Skip to content

Commit cbd7d1b

Browse files
authored
Merge pull request #695 from mrava87/patch-kwarg_fft
Feat: added fft engine option to operators using FFT
2 parents 817720d + 3af9374 commit cbd7d1b

8 files changed

Lines changed: 315 additions & 163 deletions

File tree

pylops/signalprocessing/shift.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
__all__ = ["Shift"]
22

3-
from typing import Tuple, Union
3+
from typing import TYPE_CHECKING, Tuple, Union
44

55
import numpy as np
6-
import numpy.typing as npt
76

87
from pylops.basicoperators import Diagonal
98
from pylops.signalprocessing import FFT
109
from pylops.utils._internal import _value_or_sized_to_array
1110
from pylops.utils.backend import get_normalize_axis_index
12-
from pylops.utils.typing import DTypeLike
11+
from pylops.utils.typing import DTypeLike, NDArray
12+
13+
if TYPE_CHECKING:
14+
from pylops.linearoperator import LinearOperator
1315

1416

1517
def Shift(
1618
dims: Tuple,
17-
shift: Union[float, npt.ArrayLike],
19+
shift: Union[float, NDArray],
1820
axis: int = -1,
1921
nfft: int = None,
2022
sampling: float = 1.0,
2123
real: bool = False,
2224
engine: str = "numpy",
2325
dtype: DTypeLike = "complex128",
2426
name: str = "S",
25-
**kwargs_fftw
26-
):
27+
**kwargs_fft,
28+
) -> "LinearOperator":
2729
r"""Shift operator
2830
2931
Apply fractional shift in the frequency domain along an ``axis``
@@ -58,9 +60,8 @@ def Shift(
5860
.. versionadded:: 2.0.0
5961
6062
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
61-
**kwargs_fftw
62-
Arbitrary keyword arguments
63-
for :py:class:`pyfftw.FTTW`
63+
**kwargs_fft
64+
Arbitrary keyword arguments to be passed to the selected fft method
6465
6566
Attributes
6667
----------
@@ -98,7 +99,7 @@ def Shift(
9899
real=real,
99100
engine=engine,
100101
dtype=dtype,
101-
**kwargs_fftw
102+
**kwargs_fft,
102103
)
103104
if isinstance(dims, int):
104105
dimsdiag = None

pylops/waveeqprocessing/marchenko.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
__all__ = ["Marchenko"]
22

33
import logging
4-
from typing import Optional, Tuple, Union
4+
from typing import Any, Dict, Optional, Tuple, Union
55

66
import numpy as np
77
from scipy.signal import filtfilt
@@ -246,6 +246,7 @@ def __init__(
246246
prescaled: bool = False,
247247
fftengine: str = "numpy",
248248
dtype: DTypeLike = "float64",
249+
kwargs_fft: Optional[Dict[str, Any]] = None,
249250
) -> None:
250251
# Save inputs into class
251252
self.dt = dt
@@ -257,6 +258,7 @@ def __init__(
257258
self.prescaled = prescaled
258259
self.fftengine = fftengine
259260
self.dtype = dtype
261+
self.kwargs_fft = {} if kwargs_fft is None else kwargs_fft
260262
self.explicit = False
261263
self.ncp = get_array_module(R)
262264

@@ -384,6 +386,7 @@ def apply_onepoint(
384386
saveGt=self.saveRt,
385387
prescaled=self.prescaled,
386388
usematmul=usematmul,
389+
**self.kwargs_fft,
387390
)
388391
R1op = MDC(
389392
self.Rtwosided_fft,
@@ -397,6 +400,7 @@ def apply_onepoint(
397400
saveGt=self.saveRt,
398401
prescaled=self.prescaled,
399402
usematmul=usematmul,
403+
**self.kwargs_fft,
400404
)
401405
Rollop = Roll(
402406
(self.nt2, self.ns),
@@ -592,6 +596,7 @@ def apply_multiplepoints(
592596
fftengine=self.fftengine,
593597
prescaled=self.prescaled,
594598
usematmul=usematmul,
599+
**self.kwargs_fft,
595600
)
596601
R1op = MDC(
597602
self.Rtwosided_fft,
@@ -604,6 +609,7 @@ def apply_multiplepoints(
604609
fftengine=self.fftengine,
605610
prescaled=self.prescaled,
606611
usematmul=usematmul,
612+
**self.kwargs_fft,
607613
)
608614
Rollop = Roll(
609615
(self.nt2, self.ns, nvs),

pylops/waveeqprocessing/mdd.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,15 @@ def _MDC(
9797
real=True,
9898
ifftshift_before=twosided,
9999
dtype=rdtype,
100-
**args_FFT
100+
**args_FFT,
101101
)
102102
F1op = _FFT(
103103
dims=(nt, ns, nv),
104104
axis=0,
105105
real=True,
106106
ifftshift_before=False,
107107
dtype=rdtype,
108-
**args_FFT1
108+
**args_FFT1,
109109
)
110110

111111
# create Identity operator to extract only relevant frequencies
@@ -140,6 +140,7 @@ def MDC(
140140
usematmul: bool = False,
141141
prescaled: bool = False,
142142
name: str = "M",
143+
**kwargs_fft,
143144
) -> LinearOperator:
144145
r"""Multi-dimensional convolution.
145146
@@ -188,6 +189,10 @@ def MDC(
188189
.. versionadded:: 2.0.0
189190
190191
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
192+
**kwargs_fft
193+
.. versionadded:: 2.6.0
194+
195+
Arbitrary keyword arguments to be passed to the selected fft method
191196
192197
Raises
193198
------
@@ -243,8 +248,8 @@ def MDC(
243248
saveGt=saveGt,
244249
conj=conj,
245250
prescaled=prescaled,
246-
args_FFT={"engine": fftengine},
247-
args_FFT1={"engine": fftengine},
251+
args_FFT={**{"engine": fftengine}, **kwargs_fft},
252+
args_FFT1={**{"engine": fftengine}, **kwargs_fft},
248253
args_Fredholm1={"usematmul": usematmul},
249254
)
250255
MOp.name = name
@@ -267,7 +272,7 @@ def MDD(
267272
add_negative: bool = True,
268273
smooth_precond: int = 0,
269274
fftengine: str = "numpy",
270-
**kwargs_solver
275+
**kwargs_solver,
271276
) -> Union[
272277
Tuple[NDArray, NDArray],
273278
Tuple[NDArray, NDArray, NDArray],
@@ -483,7 +488,7 @@ def MDD(
483488
MDCop,
484489
d.ravel(),
485490
ncp.zeros(int(MDCop.shape[1]), dtype=MDCop.dtype),
486-
**kwargs_solver
491+
**kwargs_solver,
487492
)[0]
488493
minv = ncp.squeeze(minv.reshape(nt2, nr, nv))
489494
minv = ncp.moveaxis(minv, 0, -1)
@@ -502,7 +507,7 @@ def MDD(
502507
PSFop,
503508
G.ravel(),
504509
ncp.zeros(int(PSFop.shape[1]), dtype=PSFop.dtype),
505-
**kwargs_solver
510+
**kwargs_solver,
506511
)[0]
507512
psfinv = ncp.squeeze(psfinv.reshape(nt2, nr, nr))
508513
psfinv = ncp.moveaxis(psfinv, 0, -1)

pylops/waveeqprocessing/oneway.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def PhaseShift(
8080
ky: Optional[NDArray] = None,
8181
dtype: DTypeLike = "float64",
8282
name: str = "P",
83+
fftengine: str = "numpy",
84+
**kwargs_fft,
8385
) -> LinearOperator:
8486
r"""Phase shift operator
8587
@@ -110,6 +112,15 @@ def PhaseShift(
110112
.. versionadded:: 2.0.0
111113
112114
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
115+
fftengine : :obj:`str`, optional
116+
.. versionadded:: 2.6.0
117+
118+
Engine used for fft computation (``numpy``, ``scipy``, or ``fftw``). Choose
119+
``numpy`` when working with CuPy arrays.
120+
**kwargs_fft
121+
.. versionadded:: 2.6.0
122+
123+
Arbitrary keyword arguments to be passed to the selected fft method
113124
114125
Returns
115126
-------
@@ -170,7 +181,9 @@ def PhaseShift(
170181
nfft=ky.size,
171182
real=False,
172183
ifftshift_before=True,
184+
engine=fftengine,
173185
dtype=dtypefft,
186+
**kwargs_fft,
174187
)
175188
Pop = _PhaseShift(vel, dz, freq, kx, ky, dtypefft)
176189
if ky is None:
@@ -204,7 +217,7 @@ def Deghosting(
204217
solver: Callable = lsqr,
205218
dottest: bool = False,
206219
dtype: DTypeLike = "complex128",
207-
**kwargs_solver
220+
**kwargs_solver,
208221
) -> Tuple[NDArray, NDArray]:
209222
r"""Wavefield deghosting.
210223

pytests/test_marchenko.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22

33
if int(os.environ.get("TEST_CUPY_PYLOPS", 0)):
44
import cupy as np
5-
from cupy.testing import assert_array_almost_equal, assert_array_equal
65

76
backend = "cupy"
87
else:
98
import numpy as np
10-
from numpy.testing import assert_array_almost_equal, assert_array_equal
119

1210
backend = "numpy"
1311
import numpy as npp
@@ -86,11 +84,16 @@
8684
R1twosided_fft = npp.fft.rfft(R1twosided, 2 * nt - 1, axis=-1) / npp.sqrt(2 * nt - 1)
8785
R1twosided_fft = R1twosided_fft[..., :nfmax]
8886

89-
90-
par1 = {"niter": 10, "prescaled": False, "fftengine": "numpy"}
91-
par2 = {"niter": 10, "prescaled": True, "fftengine": "numpy"}
92-
par3 = {"niter": 10, "prescaled": False, "fftengine": "scipy"}
93-
par4 = {"niter": 10, "prescaled": False, "fftengine": "fftw"}
87+
# Test parameters
88+
par1 = {"niter": 10, "prescaled": False, "fftengine": "numpy", "kwargs_fft": None}
89+
par2 = {"niter": 10, "prescaled": True, "fftengine": "numpy", "kwargs_fft": None}
90+
par3 = {
91+
"niter": 10,
92+
"prescaled": False,
93+
"fftengine": "scipy",
94+
"kwargs_fft": dict(workers=4),
95+
}
96+
par4 = {"niter": 10, "prescaled": False, "fftengine": "fftw", "kwargs_fft": None}
9497

9598

9699
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)])
@@ -111,6 +114,7 @@ def test_Marchenko_freq(par):
111114
nsmooth=nsmooth,
112115
prescaled=par["prescaled"],
113116
fftengine=par["fftengine"] if backend == "numpy" else "numpy",
117+
kwargs_fft=par["kwargs_fft"] if backend == "numpy" else None,
114118
)
115119

116120
solver_dict = (

0 commit comments

Comments
 (0)