170 changes: 127 additions & 43 deletions deepmd/pt/optimizer/hybrid_muon.py
@@ -40,10 +40,14 @@
update = beta * m_t + (1 - beta) * g_t

2. Orthogonalization:
- Standard path:
- Standard path (hybrid two-stage Newton-Schulz, DeepSeek-V4 style):
X_0 = G / ||G||_F
A_k = X_k @ X_k^T
X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k
* Fast stage: ``NS_STEPS_FAST`` iters with ``NS_COEFF_FAST``
(3.4445, -4.7750, 2.0315) — drives singular values near 1.
* Polish stage: ``NS_STEPS_POLISH`` iters with ``NS_COEFF_POLISH``
(2.0, -1.5, 0.5) — exact Newton iteration stabilizes sigma at 1.
- Gram path (when ``enable_gram=True`` and the matrix is rectangular):
X_0 = G / ||G||_F
R_k = X_k @ X_k^T
@@ -53,7 +57,10 @@
RZ_k = a*R_k + R_k @ Z_k
R_{k+1} = a*RZ_k + Z_k @ RZ_k
X_out = Q_last @ X_restart
Uses float32 normalization followed by float16 iteration.
Uses float32 normalization followed by float16 iteration with
five Polar-Express coefficient tuples (kept unchanged; the
Polar-Express recipe already tight-calibrates sigma→1 on its final
step and does not accept an extra Newton polish cleanly).

3. Scaling: scale = coeff * sqrt(max(m, n)) [match-RMS mode]
scale = sqrt(max(1, m/n)) [rectangular mode]
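Illustrative sketch (not part of the patch): the standard-path update for a single 2D block, written in plain float32 torch with the bf16 cast, tall-matrix transpose, flash/Gram variants, and Magma damping omitted. The momentum form, iteration polynomial, and match-RMS scale follow the formulas above; the function name and the in-place momentum update are assumptions for readability.

import torch

def muon_update_sketch(grad: torch.Tensor, momentum: torch.Tensor,
                       beta: float = 0.95, coeff: float = 0.18) -> torch.Tensor:
    # 1. Momentum accumulation: m <- beta * m + (1 - beta) * g
    momentum.mul_(beta).add_(grad, alpha=1.0 - beta)
    # 2. Two-stage Newton-Schulz on the Frobenius-normalized momentum.
    X = momentum / momentum.norm().clamp(min=1e-7)
    for (a, b, c), steps in (((3.4445, -4.7750, 2.0315), 8),   # fast stage
                             ((2.0, -1.5, 0.5), 2)):           # polish stage
        for _ in range(steps):
            A = X @ X.T
            X = a * X + (b * A + c * (A @ A)) @ X
    # 3. Match-RMS scaling: coeff * sqrt(max(m, n)).
    m, n = grad.shape
    return coeff * (max(m, n) ** 0.5) * X

With the new default coeff = 0.18 and a 128 x 512 block, the scale is 0.18 * sqrt(512) ≈ 4.07, versus ≈ 4.53 with Moonlight's 0.2, i.e. roughly a 10% smaller effective Muon step at the same base learning rate.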
@@ -93,6 +100,10 @@
for MLIP force-field training.
.. [6] Dao-AILab, "gram-newton-schulz."
https://github.com/Dao-AILab/gram-newton-schulz
.. [7] DeepSeek-AI, "DeepSeek-V4: Towards Highly Efficient Million-Token
Context Intelligence," 2026. Section 2.4 introduces the two-stage
hybrid Newton-Schulz (8 fast + 2 polish iterations) and uses
gamma=0.18 for update-RMS rescaling with AdamW ε=1e-20.
"""

from __future__ import (
@@ -106,10 +117,17 @@
)

import torch
import torch._dynamo.config as _dynamo_config
from torch.optim.optimizer import (
Optimizer,
)

DYNAMO_CACHE_SIZE_LIMIT = 64
_dynamo_config.cache_size_limit = max(
int(_dynamo_config.cache_size_limit),
DYNAMO_CACHE_SIZE_LIMIT,
)

if TYPE_CHECKING:
from collections.abc import (
Callable,
@@ -150,14 +168,28 @@
# Constants
# ============================================================================

# Newton-Schulz iteration count
NS_STEPS: int = 5
# Numerical stability epsilon for norm clamping and Adam
EPS: float = 1e-7
# Quintic Newton-Schulz polynomial coefficients
NS_COEFF_A: float = 3.4445
NS_COEFF_B: float = -4.7750
NS_COEFF_C: float = 2.0315
# --- Newton-Schulz two-stage iteration schedule (DeepSeek-V4 §2.4) ---
# Fast stage drives singular values close to 1 rapidly; polish stage uses
# exact Newton iteration (a=2, b=-1.5, c=0.5) to stabilize sigma precisely at 1.
# Only the Standard / Flash NS paths use this schedule. The Gram (Polar-
# Express) path is a different orthogonalization recipe and keeps its own
# pre-calibrated 5-step schedule in POLAR_EXPRESS_COEFFICIENTS below.
NS_STEPS_FAST: int = 8
NS_STEPS_POLISH: int = 2
NS_COEFF_FAST: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
NS_COEFF_POLISH: tuple[float, float, float] = (2.0, -1.5, 0.5)
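A quick arithmetic check of why the two stages differ (illustrative, not from the patch): each iteration maps every singular value sigma through p(sigma) = a*sigma + b*sigma^3 + c*sigma^5. With the fast coefficients, p(1) = 3.4445 - 4.7750 + 2.0315 ≈ 0.70, so sigma = 1 is not a fixed point and the fast stage only pulls the spectrum into a band around 1. With the polish coefficients, p(1) = 2 - 1.5 + 0.5 = 1 and p'(1) = 2 - 4.5 + 2.5 = 0, so sigma = 1 is an attracting fixed point and the two polish iterations pin the spectrum onto it.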

# --- Numerical stability epsilons ---
# NS_EPS: guards Frobenius-norm clamp so X_0 = G / ||G||_F stays finite.
# Normal gradients satisfy ||G||_F >> 1e-7, so this never bites in practice.
# ADAM_EPS: Adam denominator ε in ``sqrt(v_hat) + ε``. DeepSeek-V4 uses
# ε=1e-20 so the Adam update is driven by the second-moment estimate rather
# than the floor ε. For MLIP training (SeZM) this matters for ``adam_``
# parameters whose gradient scale spans many orders of magnitude across
# training (e.g. ``adam_ffn_layer_scales`` starting at 1e-3 vs
# ``adam_type_embedding`` at O(1)).
NS_EPS: float = 1e-7
ADAM_EPS: float = 1e-20
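A small numeric illustration of the ADAM_EPS choice (hypothetical magnitudes, not taken from the patch): for a parameter whose gradient has settled near 1e-8, sqrt(v_hat) is also about 1e-8, and the epsilon floor decides whether the step keeps its unit-RMS scale.

m_hat, v_hat = 1e-8, 1e-16   # first/second moments for a tiny, steady gradient
for eps in (1e-7, 1e-20):
    print(eps, m_hat / (v_hat**0.5 + eps))
# eps = 1e-7  -> ~0.09: the floor dominates and the step is damped roughly 11x
# eps = 1e-20 -> ~1.0:  the step keeps the intended unit-RMS scale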
# Polar Express coefficients with the safety scaling used in the reference repo
_GRAM_NS_UNMODIFIED_POLAR_EXPRESS_COEFFICIENTS: tuple[
tuple[float, float, float], ...
@@ -300,8 +332,10 @@ def _flash_newton_schulz_orth(
buf2: torch.Tensor,
) -> torch.Tensor:
"""
Orthogonalize a 2D matrix via quintic Newton-Schulz with triton-accelerated
symmetric matmul. Mathematically equivalent to ``_newton_schulz_orth``.
Orthogonalize a 2D matrix via two-stage Newton-Schulz with triton-accelerated
symmetric matmul. Mathematically equivalent to ``_newton_schulz_orth``
(same DeepSeek-V4 hybrid schedule: ``NS_STEPS_FAST`` fast iters followed
by ``NS_STEPS_POLISH`` polish iters).

Parameters
----------
@@ -324,16 +358,25 @@
X = X.transpose(-2, -1)

# === Step 2. Normalize Frobenius norm to at most 1 ===
X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS)
X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS)

# === Step 3. Newton-Schulz iterations with triton symmetric matmul ===
for _ in range(NS_STEPS):
# === Step 3. Fast stage: drive sigma close to 1 ===
fast_a, fast_b, fast_c = NS_COEFF_FAST
for _ in range(NS_STEPS_FAST):
_matmul_transpose_assign(X, buf1) # buf1 = X @ X.T = A
_matmul_transpose_assign(buf1, buf2) # buf2 = A @ A.T = A² (A symmetric)
B = NS_COEFF_B * buf1 + NS_COEFF_C * buf2
X = NS_COEFF_A * X + B @ X

# === Step 4. Transpose back if needed ===
B = fast_b * buf1 + fast_c * buf2
X = fast_a * X + B @ X

# === Step 4. Polish stage: exact Newton iteration around sigma=1 ===
polish_a, polish_b, polish_c = NS_COEFF_POLISH
for _ in range(NS_STEPS_POLISH):
_matmul_transpose_assign(X, buf1)
_matmul_transpose_assign(buf1, buf2)
B = polish_b * buf1 + polish_c * buf2
X = polish_a * X + B @ X

# === Step 5. Transpose back if needed ===
if transposed:
X = X.transpose(-2, -1)

@@ -344,12 +387,17 @@ def _newton_schulz_orth(
G: torch.Tensor,
) -> torch.Tensor:
"""
Orthogonalize a 2D matrix via quintic Newton-Schulz iteration.
Orthogonalize a 2D matrix via two-stage Newton-Schulz iteration.

Mathematical formulation:
X_0 = G / ||G||_F
X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T
Coefficients: a=3.4445, b=-4.7750, c=2.0315

Two-stage schedule (DeepSeek-V4 §2.4):
* Fast stage (``NS_STEPS_FAST`` iters, ``NS_COEFF_FAST``):
a=3.4445, b=-4.7750, c=2.0315 — rapid convergence to sigma ≈ 1.
* Polish stage (``NS_STEPS_POLISH`` iters, ``NS_COEFF_POLISH``):
a=2, b=-1.5, c=0.5 — exact Newton iteration pinning sigma to 1.
"""
# === Step 1. Cast to bf16 and transpose tall matrices ===
X = G.to(dtype=torch.bfloat16)
@@ -358,15 +406,23 @@
X = X.transpose(-2, -1)

# === Step 2. Normalize Frobenius norm to at most 1 ===
X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS)
X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS)

# === Step 3. Fast stage: drive sigma close to 1 ===
fast_a, fast_b, fast_c = NS_COEFF_FAST
for _ in range(NS_STEPS_FAST):
A = torch.mm(X, X.transpose(-2, -1))
gram_update = torch.addmm(A, A, A, beta=fast_b, alpha=fast_c)
X = torch.addmm(X, gram_update, X, beta=fast_a, alpha=1.0)

# === Step 3. Newton-Schulz iterations with fused GEMM ===
for _ in range(NS_STEPS):
# === Step 4. Polish stage: exact Newton iteration around sigma=1 ===
polish_a, polish_b, polish_c = NS_COEFF_POLISH
for _ in range(NS_STEPS_POLISH):
A = torch.mm(X, X.transpose(-2, -1))
gram_update = torch.addmm(A, A, A, beta=NS_COEFF_B, alpha=NS_COEFF_C)
X = torch.addmm(X, gram_update, X, beta=NS_COEFF_A, alpha=1.0)
gram_update = torch.addmm(A, A, A, beta=polish_b, alpha=polish_c)
X = torch.addmm(X, gram_update, X, beta=polish_a, alpha=1.0)

# === Step 4. Transpose back if needed ===
# === Step 5. Transpose back if needed ===
if transposed:
X = X.transpose(-2, -1)

@@ -377,7 +433,11 @@ def _batched_newton_schulz_orth(
G: torch.Tensor,
) -> torch.Tensor:
"""
Orthogonalize a batch of matrices via quintic Newton-Schulz iteration.
Orthogonalize a batch of matrices via two-stage Newton-Schulz iteration.

Uses the same DeepSeek-V4 hybrid schedule as ``_newton_schulz_orth``:
``NS_STEPS_FAST`` fast iters with ``NS_COEFF_FAST`` followed by
``NS_STEPS_POLISH`` polish iters with ``NS_COEFF_POLISH``.

Parameters
----------
@@ -399,15 +459,23 @@
X = X.transpose(-2, -1)

# === Step 2. Normalize each slice by Frobenius norm ===
X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS)
X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS)

# === Step 3. Fast stage: drive sigma close to 1 ===
fast_a, fast_b, fast_c = NS_COEFF_FAST
for _ in range(NS_STEPS_FAST):
A = torch.bmm(X, X.transpose(-2, -1))
gram_update = torch.baddbmm(A, A, A, beta=fast_b, alpha=fast_c)
X = torch.baddbmm(X, gram_update, X, beta=fast_a, alpha=1.0)

# === Step 3. Batched Newton-Schulz iterations ===
for _ in range(NS_STEPS):
# === Step 4. Polish stage: exact Newton iteration around sigma=1 ===
polish_a, polish_b, polish_c = NS_COEFF_POLISH
for _ in range(NS_STEPS_POLISH):
A = torch.bmm(X, X.transpose(-2, -1))
gram_update = torch.baddbmm(A, A, A, beta=NS_COEFF_B, alpha=NS_COEFF_C)
X = torch.baddbmm(X, gram_update, X, beta=NS_COEFF_A, alpha=1.0)
gram_update = torch.baddbmm(A, A, A, beta=polish_b, alpha=polish_c)
X = torch.baddbmm(X, gram_update, X, beta=polish_a, alpha=1.0)

# === Step 4. Restore original orientation ===
# === Step 5. Restore original orientation ===
if transposed:
X = X.transpose(-2, -1)

@@ -421,7 +489,11 @@ class _GramNewtonSchulzOrthogonalizer:
"""

def __init__(self) -> None:
self.ns_epsilon = float(EPS)
# Gram path uses NS_EPS (same numerical role as Standard NS norm clamp).
# It intentionally does NOT share the smaller ADAM_EPS, because the
# Polar-Express recipe normalizes before its first iteration and a
# looser floor is safer under fp32→fp16 cast.
self.ns_epsilon = float(NS_EPS)
self.ns_coefficients = tuple(
(float(a), float(b), float(c)) for a, b, c in POLAR_EXPRESS_COEFFICIENTS
)
@@ -801,9 +873,13 @@ class HybridMuonOptimizer(Optimizer):
scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust.
Default is 0.0 (match-RMS scaling).
lr_adjust_coeff : float
Coefficient with default 0.2 for match-RMS scaling when
Coefficient with default 0.18 for match-RMS scaling when
``lr_adjust <= 0``:
``scale = lr_adjust_coeff * sqrt(max(m, n))``.
0.18 is the value calibrated by DeepSeek-V4 so that Muon's
per-element update RMS matches AdamW's typical RMS, enabling
reuse of AdamW learning rates across both paths. The Moonlight
reference uses 0.2; both are empirically viable.
muon_mode : str
Muon routing mode with default ``"slice"``.
- ``"2d"``: only 2D parameters are Muon candidates.
@@ -826,10 +902,12 @@
``enable_gram=True``.
Default is True.
magma_muon : bool
Enable Magma-lite damping on Muon updates with default False.
Enable Magma-lite damping on Muon updates with default True.
This computes momentum-gradient cosine alignment per Muon block,
applies EMA smoothing, and rescales Muon updates in [0.1, 1.0].
Adam/AdamW paths are unchanged.
Adam/AdamW paths are unchanged. Empirically beneficial for
MLIP / SeZM training under heavy-tailed gradient noise from
conservative-force (second-order) autograd.
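To make the magma_muon description above concrete, a hedged sketch of one plausible damping rule matching it; the function name, EMA constant, and the exact alignment-to-scale mapping are assumptions, not read from the patch.

import torch

def magma_lite_scale(momentum: torch.Tensor, grad: torch.Tensor,
                     ema_align: float, ema_beta: float = 0.95) -> tuple[float, float]:
    # Cosine alignment between a Muon block's momentum and its raw gradient.
    cos = torch.nn.functional.cosine_similarity(
        momentum.flatten(), grad.flatten(), dim=0
    ).item()
    # EMA smoothing, then clamp into the documented [0.1, 1.0] range.
    ema_align = ema_beta * ema_align + (1.0 - ema_beta) * cos
    scale = min(1.0, max(0.1, ema_align))
    return scale, ema_align

The Muon update for that block would then be multiplied by scale; Adam/AdamW parameters are left untouched, as stated above.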

Examples
--------
@@ -848,12 +926,12 @@ def __init__(
weight_decay: float = 0.001,
adam_betas: tuple[float, float] = (0.9, 0.95),
lr_adjust: float = 0.0,
lr_adjust_coeff: float = 0.2,
lr_adjust_coeff: float = 0.18,
muon_mode: str = "slice",
named_parameters: Iterable[tuple[str, torch.Tensor]] | None = None,
enable_gram: bool = True,
flash_muon: bool = True,
magma_muon: bool = False,
magma_muon: bool = True,
use_foreach: bool | None = None,
) -> None:
# === Step 1. Validate routing mode ===
@@ -1551,7 +1629,7 @@ def step(
lr_adjust = group["lr_adjust"]
lr_adjust_coeff = group["lr_adjust_coeff"]
enable_gram = bool(group.get("enable_gram", True))
magma_muon = bool(group.get("magma_muon", False))
magma_muon = bool(group.get("magma_muon", True))

# === Step 1. Adam update for non-decay Adam path ===
# === Step 1.1. Collect gradients and initialize state ===
@@ -1602,7 +1680,11 @@
bias_corr1 = 1 - state["beta1_pow"]
bias_corr2 = 1 - state["beta2_pow"]
step_size = adam_lr / bias_corr1
denom = (adam_no_decay_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS)
denom = (
(adam_no_decay_exp_avg_sqs[i] / bias_corr2)
.sqrt()
.add_(ADAM_EPS)
)
delta_fp32 = -step_size * (adam_no_decay_exp_avgs[i] / denom)
p.add_(delta_fp32.to(p.dtype))

@@ -1660,7 +1742,9 @@
bias_corr1 = 1 - state["beta1_pow"]
bias_corr2 = 1 - state["beta2_pow"]
step_size = adam_lr / bias_corr1
denom = (adam_decay_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS)
denom = (
(adam_decay_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS)
)
delta_fp32 = -step_size * (adam_decay_exp_avgs[i] / denom)
p.add_(delta_fp32.to(p.dtype))

4 changes: 1 addition & 3 deletions deepmd/pt/train/training.py
@@ -920,9 +920,7 @@ def single_model_finetune(
"lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
"muon_mode": str(self.opt_param.get("muon_mode", "slice")),
"named_parameters": tuple(self.wrapper.named_parameters()),
"enable_gram": False
if self.is_distributed
else bool(self.opt_param.get("enable_gram")),
"enable_gram": bool(self.opt_param.get("enable_gram")),
"flash_muon": bool(self.opt_param.get("flash_muon")),
"magma_muon": bool(self.opt_param.get("magma_muon")),
# FSDP2 shards parameters as DTensor; several torch._foreach_*
9 changes: 5 additions & 4 deletions deepmd/utils/argcheck.py
@@ -3107,9 +3107,11 @@ def optimizer_hybrid_muon() -> list[Argument]:
"lr_adjust_coeff",
float,
optional=True,
default=0.2,
default=0.18,
doc=doc_only_pt_supported
+ "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.",
+ "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0. "
+ "Default 0.18 follows DeepSeek-V4's calibration so Muon update RMS "
+ "matches AdamW's typical RMS; Moonlight's original recipe uses 0.2.",
),
Argument(
"muon_mode",
@@ -3130,8 +3132,7 @@
default=True,
doc=doc_only_pt_supported
+ "Enable the compiled Gram Newton-Schulz path for rectangular Muon matrices. "
+ "Square matrices keep using the current standard Newton-Schulz path. "
+ "Automatically disabled in distributed (multi-GPU) training.",
+ "Square matrices keep using the current standard Newton-Schulz path.",
),
Argument(
"flash_muon",