@@ -51,13 +51,13 @@ def test_TorchOperator(par, dtype):
5151 # torch operator
5252 yt = Top .apply (xt )
5353 yt .backward (vt , retain_graph = True )
54- yt = yt .detach ().cpu ()
55- xadjt = xt .grad .cpu ()
54+ yt = yt .detach ().cpu (). numpy ()
55+ xadjt = xt .grad .cpu (). numpy ()
5656
57- assert yt .dtype == torch . from_numpy ( x ) .dtype
58- assert xadjt .dtype == torch . from_numpy ( x ) .dtype
59- assert_array_equal (y , yt . numpy () )
60- assert_array_equal (xadj , xadjt . numpy () )
57+ assert yt .dtype == x .dtype
58+ assert xadjt .dtype == x .dtype
59+ assert_array_equal (y , yt )
60+ assert_array_equal (xadj , xadjt )
6161
6262
6363@pytest .mark .skipif (platform .system () == "Darwin" , reason = "Not OSX enabled" )
@@ -78,9 +78,9 @@ def test_TorchOperator_batch(par, dtype):
7878
7979 y = Dop .matmat (x .T ).T
8080 yt = Top .apply (xt )
81-
82- assert yt .dtype == torch . from_numpy ( x ) .dtype
83- assert_array_equal (y , yt . detach (). cpu (). numpy () )
81+ yt = yt . detach (). cpu (). numpy ()
82+ assert yt .dtype == x .dtype
83+ assert_array_equal (y , yt )
8484
8585
8686@pytest .mark .skipif (platform .system () == "Darwin" , reason = "Not OSX enabled" )
@@ -105,6 +105,9 @@ def test_TorchOperator_batch_nd(par, dtype):
105105
106106 y = (Dop @ x .transpose (1 , 2 , 0 )).transpose (2 , 0 , 1 )
107107 yt = Top .apply (xt )
108+ yt = yt .detach ().cpu ().numpy ()
108109
109- assert yt .dtype == torch .from_numpy (x ).dtype
110- assert_array_equal (y , yt .detach ().cpu ().numpy ())
110+ assert yt .dtype == dtype
111+ assert_array_equal (
112+ y ,
113+ )
0 commit comments