Skip to content

Commit 607b253

Browse files
committed
Add mkl_fft in fft
1 parent 817720d commit 607b253

12 files changed

Lines changed: 533 additions & 20 deletions

File tree

environment-dev-arm.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ dependencies:
2323
- autopep8
2424
- isort
2525
- black
26+
- mkl_fft
2627
- pip:
2728
- torch
2829
- devito

environment-dev-gpu.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies:
2121
- autopep8
2222
- isort
2323
- black
24+
- mkl_fft
2425
- pip:
2526
- torch
2627
- pytest-runner

environment-dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ dependencies:
2424
- autopep8
2525
- isort
2626
- black
27+
- mkl_fft
2728
- pip:
2829
- torch
2930
- devito

examples/plot_fft.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,30 @@
6666
axs[1].set_xlim([0, 3 * f0])
6767
plt.tight_layout()
6868

69+
###############################################################################
70+
# PyLops also has a third FFT engine (engine='mkl_fft') that uses the well-known
71+
# `Intel MKL FFT <https://github.com/IntelPython/mkl_fft>`_. This is a Python wrapper around
72+
# the `Intel® oneAPI Math Kernel Library (oneMKL) <https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2025-2/fourier-transform-functions.html>`_
73+
# Fourier Transform functions. It lets PyLops run discrete Fourier transforms faster
74+
# by using Intel’s highly optimized math routines.
75+
76+
FFTop = pylops.signalprocessing.FFT(dims=nt, nfft=nfft, sampling=dt, engine="mkl_fft")
77+
D = FFTop * d
78+
79+
# Inverse for FFT
80+
dinv = FFTop.H * D
81+
dinv = FFTop / D
82+
83+
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
84+
axs[0].plot(t, d, "k", lw=2, label="True")
85+
axs[0].plot(t, dinv.real, "--r", lw=2, label="Inverted")
86+
axs[0].legend()
87+
axs[0].set_title("Signal")
88+
axs[1].plot(FFTop.f[: int(FFTop.nfft / 2)], np.abs(D[: int(FFTop.nfft / 2)]), "k", lw=2)
89+
axs[1].set_title("Fourier Transform with MKL FFT")
90+
axs[1].set_xlim([0, 3 * f0])
91+
plt.tight_layout()
92+
6993
###############################################################################
7094
# We can also apply the one dimensional FFT to to a two-dimensional
7195
# signal (along one of the first axis)

pylops/signalprocessing/fft.py

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,20 @@
1616
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
1717

1818
pyfftw_message = deps.pyfftw_import("the fft module")
19+
mkl_fft_message = deps.mkl_fft_import("the mkl fft module")
1920

2021
if pyfftw_message is None:
2122
import pyfftw
2223

24+
if mkl_fft_message is None:
25+
import mkl_fft.interfaces.numpy_fft as mkl_backend
26+
27+
try:
28+
import scipy.fft # noqa: F401
29+
import mkl_fft.interfaces.scipy_fft as mkl_backend
30+
except ImportError:
31+
pass
32+
2333
logger = logging.getLogger(__name__)
2434

2535

@@ -394,6 +404,94 @@ def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike:
394404
return self._rmatvec(y) / self._scale
395405

396406

