@@ -744,18 +744,16 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], *exp_avg: Tensor
744744 :param GG: List of accumulated gradient outer products.
745745 :param Q: List of current eigenbases (updated in-place to Q_new).
746746 :param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
747+ Pass nothing (or only `None` entries) to refresh Q without rotating any state.
747748 """
748- if not exp_avg :
749+ if isinstance ( Q , list ) and not Q :
749750 return
750751
751- ref = exp_avg [0 ]
752- if ref .dim () == 0 : # preconditioning doesn't make sense here
752+ ref = exp_avg [0 ] if exp_avg else None
753+ if ref is not None and ref .dim () == 0 : # preconditioning doesn't make sense here
753754 Q .clear ()
754755 return
755756
756- if isinstance (Q , list ) and not Q :
757- return
758-
759757 if ref is not None and ref .dim () != len (Q ):
760758 raise ValueError (f"ref dim { ref .dim ()} does not match Q length { len (Q )} " )
761759
@@ -778,7 +776,8 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], *exp_avg: Tensor
778776
779777 if ref is None :
780778 for q , q_new in zip (Q , new_qs ):
781- copy_stochastic_ (q , q_new )
779+ if q is not None :
780+ copy_stochastic_ (q , q_new )
782781 return
783782
784783 assert ref .ndim < 13 , "ref.ndim must be less than 13"
@@ -1145,6 +1144,28 @@ def update_ggt_kl(grad, GG, Q, max_precond_dim, precondition_1d, beta, eps):
11451144 stochastic_lerp_ (m , outer , 1 - beta )
11461145
11471146
1147+ @decorator_knowngood
1148+ def _kl_shampoo_kron_scale (grad : Tensor , Q : List [Optional [Tensor ]], GG : List [Optional [Tensor ]], eps : float ):
1149+ out = promote (grad )
1150+ for idx , (q , m ) in enumerate (zip (Q , GG )):
1151+ if q is None or m is None :
1152+ continue
1153+ q32 , m32 = promote (q ), promote (m )
1154+ d = ((q32 .T @ m32 ) * q32 .T ).sum (dim = 1 ).clamp_min (eps ).rsqrt ()
1155+ shape = [1 ] * out .ndim
1156+ shape [idx ] = - 1
1157+ out = out * d .view (shape )
1158+ return out .to (grad .dtype )
1159+
1160+
1161+ def kl_shampoo_precondition (grad , Q , GG , eps ):
1162+ """KL-Shampoo Kronecker preconditioner (arXiv:2509.03378).
1163+
1164+ Applies ⊗_i Q[i] diag(d_i^{-1/2}) Q[i].T to grad, with d_i = diag(Q[i].T @ GG[i] @ Q[i]).
1165+ """
1166+ return project (_kl_shampoo_kron_scale (project (grad , Q , back = False ), Q , GG , eps ), Q , back = True )
1167+
1168+
11481169def tree_apply (fn : Callable [[Any ], Any ]) -> Callable [[Any ], Any ]:
11491170 def _fn (* args ):
11501171 return tree_map (fn , * args )
@@ -1241,22 +1262,29 @@ def update_preconditioner(grad, Q, GG, exp_avg, max_precond_dim, precondition_1d
12411262 get_orthogonal_matrix_QR (GG , Q , * exp_avg )
12421263
12431264
1244- def init_preconditioner (grad , state , max_precond_dim , precondition_1d ):
1265+ def init_preconditioner (grad , state , max_precond_dim , precondition_1d , init_factor : float = 0.0 ):
12451266 """
12461267 Initializes the preconditioner matrices (L and R in the paper).
1268+
1269+ If init_factor > 0, GG starts as init_factor * I per side (uniform-eigval seed used by KL-Shampoo
1270+ to avoid the rank-1 explosion: 1/sqrt(eps) along null directions). Otherwise, seeds with one
1271+ outer product of grad (standard SOAP behavior).
12471272 """
12481273 state ["GG" ] = [] # Will hold all the preconditioner matrices (L and R in the paper).
12491274 if grad .numel () > 1 and (grad .ndim > 1 or precondition_1d ):
12501275 for sh in grad .shape :
12511276 if sh > max_precond_dim or sh == 1 :
12521277 # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
12531278 state ["GG" ].append (None )
1279+ elif init_factor > 0 :
1280+ state ["GG" ].append (torch .eye (sh , device = grad .device , dtype = grad .dtype ) * init_factor )
12541281 else :
12551282 state ["GG" ].append (torch .zeros (sh , sh , device = grad .device , dtype = grad .dtype ))
12561283 else :
12571284 state ["GG" ].append (None )
12581285
1259- update_ggt (grad , state ["GG" ], max_precond_dim , precondition_1d , 0 )
1286+ if init_factor <= 0 :
1287+ update_ggt (grad , state ["GG" ], max_precond_dim , precondition_1d , 0 )
12601288 state ["Q" ] = get_orthogonal_matrix (state ["GG" ])
12611289
12621290
0 commit comments