Skip to content

Commit cfc2eb9

Browse files
committed
test: added dtype tests for twoway and waveeqprocessing
1 parent c65a3a7 commit cfc2eb9

2 files changed

Lines changed: 44 additions & 8 deletions

File tree

pytests/test_twoway.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,9 @@ def test_acwave2d():
6161
assert dottest(
6262
Dop, par["ns"] * par["nr"] * Dop.geometry.nt, par["nz"] * par["nx"], atol=1e-1
6363
)
64+
65+
x = np.ones(par["nz"] * par["nx"], dtype="float32")
66+
y = Dop * x
67+
xadj = Dop.H * y
68+
assert y.dtype == "float32"
69+
assert xadj.dtype == "float32"

pytests/test_waveeqprocessing.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,12 @@ def create_data(par, nv):
140140
@pytest.mark.parametrize(
141141
"par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)]
142142
)
143-
def test_MDC_1virtualsource(par):
143+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
144+
def test_MDC_1virtualsource(par, dtype):
144145
"""Dot-test and inversion for MDC operator of 1 virtual source"""
146+
cdtype = (np.empty(0, dtype=dtype) + 1j * np.empty(0, dtype=dtype)).dtype
145147
nt2, wav, mwav, Gwav, Gwav_fft = create_data(par, 1)
148+
Gwav_fft = Gwav_fft.astype(cdtype)
146149

147150
MDCop = MDC(
148151
np.asarray(Gwav_fft).transpose(2, 0, 1),
@@ -152,10 +155,18 @@ def test_MDC_1virtualsource(par):
152155
dr=parmod["dx"],
153156
twosided=par["twosided"],
154157
)
155-
dottest(MDCop, nt2 * parmod["ny"], nt2 * parmod["nx"], backend=backend)
156-
mwav = np.asarray(mwav).T
158+
dottest(
159+
MDCop,
160+
nt2 * parmod["ny"],
161+
nt2 * parmod["nx"],
162+
rtol=1e-4 if dtype == np.float32 else 1e-6,
163+
backend=backend,
164+
)
165+
166+
mwav = np.asarray(mwav.astype(dtype)).T
157167
d = MDCop * mwav.ravel()
158168
d = d.reshape(nt2, parmod["ny"])
169+
assert d.dtype == dtype
159170

160171
for it, amp in zip(it0_G, amp_G, strict=True):
161172
ittot = it0_m + it
@@ -201,9 +212,12 @@ def test_MDC_1virtualsource(par):
201212
@pytest.mark.parametrize(
202213
"par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)]
203214
)
204-
def test_MDC_Nvirtualsources(par):
215+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
216+
def test_MDC_Nvirtualsources(par, dtype):
205217
"""Dot-test and inversion for MDC operator of N virtual source"""
218+
cdtype = (np.empty(0, dtype=dtype) + 1j * np.empty(0, dtype=dtype)).dtype
206219
nt2, _, mwav, Gwav, Gwav_fft = create_data(par, parmod["nx"])
220+
Gwav_fft = Gwav_fft.astype(cdtype)
207221

208222
MDCop = MDC(
209223
np.asarray(Gwav_fft).transpose(2, 0, 1),
@@ -217,12 +231,14 @@ def test_MDC_Nvirtualsources(par):
217231
MDCop,
218232
nt2 * parmod["ny"] * parmod["nx"],
219233
nt2 * parmod["nx"] * parmod["nx"],
234+
rtol=1e-4 if dtype == np.float32 else 1e-6,
220235
backend=backend,
221236
)
222237

223-
mwav = np.asarray(mwav).transpose(2, 0, 1)
238+
mwav = np.asarray(mwav.astype(dtype)).transpose(2, 0, 1)
224239
d = MDCop * mwav.ravel()
225240
d = d.reshape(nt2, parmod["ny"], parmod["nx"])
241+
assert d.dtype == dtype
226242

227243
for it, _ in zip(it0_G, amp_G, strict=True):
228244
ittot = it0_m + it
@@ -272,9 +288,12 @@ def test_MDC_Nvirtualsources(par):
272288
(par1),
273289
],
274290
)
275-
def test_MDC_1virtualsource_scipy(par):
291+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
292+
def test_MDC_1virtualsource_scipy(par, dtype):
276293
"""Dot-test for MDC operator of 1 virtual source with scipy engine and workers"""
277-
nt2, _, _, _, Gwav_fft = create_data(par, 1)
294+
cdtype = (np.empty(0, dtype=dtype) + 1j * np.empty(0, dtype=dtype)).dtype
295+
nt2, _, mwav, _, Gwav_fft = create_data(par, 1)
296+
Gwav_fft = Gwav_fft.astype(cdtype)
278297

279298
MDCop = MDC(
280299
np.asarray(Gwav_fft).transpose(2, 0, 1),
@@ -286,4 +305,15 @@ def test_MDC_1virtualsource_scipy(par):
286305
engine="scipy",
287306
**dict(workers=4),
288307
)
289-
dottest(MDCop, nt2 * parmod["ny"], nt2 * parmod["nx"], backend=backend)
308+
dottest(
309+
MDCop,
310+
nt2 * parmod["ny"],
311+
nt2 * parmod["nx"],
312+
rtol=1e-4 if dtype == np.float32 else 1e-6,
313+
backend=backend,
314+
)
315+
316+
mwav = np.asarray(mwav.astype(dtype)).T
317+
d = MDCop * mwav.ravel()
318+
d = d.reshape(nt2, parmod["ny"])
319+
assert d.dtype == dtype

0 commit comments

Comments
 (0)