Skip to content

Commit eafdf9d

Browse files
committed
fix: fixed dtypes in torchoperator tests
1 parent d223813 commit eafdf9d

1 file changed

Lines changed: 14 additions & 11 deletions

File tree

pytests/test_torchoperator.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ def test_TorchOperator(par, dtype):
5151
# torch operator
5252
yt = Top.apply(xt)
5353
yt.backward(vt, retain_graph=True)
54-
yt = yt.detach().cpu()
55-
xadjt = xt.grad.cpu()
54+
yt = yt.detach().cpu().numpy()
55+
xadjt = xt.grad.cpu().numpy()
5656

57-
assert yt.dtype == torch.from_numpy(x).dtype
58-
assert xadjt.dtype == torch.from_numpy(x).dtype
59-
assert_array_equal(y, yt.numpy())
60-
assert_array_equal(xadj, xadjt.numpy())
57+
assert yt.dtype == x.dtype
58+
assert xadjt.dtype == x.dtype
59+
assert_array_equal(y, yt)
60+
assert_array_equal(xadj, xadjt)
6161

6262

6363
@pytest.mark.skipif(platform.system() == "Darwin", reason="Not OSX enabled")
@@ -78,9 +78,9 @@ def test_TorchOperator_batch(par, dtype):
7878

7979
y = Dop.matmat(x.T).T
8080
yt = Top.apply(xt)
81-
82-
assert yt.dtype == torch.from_numpy(x).dtype
83-
assert_array_equal(y, yt.detach().cpu().numpy())
81+
yt = yt.detach().cpu().numpy()
82+
assert yt.dtype == x.dtype
83+
assert_array_equal(y, yt)
8484

8585

8686
@pytest.mark.skipif(platform.system() == "Darwin", reason="Not OSX enabled")
@@ -105,6 +105,9 @@ def test_TorchOperator_batch_nd(par, dtype):
105105

106106
y = (Dop @ x.transpose(1, 2, 0)).transpose(2, 0, 1)
107107
yt = Top.apply(xt)
108+
yt = yt.detach().cpu().numpy()
108109

109-
assert yt.dtype == torch.from_numpy(x).dtype
110-
assert_array_equal(y, yt.detach().cpu().numpy())
110+
assert yt.dtype == dtype
111+
assert_array_equal(
112+
y,
113+
)

0 commit comments

Comments
 (0)