[MoE] FP8 support for MoE, specifically GLM 4.7 flash#548

Open
Datta0 wants to merge 12 commits into unslothai:main from Datta0:moe_fp8

Conversation

@Datta0
Collaborator

Datta0 commented Mar 16, 2026

grouped_mm is not a supported operation for FP8, so we try to use _scaled_grouped_mm instead, but that is only supported on newer GPU architectures. When _scaled_grouped_mm is unavailable, we dequantize the tensors on the fly and pass them to the same grouped_mm function.
There is also an issue where transformers does not load the weight scales for the FP8 tensors, so we needed to patch that ourselves. The patch is currently GLM-specific; we are still unsure about Qwen3 MoE and other MoEs.

Note that this has only been tested on a B200 so far; the losses closely match those of the BF16 training counterpart.

To Test:

  1. Use the Qwen3 MoE notebook but replace the model name with Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8. Memory usage should be <40GB
  2. Use the GLM 4.7 flash notebook but replace the model name with unsloth/GLM-4.7-Flash-FP8-Dynamic. Same memory expectation as above
  3. Sanity: Ensure the BF16 variants (unchanged notebooks) work too
  4. Ensure that Qwen3 8B dense FP8 GRPO works
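The dispatch the PR description outlines — prefer _scaled_grouped_mm on newer GPUs, otherwise dequantize on the fly and reuse grouped_mm — boils down to something like the following. The function name and signature are hypothetical, not the PR's actual API:

```python
# Hypothetical sketch (not the PR's actual API) of the backend choice
# described above for FP8 MoE grouped GEMMs.
def select_fp8_moe_path(scaled_grouped_mm_available: bool, sm_major: int) -> str:
    """Pick a grouped-GEMM strategy for FP8 expert weights."""
    if scaled_grouped_mm_available and sm_major >= 9:
        # Hopper/Blackwell-class GPUs can use torch._scaled_grouped_mm directly.
        return "scaled_grouped_mm"
    # Older parts: dequantize routed experts on the fly, reuse plain grouped_mm.
    return "dequantize_then_grouped_mm"
```

On a B200-class part this would select the scaled path; pre-Hopper GPUs would take the dequantize-then-grouped_mm fallback.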

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Mixture-of-Experts (MoE) implementation by adding robust support for FP8 quantization. It introduces a flexible strategy to leverage torch._scaled_grouped_mm for performance gains on newer hardware, while also providing intelligent fallbacks for broader compatibility. A critical patch is included to correctly handle FP8 weight scales for GLM4-MoE Lite models, ensuring proper functionality and training stability. These changes aim to optimize MoE operations for efficiency and reliability across various hardware configurations.

Highlights

  • FP8 Support for MoE: Introduced comprehensive support for FP8 (Float8) quantized Mixture-of-Experts (MoE) models, specifically targeting GLM 4.7 flash, by integrating torch._scaled_grouped_mm for efficient FP8 operations on compatible GPU architectures (e.g., B200).
  • Dynamic Grouped Matrix Multiplication (GEMM) Strategy: Implemented a dynamic fallback mechanism for grouped GEMM operations. If torch._scaled_grouped_mm is not supported or compatible with the FP8 weights, the system intelligently dequantizes only the routed experts to a higher precision (e.g., BF16) before performing the grouped GEMM, ensuring broader compatibility while maintaining performance where possible.
  • GLM4-MoE Weight Scale Patching: Addressed a specific issue in GLM4-MoE Lite models where transformers might not correctly load FP8 weight_scale tensors for stacked expert parameters. A new utility function maybe_patch_stacked_moe_expert_fp8_scales was added to manually load and attach these missing scales from safetensors files.
  • Improved Logging and Error Handling: Added a _log_warn_once utility to provide informative warnings to users regarding the MoE backend selection and FP8 compatibility, helping to diagnose issues related to hardware support or weight metadata.
Changelog
  • unsloth_zoo/temporary_patches/glm4_moe.py
    • Adjusted whitespace for improved code readability.
  • unsloth_zoo/temporary_patches/moe_utils.py
    • Added _WARNED_MOE_MESSAGES set and _log_warn_once function for controlled warning messages.
    • Introduced global flags _TORCH_SCALED_GROUPED_MM_AVAILABLE and _TORCH_SCALED_GROUPED_MM_SUPPORTED to track _scaled_grouped_mm availability and support.
    • Implemented _check_torch_scaled_grouped_mm_supported to programmatically verify GPU architecture (e.g., SM90 for FP8) and runtime support for _scaled_grouped_mm.
    • Added utility functions _is_float8_tensor, _get_fp8_dequant_target_dtype, and _build_active_expert_grouping for FP8 tensor inspection and expert grouping.
    • Created helper functions _get_base_weight_and_quant_state, _get_moe_weight_and_quant_state, and _get_moe_weight_and_quant_info to extract expert weights and their associated FP8 quantization metadata.
    • Developed functions _slice_fp8_quant_state, _dequantize_expert_slice, _dequantize_full_expert_weights, and _dequantize_active_expert_weights to handle slicing and dequantization of FP8 expert weights.
    • Added _make_grouped_mm_rhs_column_major for weight tensor reshaping and _extract_scaled_grouped_mm_weight_scale to retrieve scales for _scaled_grouped_mm.
    • Introduced _prepare_scaled_grouped_mm_weight and _quantize_inputs_for_scaled_grouped_mm to prepare weights and inputs for _scaled_grouped_mm.
    • Wrapped torch._scaled_grouped_mm in _scaled_grouped_mm_with_backward_fix.
    • Implemented _manual_fp8_rowwise_quantize for custom row-wise FP8 quantization.
    • Added _moe_uses_fp8_expert_weights to detect if MoE layers are using FP8 weights.
    • Refactored forward_native_grouped_mm to include FP8 handling logic, prioritizing _scaled_grouped_mm and falling back to active expert dequantization or native loop.
    • Introduced _forward_native_grouped_mm_active_dequant and _forward_native_grouped_mm_scaled_fp8 as new forward pass implementations for FP8 MoE.
    • Modified _forward_triton_grouped_gemm_impl to accept optional gate_up_proj and down_proj arguments, enabling pre-dequantized weights.
    • Updated forward_triton_grouped_gemm to incorporate FP8 handling, routing to forward_native_grouped_mm or performing full expert dequantization before calling the Triton implementation.
  • unsloth_zoo/temporary_patches/moe_utils_fp8.py
    • Added new file moe_utils_fp8.py to handle FP8 specific utilities.
    • Implemented _maybe_patch_glm4_stacked_moe_fp8_scales to patch missing FP8 scale tensors for GLM4-MoE Lite models by loading them from safetensors.
    • Provided maybe_patch_stacked_moe_expert_fp8_scales as a generic entry point for model-specific FP8 scale patching.
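The core stacking step such a patch performs — fusing per-expert gate/up scales to match the stacked gate_up_proj parameter — can be sketched like this. Shapes and the helper name are illustrative, not the PR's actual code:

```python
import torch

def stack_expert_scales(gate_scales, up_scales, down_scales):
    """Fuse per-expert gate/up scale tensors and stack across experts,
    mirroring the layout of stacked gate_up_proj / down_proj parameters."""
    gate_up = torch.stack(
        [torch.cat([g, u], dim=0) for g, u in zip(gate_scales, up_scales)]
    )
    down = torch.stack(down_scales)
    return gate_up, down
```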
Activity
  • The author, Datta0, has implemented FP8 support for MoE, specifically for GLM 4.7 flash models.
  • The implementation addresses the limitation of grouped_mm not supporting FP8 by utilizing _scaled_grouped_mm where available (newer GPU architectures).
  • A fallback mechanism is in place to de-quantize tensors on the fly and pass them to grouped_mm if _scaled_grouped_mm is not supported.
  • The author noted and patched an issue where transformers was not loading weight scales for FP8 tensors, particularly for GLM models.
  • Initial testing has been conducted on B200 GPUs, showing that losses closely match those from BF16 training counterparts.

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces FP8 support for Mixture-of-Experts (MoE) models, with a specific focus on GLM-4. The changes include new forward paths to handle FP8 tensors, leveraging torch._scaled_grouped_mm on compatible hardware and providing a fallback that dequantizes weights on the fly for other GPUs. A patch is also included to correctly load FP8 scales for GLM-4 models. The implementation is comprehensive, addressing various hardware capabilities. My review feedback centers on enhancing maintainability by reducing code duplication in the newly added forward paths, improving the specificity of exception handling to prevent masking potential issues, and suggesting a refactoring of a lengthy function to improve its readability.

Comment on lines +1209 to +1379
def _forward_native_grouped_mm_active_dequant(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> Optional[torch.Tensor]:
    """
    FP8 compatibility path: dequantize only routed experts, then run grouped_mm.
    Falls back to None when the expert quant metadata cannot be interpreted safely.
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    if not hasattr(self, "gate_up_proj") or not hasattr(self, "down_proj"):
        return None

    is_2d_input = hidden_states.dim() == 2
    if is_2d_input:
        sequence_length, hidden_dim = hidden_states.shape
        batch_size = 1
    else:
        batch_size, sequence_length, hidden_dim = hidden_states.shape

    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.view(-1, hidden_dim)

    flat_top_k = top_k_index.view(-1)
    num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int()
    sorted_indices = torch.argsort(flat_top_k, stable=True)
    token_indices = sorted_indices // top_k_index.shape[-1]
    permuted_input = hidden_states[token_indices]

    active_expert_ids, active_counts, offsets = _build_active_expert_grouping(num_tokens_per_expert)
    if active_expert_ids.numel() == 0:
        return torch.zeros_like(hidden_states) if is_2d_input else hidden_states.new_zeros(batch_size, sequence_length, hidden_dim)

    target_dtype = _get_fp8_dequant_target_dtype(permuted_input)
    model_type = getattr(self, "_unsloth_model_type", None)
    use_separated_lora = _should_use_separated_lora()

    gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj")
    down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj")
    gate_up_weight = _dequantize_active_expert_weights(
        gate_up_base,
        gate_up_quant,
        active_expert_ids,
        target_dtype,
        "gate_up",
        hidden_dim,
        model_type,
    )
    down_weight = _dequantize_active_expert_weights(
        down_base,
        down_quant,
        active_expert_ids,
        target_dtype,
        "down",
        hidden_dim,
        model_type,
    )
    if gate_up_weight is None or down_weight is None:
        return None

    permuted_input = permuted_input.to(target_dtype)
    mm1_out = _grouped_mm_with_backward_fix(permuted_input, gate_up_weight.contiguous(), offsets)

    gate_up_lora = None
    if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None:
        gate_up_lora = self._unsloth_lora_gate_up_proj[:3]
    elif use_separated_lora and _has_lora_adapters(self.gate_up_proj):
        gate_up_lora = _extract_lora_weights(
            self.gate_up_proj, num_experts=self.num_experts, experts_module=self
        )

    if gate_up_lora is not None:
        first_weight, second_weight, scaling = gate_up_lora
        active_expert_ids_device = active_expert_ids.to(first_weight.device)
        first_weight = first_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous()
        second_weight = second_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous()
        lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets).contiguous()
        try:
            lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets)
        except RuntimeError:
            lora_delta = torch.empty(
                (lora_out.shape[0], second_weight.shape[-1]),
                dtype=lora_out.dtype,
                device=lora_out.device,
            )
            cpu_offsets = offsets.cpu().tolist()
            prev_offset = 0
            for i, end in enumerate(cpu_offsets):
                if prev_offset < end:
                    lora_delta[prev_offset:end] = torch.matmul(
                        lora_out[prev_offset:end], second_weight[i]
                    )
                prev_offset = end
        mm1_out = mm1_out + lora_delta * scaling

    if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None:
        bias_indices = active_expert_ids.to(self.gate_up_proj_bias.device)
        bias_expanded = self.gate_up_proj_bias.index_select(0, bias_indices).repeat_interleave(
            active_counts.to(self.gate_up_proj_bias.device), dim=0
        )
        mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype)

    if "GptOssExperts" in self.__class__.__name__:
        gate = mm1_out[..., ::2]
        up = mm1_out[..., 1::2]
        limit = getattr(self, "limit", 7.0)
        alpha = getattr(self, "alpha", 1.702)
        gate = gate.clamp(min=None, max=limit)
        up = up.clamp(min=-limit, max=limit)
        inter = (up + 1.0) * (gate * torch.sigmoid(gate * alpha))
    else:
        gate, up = mm1_out.chunk(2, dim=-1)
        inter = F.silu(gate) * up

    mm2_out = _grouped_mm_with_backward_fix(inter, down_weight.contiguous(), offsets)

    down_lora = None
    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
        down_lora = self._unsloth_lora_down_proj[:3]
    elif use_separated_lora and _has_lora_adapters(self.down_proj):
        down_lora = _extract_lora_weights(
            self.down_proj, num_experts=self.num_experts, experts_module=self
        )

    if down_lora is not None:
        first_weight, second_weight, scaling = down_lora
        active_expert_ids_device = active_expert_ids.to(first_weight.device)
        first_weight = first_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous()
        second_weight = second_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous()
        lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets).contiguous()
        try:
            lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets)
        except RuntimeError:
            lora_delta = torch.empty(
                (lora_out.shape[0], second_weight.shape[-1]),
                dtype=lora_out.dtype,
                device=lora_out.device,
            )
            cpu_offsets = offsets.cpu().tolist()
            prev_offset = 0
            for i, end in enumerate(cpu_offsets):
                if prev_offset < end:
                    lora_delta[prev_offset:end] = torch.matmul(
                        lora_out[prev_offset:end], second_weight[i]
                    )
                prev_offset = end
        mm2_out = mm2_out + lora_delta * scaling

    if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None:
        bias_indices = active_expert_ids.to(self.down_proj_bias.device)
        bias_expanded = self.down_proj_bias.index_select(0, bias_indices).repeat_interleave(
            active_counts.to(self.down_proj_bias.device), dim=0
        )
        mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype)

    flat_weights = top_k_weights.view(-1)
    permuted_weights = flat_weights[sorted_indices]
    mm2_out = mm2_out * permuted_weights.unsqueeze(-1)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=input_dtype,
        device=hidden_states.device,
    )
    final_hidden_states.index_add_(0, token_indices, mm2_out.to(input_dtype))

    if is_2d_input:
        return final_hidden_states
    return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
Contributor


medium

This function, along with _forward_native_grouped_mm_scaled_fp8 and the original forward_native_grouped_mm, contains a significant amount of duplicated code. For example, the logic for handling LoRA projections, applying biases, and the final scatter-add operation is nearly identical across these functions. This extensive duplication increases the maintenance burden and the risk of introducing inconsistencies. I recommend refactoring the common logic into helper functions. For instance, the LoRA application logic could be extracted into a single helper that is called from all three forward passes, which would make the codebase more modular and easier to manage.

References
  1. Refactor duplicated logic into shared helper functions to avoid code duplication and improve maintainability.

Comment on lines +85 to +108
for expert_idx in range(num_experts):
    gate = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight"
    )
    gate_scale = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight_scale"
    )
    up = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight"
    )
    up_scale = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight_scale"
    )
    down = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight"
    )
    down_scale = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight_scale"
    )

    gate_up_rows.append(torch.cat([gate, up], dim = 0))
    down_rows.append(down)
    gate_up_scales.append(torch.cat([gate_scale, up_scale], dim = 0))
    down_scales.append(down_scale)
Contributor


medium

The function _maybe_patch_glm4_stacked_moe_fp8_scales is quite long, which can make it difficult to read and maintain. The loop over num_experts (lines 85-108), which is responsible for loading multiple tensors for each expert from the safetensors file, could be extracted into a separate helper function. Creating a helper such as _load_and_prepare_expert_tensors(file, layer_idx, expert_idx) would encapsulate this logic, thereby improving the overall readability and modularity of the code.


@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b4bcdef7a4

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +711 to +712
if expert_quant_state is None:
    return expert_weight.to(target_dtype)


P1 Badge Fall back when FP8 quant scales are missing

In _dequantize_expert_slice, FP8 weights with missing expert_quant_state are directly cast via expert_weight.to(target_dtype) instead of treated as an unsupported dequantization case. For compressed-tensors checkpoints where weight_scale tensors were not attached, this interprets quantized FP8 codes as real weights and silently produces incorrect MoE activations, while bypassing the intended fallback logic that should trigger when metadata is insufficient.


Comment on lines +205 to +208
major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
if major != 9:
    _TORCH_SCALED_GROUPED_MM_SUPPORTED = False
    return False


P2 Badge Remove SM90-only gate before scaled_grouped_mm probe

_check_torch_scaled_grouped_mm_supported returns False for every GPU whose compute capability major is not 9 before running the runtime probe. That means the FP8 scaled grouped-mm path is forcibly disabled on other architectures even when torch._scaled_grouped_mm exists and could work, so these runs always fall back to the slower dequantization path.

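The structure this comment argues for — a loose capability pre-filter followed by a runtime probe — might look like the sketch below. The probe callable (a tiny torch._scaled_grouped_mm invocation) is left to the caller, since its exact shapes are an implementation detail; the synchronize after the probe also addresses the async-launch concern raised elsewhere in this thread:

```python
import torch

def check_scaled_grouped_mm_supported(probe_fn) -> bool:
    """Decide FP8 scaled-grouped-mm support via a runtime probe.

    probe_fn is assumed to run a tiny torch._scaled_grouped_mm call.
    """
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    if major < 9:  # FP8 tensor cores start at SM 9.0; reject only pre-Hopper
        return False
    try:
        probe_fn()
        torch.cuda.synchronize()  # surface asynchronous launch failures
        return True
    except RuntimeError:
        return False
```

Unlike a hard `major != 9` gate, SM 10.0/12.0 parts reach the probe and are accepted or rejected based on what actually runs.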


@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 64491aa59d


# Cache model dimensions and kernel configs on first call
if self._unsloth_moe_configs is None:
intermediate_dim = self.gate_up_proj.shape[1] // 2
intermediate_dim = gate_up_proj.shape[1] // 2


P1 Badge Compute Triton intermediate dim from actual GEMM weight axis

In the FP8 Triton fallback, gate_up_proj is passed from _dequantize_full_expert_weights, which already applies preprocess_weight and commonly yields a (E, hidden_dim, 2*intermediate) layout. Using gate_up_proj.shape[1] // 2 here therefore derives hidden_dim/2 instead of the real intermediate size, and that incorrect value is sent into get_or_autotune_moe_kernels. On FP8 models running the Triton path (e.g., when torch._grouped_mm is unavailable), this can cache/autotune kernels for the wrong matrix shapes and trigger grouped-GEMM runtime failures.


@danielhanchen
Contributor

Review: FP8 MoE Support (zoo PR #548)

Tested on NVIDIA RTX PRO 6000 Blackwell (SM 12.0, 98 GB VRAM), PyTorch 2.10.0, Transformers 5.3.0. Also reviewed the cleanup branch (datta0/moe_fp8_cleanup).

Cleanup Branch Recommendation

The cleanup branch is significantly better than the original PR and should be preferred. It fixes several bugs and has better code organization. However, some issues remain in both versions.

Bugs Found

Critical (P1):

  1. moe_utils.py:206 -- SM gate blocks SM 10.0: if major != 9 only permits Hopper. PyTorch reports _scaled_grouped_mm support on [9.0, 10.0]. Fixed in cleanup (< 9), which correctly allows the runtime probe to execute for SM >= 9 and lets the try/except catch unsupported devices (like our SM 12.0).

  2. moe_utils.py:720-724 -- Numerically incorrect dequantization: _dequantize_expert_slice when expert_quant_state is None does .to(target_dtype) on raw FP8 values without applying the weight scale. This produces numerically wrong results (off by ~100-400x) with no crash or error. Fixed in cleanup (returns None, correctly falling back).

Medium (P2):

  1. moe_utils_fp8.py:63-72 -- Sharded checkpoints not supported: _maybe_patch_glm4_stacked_moe_fp8_scales hardcodes model.safetensors filename. Sharded checkpoints (model-00001-of-NNNNN.safetensors) fail with EntryNotFoundError (remote) or FileNotFoundError (local). Not fixed in cleanup. Fix: parse model.safetensors.index.json to find relevant shards.

  2. moe_utils_fp8.py:57 -- Scale patching inoperative: The dtype check experts.gate_up_proj.dtype == torch.float8_e4m3fn always returns False because compressed_tensors decompresses weights to BF16 during from_pretrained(). FP8 scales from the checkpoint appear as UNEXPECTED keys and are never attached. The entire scale-patching feature is dead code. Not fixed in cleanup.

  3. moe_utils.py -- Missing torch.cuda.synchronize() in SM probe: The _check_torch_scaled_grouped_mm_supported probe calls _scaled_grouped_mm inside try/except but CUDA kernels launch asynchronously. Without sync, the error may not be caught and _TORCH_SCALED_GROUPED_MM_SUPPORTED is cached as True incorrectly. Fixed in cleanup.

  4. moe_utils.py:2185-2225 -- forward_native_moe_loop assumes 2D input: Doesn't flatten 3D hidden_states/top_k_index like forward_native_grouped_mm does. 3D input (batch, seq, hidden) causes wrong expert routing via incorrect F.one_hot + permute(2, 1, 0). Not fixed in cleanup.

Minor (P3):

  1. moe_utils.py:1285-1310 -- Memory pressure in active-dequant: No del gate_up_weight after computing mm1_out. Both bf16 temporaries (gate_up and down) held alive simultaneously, doubling memory pressure for active experts. Adding del gate_up_weight after the first grouped_mm would release the buffer before allocating the next.

  2. moe_utils_fp8.py:85-125 -- Memory spike from re-loading weights: Scale patcher re-reads all per-expert weights from safetensors and constructs new stacked nn.Parameter objects while originals are still in GPU memory, temporarily doubling expert weight memory.
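For P2 item 4, the missing input normalization amounts to flattening before routing, the way forward_native_grouped_mm already does. A minimal sketch with illustrative names:

```python
import torch

def flatten_moe_inputs(hidden_states, top_k_index):
    """Collapse (batch, seq, hidden) to (tokens, hidden) before expert routing."""
    if hidden_states.dim() == 3:
        hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
    # Align routing indices with the flattened token axis.
    return hidden_states, top_k_index.reshape(hidden_states.shape[0], -1)
```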

SM 12.0 Code Path Trace

On Blackwell SM 12.0:

  • select_moe_backend() returns "grouped_mm" (correct)
  • torch._grouped_mm works for BF16 (correct)
  • _forward_native_grouped_mm_scaled_fp8 returns None immediately due to SM gate (correct outcome on SM 12.0)
  • Falls to _forward_native_grouped_mm_active_dequant: dequantizes routed experts to bf16, runs grouped_mm (correct and functional)
  • FP8 MoE effectively runs as BF16 since compressed_tensors decompresses weights during loading

Cleanup Branch Fixes Summary

| Bug | Fixed? |
| --- | --- |
| SM gate != 9 | Yes (< 9) |
| Numerically incorrect dequant | Yes (returns None) |
| Missing cuda sync | Yes |
| Scale attribute aliases | Yes |
| Module caching for relative imports | Yes |
| force=True for newer transformers | Yes |
| Sharded safetensors | No |
| Inoperative scale patching | No |
| 3D input in native loop | No |
| Memory pressure in dequant | No |

@danielhanchen
Contributor

Investigation: Why dequant doesn't work on cleanup branch

Traced the full code path on the cleanup branch. Here is the root cause chain:

The Problem

The dequant path works fine on our Blackwell GPU, but only because compressed_tensors decompresses FP8 weights to BF16 during from_pretrained(). The weights arrive as BF16 with no scales, so _moe_uses_fp8_expert_weights() returns False and the code takes the normal non-FP8 forward_native_grouped_mm path. It "works" but is effectively not doing FP8 at all.

When weights actually ARE FP8 (e.g., on Hopper where compressed_tensors might keep them in float8), the dequant fails because the scales are never attached.

Root Cause Chain

  1. compressed_tensors decompresses FP8 -> BF16 during AutoModelForCausalLM.from_pretrained() -- by the time the model object exists, expert weights have dtype=torch.bfloat16, not torch.float8_e4m3fn

  2. Scale keys are dropped as UNEXPECTED during loading -- the checkpoint has model.layers.{N}.mlp.experts.gate_up_proj_scale and down_proj_scale, but transformers' state dict loading doesn't know how to attach these to the stacked nn.Parameter (they're not registered parameters/buffers on the experts module). They appear in the load report as UNEXPECTED and are silently discarded.

  3. maybe_patch_stacked_moe_expert_fp8_scales never fires -- it checks experts.gate_up_proj.dtype == torch.float8_e4m3fn at moe_utils_fp8.py:57 (cleanup line ~37 in the original). Since weights are already BF16, this returns False immediately. The scale re-attachment code never runs.

  4. Without scales, all dequant paths return None:

    • _get_moe_weight_and_quant_info finds quant_state=None (no *_weight_scale, *_scale_inv, *_scale attributes on experts module)
    • _dequantize_full_expert_weights -> _slice_fp8_quant_state returns None -> _dequantize_expert_slice returns None -> full function returns None
    • forward_grouped_mm_fp8 sees None -> falls to _forward_native_moe_loop_fp8
    • _forward_native_fp8_expert_loop calls fp8_linear which calls fbgemm_fp8_linear -> crashes if fbgemm genai ops are unavailable

The Fix

The scale patching must happen BEFORE or independently of the dtype check. Two approaches:

Option A: Intercept loading to prevent decompression of MoE stacked experts

  • Tell compressed_tensors to skip decompression for modules matching *.mlp.experts.* patterns
  • This keeps weights as float8 and preserves scales during loading
  • Requires understanding compressed_tensors' ignore config

Option B: Patch scales regardless of weight dtype

  • In _maybe_patch_glm4_stacked_moe_fp8_scales, remove the dtype == float8_e4m3fn check
  • Instead check if the checkpoint has scale keys by reading model.safetensors metadata
  • Attach scales even to BF16 weights -- the dequant path will still work (it just becomes a no-op multiply)
  • This is simpler but wasteful

Option C: Keep FP8 weights by using quantization_status="frozen" instead of decompressing

  • Configure the compressed_tensors quantizer to keep weights compressed but mark them as frozen
  • This is the cleanest fix but requires changes to how the model is loaded

For any option, the _maybe_patch_glm4_stacked_moe_fp8_scales function also needs to handle sharded safetensors (support model.safetensors.index.json).
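The sharded-safetensors handling all three options need could start from a small index-parsing helper along these lines. The weight_map layout is the standard HF sharded-checkpoint index format; the function name is illustrative:

```python
import json

def resolve_shards(index_json_text: str, wanted_keys):
    """Group tensor names by the shard file that stores them,
    using the weight_map from model.safetensors.index.json."""
    weight_map = json.loads(index_json_text)["weight_map"]
    shards = {}
    for key in wanted_keys:
        shards.setdefault(weight_map[key], []).append(key)
    return shards
```

Each shard then only needs to be opened once, which also limits the memory spike noted in the earlier review.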

Debug Evidence

Layer 1:
  gate_up_proj: dtype=torch.bfloat16, shape=torch.Size([64, 3072, 2048])
  Scale attrs: []           <-- no scales attached
  gate_up_proj.quant_state: None

  _is_float8_tensor(gate_up_proj): False
  _moe_uses_fp8_expert_weights(experts): False
  _get_moe_weight_and_quant_info('gate_up_proj'):
    weight dtype: torch.bfloat16, quant_state: None, quant_kind: None

The UNEXPECTED keys from the load report confirm scales exist in checkpoint but are dropped:

model.layers.{2...45}.mlp.experts.gate_up_proj_scale | UNEXPECTED
model.layers.{2...45}.mlp.experts.down_proj_scale    | UNEXPECTED

@Datta0
Collaborator (Author)

Datta0 commented Mar 16, 2026

The investigation done by the model is wrong lol.
Layer 1 is BF16 while the others are FP8. I have fixed it in the cleanup branch and will merge it here in a while.

(screenshots attached)

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 20f56f8fea


Comment on lines +141 to +144
fp8_module = _load_cached_moe_utils_fp8_module()
if fp8_module is not None and hasattr(fp8_module, "forward_moe_backend_fp8"):
    _CACHED_FORWARD_MOE_BACKEND = fp8_module.forward_moe_backend_fp8
    return _CACHED_FORWARD_MOE_BACKEND


P1 Badge Route non-FP8 MoE through standard backend selector

get_forward_moe_backend now prioritizes forward_moe_backend_fp8 whenever the cached FP8 module exists, and install_to_cache creates that cache by default, so non-FP8 MoE models are also forced through the FP8 path. In that path, the code dequantizes/stacks full expert weights on every call and runs under @torch.compiler.disable, which introduces a major per-step performance regression for regular bf16/fp16 MoE training and inference.


Comment on lines +454 to +456
def create_block_mask_wrapper(*args, **kwargs):
kwargs["_compile"] = False
return torch_create_block_mask(*args, **kwargs)


P1 Badge Guard _compile kwarg injection for older torch versions

The new create_block_mask_wrapper always sets kwargs["_compile"] = False before delegating to torch_create_block_mask; on torch builds where create_block_mask does not accept _compile, this raises an unexpected-keyword TypeError during mask creation. The previous signature-based compatibility guard was removed, so this is now a runtime break for those environments.

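A sketch of the signature-based guard the review says was removed; `make_wrapper` is a hypothetical helper that probes the installed `create_block_mask` with `inspect.signature` before injecting the private kwarg:

```python
# Sketch: only pass `_compile=False` when the wrapped function actually
# accepts it, so older torch builds don't hit an unexpected-keyword TypeError.
import inspect

def make_wrapper(torch_create_block_mask):
    try:
        accepts = "_compile" in inspect.signature(torch_create_block_mask).parameters
    except (TypeError, ValueError):
        accepts = False  # builtins / C functions without inspectable signatures

    def create_block_mask_wrapper(*args, **kwargs):
        if accepts:
            kwargs.setdefault("_compile", False)
        return torch_create_block_mask(*args, **kwargs)

    return create_block_mask_wrapper
```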

Comment on lines +66 to +68
safetensors_path = hf_hub_download(
repo_id = model_name,
filename = "model.safetensors",


P2 Badge Support sharded safetensors in FP8 scale recovery

The GLM4 FP8 scale patch hardcodes filename="model.safetensors" when downloading checkpoint weights, but many large checkpoints are distributed only as sharded safetensors (model-0000x-of-... plus an index). In those repos this call fails and aborts patching instead of recovering scales, so affected models cannot use the intended remediation path.

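A sketch of sharded-checkpoint resolution using the real `huggingface_hub` download API and the standard `model.safetensors.index.json` layout; `resolve_safetensors_paths` and `shard_filenames` are hypothetical helper names, not the PR's code:

```python
# Sketch: fall back from a single-file checkpoint to the sharded index,
# then fetch every shard referenced by the index's weight_map.
import json

def shard_filenames(index: dict) -> list:
    # The index maps each tensor name to the shard file that stores it.
    return sorted(set(index["weight_map"].values()))

def resolve_safetensors_paths(repo_id: str) -> list:
    # Imported lazily so the pure index parsing stays testable offline.
    from huggingface_hub import hf_hub_download
    try:
        return [hf_hub_download(repo_id=repo_id, filename="model.safetensors")]
    except Exception:
        index_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors.index.json")
        with open(index_path) as f:
            index = json.load(f)
        return [hf_hub_download(repo_id=repo_id, filename=s) for s in shard_filenames(index)]
```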


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Reviewed commit: 48c2e01790

Comment on lines +620 to +621
gate_up_param = nn.Parameter(gate_up_proj, requires_grad=old_gate_up.requires_grad)
down_param = nn.Parameter(down_proj, requires_grad=old_down.requires_grad)


P1 Badge Preserve gradients when installing temporary dequantized weights

Wrapping gate_up_proj/down_proj in new nn.Parameter objects here detaches them from the dequantization graph, so gradients (if experts are trainable) land on temporary parameters that are discarded in finally instead of updating the original MoE weights. This makes the FP8 dequant fallback silently non-trainable for full-finetune or expert-unfrozen runs.

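A minimal repro of the hazard, using an explicit `detach()` to stand in for the effect of installing a fresh leaf parameter in place of the dequantized tensor:

```python
# Sketch: gradients stop at a freshly created leaf, so the dequantization
# graph (and hence the original weights/scales) never receives them.
import torch

scale = torch.tensor(2.0, requires_grad=True)
payload = torch.ones(2, 2)        # stand-in for the dequantized FP8 payload
dequant = payload * scale         # graph: depends on `scale`

# What the fallback effectively does: install a new leaf with the same values.
temp = dequant.detach().clone().requires_grad_(True)
temp.sum().backward()
assert scale.grad is None         # original weights get no gradient

# Keeping the plain tensor preserves the autograd graph:
dequant.sum().backward()
assert scale.grad is not None
```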

routing_weights = routing_weights.to(router_input.dtype)

final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
final_hidden_states = self.experts(router_input, selected_experts, routing_weights)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep expert inputs in expert-weight dtype

router_input is cast to the gate weight dtype and then reused for self.experts(...). On checkpoints where router weights are fp32 but expert weights are bf16/fp16, this passes fp32 activations into the MoE backend and can trigger dtype-mismatch failures in grouped GEMM paths (or unnecessary fp32 execution). The dtype cast should stay scoped to router computation, while experts should consume the original hidden-state tensor.

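A sketch of the scoping this comment suggests (function and argument names are illustrative): cast only the router input to the gate's dtype, and hand the experts the original activations:

```python
# Sketch: the fp32 cast stays scoped to router math; experts consume the
# untouched hidden states in the expert-weight dtype.
import torch
import torch.nn.functional as F

def route_and_dispatch(hidden_states, gate, experts, top_k):
    router_input = hidden_states.to(gate.weight.dtype)   # router-only cast
    logits = F.linear(router_input, gate.weight)
    routing_weights, selected = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
    routing_weights = routing_weights.to(hidden_states.dtype)
    # Experts see the original activations, not the router-cast copy.
    return experts(hidden_states, selected, routing_weights)
```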

@danielhanchen
Contributor

Updated Review: FP8 MoE Support (zoo PR #548)

Tested on NVIDIA RTX PRO 6000 Blackwell (SM 12.0, 98 GB VRAM), PyTorch 2.10.0, Transformers 5.3.0.

Correction: FP8 IS Active

The previous review's conclusion that FP8 was inoperative was incorrect -- it only checked layer 1 (BF16 by design, in the ignore list). All MoE layers 2-38 and 40-45 are float8_e4m3fn with scales attached. Memory after load is 30.30 GB (BF16 would be ~60 GB). @Datta0 was right.

What Works

  • FP8 weights load correctly with per-channel scales (gate_up_proj_weight_scale, down_proj_scale)
  • _check_torch_scaled_grouped_mm_supported() correctly returns False on SM 12.0 (gate restricts to SM 9.x)
  • Dequant-plus-grouped_mm fallback path works: training completes 61 steps, losses decrease normally [1.463 -> 1.065], zero NaN/Inf
  • Peak training memory 35.92 GB confirms FP8 storage advantage

Backend Selection on SM 12.0

_check_torch_scaled_grouped_mm_supported() = False
torch._scaled_grouped_mm available: True (but SM gate blocks it)
Backend: dequant + torch._grouped_mm (BF16)

Bugs Found

Critical:

| ID | File:Line | Description |
| --- | --- | --- |
| B7 | moe_utils_fp8.py:62-71 | `_maybe_patch_glm4_stacked_moe_fp8_scales` hardcodes `model.safetensors`. Sharded models (`model-00001-of-NNNNN.safetensors`) raise `FileNotFoundError`; there is no `model.safetensors.index.json` parsing. |
| B8 | moe_utils_fp8.py:249-337 | `_dequantize_expert_slice` applies `w * s` directly for `weight_scale_inv`; inverse scales are never inverted. Contrast with `_slice_fp8_linear_quant_state` (line 572), which correctly calls `.reciprocal()`. Produces numerically wrong outputs for any model storing `weight_scale_inv`. |
| B12 | moe_utils.py:140-153 | `get_forward_moe_backend` now unconditionally returns `forward_moe_backend_fp8`. Non-FP8 MoE models (plain BF16 Mixtral, DeepSeek) are routed through FP8 code paths, causing unnecessary overhead and potential crashes. |

High:

| ID | File:Line | Description |
| --- | --- | --- |
| B13 | moe_utils_fp8.py | The temporary `nn.Parameter` swap in `_call_with_temporary_moe_weights` severs gradients to the original expert weights during training. |
| B18 | moe_utils_fp8.py:580 | Unguarded `from unsloth.kernels.fp8 import fp8_linear` in `_forward_native_fp8_expert_loop` raises `ImportError` in the fallback path if unsloth is not installed. |
| B19 | moe_utils_fp8.py:596-598 | `_forward_native_fp8_expert_loop` crashes with a 3D `top_k_index`: `hidden_states` is reshaped to 2D but `top_k_index` is not, causing a permute dimension mismatch. |
| B23 | moe_utils_fp8.py:341 | `_make_grouped_mm_rhs_column_major` is a no-op (double transpose). It does not produce the column-major layout that `_scaled_grouped_mm` expects on Hopper. Should be `weight.contiguous().transpose(-2, -1)`. |
| B24 | moe_utils_fp8.py:117-126 | GLM4 scale attributes are written as `_weight_scale`/`_scale`, but `FP8Expert.forward` hardcodes `_scale_inv`, so the patched scales are invisible to the stock expert forward path. |
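The B23 no-op claim is easy to verify with strides: a double transpose returns a view with the original layout, while forcing contiguity between the two transposes actually changes the memory layout:

```python
# Sketch: demonstrate the no-op vs a real column-major relayout.
import torch

w = torch.randn(4, 8)

noop = w.transpose(-2, -1).transpose(-2, -1)
assert noop.data_ptr() == w.data_ptr()   # same storage
assert noop.stride() == w.stride()       # same strides: a no-op view

col_major = w.transpose(-2, -1).contiguous().transpose(-2, -1)
assert col_major.shape == w.shape
assert col_major.stride(-2) == 1         # column-major in the last two dims
```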

Medium:

| ID | File:Line | Description |
| --- | --- | --- |
| B9 | moe_utils_fp8.py:394-397 | 1D scale is compared against the wrong weight dimension (`shape[-1]` is `in_dim`, not `out_dim`). |
| B25 | moe_utils_fp8.py:630, moe_utils.py:1467 | The GLM4 stacked experts module has no `act_fn` attribute, so `_forward_native_fp8_expert_loop` and `forward_native_moe_loop` will raise `AttributeError`. |
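A one-line sketch of the B25 fallback, using the `act_fn` attribute name from the review (the helper name is hypothetical):

```python
# Sketch: fall back to SiLU when the stacked experts module does not expose act_fn.
import torch
import torch.nn.functional as F

def get_act_fn(experts_module):
    return getattr(experts_module, "act_fn", F.silu)
```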

Low/Performance:

| ID | File:Line | Description |
| --- | --- | --- |
| B10 | moe_utils.py:1428-1487 | `forward_native_moe_loop` doesn't handle 3D input (callers flatten, so this is safe in practice but fragile). |
| B11 | moe_utils_fp8.py:710-735 | Full dequant on every forward call with no caching. For 64-expert GLM4-MoE this is ~1-2 GB of fresh allocations per batch. |
| B26 | moe_utils_fp8.py:653-679 | `_forward_native_moe_loop_fp8` is defined but never called (dead code). |

Summary

The FP8 MoE loading, scale patching, and dequant-plus-grouped_mm fallback work correctly on SM 12.0. Training converges with normal loss curves and no numerical issues.

Key issues to address: (1) inverse scale handling in dequant path (B8), (2) unconditional FP8 routing for non-FP8 models (B12), (3) sharded checkpoint support (B7), (4) the no-op _make_grouped_mm_rhs_column_major (B23) which will produce wrong results on Hopper when _scaled_grouped_mm is eventually enabled. The existing fix PR #551 addresses some of these.

- B7: Support sharded safetensors (multi-shard) in GLM4 FP8 scale patching
- B8: Add quant_kind param to _dequantize_expert_slice for weight_scale_inv
  reciprocal handling
- B18: Guard fp8_linear import with try/except and dequant fallback
- B19: Flatten 3D top_k_index/top_k_weights alongside hidden_states
- B23: Fix _make_grouped_mm_rhs_column_major (was no-op double transpose,
  now weight.mT.contiguous())
- B25: Add act_fn fallback to F.silu when attribute missing
- B26: Remove dead _forward_native_moe_loop_fp8 function
- B10: Add 3D input reshape in forward_native_moe_loop
- B12: Prefer generic backend over unconditional FP8 in
  get_forward_moe_backend; use forward_moe_backend as final fallback
- Fix use_separated_lora to respect _should_use_separated_lora() instead
  of hardcoded True
- Remove no-op try/except RuntimeError and duplicate comment
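The B19/B10 flattening fix above can be sketched as follows (function name hypothetical): flatten the routing tensors alongside the activations so their ranks stay matched:

```python
# Sketch: collapse (batch, seq, ...) into a single token dimension for all
# three tensors before they reach the expert loop.
import torch

def flatten_moe_inputs(hidden_states, top_k_index, top_k_weights):
    if hidden_states.dim() == 3:
        hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
        top_k_index = top_k_index.reshape(-1, top_k_index.shape[-1])
        top_k_weights = top_k_weights.reshape(-1, top_k_weights.shape[-1])
    return hidden_states, top_k_index, top_k_weights
```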
@danielhanchen danielhanchen mentioned this pull request Mar 17, 2026

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Reviewed commit: 1252be1abb

Comment on lines +164 to +166
gate = file.get_tensor(
f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight"
)


P1 Badge Catch missing-key failures in GLM4 FP8 scale patching

This helper performs hard get_tensor(...) lookups for per-expert gate_proj/up_proj/down_proj keys, but there is no exception handling around these reads; if a GLM4 compressed checkpoint uses a different safetensors layout (for example, stacked or renamed expert tensors), this raises and bubbles out of maybe_patch_stacked_moe_expert_fp8_scales, aborting model patching instead of cleanly skipping the recovery path.

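A sketch of the guarded lookup Codex suggests, using the real `safetensors.safe_open` API; the helper names and the skip-on-`None` contract are hypothetical:

```python
# Sketch: check keys before reading so an unexpected GLM4 checkpoint layout
# skips scale recovery cleanly instead of aborting model patching.

def expert_tensor_names(layer_idx: int, expert_idx: int) -> list:
    prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}"
    return [f"{prefix}.{p}.weight" for p in ("gate_proj", "up_proj", "down_proj")]

def read_expert_tensors(path, layer_idx, expert_idx):
    from safetensors import safe_open  # lazy: only needed when actually reading
    names = expert_tensor_names(layer_idx, expert_idx)
    with safe_open(path, framework="pt") as f:
        available = set(f.keys())
        if not all(n in available for n in names):
            return None  # unknown layout: caller skips the recovery path
        return tuple(f.get_tensor(n) for n in names)
```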

Comment on lines +702 to +705
gate_up_param = nn.Parameter(gate_up_proj, requires_grad=old_gate_up.requires_grad)
down_param = nn.Parameter(down_proj, requires_grad=old_down.requires_grad)
setattr(experts_module, "gate_up_proj", gate_up_param)
setattr(experts_module, "down_proj", down_param)


P1 Badge Keep MoE LoRA adapters active in FP8 dequant fallback

The dequant fallback swaps gate_up_proj and down_proj to raw nn.Parameters, which drops wrapper metadata (lora_A/lora_B) used by grouped-mm/triton paths to extract separated LoRA deltas; as a result, FP8 runs that enter this fallback (e.g., non-SM90 or missing scaled-grouped-mm scales) silently execute base MoE weights and ignore active LoRA adapters.

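One way to keep adapters discoverable across the swap can be sketched like this (helper name hypothetical; `lora_A`/`lora_B` are the attribute names this comment mentions). Note this only preserves adapter metadata; gradient flow through the temporary parameter is a separate concern (B13 above):

```python
# Sketch: copy adapter handles onto the temporary parameter so downstream
# separated-LoRA paths can still find them after the weight swap.
import torch
import torch.nn as nn

def install_temp_weight(module, name, new_tensor, adapter_attrs=("lora_A", "lora_B")):
    old = getattr(module, name)
    temp = nn.Parameter(new_tensor, requires_grad=old.requires_grad)
    for attr in adapter_attrs:
        if hasattr(old, attr):
            setattr(temp, attr, getattr(old, attr))  # keep adapters visible
    setattr(module, name, temp)
    return old  # caller restores this in its `finally` block
```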
