A comprehensive, all-in-one collection of state-of-the-art optimization algorithms for deep learning. Designed for maximum efficiency, minimal memory footprint, and superior performance across diverse model architectures and training scenarios.
`pip install adv_optm`

Requires PyTorch 2.3+ for `torch.compile` support.
This major update introduces a complete architectural refactor of the library:
🆕 New Optimizers & Scaling
- `SinkSGD_adv`: Added a powerful new optimizer to the lineup.
- Spectral Scaling: Now available across all optimizers, achieving width/rank invariant updates for highly stable training.
💾 Memory & State Precision Control
- Granular State Precision (`state_precision`): Drastically reduce memory overhead with new optimizer state modes:
  - `factored` (Rank-2 factored mode)
  - `fp32` (Full precision)
  - `bf16_sr` & `int8_sr` (BF16/Int8 with Stochastic Rounding)
- Factored Second Moment (`factored_2nd`): Available for all Adam variants. Works seamlessly alongside any `state_precision` setting to further slash memory usage.
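To illustrate why factoring a second-moment estimate saves memory, here is a minimal sketch in plain Python. It assumes an Adafactor-style row/column factorization, which may differ from what `factored_2nd` actually does internally. The helper names `factor_second_moment` and `reconstruct` are hypothetical and not part of the library.

```python
def factor_second_moment(v):
    """Compress a non-negative matrix into its row sums and column sums."""
    row = [sum(r) for r in v]
    col = [sum(c) for c in zip(*v)]
    return row, col


def reconstruct(row, col):
    """Rank-1 estimate: v_hat[i][j] = row[i] * col[j] / total."""
    total = sum(row)
    if total == 0.0:
        return [[0.0 for _ in col] for _ in row]
    return [[r * c / total for c in col] for r in row]


# Squared gradients of a 2x3 weight matrix. The matrix is rank-1 here,
# so the factored estimate recovers it exactly; in general it is an approximation.
g = [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]]
v = [[x * x for x in r] for r in g]

row, col = factor_second_moment(v)
v_hat = reconstruct(row, col)

# Stored state: len(row) + len(col) values instead of rows * cols.
stored = len(row) + len(col)
```

For an `n x m` parameter, the stored state shrinks from `n * m` values to `n + m`, which is where the memory saving comes from on large matrices.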
⚙️ Advanced Dynamics & Momentum
- Variance Normalized Momentum (`normed_momentum`): Applies optimizer normalization before momentum (Normalization then Momentum/NtM). Available for `AdamW_adv`, `SignSGD_adv`, and `SinkSGD_adv`.
- Universal Nesterov Momentum: Replaced the hard-to-tune Simplified_AdEMAMix with Nesterov momentum (`nesterov`) and a dedicated coefficient (`nesterov_coef`) across all optimizers.
- Preconditioning & Signs:
  - Improved CANS (`accelerated_ns`): Enhanced for Muon variants by integrating a dynamic lower bound.
  - New OrthoGrad modes (`orthogonal_gradient`): Standard OrthoGrad `flattened` and a new matrix-wise mode `iterative`.
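As a rough illustration of the flattened OrthoGrad idea, the sketch below treats a parameter as a flat vector, removes the gradient component parallel to the weights, and rescales the result to the original gradient norm. It is a plain-Python sketch of OrthoGrad as commonly described, not the library's implementation. The function name `orthogonalize_gradient` and the `eps` value are assumptions.

```python
import math


def orthogonalize_gradient(w, g, eps=1e-30):
    """Project g onto the subspace orthogonal to w, then restore g's norm."""
    dot_wg = sum(wi * gi for wi, gi in zip(w, g))
    dot_ww = sum(wi * wi for wi in w)
    proj = dot_wg / (dot_ww + eps)

    # Remove the component of g that points along w.
    g_orth = [gi - proj * wi for wi, gi in zip(w, g)]

    # Rescale so the update keeps the original gradient magnitude.
    g_norm = math.sqrt(sum(gi * gi for gi in g))
    g_orth_norm = math.sqrt(sum(gi * gi for gi in g_orth))
    scale = g_norm / (g_orth_norm + eps)
    return [scale * gi for gi in g_orth]


w = [3.0, 4.0]
g = [1.0, 0.0]
g_orth = orthogonalize_gradient(w, g)

# The result is orthogonal to w and has the same norm as g.
dot_check = sum(wi * gi for wi, gi in zip(w, g_orth))
norm_check = math.sqrt(sum(gi * gi for gi in g_orth))
```

Because the projected gradient has no component along the weights, a step along it leaves the weight norm unchanged to first order.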
⚓ Weight Decay Innovations
- Centered Weight Decay (`centered_wd`): Pulls weights toward their pre-train state (anchor). To save memory, anchor precision (`centered_wd_mode`) can be set to full, float8, int8, or int4.
- Fisher Weight Decay (`fisher_wd`): Now available for Adam variants based on the FAdam paper.
- Geometric Weight Decay: Added specifically for `SinkSGD_adv` and `SignSGD_adv`.
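The centered weight decay idea can be sketched in a few lines of plain Python. This assumes a decoupled decay term of the form `lr * wd * (w - anchor)`; the exact formulation behind `centered_wd` may differ, and the function name `centered_weight_decay` is hypothetical.

```python
def centered_weight_decay(weights, anchor, lr, wd):
    """Decay each weight toward its anchor value instead of toward zero."""
    return [w - lr * wd * (w - a) for w, a in zip(weights, anchor)]


# Anchor = snapshot of the pre-trained weights taken before fine-tuning.
anchor = [1.0, -2.0, 0.5]
weights = [1.5, -2.0, 0.0]

decayed = centered_weight_decay(weights, anchor, lr=0.1, wd=0.1)

# Distance to the anchor before and after the decay step.
before = [abs(w - a) for w, a in zip(weights, anchor)]
after = [abs(w - a) for w, a in zip(decayed, anchor)]
```

A weight that already equals its anchor is left untouched, whereas standard weight decay would still shrink it toward zero.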
(Note: Lion_Prodigy_adv, Simplified_AdEMAMix, and heuristic cautious/grams modes have been deprecated in favor of these superior, theoretically-grounded features).
Older release notes (v1.2.x - v2.1.x)
- New Optimizer: Added Signum (SignSGD with momentum) to the `SignSGD_adv` family.
- ⚡ `torch.compile` Support: Fully implemented for all advanced optimizers. Enable via `compiled_optimizer=True` to heavily fuse and optimize the optimizer step path.
- 📉 1-Bit Factored Mode: Vastly improved implementation via `nnmf_factor=True`.
- 🛠️ Broad performance and stability improvements across all optimizers.
- Advanced Muon Variants: Brought the groundbreaking Muon optimizer into the fold, enriched with features from recent literature.
| Optimizer | Description |
|---|---|
| `Muon_adv` | Advanced Muon implementation featuring CANS, NorMuon, Low-Rank Orthogonalization, and more. |
| `AdaMuon_adv` | Combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
- Prodigy Speedup: Prodigy variants are now 50% faster by eliminating unnecessary CUDA syncs (Shoutout to @dxqb!).
- Stochastic Rounding for BF16: Parameter updates and weight decay now accumulate in float32 and round once at the end.
- Cautious Weight Decay: Implemented for all advanced optimizers (Paper).
- Fused Operations: Transitioned to fused and in-place operations wherever possible.
(Documentation expanding on the theory and usage of these features is coming soon!)