feat(pt): optimize HybridMuon by borrowing some ideas from the DeepSeek-V4 paper #5424
OutisLi wants to merge 1 commit into deepmodeling:master from
Conversation
Pull request overview
This PR updates the PyTorch HybridMuon optimizer configuration and implementation to align with reported DeepSeek‑V4 calibration choices (two-stage Newton–Schulz schedule and match-RMS coefficient), and adjusts related training/config plumbing.
Changes:
- Change the HybridMuon match-RMS scaling default (`lr_adjust_coeff`) from 0.2 to 0.18 and update the documentation accordingly.
- Remove the training-time auto-disable of `enable_gram` in distributed mode.
- Update HybridMuon's orthogonalization and Adam epsilon behavior (two-stage Newton–Schulz; `ADAM_EPS=1e-20`) and introduce a Dynamo cache size tweak (a rough sketch of the two-stage schedule follows this list).
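The two-stage schedule replaces a single run of Newton–Schulz iterations with a fast phase followed by a short polish phase. Below is a minimal sketch of the idea; the coefficients, step counts, and the `two_stage_newton_schulz` helper name are illustrative assumptions, not the calibrated values or API this PR adds to `deepmd/pt/optimizer/hybrid_muon.py`.

```python
import torch

# Illustrative placeholders -- the PR's calibrated values live in hybrid_muon.py.
NS_EPS = 1e-7
FAST_STEPS, FAST_COEFFS = 4, (3.4445, -4.7750, 2.0315)   # aggressive quintic steps
POLISH_STEPS, POLISH_COEFFS = 2, (1.5, -0.5, 0.0)        # classical cubic refinement


def two_stage_newton_schulz(grad: torch.Tensor) -> torch.Tensor:
    """Approximately orthogonalize a 2-D gradient in two phases."""
    x = grad / (grad.norm() + NS_EPS)       # scale so singular values are <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:                           # iterate on the wide orientation
        x = x.mT
    for steps, (a, b, c) in ((FAST_STEPS, FAST_COEFFS), (POLISH_STEPS, POLISH_COEFFS)):
        for _ in range(steps):
            gram = x @ x.mT                  # Gram matrix of the current iterate
            x = a * x + (b * gram + c * (gram @ gram)) @ x
    return x.mT if transposed else x
```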
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `deepmd/utils/argcheck.py` | Updates HybridMuon optimizer defaults/docs (notably `lr_adjust_coeff`) and removes the distributed-disable wording for `enable_gram` (a hypothetical config fragment follows the table). |
| `deepmd/pt/train/training.py` | Changes HybridMuon optimizer construction to pass `enable_gram` directly (no longer forced off under distributed training). |
| `deepmd/pt/optimizer/hybrid_muon.py` | Implements the DeepSeek-style two-stage Newton–Schulz schedule, updates defaults/docs, changes Adam epsilon handling, and adjusts Dynamo cache behavior. |
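For reference, a hypothetical fragment of the optimizer section of a training config combining the touched knobs could look like the sketch below; the key names mirror the PR description, but the exact nesting and any surrounding keys are assumptions, so the merged `deepmd/utils/argcheck.py` docs remain the source of truth.

```python
# Hypothetical optimizer-config fragment; nesting and extra keys are assumptions.
opt_config = {
    "opt_type": "HybridMuon",
    "lr_adjust_coeff": 0.18,  # new match-RMS default (was 0.2)
    "enable_gram": True,      # no longer auto-disabled under distributed training
}
```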
📝 Walkthrough
The pull request updates the HybridMuon optimizer with a two-stage Newton–Schulz orthogonalization schedule (fast and polish phases), introduces distinct epsilon values for Adam operations, adjusts default parameters, and integrates PyTorch Dynamo cache sizing. It also simplifies the distributed training logic for the `enable_gram` option.
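To make the match-RMS coefficient and the "distinct epsilon values" points concrete, here is a rough sketch under stated assumptions: the `lr_adjust_coeff * sqrt(max(dim))` form follows the common Muon match-RMS heuristic and may not be exactly what HybridMuon computes, and the helper names are invented for illustration.

```python
import math

import torch

ADAM_EPS = 1e-20          # tiny epsilon this PR uses for the Adam denominator
LR_ADJUST_COEFF = 0.18    # new match-RMS coefficient default


def match_rms_scale(ortho_update: torch.Tensor) -> torch.Tensor:
    """Rescale an orthogonalized update so its magnitude is comparable to Adam's."""
    return ortho_update * LR_ADJUST_COEFF * math.sqrt(max(ortho_update.shape))


def adam_denominator(exp_avg_sq: torch.Tensor) -> torch.Tensor:
    """Adam-style denominator; the very small ADAM_EPS barely perturbs sqrt(v)."""
    return exp_avg_sq.sqrt() + ADAM_EPS
```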
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
🧹 Nitpick comments (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
120-129: Avoid mutating `torch._dynamo.config.cache_size_limit` at module import time.

Importing this module unconditionally bumps a process-global PyTorch Dynamo setting, even for callers that never instantiate `HybridMuonOptimizer` (e.g., serving / inference entry points that merely import the optimizer registry). The cache-size bump is only needed for `_GramNewtonSchulzOrthogonalizer._compiled_call` (the sole `torch.compile` site in this file), so the side effect should be scoped to where it's actually required.

♻️ Move the bump into `_GramNewtonSchulzOrthogonalizer.__init__`
```diff
 import torch
-import torch._dynamo.config as _dynamo_config
 from torch.optim.optimizer import (
     Optimizer,
 )

 DYNAMO_CACHE_SIZE_LIMIT = 64
-_dynamo_config.cache_size_limit = max(
-    int(_dynamo_config.cache_size_limit),
-    DYNAMO_CACHE_SIZE_LIMIT,
-)

     def __init__(self) -> None:
+        # Gram NS compiles per-shape; bump Dynamo's cache budget locally so the
+        # repeated recompilation across MLIP parameter shapes doesn't spill,
+        # without polluting global state for unrelated torch.compile users.
+        import torch._dynamo.config as _dynamo_config
+
+        _dynamo_config.cache_size_limit = max(
+            int(_dynamo_config.cache_size_limit),
+            DYNAMO_CACHE_SIZE_LIMIT,
+        )
         # Gram path uses NS_EPS (same numerical role as Standard NS norm clamp).
```
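Whichever placement is chosen, the important property is that the bump stays idempotent (a `max` of the current and target limits), so running it more than once is harmless. A tiny sketch of that pattern as a shared helper, with an assumed name and the PR's limit of 64 inlined as the default:

```python
# Sketch of a shared, idempotent bump helper; the function name is an assumption.
def _ensure_dynamo_cache_budget(limit: int = 64) -> None:
    import torch._dynamo.config as _dynamo_config

    _dynamo_config.cache_size_limit = max(int(_dynamo_config.cache_size_limit), limit)
```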
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 7a925eec-d673-4562-93b0-abf4226eb0cf
📒 Files selected for processing (3)
- deepmd/pt/optimizer/hybrid_muon.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #5424      +/-   ##
==========================================
+ Coverage   82.39%   82.42%    +0.03%
==========================================
  Files         824      824
  Lines       87395    87418       +23
  Branches     4197     4197
==========================================
+ Hits        72009    72055       +46
+ Misses      14111    14088       -23
  Partials     1275     1275
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit
Release Notes