
Fix FP8 MoE scale patching for compressed-tensors models#551

Open
danielhanchen wants to merge 6 commits into main from fix/moe-fp8-scale-patching

Conversation

@danielhanchen
Contributor

Summary

Fixes the FP8 MoE scale patching in moe_utils_fp8.py, which was silently failing for all compressed-tensors FP8 MoE models (e.g., GLM-4.7-Flash-FP8-Dynamic). Based on the cleanup branch from #548.

The scale patcher (_maybe_patch_glm4_stacked_moe_fp8_scales) was not firing due to two bugs:

  1. Enum vs string comparison: quantization_config.quant_method returns a QuantizationMethod enum in newer transformers, not the string "compressed-tensors". The comparison quant_method != "compressed-tensors" always returned True, causing an early return.

  2. Wrong layer probe: The probe for scale keys in safetensors checked only the first routed layer. Some layers (e.g., layers 1, 39, 46 in GLM-4.7-Flash) are in the quantization ignore list and have no scales. If the first routed layer happened to be ignored, the probe returned False and the entire patcher bailed out.

Additionally, the patcher only supported single-file model.safetensors checkpoints. Sharded checkpoints (model-00001-of-NNNNN.safetensors) would crash with EntryNotFoundError.

Changes

  • Normalize quant_method to string before comparison (handles both enum and string)
  • Scan all routed layers to find those with scale keys, skip layers without
  • Read the full set of available keys from safetensors once instead of probing per-layer
  • Add _resolve_safetensors_files() supporting both single-file and sharded layouts via model.safetensors.index.json
  • Add _open_safetensors_for_keys() that opens only the relevant shard(s)

Test results

Tested on NVIDIA RTX PRO 6000 Blackwell (SM 12.0, 98 GB):

  • GLM-4.7-Flash-FP8-Dynamic: 43 FP8 expert layers get scales attached (previously 0)
  • Forward pass produces valid output, no NaN/Inf
  • 21-step LoRA training completes successfully
  • Normal (non-FP8, non-MoE) models are unaffected -- Llama-3.2-1B 4-bit training passes

Backwards compatibility

  • Works with transformers 4.57.6 (string quant_method) and 5.3.0 (enum quant_method)
  • Falls back gracefully if safetensors files are not found
  • No changes to any other code paths

Datta0 and others added 6 commits March 15, 2026 13:49
The scale patcher was not firing for two reasons:

1. quant_method comparison failed because transformers returns a
   QuantizationMethod enum, not the string "compressed-tensors".
   Normalize to string before comparing.

2. The probe checked only the first routed layer for scale keys,
   but some layers are in the quantization ignore list and have
   no scales. Now scans all layers and filters to those that
   actually have scale keys in the checkpoint.

Also adds support for sharded safetensors checkpoints via
model.safetensors.index.json, so this works with both single-file
and multi-shard models.

Tested on GLM-4.7-Flash-FP8-Dynamic: 43 FP8 expert layers now
correctly get scales attached, forward pass and 21-step LoRA
training both pass with no NaN. Normal (non-FP8) models unaffected.
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the functionality and reliability of FP8 Mixture-of-Experts (MoE) models, particularly those utilizing compressed-tensors quantization. By rectifying several underlying issues related to scale application and checkpoint loading, it ensures that these models can now properly initialize and operate, enabling successful training and inference for advanced architectures like GLM-4.7-Flash-FP8-Dynamic. The changes improve compatibility across different transformers versions and handle diverse safetensors storage formats.

Highlights

  • FP8 MoE Scale Patching Fix: Resolved critical issues preventing FP8 MoE scales from being correctly applied to compressed-tensors models, which previously led to silent failures.
  • Improved Quantization Method Handling: Addressed a bug where the quant_method comparison failed due to Enum vs. string type mismatch in newer transformers versions.
  • Robust Layer Probing: Fixed an issue where the scale patcher would fail if the first routed layer was in the quantization ignore list, by scanning all routed layers for valid scale keys.
  • Sharded Checkpoint Support: Implemented support for loading FP8 scales from sharded safetensors checkpoints (e.g., model-00001-of-NNNNN.safetensors), preventing EntryNotFoundError.
