Skip to content

Commit 17d77bf

Browse files
authored
Merge pull request #409 from mrava87/v1.18.x
bug: fix ista/fista for cupy arrays
2 parents 1cf20d0 + 1f35d9d commit 17d77bf

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

pylops/optimization/sparsity.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,7 @@ def ISTA(
906906
Op1, niter=eigsiter, tol=eigstol, dtype=Op1.dtype, backend="cupy"
907907
)[0]
908908
)
909-
alpha = 1.0 / maxeig
909+
alpha = 1.0 / float(maxeig)
910910

911911
# define threshold
912912
thresh = eps * alpha * 0.5
@@ -959,8 +959,8 @@ def ISTA(
959959
normresold = normres
960960

961961
# compute gradient
962-
grad = alpha * Op.H @ res
963-
962+
grad = alpha * (Op.H @ res)
963+
964964
# update inverted model
965965
xinv_unthesh = xinv + grad
966966
if SOp is not None:
@@ -1211,7 +1211,7 @@ def FISTA(
12111211
Op1, niter=eigsiter, tol=eigstol, dtype=Op1.dtype, backend="cupy"
12121212
)[0]
12131213
)
1214-
alpha = 1.0 / maxeig
1214+
alpha = 1.0 / float(maxeig)
12151215

12161216
# define threshold
12171217
thresh = eps * alpha * 0.5
@@ -1254,7 +1254,7 @@ def FISTA(
12541254
resz = data - Op @ zinv
12551255

12561256
# compute gradient
1257-
grad = alpha * Op.H @ resz
1257+
grad = alpha * (Op.H @ resz)
12581258

12591259
# update inverted model
12601260
xinv_unthesh = zinv + grad

0 commit comments

Comments
 (0)