@@ -1687,9 +1687,7 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
16871687 if not self .preallocate :
16881688 res : NDArray = self .y - self .Opmatvec (x )
16891689 else :
1690- self .ncp .subtract (
1691- self .y , self .Opmatvec (x ), out = self .res if self .preallocate else res
1692- )
1690+ self .ncp .subtract (self .y , self .Opmatvec (x ), out = self .res )
16931691
16941692 if self .monitorres :
16951693 self .normres = np .linalg .norm (self .res if self .preallocate else res )
@@ -1707,9 +1705,9 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17071705 grad : NDArray = self .alpha * (self .Oprmatvec (res ))
17081706 else :
17091707 self .ncp .multiply (
1710- self .Oprmatvec (self .res if self . preallocate else res ),
1708+ self .Oprmatvec (self .res ),
17111709 self .alpha ,
1712- out = self .grad if self . preallocate else grad ,
1710+ out = self .grad ,
17131711 )
17141712
17151713 # update inverted model
@@ -1718,8 +1716,8 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17181716 else :
17191717 self .ncp .add (
17201718 x ,
1721- self .grad if self . preallocate else grad ,
1722- out = self .x_unthesh if self . preallocate else x_unthesh ,
1719+ self .grad ,
1720+ out = self .x_unthesh ,
17231721 )
17241722
17251723 # apply SOp.H to current x
@@ -2653,6 +2651,7 @@ def setup(
26532651 tol : float = 1e-10 ,
26542652 tau : float = 1.0 ,
26552653 restart : bool = False ,
2654+ preallocate : bool = False ,
26562655 show : bool = False ,
26572656 ) -> NDArray :
26582657 r"""Setup solver
@@ -2698,6 +2697,10 @@ def setup(
26982697 the initial guess (``True``) or with the last estimate (``False``).
26992698 Note that when this is set to ``True``, the ``x0`` provided in the setup will
27002699 be used in all iterations.
2700+ preallocate : :obj:`bool`, optional
2701+ .. versionadded:: 2.5.0
2702+
2703+ Pre-allocate all variables used by the solver
27012704 show : :obj:`bool`, optional
27022705 Display setup log
27032706
@@ -2719,14 +2722,19 @@ def setup(
27192722 self .tol = tol
27202723 self .tau = tau
27212724 self .restart = restart
2725+
27222726 self .ncp = get_array_module (y )
2727+ self .isjax = get_module_name (self .ncp ) == "jax"
2728+ self ._setpreallocate (preallocate )
27232729
27242730 # L1 regularizations
27252731 self .nregsL1 = len (RegsL1 )
27262732 self .b = [
27272733 self .ncp .zeros (RegL1 .shape [0 ], dtype = self .Op .dtype ) for RegL1 in RegsL1
27282734 ]
2729- self .d = self .b .copy ()
2735+ self .d = [
2736+ self .ncp .zeros (RegL1 .shape [0 ], dtype = self .Op .dtype ) for RegL1 in RegsL1
2737+ ]
27302738
27312739 # L2 regularizations
27322740 self .nregsL2 = 0 if RegsL2 is None else len (RegsL2 )
@@ -2797,11 +2805,22 @@ def step(
27972805 Updated model vector
27982806
27992807 """
2808+ # add preallocate to keywords of solver
2809+ if self .preallocate and (engine == "pylops" or self .ncp != np ):
2810+ kwargs_solver ["preallocate" ] = True
2811+
28002812 for _ in range (self .niter_inner ):
28012813 # regularized problem
2802- dataregs = self .dataregsL2 + [
2803- self .d [ireg ] - self .b [ireg ] for ireg in range (self .nregsL1 )
2804- ]
2814+ if not self .preallocate :
2815+ dataregs = self .dataregsL2 + [
2816+ self .d [ireg ] - self .b [ireg ] for ireg in range (self .nregsL1 )
2817+ ]
2818+ else :
2819+ for ireg in range (self .nregsL1 ):
2820+ self .ncp .subtract (self .d [ireg ], self .b [ireg ], out = self .d [ireg ])
2821+ dataregs = self .dataregsL2 + [
2822+ self .d [ireg ] for ireg in range (self .nregsL1 )
2823+ ]
28052824 x = regularized_inversion (
28062825 self .Op ,
28072826 self .y ,
@@ -2813,11 +2832,18 @@ def step(
28132832 engine = engine ,
28142833 ** kwargs_solver ,
28152834 )[0 ]
2816- # Shrinkage
2817- for ireg in range (self .nregsL1 ):
2818- self .d [ireg ] = _softthreshold (
2819- self .RegsL1 [ireg ].matvec (x ) + self .b [ireg ], self .epsRL1s [ireg ]
2820- )
2835+ # shrinkage
2836+ if not self .preallocate :
2837+ for ireg in range (self .nregsL1 ):
2838+ self .d [ireg ] = _softthreshold (
2839+ self .RegsL1 [ireg ].matvec (x ) + self .b [ireg ], self .epsRL1s [ireg ]
2840+ )
2841+ else :
2842+ for ireg in range (self .nregsL1 ):
2843+ self .ncp .add (
2844+ self .RegsL1 [ireg ].matvec (x ), self .b [ireg ], out = self .d [ireg ]
2845+ )
2846+ self .d [ireg ] = _softthreshold (self .d [ireg ], self .epsRL1s [ireg ])
28212847
28222848 # Bregman update
28232849 for ireg in range (self .nregsL1 ):
0 commit comments