[MoE] FP8 support for MoE, specifically GLM 4.7 flash #548
Datta0 wants to merge 12 commits into unslothai:main
Conversation
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the Mixture-of-Experts (MoE) implementation by adding robust support for FP8 quantization, introducing a flexible strategy that leverages hardware-accelerated FP8 kernels where available and falls back to on-the-fly dequantization elsewhere.
Code Review
This pull request introduces FP8 support for Mixture-of-Experts (MoE) models, with a specific focus on GLM-4. The changes include new forward paths to handle FP8 tensors, leveraging torch._scaled_grouped_mm on compatible hardware and providing a fallback that dequantizes weights on the fly for other GPUs. A patch is also included to correctly load FP8 scales for GLM-4 models. The implementation is comprehensive, addressing various hardware capabilities. My review feedback centers on enhancing maintainability by reducing code duplication in the newly added forward paths, improving the specificity of exception handling to prevent masking potential issues, and suggesting a refactoring of a lengthy function to improve its readability.
```python
def _forward_native_grouped_mm_active_dequant(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> Optional[torch.Tensor]:
    """
    FP8 compatibility path: dequantize only routed experts, then run grouped_mm.
    Falls back to None when the expert quant metadata cannot be interpreted safely.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    if not hasattr(self, "gate_up_proj") or not hasattr(self, "down_proj"):
        return None

    is_2d_input = hidden_states.dim() == 2
    if is_2d_input:
        sequence_length, hidden_dim = hidden_states.shape
        batch_size = 1
    else:
        batch_size, sequence_length, hidden_dim = hidden_states.shape

    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.view(-1, hidden_dim)

    flat_top_k = top_k_index.view(-1)
    num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int()
    sorted_indices = torch.argsort(flat_top_k, stable=True)
    token_indices = sorted_indices // top_k_index.shape[-1]
    permuted_input = hidden_states[token_indices]

    active_expert_ids, active_counts, offsets = _build_active_expert_grouping(num_tokens_per_expert)
    if active_expert_ids.numel() == 0:
        return torch.zeros_like(hidden_states) if is_2d_input else hidden_states.new_zeros(batch_size, sequence_length, hidden_dim)

    target_dtype = _get_fp8_dequant_target_dtype(permuted_input)
    model_type = getattr(self, "_unsloth_model_type", None)
    use_separated_lora = _should_use_separated_lora()

    gate_up_base, gate_up_quant = _get_moe_weight_and_quant_state(self, "gate_up_proj")
    down_base, down_quant = _get_moe_weight_and_quant_state(self, "down_proj")
    gate_up_weight = _dequantize_active_expert_weights(
        gate_up_base,
        gate_up_quant,
        active_expert_ids,
        target_dtype,
        "gate_up",
        hidden_dim,
        model_type,
    )
    down_weight = _dequantize_active_expert_weights(
        down_base,
        down_quant,
        active_expert_ids,
        target_dtype,
        "down",
        hidden_dim,
        model_type,
    )
    if gate_up_weight is None or down_weight is None:
        return None

    permuted_input = permuted_input.to(target_dtype)
    mm1_out = _grouped_mm_with_backward_fix(permuted_input, gate_up_weight.contiguous(), offsets)

    gate_up_lora = None
    if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None:
        gate_up_lora = self._unsloth_lora_gate_up_proj[:3]
    elif use_separated_lora and _has_lora_adapters(self.gate_up_proj):
        gate_up_lora = _extract_lora_weights(
            self.gate_up_proj, num_experts=self.num_experts, experts_module=self
        )

    if gate_up_lora is not None:
        first_weight, second_weight, scaling = gate_up_lora
        active_expert_ids_device = active_expert_ids.to(first_weight.device)
        first_weight = first_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous()
        second_weight = second_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous()
        lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets).contiguous()
        try:
            lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets)
        except RuntimeError:
            lora_delta = torch.empty(
                (lora_out.shape[0], second_weight.shape[-1]),
                dtype=lora_out.dtype,
                device=lora_out.device,
            )
            cpu_offsets = offsets.cpu().tolist()
            prev_offset = 0
            for i, end in enumerate(cpu_offsets):
                if prev_offset < end:
                    lora_delta[prev_offset:end] = torch.matmul(
                        lora_out[prev_offset:end], second_weight[i]
                    )
                prev_offset = end
        mm1_out = mm1_out + lora_delta * scaling

    if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None:
        bias_indices = active_expert_ids.to(self.gate_up_proj_bias.device)
        bias_expanded = self.gate_up_proj_bias.index_select(0, bias_indices).repeat_interleave(
            active_counts.to(self.gate_up_proj_bias.device), dim=0
        )
        mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype)

    if "GptOssExperts" in self.__class__.__name__:
        gate = mm1_out[..., ::2]
        up = mm1_out[..., 1::2]
        limit = getattr(self, "limit", 7.0)
        alpha = getattr(self, "alpha", 1.702)
        gate = gate.clamp(min=None, max=limit)
        up = up.clamp(min=-limit, max=limit)
        inter = (up + 1.0) * (gate * torch.sigmoid(gate * alpha))
    else:
        gate, up = mm1_out.chunk(2, dim=-1)
        inter = F.silu(gate) * up

    mm2_out = _grouped_mm_with_backward_fix(inter, down_weight.contiguous(), offsets)

    down_lora = None
    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
        down_lora = self._unsloth_lora_down_proj[:3]
    elif use_separated_lora and _has_lora_adapters(self.down_proj):
        down_lora = _extract_lora_weights(
            self.down_proj, num_experts=self.num_experts, experts_module=self
        )

    if down_lora is not None:
        first_weight, second_weight, scaling = down_lora
        active_expert_ids_device = active_expert_ids.to(first_weight.device)
        first_weight = first_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous()
        second_weight = second_weight.index_select(0, active_expert_ids_device).to(target_dtype).contiguous()
        lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets).contiguous()
        try:
            lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets)
        except RuntimeError:
            lora_delta = torch.empty(
                (lora_out.shape[0], second_weight.shape[-1]),
                dtype=lora_out.dtype,
                device=lora_out.device,
            )
            cpu_offsets = offsets.cpu().tolist()
            prev_offset = 0
            for i, end in enumerate(cpu_offsets):
                if prev_offset < end:
                    lora_delta[prev_offset:end] = torch.matmul(
                        lora_out[prev_offset:end], second_weight[i]
                    )
                prev_offset = end
        mm2_out = mm2_out + lora_delta * scaling

    if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None:
        bias_indices = active_expert_ids.to(self.down_proj_bias.device)
        bias_expanded = self.down_proj_bias.index_select(0, bias_indices).repeat_interleave(
            active_counts.to(self.down_proj_bias.device), dim=0
        )
        mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype)

    flat_weights = top_k_weights.view(-1)
    permuted_weights = flat_weights[sorted_indices]
    mm2_out = mm2_out * permuted_weights.unsqueeze(-1)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=input_dtype,
        device=hidden_states.device,
    )
    final_hidden_states.index_add_(0, token_indices, mm2_out.to(input_dtype))

    if is_2d_input:
        return final_hidden_states
    return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
```
This function, along with _forward_native_grouped_mm_scaled_fp8 and the original forward_native_grouped_mm, contains a significant amount of duplicated code. For example, the logic for handling LoRA projections, applying biases, and the final scatter-add operation is nearly identical across these functions. This extensive duplication increases the maintenance burden and the risk of introducing inconsistencies. I recommend refactoring the common logic into helper functions. For instance, the LoRA application logic could be extracted into a single helper that is called from all three forward passes, which would make the codebase more modular and easier to manage.
References
- Refactor duplicated logic into shared helper functions to avoid code duplication and improve maintainability.
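The shared helper the review asks for could look like the following sketch. Tensors and GEMMs are abstracted behind plain Python stand-ins so only the per-expert segmentation logic is shown; `segmented_lora_delta` and its parameters are hypothetical names, not code from the PR.

```python
def segmented_lora_delta(lora_out, second_weight, offsets, matmul):
    """Apply the per-expert second LoRA matmul over contiguous token segments.

    `offsets` holds cumulative segment ends (as in the grouped_mm fallback
    loop above); segment i uses second_weight[i]. `matmul` abstracts the
    per-segment GEMM so this routine can be shared by all forward paths.
    """
    delta = list(lora_out)  # placeholder storage, same length as lora_out
    prev = 0
    for i, end in enumerate(offsets):
        if prev < end:
            delta[prev:end] = matmul(lora_out[prev:end], second_weight[i])
        prev = end
    return delta
```

Both the gate_up and down LoRA branches could then call this single helper, so a future fix to the segment loop lands in one place.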
```python
for expert_idx in range(num_experts):
    gate = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight"
    )
    gate_scale = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight_scale"
    )
    up = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight"
    )
    up_scale = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight_scale"
    )
    down = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight"
    )
    down_scale = file.get_tensor(
        f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight_scale"
    )

    gate_up_rows.append(torch.cat([gate, up], dim = 0))
    down_rows.append(down)
    gate_up_scales.append(torch.cat([gate_scale, up_scale], dim = 0))
    down_scales.append(down_scale)
```
The function _maybe_patch_glm4_stacked_moe_fp8_scales is quite long, which can make it difficult to read and maintain. The loop over num_experts (lines 85-108), which is responsible for loading multiple tensors for each expert from the safetensors file, could be extracted into a separate helper function. Creating a helper such as _load_and_prepare_expert_tensors(file, layer_idx, expert_idx) would encapsulate this logic, thereby improving the overall readability and modularity of the code.
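A minimal sketch of the suggested helper, assuming the GLM-4 per-expert key layout shown above. `load_expert_tensors` and its `get_tensor` callable parameter are hypothetical names for illustration, not the PR's actual API.

```python
def load_expert_tensors(get_tensor, layer_idx, expert_idx):
    """Load the six per-expert tensors (three weights, three scales)
    for one GLM-4 MoE expert, keyed by projection name."""
    base = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}"
    projections = ("gate_proj", "up_proj", "down_proj")
    weights = {p: get_tensor(f"{base}.{p}.weight") for p in projections}
    scales = {p: get_tensor(f"{base}.{p}.weight_scale") for p in projections}
    return weights, scales
```

The outer loop in `_maybe_patch_glm4_stacked_moe_fp8_scales` would then shrink to one call per expert, with the concatenation into stacked rows kept at the call site.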
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b4bcdef7a4
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
```python
if expert_quant_state is None:
    return expert_weight.to(target_dtype)
```
Fall back when FP8 quant scales are missing
In _dequantize_expert_slice, FP8 weights with missing expert_quant_state are directly cast via expert_weight.to(target_dtype) instead of treated as an unsupported dequantization case. For compressed-tensors checkpoints where weight_scale tensors were not attached, this interprets quantized FP8 codes as real weights and silently produces incorrect MoE activations, while bypassing the intended fallback logic that should trigger when metadata is insufficient.
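The intended fallback could be restored along these lines. This is a schematic pure-Python sketch, not the PR's actual `_dequantize_expert_slice` signature: a string dtype name and per-tensor scaling stand in for torch dtypes and the real block-wise dequantization.

```python
def dequantize_expert_slice(weight, dtype_name, quant_state, cast):
    """Return the dequantized weight, or None to signal the safe fallback.

    FP8 codes without scale metadata are uninterpretable, so refuse to
    cast them rather than produce silently wrong MoE activations.
    """
    if quant_state is None:
        if dtype_name.startswith("float8"):
            return None  # missing scales: let the caller fall back
        return [cast(x) for x in weight]  # non-quantized weights: plain cast
    scale = quant_state["weight_scale"]
    return [cast(x) * scale for x in weight]
```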
```python
major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
if major != 9:
    _TORCH_SCALED_GROUPED_MM_SUPPORTED = False
    return False
```
Remove SM90-only gate before scaled_grouped_mm probe
_check_torch_scaled_grouped_mm_supported returns False for every GPU whose compute capability major is not 9 before running the runtime probe. That means the FP8 scaled grouped-mm path is forcibly disabled on other architectures even when torch._scaled_grouped_mm exists and could work, so these runs always fall back to the slower dequantization path.
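One way to drop the hard gate is to let a one-time runtime probe decide, as in this hedged sketch: the `probe` callable stands in for a tiny `torch._scaled_grouped_mm` call on dummy tensors, and the result is cached so the probe runs at most once.

```python
_SUPPORT_CACHE = {}

def scaled_grouped_mm_supported(probe, cache=_SUPPORT_CACHE):
    """Decide FP8 grouped-mm support by probing once at runtime,
    instead of gating on compute capability major == 9."""
    if "ok" not in cache:
        try:
            probe()  # tiny scaled grouped-mm call on dummy inputs
            cache["ok"] = True
        except Exception:
            cache["ok"] = False
    return cache["ok"]
```

This keeps the fast path available on Blackwell and future architectures while still falling back cleanly wherever the op genuinely fails.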
💡 Codex Review
Reviewed commit: 64491aa59d
```diff
 # Cache model dimensions and kernel configs on first call
 if self._unsloth_moe_configs is None:
-    intermediate_dim = self.gate_up_proj.shape[1] // 2
+    intermediate_dim = gate_up_proj.shape[1] // 2
```
Compute Triton intermediate dim from actual GEMM weight axis
In the FP8 Triton fallback, gate_up_proj is passed from _dequantize_full_expert_weights, which already applies preprocess_weight and commonly yields a (E, hidden_dim, 2*intermediate) layout. Using gate_up_proj.shape[1] // 2 here therefore derives hidden_dim/2 instead of the real intermediate size, and that incorrect value is sent into get_or_autotune_moe_kernels. On FP8 models running the Triton path (e.g., when torch._grouped_mm is unavailable), this can cache/autotune kernels for the wrong matrix shapes and trigger grouped-GEMM runtime failures.
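A defensive sketch of deriving the intermediate size from whichever axis is not `hidden_dim`, covering both the `(E, 2*I, H)` and post-`preprocess_weight` `(E, H, 2*I)` layouts mentioned above. `infer_intermediate_dim` is a hypothetical helper, and the heuristic is ambiguous in the corner case `2*intermediate == hidden_dim`.

```python
def infer_intermediate_dim(gate_up_shape, hidden_dim):
    """Pick the fused gate/up axis regardless of weight layout.

    Assumes gate_up is (E, 2*I, H) or (E, H, 2*I); the axis that does
    not equal hidden_dim is the fused gate/up dimension 2*I.
    """
    _, a, b = gate_up_shape
    fused = b if a == hidden_dim else a
    return fused // 2
```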
Review: FP8 MoE Support (zoo PR #548)
Tested on NVIDIA RTX PRO 6000 Blackwell (SM 12.0, 98 GB VRAM), PyTorch 2.10.0, Transformers 5.3.0. Also reviewed the cleanup branch.

Cleanup Branch Recommendation
The cleanup branch is significantly better than the original PR and should be preferred. It fixes several bugs and has better code organization. However, some issues remain in both versions.

Bugs Found
Critical (P1):
Medium (P2):
Minor (P3):

SM 12.0 Code Path Trace
On Blackwell SM 12.0:

Cleanup Branch Fixes Summary
Investigation: Why dequant doesn't work on the cleanup branch
Traced the full code path on the cleanup branch. Here is the root cause chain:

The Problem
The dequant path works fine on our Blackwell GPU, but only because the weights are decompressed at load time. When weights actually ARE FP8 (e.g., on Hopper, where compressed_tensors might keep them in float8), the dequant fails because the scales are never attached.

Root Cause Chain

The Fix
The scale patching must happen BEFORE or independently of the dtype check. Three approaches:

Option A: Intercept loading to prevent decompression of MoE stacked experts
Option B: Patch scales regardless of weight dtype
Option C: Keep FP8 weights by using

For any option, the

Debug Evidence
The UNEXPECTED keys from the load report confirm the scales exist in the checkpoint but are dropped.
💡 Codex Review
Reviewed commit: 20f56f8fea
```python
fp8_module = _load_cached_moe_utils_fp8_module()
if fp8_module is not None and hasattr(fp8_module, "forward_moe_backend_fp8"):
    _CACHED_FORWARD_MOE_BACKEND = fp8_module.forward_moe_backend_fp8
    return _CACHED_FORWARD_MOE_BACKEND
```
Route non-FP8 MoE through standard backend selector
get_forward_moe_backend now prioritizes forward_moe_backend_fp8 whenever the cached FP8 module exists, and install_to_cache creates that cache by default, so non-FP8 MoE models are also forced through the FP8 path. In that path, the code dequantizes/stacks full expert weights on every call and runs under @torch.compiler.disable, which introduces a major per-step performance regression for regular bf16/fp16 MoE training and inference.
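A sketch of the suggested routing fix: inspect the stacked expert weight's dtype and only hand back the FP8 backend for modules that genuinely hold FP8 weights. All names here are illustrative; strings stand in for torch dtypes.

```python
def select_moe_backend(experts_module, fp8_backend, generic_backend):
    """Route through the FP8 backend only when gate_up_proj is FP8,
    so bf16/fp16 MoE models keep the standard (compilable) path."""
    weight = getattr(experts_module, "gate_up_proj", None)
    dtype_name = str(getattr(weight, "dtype", ""))
    if "float8" in dtype_name:
        return fp8_backend
    return generic_backend
```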
```python
def create_block_mask_wrapper(*args, **kwargs):
    kwargs["_compile"] = False
    return torch_create_block_mask(*args, **kwargs)
```
Guard _compile kwarg injection for older torch versions
The new create_block_mask_wrapper always sets kwargs["_compile"] = False before delegating to torch_create_block_mask; on torch builds where create_block_mask does not accept _compile, this raises an unexpected-keyword TypeError during mask creation. The previous signature-based compatibility guard was removed, so this is now a runtime break for those environments.
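The removed compatibility guard can be reinstated with a signature check, roughly like this sketch: the `_compile=False` kwarg is only injected when the wrapped function actually accepts it (directly or via `**kwargs`).

```python
import inspect

def make_create_block_mask_wrapper(create_block_mask):
    """Wrap create_block_mask, injecting _compile=False only on torch
    builds whose signature accepts that keyword."""
    params = inspect.signature(create_block_mask).parameters
    accepts = "_compile" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    def wrapper(*args, **kwargs):
        if accepts:
            kwargs.setdefault("_compile", False)
        return create_block_mask(*args, **kwargs)
    return wrapper
```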
```python
safetensors_path = hf_hub_download(
    repo_id = model_name,
    filename = "model.safetensors",
```
Support sharded safetensors in FP8 scale recovery
The GLM4 FP8 scale patch hardcodes filename="model.safetensors" when downloading checkpoint weights, but many large checkpoints are distributed only as sharded safetensors (model-0000x-of-... plus an index). In those repos this call fails and aborts patching instead of recovering scales, so affected models cannot use the intended remediation path.
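Resolving the right shards from the standard `model.safetensors.index.json` `weight_map` could look like this sketch, with a plain dict standing in for the parsed index file; the patch would then open only the shard files that actually hold a given layer's expert tensors.

```python
def shards_for_layer(weight_map, layer_idx):
    """Return the sorted set of shard files holding a layer's MoE expert
    tensors, per the safetensors index weight_map (name -> shard file)."""
    prefix = f"model.layers.{layer_idx}.mlp.experts."
    return sorted({shard for name, shard in weight_map.items()
                   if name.startswith(prefix)})
```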
💡 Codex Review
Reviewed commit: 48c2e01790
```python
gate_up_param = nn.Parameter(gate_up_proj, requires_grad=old_gate_up.requires_grad)
down_param = nn.Parameter(down_proj, requires_grad=old_down.requires_grad)
```
Preserve gradients when installing temporary dequantized weights
Wrapping gate_up_proj/down_proj in new nn.Parameter objects here detaches them from the dequantization graph, so gradients (if experts are trainable) land on temporary parameters that are discarded in finally instead of updating the original MoE weights. This makes the FP8 dequant fallback silently non-trainable for full-finetune or expert-unfrozen runs.
```diff
 routing_weights = routing_weights.to(router_input.dtype)

-final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
+final_hidden_states = self.experts(router_input, selected_experts, routing_weights)
```
Keep expert inputs in expert-weight dtype
router_input is cast to the gate weight dtype and then reused for self.experts(...). On checkpoints where router weights are fp32 but expert weights are bf16/fp16, this passes fp32 activations into the MoE backend and can trigger dtype-mismatch failures in grouped GEMM paths (or unnecessary fp32 execution). The dtype cast should stay scoped to router computation, while experts should consume the original hidden-state tensor.
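The suggested scoping can be sketched abstractly; `(value, dtype)` tuples stand in for tensors so only the placement of the casts is shown, and every name here is illustrative rather than the module's real API.

```python
def moe_block_forward(hidden, hidden_dtype, router_dtype, router, experts, cast):
    """Keep the router-dtype cast scoped to routing math; the experts
    consume activations in their original dtype."""
    routing_weights, selected = router(cast(hidden, router_dtype))
    routing_weights = cast(routing_weights, hidden_dtype)
    return experts(hidden, selected, routing_weights)
```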
Updated Review: FP8 MoE Support (zoo PR #548)
Tested on NVIDIA RTX PRO 6000 Blackwell (SM 12.0, 98 GB VRAM), PyTorch 2.10.0, Transformers 5.3.0.

Correction: FP8 IS Active
The previous review's conclusion that FP8 was inoperative was incorrect -- it only checked layer 1 (BF16 by design, in the ignore list). All MoE layers 2-38 and 40-45 are FP8.

What Works

Backend Selection on SM 12.0

Bugs Found
Critical:
High:
Medium:
Low/Performance:

Summary
The FP8 MoE loading, scale patching, and dequant-plus-grouped_mm fallback work correctly on SM 12.0. Training converges with normal loss curves and no numerical issues. Key issues to address: (1) inverse scale handling in the dequant path (B8), (2) unconditional FP8 routing for non-FP8 models (B12), (3) sharded checkpoint support (B7), (4) the no-op `_make_grouped_mm_rhs_column_major` double transpose (B23).
- B7: Support sharded safetensors (multi-shard) in GLM4 FP8 scale patching
- B8: Add quant_kind param to _dequantize_expert_slice for weight_scale_inv reciprocal handling
- B18: Guard fp8_linear import with try/except and dequant fallback
- B19: Flatten 3D top_k_index/top_k_weights alongside hidden_states
- B23: Fix _make_grouped_mm_rhs_column_major (was a no-op double transpose; now weight.mT.contiguous())
- B25: Add act_fn fallback to F.silu when the attribute is missing
- B26: Remove dead _forward_native_moe_loop_fp8 function
- B10: Add 3D input reshape in forward_native_moe_loop
- B12: Prefer generic backend over unconditional FP8 in get_forward_moe_backend; use forward_moe_backend as final fallback
- Fix use_separated_lora to respect _should_use_separated_lora() instead of hardcoded True
- Remove no-op try/except RuntimeError and duplicate comment
💡 Codex Review
Reviewed commit: 1252be1abb
```python
gate = file.get_tensor(
    f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight"
)
```
Catch missing-key failures in GLM4 FP8 scale patching
This helper performs hard get_tensor(...) lookups for per-expert gate_proj/up_proj/down_proj keys, but there is no exception handling around these reads; if a GLM4 compressed checkpoint uses a different safetensors layout (for example, stacked or renamed expert tensors), this raises and bubbles out of maybe_patch_stacked_moe_expert_fp8_scales, aborting model patching instead of cleanly skipping the recovery path.
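A sketch of the suggested guard: wrap the hard lookups so an unexpected checkpoint layout skips the recovery path cleanly instead of raising out of the patcher. `safe_get_tensors` is a hypothetical helper, and the broad `except` mirrors the fact that safetensors backends differ in which exception type a missing key raises.

```python
def safe_get_tensors(get_tensor, keys):
    """Return all requested tensors, or None when any key is absent,
    so the caller can skip FP8 scale recovery for unknown layouts."""
    try:
        return [get_tensor(k) for k in keys]
    except Exception:
        return None  # unknown tensor layout: skip recovery, don't abort patching
```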
```python
gate_up_param = nn.Parameter(gate_up_proj, requires_grad=old_gate_up.requires_grad)
down_param = nn.Parameter(down_proj, requires_grad=old_down.requires_grad)
setattr(experts_module, "gate_up_proj", gate_up_param)
setattr(experts_module, "down_proj", down_param)
```
Keep MoE LoRA adapters active in FP8 dequant fallback
The dequant fallback swaps gate_up_proj and down_proj to raw nn.Parameters, which drops wrapper metadata (lora_A/lora_B) used by grouped-mm/triton paths to extract separated LoRA deltas; as a result, FP8 runs that enter this fallback (e.g., non-SM90 or missing scaled-grouped-mm scales) silently execute base MoE weights and ignore active LoRA adapters.


`grouped_mm` is not a supported operation on FP8, so we try to make use of `_scaled_grouped_mm`, but that is only supported on newer GPU architectures. If the operation is not supported, we de-quantize the tensors on the fly and pass them to the same `grouped_mm` function. There also seems to be an issue where Transformers is not keen on loading the weight scales for the FP8 tensors, so we needed to patch it up. Currently, this patch is very specific to GLM; still not sure about Qwen3 MoE or other MoEs.
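The control flow described above can be sketched as follows, with exceptions standing in for the capability check and the two callables standing in for the real kernels; all names here are illustrative, not the PR's actual functions.

```python
def grouped_mm_fp8(x, experts, run_scaled, run_dequant):
    """Prefer the native FP8 scaled grouped-mm; when the op is unsupported
    on this architecture, dequantize on the fly and reuse plain grouped_mm."""
    try:
        return run_scaled(x, experts)
    except (RuntimeError, NotImplementedError):
        return run_dequant(x, experts)
```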
Note that this is only tested on B200 so far; the losses closely match the BF16 training counterpart.
To Test:
- `Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8`. Memory usage should be <40GB.
- `unsloth/GLM-4.7-Flash-FP8-Dynamic`. Same as above.