Skip to content

Commit b35bd93

Browse files
committed
add KLShampoo
1 parent d972fe6 commit b35bd93

4 files changed

Lines changed: 107 additions & 10 deletions

File tree

heavyball/__init__.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,60 @@ def __init__(
737737
self._build_soap_defaults(locals(), fns=(C.scale_by_kl_soap,))
738738

739739

740+
class KLShampoo(SOAPBase):
    """
    KL-Shampoo

    Shampoo with KL-corrected Kronecker factor accumulation, applied directly as
    ⊗_i Q[i] diag(d_i^{-1/2}) Q[i].T to a momentum-EMA gradient. Unlike KL-SOAP,
    no Adam runs in the projected space, and the eigenvalues d_i = diag(Q[i].T @ GG[i] @ Q[i])
    are the preconditioner. GG is seeded with init_factor * I to keep the first preconditioner
    uniform (= 1/sqrt(init_factor) * I) instead of exploding along the null space of a
    rank-1 (single outer product) seed.

    Arguments whose handling is specific to this optimizer (everything else is passed
    through `_build_soap_defaults` unchanged):
        betas: the momentum EMA coefficient is read via `utils.get_beta1(group)`
            (presumably betas[0] -- confirm against `get_beta1`).
        eps: floor applied to each eigenvalue estimate d_i before the inverse square root.
        max_precond_dim: axes longer than this (and axes of length 1) get no Kronecker factor.
        precondition_1d: gradients with a single dimension only get a factor when this is True.
        init_factor: scale of the identity seed for every GG factor. With the default 0.1 the
            first preconditioner is 1/sqrt(0.1) * I per preconditioned axis. Values <= 0 fall
            back to seeding GG from the first gradient's outer product.

    Sources:
        KL-Shampoo:
            Understanding and Improving Shampoo and SOAP via Kullback-Leibler Minimization
            Wu Lin, Scott C. Lowe, Felix Dangel, Runa Eschenhagen, Zikun Xu, Roger B. Grosse
            https://arxiv.org/abs/2509.03378
    """

    def __init__(
        self,
        params,
        lr: float = 3e-3,
        betas=(0.9, 0.95),
        shampoo_beta: float = 0.95,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        precondition_frequency: int = 2,
        max_precond_dim: int = 2048,
        merge_dims: bool = True,
        precondition_1d: bool = False,
        warmup_steps: int = 0,
        split: bool = False,
        multi_tensor: bool = True,
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        palm: bool = C.use_default,
        precond_scheduler=(1 / 3, 9),
        beta2_scale: float = 0.8,
        use_precond_schedule: bool = C.use_default,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        storage_dtype: str = "float32",
        precond_grad_accum: bool = False,
        compile_step: bool = C.use_default,
        promote: bool = C.use_default,
        ecc: str | None = None,
        param_ecc: str | None = None,
        orig_shapes: ShapeMap | None = None,
        init_factor: float = 0.1,
        **kwargs,
    ):
        # locals() is captured as-is, so no local variable may be introduced above this call.
        # NOTE(review): presumably every captured name becomes a param-group default, which is
        # how `init_factor` reaches `group.get("init_factor", 0.0)` -- confirm in SOAPBase.
        self._build_soap_defaults(locals(), fns=(C.scale_by_kl_shampoo,))
792+
793+
740794
class SOAPNAdam(SOAPBase):
741795
def __init__(
742796
self,

heavyball/chainable.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,9 @@ def scion_auto_norm(group, update, grad, param, scion_state):
10731073

10741074

10751075
def _init_soap(state, group, update, grad, param):
    """
    State initializer for the SOAP-family transforms: builds ``state["GG"]`` and ``state["Q"]``
    for one parameter through ``utils.init_preconditioner``.

    ``update`` and ``param`` are accepted for the init-function signature but are not used.
    """
    # Groups that carry no "init_factor" entry fall back to 0.0, i.e. the gradient-seeded
    # initialisation rather than the identity seed.
    seed_scale = group.get("init_factor", 0.0)
    utils.init_preconditioner(grad, state, group["max_precond_dim"], group["precondition_1d"], seed_scale)
10771079

10781080

10791081
def _apply_soap_preconditioner(group, update, Q, GG, *references, use_kl: bool = False, eps=1e-8):
@@ -1132,6 +1134,17 @@ def scale_by_kl_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG):
11321134
return precond
11331135

11341136

1137+
@needs_full_param
@zero_guard("exp_avg")
@general_guard("Q", "GG", init_fn=_init_soap)
@no_state
def scale_by_kl_shampoo(group, update, grad, param, exp_avg, Q, GG):
    """
    KL-Shampoo update rule: momentum EMA followed by the Kronecker eigenvalue preconditioner.

    :param group: Param-group dict; reads ``eps`` here and beta1 via ``utils.get_beta1``.
    :param update: Incoming updates; folded into ``exp_avg`` and handed to the preconditioner refresh.
    :param exp_avg: Momentum buffers (modified in-place).
    :param Q: Per-parameter lists of eigenbases.
    :param GG: Per-parameter lists of accumulated Kronecker factors.
    :return: The preconditioned momentum, one entry per parameter.
    """
    momentum_weight = 1 - utils.get_beta1(group)
    utils.stochastic_lerp_(exp_avg, update, momentum_weight)

    eps = group["eps"]
    # The output is computed from Q/GG as they stand *before* the refresh below, so this
    # statement must stay ahead of the `_apply_soap_preconditioner` call.
    preconditioned = [
        utils.kl_shampoo_precondition(momentum, bases, factors, eps)
        for momentum, bases, factors in zip(exp_avg, Q, GG)
    ]

    # No state references are passed: there is no projected-space state to rotate here.
    # NOTE(review): presumably this accumulates the KL-corrected statistics into GG and
    # refreshes Q on the preconditioning schedule -- confirm in `_apply_soap_preconditioner`.
    _apply_soap_preconditioner(group, update, Q, GG, use_kl=True, eps=eps)
    return preconditioned
1146+
1147+
11351148
@needs_full_param
11361149
@zero_guard("exp_avg", "exp_avg_sq")
11371150
@general_guard("mu_product", init_fn=_init_mu_product, skip_first=False)

heavyball/utils.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -744,18 +744,16 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], *exp_avg: Tensor
744744
:param GG: List of accumulated gradient outer products.
745745
:param Q: List of current eigenbases (updated in-place to Q_new).
746746
:param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
747+
Pass nothing (or only `None` entries) to refresh Q without rotating any state.
747748
"""
748-
if not exp_avg:
749+
if isinstance(Q, list) and not Q:
749750
return
750751

751-
ref = exp_avg[0]
752-
if ref.dim() == 0: # preconditioning doesn't make sense here
752+
ref = exp_avg[0] if exp_avg else None
753+
if ref is not None and ref.dim() == 0: # preconditioning doesn't make sense here
753754
Q.clear()
754755
return
755756

756-
if isinstance(Q, list) and not Q:
757-
return
758-
759757
if ref is not None and ref.dim() != len(Q):
760758
raise ValueError(f"ref dim {ref.dim()} does not match Q length {len(Q)}")
761759

@@ -778,7 +776,8 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], *exp_avg: Tensor
778776

779777
if ref is None:
780778
for q, q_new in zip(Q, new_qs):
781-
copy_stochastic_(q, q_new)
779+
if q is not None:
780+
copy_stochastic_(q, q_new)
782781
return
783782

784783
assert ref.ndim < 13, "ref.ndim must be less than 13"
@@ -1145,6 +1144,28 @@ def update_ggt_kl(grad, GG, Q, max_precond_dim, precondition_1d, beta, eps):
11451144
stochastic_lerp_(m, outer, 1 - beta)
11461145

11471146

1147+
@decorator_knowngood
def _kl_shampoo_kron_scale(grad: Tensor, Q: List[Optional[Tensor]], GG: List[Optional[Tensor]], eps: float):
    """
    Scale ``grad`` along every preconditioned axis by d_i^{-1/2}, where
    d_i = diag(Q[i].T @ GG[i] @ Q[i]) clamped from below at ``eps``.

    Axes whose ``Q`` or ``GG`` entry is ``None`` are left untouched. The arithmetic runs on
    promoted copies and the result is cast back to ``grad.dtype``.
    """
    scaled = promote(grad)
    for axis, (basis, factor) in enumerate(zip(Q, GG)):
        if basis is None or factor is None:
            continue
        basis_t = promote(basis).T
        # Row-wise dot product of (Q.T @ GG) with Q.T yields diag(Q.T @ GG @ Q) without
        # materialising the full matrix product.
        eigvals = ((basis_t @ promote(factor)) * basis_t).sum(dim=1)
        inv_sqrt = eigvals.clamp_min(eps).rsqrt()
        # Broadcast the per-axis scale over all other axes.
        broadcast_shape = [1] * scaled.ndim
        broadcast_shape[axis] = -1
        scaled = scaled * inv_sqrt.view(broadcast_shape)
    return scaled.to(grad.dtype)
1159+
1160+
1161+
def kl_shampoo_precondition(grad, Q, GG, eps):
    """KL-Shampoo Kronecker preconditioner (arXiv:2509.03378).

    Applies ⊗_i Q[i] diag(d_i^{-1/2}) Q[i].T to grad, with d_i = diag(Q[i].T @ GG[i] @ Q[i]),
    in three steps: project forward with Q, rescale each axis, project back.
    """
    in_eigenbasis = project(grad, Q, back=False)
    rescaled = _kl_shampoo_kron_scale(in_eigenbasis, Q, GG, eps)
    return project(rescaled, Q, back=True)
1167+
1168+
11481169
def tree_apply(fn: Callable[[Any], Any]) -> Callable[[Any], Any]:
11491170
def _fn(*args):
11501171
return tree_map(fn, *args)
@@ -1241,22 +1262,29 @@ def update_preconditioner(grad, Q, GG, exp_avg, max_precond_dim, precondition_1d
12411262
get_orthogonal_matrix_QR(GG, Q, *exp_avg)
12421263

12431264

1244-
def init_preconditioner(grad, state, max_precond_dim, precondition_1d, init_factor: float = 0.0):
    """
    Initializes the preconditioner matrices (L and R in the paper).

    If init_factor > 0, GG starts as init_factor * I per side (uniform-eigval seed used by KL-Shampoo
    to avoid the rank-1 explosion: 1/sqrt(eps) along null directions). Otherwise, seeds with one
    outer product of grad (standard SOAP behavior).

    :param grad: Gradient whose shape, device and dtype determine the factors.
    :param state: Per-parameter state dict; ``state["GG"]`` and ``state["Q"]`` are written.
    :param max_precond_dim: Axes longer than this get ``None`` instead of a factor.
    :param precondition_1d: Whether gradients with a single dimension get a factor at all.
    :param init_factor: Scale of the identity seed; any value that is not strictly positive
        (including NaN) selects the gradient-seeded path.
    """
    # Decide the seeding mode exactly once. Testing `init_factor > 0` in one place and
    # `init_factor <= 0` in another is not exhaustive: NaN fails both comparisons, which
    # would leave GG all-zero and never seeded.
    identity_seed = init_factor > 0

    state["GG"] = factors = []  # Will hold all the preconditioner matrices (L and R in the paper).
    if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
        for sh in grad.shape:
            if sh > max_precond_dim or sh == 1:
                # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
                factors.append(None)
            elif identity_seed:
                factors.append(torch.eye(sh, device=grad.device, dtype=grad.dtype) * init_factor)
            else:
                factors.append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
    else:
        factors.append(None)

    if not identity_seed:
        # beta = 0 makes the zero-initialised factors take on the first outer product directly.
        update_ggt(grad, factors, max_precond_dim, precondition_1d, 0)
    state["Q"] = get_orthogonal_matrix(factors)
12611289

12621290

test/test_chainable_cpu.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def state_fn(_x):
8484
# Optimizers whose chains use shape-dependent or global-reduction ops must need gather
8585
_EXPECT_GATHER = {
8686
"SOAP",
87+
"KLSOAP",
88+
"KLShampoo",
8789
"SOAPNAdam",
8890
"SOAPAdEMAMix",
8991
"SOLP",

0 commit comments

Comments
 (0)