Changelog
  • unsloth_zoo/patching_utils.py
    • Imported maybe_patch_stacked_moe_expert_fp8_scales to enable FP8 MoE scale patching.
    • Invoked maybe_patch_stacked_moe_expert_fp8_scales after patch_to_dict to apply FP8 scales during model initialization.
  • unsloth_zoo/temporary_patches/glm4_moe.py
    • Modified patch_function calls for Glm4MoeLiteNaiveMoe and Glm4MoeLiteMoE to use force = True, ensuring the Unsloth MoE backend bypasses transformer's expert forward wrapper.
  • unsloth_zoo/temporary_patches/moe_utils.py
    • Added logic to install moe_utils_fp8.py to the cache if it exists.
    • Introduced _CACHED_MOE_UTILS_FP8_MODULE for caching the FP8 utility module.
    • Updated the module name for _load_cached_moe_utils_module to unsloth_zoo.temporary_patches._cached_moe_utils.
    • Set the __package__ attribute for dynamically loaded modules to ensure correct relative imports.
    • Implemented _load_cached_moe_utils_fp8_module to load the cached FP8 utility module.
    • Modified get_forward_moe_backend to prioritize the FP8-specific backend (forward_moe_backend_fp8) if FP8 weights are detected.
    • Removed a redundant definition of _TORCH_GROUPED_MM_AVAILABLE.
    • Added _try_attach_block_size to safely attach block size attributes to tensors.
    • Introduced _get_base_weight_and_quant_state to retrieve base weights and their quantization states.
    • Added _get_moe_weight_and_quant_state for MoE-specific weight and quantization state retrieval.
    • Implemented _get_grouped_lora and _apply_grouped_lora for handling grouped LoRA weights.
    • Added _expand_grouped_bias for expanding grouped biases.
    • Refactored _patched_param_wrapper_forward to leverage new helper functions and simplify LoRA data handling for MoE experts.
  • unsloth_zoo/temporary_patches/moe_utils_fp8.py
    • Added new file moe_utils_fp8.py to handle FP8 MoE specific functionalities.
    • Implemented _resolve_safetensors_files to support loading safetensors from single files or sharded checkpoints.
    • Developed _open_safetensors_for_keys to efficiently open relevant safetensors shards for specific keys.
    • Created _maybe_patch_glm4_stacked_moe_fp8_scales to correctly patch FP8 scales for GLM4 MoE models, addressing quant_method type and layer probing issues.
    • Provided maybe_patch_stacked_moe_expert_fp8_scales as a public entry point for FP8 scale patching.
    • Included utility functions like _is_float8_tensor, _get_fp8_dequant_target_dtype, and _log_moe_fp8_backend_once.
    • Added _check_torch_scaled_grouped_mm_supported to verify hardware support for torch._scaled_grouped_mm.
    • Implemented FP8 dequantization logic with _slice_fp8_quant_state, _dequantize_expert_slice, and _dequantize_full_expert_weights.
    • Introduced _make_grouped_mm_rhs_column_major for preparing weights for grouped matrix multiplication.
    • Developed helper functions _get_moe_weight_and_quant_info, _extract_scaled_grouped_mm_weight_scale, and _prepare_scaled_grouped_mm_weight for FP8 weight processing.
    • Added _manual_fp8_rowwise_quantize for explicit row-wise FP8 quantization.
    • Implemented _forward_scaled_grouped_mm_fp8 for the FP8 grouped matrix multiplication forward pass.
    • Defined _moe_uses_fp8_expert_weights to determine if MoE modules are using FP8 weights.
    • Introduced _call_with_temporary_moe_weights to facilitate temporary weight replacement during forward passes.
    • Implemented _slice_fp8_linear_quant_state for slicing FP8 linear quantization states.
    • Developed _forward_native_fp8_expert_loop for direct FP8 expert processing and _forward_native_moe_loop_fp8 as a fallback.
    • Integrated various FP8 MoE backends including forward_grouped_mm_fp8, forward_moe_backend_fp8, and forward_triton_grouped_gemm_fp8.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in the FP8 MoE scale patching for compressed-tensors models, which was silently failing. The changes include normalizing the quant_method, scanning all routed layers, supporting sharded checkpoints, and adding helper functions for safetensors file resolution. The code has been reviewed, and suggestions have been made to improve error handling and code clarity.

Comment on lines +206 to +207
if any(t is None for t in (gate, gate_scale, up, up_scale, down, down_scale)):
return False
Contributor


high: The function returns False if any of the tensors are None. This could lead to the patch not being applied even if some of the tensors are available. It might be better to log a warning and continue with the available tensors, or raise an exception to halt execution.


try:
-module_name = "unsloth_cached_moe_utils"
+module_name = "unsloth_zoo.temporary_patches._cached_moe_utils"
Contributor


medium: Consider adding a more descriptive name to the module to avoid potential naming conflicts. Using a more specific name like _cached_moe_utils_module would enhance clarity.

module_name = "unsloth_zoo.temporary_patches._cached_moe_utils_module"

Comment on lines +131 to +132
except Exception:
return None
Contributor


medium: The except Exception block is too broad. It would be better to catch specific exceptions like FileNotFoundError to avoid masking unexpected errors.

References
  1. Avoid using try...except ImportError for libraries that are required dependencies of the project, as the check is redundant.

Comment on lines +559 to +560
except RuntimeError:
return None
Contributor


medium: Consider adding a check to ensure that offsets is not None before using it. This can prevent potential AttributeError exceptions if offsets is unexpectedly None.

Comment on lines +739 to +740
except Exception:
pass
Contributor


medium: The except Exception block is too broad. It would be better to catch specific exceptions like RuntimeError to avoid masking unexpected errors.

References
  1. Avoid using try...except ImportError for libraries that are required dependencies of the project, as the check is redundant.


@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 001479bec6




def forward_moe_backend_fp8(self, hidden_states, top_k_index, top_k_weights):


P1: Preserve forward annotations for strict MoE patching

get_forward_moe_backend now routes to forward_moe_backend_fp8, but this function’s signature has no torch.Tensor annotations, and patch_function uses strict fingerprint matching (including annotations) by default. In practice, the unannotated replacement is rejected for annotated expert forwards like DeepseekV3NaiveMoe and Qwen3MoeExperts (patched without force=True), so those models silently skip the backend patch and lose the intended MoE/LoRA path.

