Skip to content

Commit c75da67

Browse files
committed
feat: added preallocate to SplitBregman
1 parent f98cc31 commit c75da67

2 files changed

Lines changed: 48 additions & 16 deletions

File tree

pylops/optimization/cls_sparsity.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1687,9 +1687,7 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
16871687
if not self.preallocate:
16881688
res: NDArray = self.y - self.Opmatvec(x)
16891689
else:
1690-
self.ncp.subtract(
1691-
self.y, self.Opmatvec(x), out=self.res if self.preallocate else res
1692-
)
1690+
self.ncp.subtract(self.y, self.Opmatvec(x), out=self.res)
16931691

16941692
if self.monitorres:
16951693
self.normres = np.linalg.norm(self.res if self.preallocate else res)
@@ -1707,9 +1705,9 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17071705
grad: NDArray = self.alpha * (self.Oprmatvec(res))
17081706
else:
17091707
self.ncp.multiply(
1710-
self.Oprmatvec(self.res if self.preallocate else res),
1708+
self.Oprmatvec(self.res),
17111709
self.alpha,
1712-
out=self.grad if self.preallocate else grad,
1710+
out=self.grad,
17131711
)
17141712

17151713
# update inverted model
@@ -1718,8 +1716,8 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17181716
else:
17191717
self.ncp.add(
17201718
x,
1721-
self.grad if self.preallocate else grad,
1722-
out=self.x_unthesh if self.preallocate else x_unthesh,
1719+
self.grad,
1720+
out=self.x_unthesh,
17231721
)
17241722

17251723
# apply SOp.H to current x
@@ -2653,6 +2651,7 @@ def setup(
26532651
tol: float = 1e-10,
26542652
tau: float = 1.0,
26552653
restart: bool = False,
2654+
preallocate: bool = False,
26562655
show: bool = False,
26572656
) -> NDArray:
26582657
r"""Setup solver
@@ -2698,6 +2697,10 @@ def setup(
26982697
the initial guess (``True``) or with the last estimate (``False``).
26992698
Note that when this is set to ``True``, the ``x0`` provided in the setup will
27002699
be used in all iterations.
2700+
preallocate : :obj:`bool`, optional
2701+
.. versionadded:: 2.5.0
2702+
2703+
Pre-allocate all variables used by the solver
27012704
show : :obj:`bool`, optional
27022705
Display setup log
27032706
@@ -2719,14 +2722,19 @@ def setup(
27192722
self.tol = tol
27202723
self.tau = tau
27212724
self.restart = restart
2725+
27222726
self.ncp = get_array_module(y)
2727+
self.isjax = get_module_name(self.ncp) == "jax"
2728+
self._setpreallocate(preallocate)
27232729

27242730
# L1 regularizations
27252731
self.nregsL1 = len(RegsL1)
27262732
self.b = [
27272733
self.ncp.zeros(RegL1.shape[0], dtype=self.Op.dtype) for RegL1 in RegsL1
27282734
]
2729-
self.d = self.b.copy()
2735+
self.d = [
2736+
self.ncp.zeros(RegL1.shape[0], dtype=self.Op.dtype) for RegL1 in RegsL1
2737+
]
27302738

27312739
# L2 regularizations
27322740
self.nregsL2 = 0 if RegsL2 is None else len(RegsL2)
@@ -2797,11 +2805,22 @@ def step(
27972805
Updated model vector
27982806
27992807
"""
2808+
# add preallocate to keywords of solver
2809+
if self.preallocate and (engine == "pylops" or self.ncp != np):
2810+
kwargs_solver["preallocate"] = True
2811+
28002812
for _ in range(self.niter_inner):
28012813
# regularized problem
2802-
dataregs = self.dataregsL2 + [
2803-
self.d[ireg] - self.b[ireg] for ireg in range(self.nregsL1)
2804-
]
2814+
if not self.preallocate:
2815+
dataregs = self.dataregsL2 + [
2816+
self.d[ireg] - self.b[ireg] for ireg in range(self.nregsL1)
2817+
]
2818+
else:
2819+
for ireg in range(self.nregsL1):
2820+
self.ncp.subtract(self.d[ireg], self.b[ireg], out=self.d[ireg])
2821+
dataregs = self.dataregsL2 + [
2822+
self.d[ireg] for ireg in range(self.nregsL1)
2823+
]
28052824
x = regularized_inversion(
28062825
self.Op,
28072826
self.y,
@@ -2813,11 +2832,18 @@ def step(
28132832
engine=engine,
28142833
**kwargs_solver,
28152834
)[0]
2816-
# Shrinkage
2817-
for ireg in range(self.nregsL1):
2818-
self.d[ireg] = _softthreshold(
2819-
self.RegsL1[ireg].matvec(x) + self.b[ireg], self.epsRL1s[ireg]
2820-
)
2835+
# shrinkage
2836+
if not self.preallocate:
2837+
for ireg in range(self.nregsL1):
2838+
self.d[ireg] = _softthreshold(
2839+
self.RegsL1[ireg].matvec(x) + self.b[ireg], self.epsRL1s[ireg]
2840+
)
2841+
else:
2842+
for ireg in range(self.nregsL1):
2843+
self.ncp.add(
2844+
self.RegsL1[ireg].matvec(x), self.b[ireg], out=self.d[ireg]
2845+
)
2846+
self.d[ireg] = _softthreshold(self.d[ireg], self.epsRL1s[ireg])
28212847

28222848
# Bregman update
28232849
for ireg in range(self.nregsL1):

pylops/optimization/sparsity.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ def splitbregman(
636636
itershow: Tuple[int, int, int] = (10, 10, 10),
637637
show_inner: bool = False,
638638
callback: Optional[Callable] = None,
639+
preallocate: bool = False,
639640
**kwargs_lsqr,
640641
) -> Tuple[NDArray, int, NDArray]:
641642
r"""Split Bregman for mixed L2-L1 norms.
@@ -698,6 +699,10 @@ def splitbregman(
698699
callback : :obj:`callable`, optional
699700
Function with signature (``callback(x)``) to call after each iteration
700701
where ``x`` is the current model vector
702+
preallocate : :obj:`bool`, optional
703+
.. versionadded:: 2.5.0
704+
705+
Pre-allocate all variables used by the solver
701706
**kwargs_lsqr
702707
Arbitrary keyword arguments for
703708
:py:func:`scipy.sparse.linalg.lsqr` solver used to solve the first
@@ -735,6 +740,7 @@ def splitbregman(
735740
tau=tau,
736741
restart=restart,
737742
engine=engine,
743+
preallocate=preallocate,
738744
show=show,
739745
itershow=itershow,
740746
show_inner=show_inner,

0 commit comments

Comments
 (0)