|
16 | 16 | from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray |
17 | 17 |
|
18 | 18 | pyfftw_message = deps.pyfftw_import("the fft module") |
| 19 | +mkl_fft_message = deps.mkl_fft_import("the mkl fft module") |
19 | 20 |
|
20 | 21 | if pyfftw_message is None: |
21 | 22 | import pyfftw |
22 | 23 |
|
| 24 | +if mkl_fft_message is None: |
| 25 | + import mkl_fft.interfaces.numpy_fft as mkl_backend |
| 26 | + |
| 27 | + try: |
| 28 | + import scipy.fft # noqa: F401 |
| 29 | + import mkl_fft.interfaces.scipy_fft as mkl_backend |
| 30 | + except ImportError: |
| 31 | + pass |
| 32 | + |
23 | 33 | logger = logging.getLogger(__name__) |
24 | 34 |
|
25 | 35 |
|
@@ -394,6 +404,94 @@ def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike: |
394 | 404 | return self._rmatvec(y) / self._scale |
395 | 405 |
|
396 | 406 |
|
| 407 | +class _FFT_mklfft(_BaseFFT): |
| 408 | + """One-dimensional Fast-Fourier Transform using mkl_fft""" |
| 409 | + |
| 410 | + def __init__( |
| 411 | + self, |
| 412 | + dims: Union[int, InputDimsLike], |
| 413 | + axis: int = -1, |
| 414 | + nfft: Optional[int] = None, |
| 415 | + sampling: float = 1.0, |
| 416 | + norm: str = "ortho", |
| 417 | + real: bool = False, |
| 418 | + ifftshift_before: bool = False, |
| 419 | + fftshift_after: bool = False, |
| 420 | + dtype: DTypeLike = "complex128", |
| 421 | + **kwargs_fft, |
| 422 | + ) -> None: |
| 423 | + super().__init__( |
| 424 | + dims=dims, |
| 425 | + axis=axis, |
| 426 | + nfft=nfft, |
| 427 | + sampling=sampling, |
| 428 | + norm=norm, |
| 429 | + real=real, |
| 430 | + ifftshift_before=ifftshift_before, |
| 431 | + fftshift_after=fftshift_after, |
| 432 | + dtype=dtype, |
| 433 | + ) |
| 434 | + self._kwargs_fft = kwargs_fft |
| 435 | + self._norm_kwargs = {"norm": None} |
| 436 | + if self.norm is _FFTNorms.ORTHO: |
| 437 | + self._norm_kwargs["norm"] = "ortho" |
| 438 | + self._scale = np.sqrt(1 / self.nfft) |
| 439 | + elif self.norm is _FFTNorms.NONE: |
| 440 | + self._scale = self.nfft |
| 441 | + elif self.norm is _FFTNorms.ONE_OVER_N: |
| 442 | + self._scale = 1.0 / self.nfft |
| 443 | + |
| 444 | + @reshaped |
| 445 | + def _matvec(self, x: NDArray) -> NDArray: |
| 446 | + if self.ifftshift_before: |
| 447 | + x = mkl_backend.ifftshift(x, axes=self.axis) |
| 448 | + if not self.clinear: |
| 449 | + x = np.real(x) |
| 450 | + if self.real: |
| 451 | + y = mkl_backend.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) |
| 452 | + y = np.swapaxes(y, -1, self.axis) |
| 453 | + y[..., 1 : 1 + (self.nfft - 1) // 2] *= np.sqrt(2) |
| 454 | + y = np.swapaxes(y, self.axis, -1) |
| 455 | + else: |
| 456 | + y = mkl_backend.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) |
| 457 | + if self.norm is _FFTNorms.ONE_OVER_N: |
| 458 | + y *= self._scale |
| 459 | + if self.fftshift_after: |
| 460 | + y = mkl_backend.fftshift(y, axes=self.axis) |
| 461 | + return y |
| 462 | + |
| 463 | + @reshaped |
| 464 | + def _rmatvec(self, x: NDArray) -> NDArray: |
| 465 | + if self.fftshift_after: |
| 466 | + x = mkl_backend.ifftshift(x, axes=self.axis) |
| 467 | + if self.real: |
| 468 | + x = x.copy() |
| 469 | + x = np.swapaxes(x, -1, self.axis) |
| 470 | + x[..., 1 : 1 + (self.nfft - 1) // 2] /= np.sqrt(2) |
| 471 | + x = np.swapaxes(x, self.axis, -1) |
| 472 | + y = mkl_backend.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) |
| 473 | + else: |
| 474 | + y = mkl_backend.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) |
| 475 | + if self.norm is _FFTNorms.NONE: |
| 476 | + y *= self._scale |
| 477 | + |
| 478 | + if self.nfft > self.dims[self.axis]: |
| 479 | + y = np.take(y, range(0, self.dims[self.axis]), axis=self.axis) |
| 480 | + elif self.nfft < self.dims[self.axis]: |
| 481 | + y = np.pad(y, self.ifftpad) |
| 482 | + |
| 483 | + if not self.clinear: |
| 484 | + y = np.real(y) |
| 485 | + if self.ifftshift_before: |
| 486 | + y = mkl_backend.fftshift(y, axes=self.axis) |
| 487 | + return y |
| 488 | + |
| 489 | + def __truediv__(self, y): |
| 490 | + if self.norm is not _FFTNorms.ORTHO: |
| 491 | + return self._rmatvec(y) / self._scale |
| 492 | + return self._rmatvec(y) |
| 493 | + |
| 494 | + |
397 | 495 | def FFT( |
398 | 496 | dims: Union[int, InputDimsLike], |
399 | 497 | axis: int = -1, |
@@ -481,7 +579,7 @@ def FFT( |
481 | 579 | frequencies are arranged from zero to largest positive, and then from negative |
482 | 580 | Nyquist to the frequency bin before zero. |
483 | 581 | engine : :obj:`str`, optional |
484 | | - Engine used for fft computation (``numpy``, ``fftw``, or ``scipy``). Choose |
| 582 | + Engine used for fft computation (``numpy``, ``fftw``, ``scipy`` or ``mkl_fft``). Choose |
485 | 583 | ``numpy`` when working with cupy and jax arrays. |
486 | 584 |
|
487 | 585 | .. note:: Since version 1.17.0, accepts "scipy". |
@@ -534,7 +632,7 @@ def FFT( |
534 | 632 | - If ``dims`` is provided and ``axis`` is bigger than ``len(dims)``. |
535 | 633 | - If ``norm`` is not one of "ortho", "none", or "1/n". |
536 | 634 | NotImplementedError |
537 | | - If ``engine`` is neither ``numpy``, ``fftw``, nor ``scipy``. |
| 635 | + If ``engine`` is neither ``numpy``, ``fftw``, ``scipy`` nor ``mkl_fft``. |
538 | 636 |
|
539 | 637 | See Also |
540 | 638 | -------- |
@@ -579,7 +677,24 @@ def FFT( |
579 | 677 | dtype=dtype, |
580 | 678 | **kwargs_fft, |
581 | 679 | ) |
582 | | - elif engine == "numpy" or (engine == "fftw" and pyfftw_message is not None): |
| 680 | + elif engine == "mkl_fft" and mkl_fft_message is None: |
| 681 | + f = _FFT_mklfft( |
| 682 | + dims, |
| 683 | + axis=axis, |
| 684 | + nfft=nfft, |
| 685 | + sampling=sampling, |
| 686 | + norm=norm, |
| 687 | + real=real, |
| 688 | + ifftshift_before=ifftshift_before, |
| 689 | + fftshift_after=fftshift_after, |
| 690 | + dtype=dtype, |
| 691 | + **kwargs_fft, |
| 692 | + ) |
| 693 | + elif ( |
| 694 | + engine == "numpy" |
| 695 | + or (engine == "fftw" and pyfftw_message is not None) |
| 696 | + or (engine == "mkl_fft" and mkl_fft_message is not None) |
| 697 | + ): |
583 | 698 | if engine == "fftw" and pyfftw_message is not None: |
584 | 699 | logger.warning(pyfftw_message) |
585 | 700 | f = _FFT_numpy( |
@@ -608,6 +723,6 @@ def FFT( |
608 | 723 | **kwargs_fft, |
609 | 724 | ) |
610 | 725 | else: |
611 | | - raise NotImplementedError("engine must be numpy, fftw or scipy") |
| 726 | + raise NotImplementedError("engine must be numpy, fftw, scipy or mkl_fft") |
612 | 727 | f.name = name |
613 | 728 | return f |
0 commit comments