407+
class _FFT_mklfft(_BaseFFT):
408+
"""One-dimensional Fast-Fourier Transform using mkl_fft"""
409+
410+
def __init__(
411+
self,
412+
dims: Union[int, InputDimsLike],
413+
axis: int = -1,
414+
nfft: Optional[int] = None,
415+
sampling: float = 1.0,
416+
norm: str = "ortho",
417+
real: bool = False,
418+
ifftshift_before: bool = False,
419+
fftshift_after: bool = False,
420+
dtype: DTypeLike = "complex128",
421+
**kwargs_fft,
422+
) -> None:
423+
super().__init__(
424+
dims=dims,
425+
axis=axis,
426+
nfft=nfft,
427+
sampling=sampling,
428+
norm=norm,
429+
real=real,
430+
ifftshift_before=ifftshift_before,
431+
fftshift_after=fftshift_after,
432+
dtype=dtype,
433+
)
434+
self._kwargs_fft = kwargs_fft
435+
self._norm_kwargs = {"norm": None}
436+
if self.norm is _FFTNorms.ORTHO:
437+
self._norm_kwargs["norm"] = "ortho"
438+
self._scale = np.sqrt(1 / self.nfft)
439+
elif self.norm is _FFTNorms.NONE:
440+
self._scale = self.nfft
441+
elif self.norm is _FFTNorms.ONE_OVER_N:
442+
self._scale = 1.0 / self.nfft
443+
444+
@reshaped
445+
def _matvec(self, x: NDArray) -> NDArray:
446+
if self.ifftshift_before:
447+
x = mkl_backend.ifftshift(x, axes=self.axis)
448+
if not self.clinear:
449+
x = np.real(x)
450+
if self.real:
451+
y = mkl_backend.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
452+
y = np.swapaxes(y, -1, self.axis)
453+
y[..., 1 : 1 + (self.nfft - 1) // 2] *= np.sqrt(2)
454+
y = np.swapaxes(y, self.axis, -1)
455+
else:
456+
y = mkl_backend.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
457+
if self.norm is _FFTNorms.ONE_OVER_N:
458+
y *= self._scale
459+
if self.fftshift_after:
460+
y = mkl_backend.fftshift(y, axes=self.axis)
461+
return y
462+
463+
@reshaped
464+
def _rmatvec(self, x: NDArray) -> NDArray:
465+
if self.fftshift_after:
466+
x = mkl_backend.ifftshift(x, axes=self.axis)
467+
if self.real:
468+
x = x.copy()
469+
x = np.swapaxes(x, -1, self.axis)
470+
x[..., 1 : 1 + (self.nfft - 1) // 2] /= np.sqrt(2)
471+
x = np.swapaxes(x, self.axis, -1)
472+
y = mkl_backend.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
473+
else:
474+
y = mkl_backend.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
475+
if self.norm is _FFTNorms.NONE:
476+
y *= self._scale
477+
478+
if self.nfft > self.dims[self.axis]:
479+
y = np.take(y, range(0, self.dims[self.axis]), axis=self.axis)
480+
elif self.nfft < self.dims[self.axis]:
481+
y = np.pad(y, self.ifftpad)
482+
483+
if not self.clinear:
484+
y = np.real(y)
485+
if self.ifftshift_before:
486+
y = mkl_backend.fftshift(y, axes=self.axis)
487+
return y
488+
489+
def __truediv__(self, y):
490+
if self.norm is not _FFTNorms.ORTHO:
491+
return self._rmatvec(y) / self._scale
492+
return self._rmatvec(y)
493+
494+
397495
def FFT(
398496
dims: Union[int, InputDimsLike],
399497
axis: int = -1,
@@ -481,7 +579,7 @@ def FFT(
481579
frequencies are arranged from zero to largest positive, and then from negative
482580
Nyquist to the frequency bin before zero.
483581
engine : :obj:`str`, optional
484-
Engine used for fft computation (``numpy``, ``fftw``, or ``scipy``). Choose
582+
Engine used for fft computation (``numpy``, ``fftw``, ``scipy`` or ``mkl_fft``). Choose
485583
``numpy`` when working with cupy and jax arrays.
486584
487585
.. note:: Since version 1.17.0, accepts "scipy".
@@ -534,7 +632,7 @@ def FFT(
534632
- If ``dims`` is provided and ``axis`` is bigger than ``len(dims)``.
535633
- If ``norm`` is not one of "ortho", "none", or "1/n".
536634
NotImplementedError
537-
If ``engine`` is neither ``numpy``, ``fftw``, nor ``scipy``.
635+
If ``engine`` is neither ``numpy``, ``fftw``, ``scipy`` nor ``mkl_fft``.
538636
539637
See Also
540638
--------
@@ -579,7 +677,24 @@ def FFT(
579677
dtype=dtype,
580678
**kwargs_fft,
581679
)
582-
elif engine == "numpy" or (engine == "fftw" and pyfftw_message is not None):
680+
elif engine == "mkl_fft" and mkl_fft_message is None:
681+
f = _FFT_mklfft(
682+
dims,
683+
axis=axis,
684+
nfft=nfft,
685+
sampling=sampling,
686+
norm=norm,
687+
real=real,
688+
ifftshift_before=ifftshift_before,
689+
fftshift_after=fftshift_after,
690+
dtype=dtype,
691+
**kwargs_fft,
692+
)
693+
elif (
694+
engine == "numpy"
695+
or (engine == "fftw" and pyfftw_message is not None)
696+
or (engine == "mkl_fft" and mkl_fft_message is not None)
697+
):
583698
if engine == "fftw" and pyfftw_message is not None:
584699
logger.warning(pyfftw_message)
585700
f = _FFT_numpy(
@@ -608,6 +723,6 @@ def FFT(
608723
**kwargs_fft,
609724
)
610725
else:
611-
raise NotImplementedError("engine must be numpy, fftw or scipy")
726+
raise NotImplementedError("engine must be numpy, fftw, scipy or mkl_fft")
612727
f.name = name
613728
return f

0 commit comments

Comments
 (0)