@@ -140,9 +140,12 @@ def create_data(par, nv):
140140@pytest .mark .parametrize (
141141 "par" , [(par1 ), (par2 ), (par3 ), (par4 ), (par5 ), (par6 ), (par7 ), (par8 )]
142142)
143- def test_MDC_1virtualsource (par ):
143+ @pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
144+ def test_MDC_1virtualsource (par , dtype ):
144145 """Dot-test and inversion for MDC operator of 1 virtual source"""
146+ cdtype = (np .empty (0 , dtype = dtype ) + 1j * np .empty (0 , dtype = dtype )).dtype
145147 nt2 , wav , mwav , Gwav , Gwav_fft = create_data (par , 1 )
148+ Gwav_fft = Gwav_fft .astype (cdtype )
146149
147150 MDCop = MDC (
148151 np .asarray (Gwav_fft ).transpose (2 , 0 , 1 ),
@@ -152,10 +155,18 @@ def test_MDC_1virtualsource(par):
152155 dr = parmod ["dx" ],
153156 twosided = par ["twosided" ],
154157 )
155- dottest (MDCop , nt2 * parmod ["ny" ], nt2 * parmod ["nx" ], backend = backend )
156- mwav = np .asarray (mwav ).T
158+ dottest (
159+ MDCop ,
160+ nt2 * parmod ["ny" ],
161+ nt2 * parmod ["nx" ],
162+ rtol = 1e-4 if dtype == np .float32 else 1e-6 ,
163+ backend = backend ,
164+ )
165+
166+ mwav = np .asarray (mwav .astype (dtype )).T
157167 d = MDCop * mwav .ravel ()
158168 d = d .reshape (nt2 , parmod ["ny" ])
169+ assert d .dtype == dtype
159170
160171 for it , amp in zip (it0_G , amp_G , strict = True ):
161172 ittot = it0_m + it
@@ -201,9 +212,12 @@ def test_MDC_1virtualsource(par):
201212@pytest .mark .parametrize (
202213 "par" , [(par1 ), (par2 ), (par3 ), (par4 ), (par5 ), (par6 ), (par7 ), (par8 )]
203214)
204- def test_MDC_Nvirtualsources (par ):
215+ @pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
216+ def test_MDC_Nvirtualsources (par , dtype ):
205217 """Dot-test and inversion for MDC operator of N virtual source"""
218+ cdtype = (np .empty (0 , dtype = dtype ) + 1j * np .empty (0 , dtype = dtype )).dtype
206219 nt2 , _ , mwav , Gwav , Gwav_fft = create_data (par , parmod ["nx" ])
220+ Gwav_fft = Gwav_fft .astype (cdtype )
207221
208222 MDCop = MDC (
209223 np .asarray (Gwav_fft ).transpose (2 , 0 , 1 ),
@@ -217,12 +231,14 @@ def test_MDC_Nvirtualsources(par):
217231 MDCop ,
218232 nt2 * parmod ["ny" ] * parmod ["nx" ],
219233 nt2 * parmod ["nx" ] * parmod ["nx" ],
234+ rtol = 1e-4 if dtype == np .float32 else 1e-6 ,
220235 backend = backend ,
221236 )
222237
223- mwav = np .asarray (mwav ).transpose (2 , 0 , 1 )
238+ mwav = np .asarray (mwav . astype ( dtype ) ).transpose (2 , 0 , 1 )
224239 d = MDCop * mwav .ravel ()
225240 d = d .reshape (nt2 , parmod ["ny" ], parmod ["nx" ])
241+ assert d .dtype == dtype
226242
227243 for it , _ in zip (it0_G , amp_G , strict = True ):
228244 ittot = it0_m + it
@@ -272,9 +288,12 @@ def test_MDC_Nvirtualsources(par):
272288 (par1 ),
273289 ],
274290)
275- def test_MDC_1virtualsource_scipy (par ):
291+ @pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
292+ def test_MDC_1virtualsource_scipy (par , dtype ):
276293 """Dot-test for MDC operator of 1 virtual source with scipy engine and workers"""
277- nt2 , _ , _ , _ , Gwav_fft = create_data (par , 1 )
294+ cdtype = (np .empty (0 , dtype = dtype ) + 1j * np .empty (0 , dtype = dtype )).dtype
295+ nt2 , _ , mwav , _ , Gwav_fft = create_data (par , 1 )
296+ Gwav_fft = Gwav_fft .astype (cdtype )
278297
279298 MDCop = MDC (
280299 np .asarray (Gwav_fft ).transpose (2 , 0 , 1 ),
@@ -286,4 +305,15 @@ def test_MDC_1virtualsource_scipy(par):
286305 engine = "scipy" ,
287306 ** dict (workers = 4 ),
288307 )
289- dottest (MDCop , nt2 * parmod ["ny" ], nt2 * parmod ["nx" ], backend = backend )
308+ dottest (
309+ MDCop ,
310+ nt2 * parmod ["ny" ],
311+ nt2 * parmod ["nx" ],
312+ rtol = 1e-4 if dtype == np .float32 else 1e-6 ,
313+ backend = backend ,
314+ )
315+
316+ mwav = np .asarray (mwav .astype (dtype )).T
317+ d = MDCop * mwav .ravel ()
318+ d = d .reshape (nt2 , parmod ["ny" ])
319+ assert d .dtype == dtype
0 commit comments