Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions unsloth_zoo/patching_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .compiler import UNSLOTH_COMPILE_LOCATION
from .utils import _get_dtype, Version
from .hf_utils import dtype_from_config, set_dtype_in_config, HAS_TORCH_DTYPE
from .temporary_patches.moe_utils_fp8 import maybe_patch_stacked_moe_expert_fp8_scales

# Also disable compiling on bitsandbytes
def patch_compiling_bitsandbytes():
Expand Down Expand Up @@ -396,6 +397,8 @@ def __fix_dtype(config):
# string when trying to save the config or serialize it
patch_to_dict()

maybe_patch_stacked_moe_expert_fp8_scales(model)

# Check all params and patch!
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
Expand Down
8 changes: 6 additions & 2 deletions unsloth_zoo/temporary_patches/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,12 @@ def moe_block_forward(self, hidden_states) -> torch.Tensor:
return hidden_states + shared_output

# Apply patches
patch_function(Glm4MoeLiteNaiveMoe, "forward", get_forward_moe_backend())
patch_function(Glm4MoeLiteMoE, "forward", moe_block_forward)
# Recent transformers wraps the expert forward with use_experts_implementation
# and drops some annotations, so strict signature matching rejects the patch.
# For GLM4 we want to bypass that wrapper entirely and route into Unsloth's
# MoE backend on purpose.
patch_function(Glm4MoeLiteNaiveMoe, "forward", get_forward_moe_backend(), force = True)
patch_function(Glm4MoeLiteMoE, "forward", moe_block_forward, force = True)

if UNSLOTH_ENABLE_LOGGING:
logger.info("Unsloth: Patched GLM4 MoE for Split LoRA support.")
Expand Down
203 changes: 183 additions & 20 deletions unsloth_zoo/temporary_patches/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,14 @@ def install_to_cache(source_path, destination_filename=None):


# Mirror this module — and its optional FP8 companion, when it ships next to
# this file — into the Unsloth compile cache so patched models can import the
# cached copies from a stable location.
install_to_cache(__file__, "moe_utils.py")
_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
_MOE_UTILS_FP8_PATH = os.path.join(_CURRENT_DIR, "moe_utils_fp8.py")
if os.path.isfile(_MOE_UTILS_FP8_PATH):
    install_to_cache(_MOE_UTILS_FP8_PATH, "moe_utils_fp8.py")

# Lazily populated caches: the resolved backend callable, plus the module
# objects loaded back from the compile-cache copies (None until first load).
_CACHED_FORWARD_MOE_BACKEND = None
_CACHED_MOE_UTILS_MODULE = None
_CACHED_MOE_UTILS_FP8_MODULE = None


def _load_cached_moe_utils_module():
Expand All @@ -80,7 +85,7 @@ def _load_cached_moe_utils_module():
return None

try:
module_name = "unsloth_cached_moe_utils"
module_name = "unsloth_zoo.temporary_patches._cached_moe_utils"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

medium: Consider adding a more descriptive name to the module to avoid potential naming conflicts. Using a more specific name like _cached_moe_utils_module would enhance clarity.

module_name = "unsloth_zoo.temporary_patches._cached_moe_utils_module"

module = sys.modules.get(module_name, None)
if module is not None and os.path.abspath(getattr(module, "__file__", "")) == cache_file:
_CACHED_MOE_UTILS_MODULE = module
Expand All @@ -90,6 +95,7 @@ def _load_cached_moe_utils_module():
if spec is None or spec.loader is None:
return None
module = importlib.util.module_from_spec(spec)
module.__package__ = "unsloth_zoo.temporary_patches"
sys.modules[module_name] = module
spec.loader.exec_module(module)
_CACHED_MOE_UTILS_MODULE = module
Expand All @@ -98,18 +104,52 @@ def _load_cached_moe_utils_module():
return None


def _load_cached_moe_utils_fp8_module():
    """Import the compile-cache copy of ``moe_utils_fp8.py`` as a module.

    Returns the module object, memoized in ``_CACHED_MOE_UTILS_FP8_MODULE``,
    or ``None`` when the cache copy is missing, is the very file we are
    running from, or fails to import (best-effort loader: callers fall back
    to the in-tree definitions).
    """
    global _CACHED_MOE_UTILS_FP8_MODULE

    cache_file = os.path.abspath(os.path.join(_get_compile_location(), "moe_utils_fp8.py"))
    current_file = os.path.abspath(_MOE_UTILS_FP8_PATH)
    if not os.path.isfile(cache_file) or cache_file == current_file:
        return None

    module_name = "unsloth_zoo.temporary_patches._cached_moe_utils_fp8"
    # Reuse an already-imported copy only when it came from the same cache
    # file; otherwise import a fresh one below.
    module = sys.modules.get(module_name, None)
    if module is not None and os.path.abspath(getattr(module, "__file__", "")) == cache_file:
        _CACHED_MOE_UTILS_FP8_MODULE = module
        return module

    try:
        spec = importlib.util.spec_from_file_location(module_name, cache_file)
        if spec is None or spec.loader is None:
            return None
        module = importlib.util.module_from_spec(spec)
        # Set __package__ so relative imports inside the cached copy resolve
        # against the real package.
        module.__package__ = "unsloth_zoo.temporary_patches"
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except Exception:
            # BUGFIX: do not leave a half-initialized module in sys.modules —
            # the cache check above matches on name/__file__ and would hand
            # the broken module back on the next call as if it had imported
            # successfully.
            sys.modules.pop(module_name, None)
            raise
        _CACHED_MOE_UTILS_FP8_MODULE = module
        return module
    except Exception:
        # Deliberate best-effort: any failure means "no cached FP8 module".
        return None
Comment on lines +131 to +132
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

medium: The except Exception block is too broad. It would be better to catch specific exceptions like FileNotFoundError to avoid masking unexpected errors.

References
  1. Avoid using try...except ImportError for libraries that are required dependencies of the project, as the check is redundant.



def get_forward_moe_backend():
    """
    Resolve forward_moe_backend from the compiled cache copy when available.
    Falls back to the local module definition.

    Resolution order: FP8 backend from the cached moe_utils_fp8 module,
    then forward_moe_backend from the cached moe_utils module, then the
    in-tree definitions. The winner is memoized in
    _CACHED_FORWARD_MOE_BACKEND before being returned.
    """
    global _CACHED_FORWARD_MOE_BACKEND
    # 1) Prefer the FP8-aware backend from the compile-cache copy.
    fp8_module = _load_cached_moe_utils_fp8_module()
    if fp8_module is not None and hasattr(fp8_module, "forward_moe_backend_fp8"):
        _CACHED_FORWARD_MOE_BACKEND = fp8_module.forward_moe_backend_fp8
        return _CACHED_FORWARD_MOE_BACKEND

    # 2) Otherwise the non-FP8 compile-cache copy of moe_utils.
    module = _load_cached_moe_utils_module()
    if module is not None and hasattr(module, "forward_moe_backend"):
        _CACHED_FORWARD_MOE_BACKEND = module.forward_moe_backend
        return _CACHED_FORWARD_MOE_BACKEND

    # 3) In-tree fallback.
    # NOTE(review): the first assignment below is immediately overwritten by
    # the FP8 variant — the diff appears to show the old fallback line next to
    # its replacement; confirm only one assignment is intended in the final
    # file, and that the relative import cannot fail here.
    _CACHED_FORWARD_MOE_BACKEND = forward_moe_backend
    from .moe_utils_fp8 import forward_moe_backend_fp8
    _CACHED_FORWARD_MOE_BACKEND = forward_moe_backend_fp8
    return _CACHED_FORWARD_MOE_BACKEND

# ============================================================================
Expand All @@ -134,7 +174,6 @@ def _grouped_mm_with_backward_fix(
# Global flag to check if grouped GEMM is available
_GROUPED_GEMM_AVAILABLE = None
_TORCH_GROUPED_MM_AVAILABLE = hasattr(torch, "_grouped_mm")

# Check if GPU supports torch._grouped_mm (verified via runtime check)
_TORCH_GROUPED_MM_SUPPORTED = None

Expand Down Expand Up @@ -181,6 +220,15 @@ def _check_torch_grouped_mm_supported():
_PERSISTENT_BUFFER = None


def _try_attach_block_size(tensor_like, block_size) -> None:
    """Best-effort: stash ``block_size`` on ``tensor_like`` as an attribute.

    Silently does nothing when either argument is ``None``, or when the
    target object rejects attribute assignment.
    """
    if tensor_like is None:
        return
    if block_size is None:
        return
    try:
        setattr(tensor_like, "block_size", block_size)
    except (AttributeError, RuntimeError):
        # Target does not allow arbitrary attributes — ignore.
        pass


def _init_triton_allocator():
"""
Initialize a persistent Triton allocator to avoid memory allocation overhead per call.
Expand Down Expand Up @@ -279,6 +327,14 @@ def forward_moe_backend(
Centralizes backend selection to keep model-specific patches minimal.
"""
# This Unsloth Zoo code section is licensed under AGPL3
try:
from .moe_utils_fp8 import _moe_uses_fp8_expert_weights, forward_moe_backend_fp8
if _moe_uses_fp8_expert_weights(self):
return forward_moe_backend_fp8(
self, hidden_states, top_k_index, top_k_weights
)
except Exception:
pass

backend = select_moe_backend()
if backend == "grouped_mm":
Expand Down Expand Up @@ -477,6 +533,130 @@ def _get_base_weight(param):
return param


def _get_base_weight_and_quant_state(param):
    """Unwrap adapter wrappers around ``param``; return (weight, quant_state).

    Walks ``.base_layer`` chains down to the innermost layer, resolves its
    weight (``get_param()``, ``.weight``, or the layer itself), and looks up
    quantization state: ``quant_state`` on the weight first, then
    ``weight_scale_inv`` / ``weight_scale`` on the layer. A ``block_size``
    attribute found on the layer is propagated onto both results.
    """
    layer = param
    # PEFT-style wrappers expose the wrapped module as .base_layer.
    while hasattr(layer, "base_layer"):
        layer = layer.base_layer

    if hasattr(layer, "get_param"):
        weight = layer.get_param()
    else:
        weight = getattr(layer, "weight", layer)

    quant_state = getattr(weight, "quant_state", None)
    if quant_state is None:
        for scale_attr in ("weight_scale_inv", "weight_scale"):
            quant_state = getattr(layer, scale_attr, None)
            if quant_state is not None:
                break

    block_size = getattr(layer, "block_size", None)
    if block_size is not None:
        _try_attach_block_size(weight, block_size)
        _try_attach_block_size(quant_state, block_size)

    return weight, quant_state


def _get_moe_weight_and_quant_state(experts_module, param_name: str):
    """Resolve the expert parameter ``param_name`` on an MoE experts module
    and its quantization scale tensor (if any); return (weight, quant_state).

    Scale lookup order: state attached to the unwrapped weight itself, then
    sibling attributes on the experts module using common naming schemes
    (``<name>_weight_scale_inv``, ``<name>_weight_scale``, ``<name>_scale_inv``,
    ``<name>_scale``). A block size found on the parameter or as
    ``<name>_block_size`` is propagated onto both results.
    """
    param = getattr(experts_module, param_name)
    weight, quant_state = _get_base_weight_and_quant_state(param)

    if quant_state is None:
        # Scales may live as sibling attributes on the experts module.
        for suffix in ("_weight_scale_inv", "_weight_scale", "_scale_inv", "_scale"):
            quant_state = getattr(experts_module, param_name + suffix, None)
            if quant_state is not None:
                break

    block_size = getattr(param, "block_size", None)
    if block_size is None:
        block_size = getattr(experts_module, f"{param_name}_block_size", None)
    if block_size is not None:
        _try_attach_block_size(weight, block_size)
        _try_attach_block_size(quant_state, block_size)

    return weight, quant_state


def _get_grouped_lora(self, projection_name: str, cache_attr: str, use_separated_lora: bool):
    """Return the 3-tuple of grouped LoRA data for a projection, or None.

    Prefers data pre-stashed on ``self`` under ``cache_attr`` (truncated to
    its first three entries); otherwise extracts it from the projection
    module when separated-LoRA mode is enabled and adapters are attached.
    """
    stashed = getattr(self, cache_attr, None)
    if stashed is not None:
        # Cached tuples may carry extra entries; callers expect exactly three.
        return stashed[:3]

    if not use_separated_lora:
        return None
    projection = getattr(self, projection_name, None)
    if projection is None or not _has_lora_adapters(projection):
        return None
    return _extract_lora_weights(
        projection, num_experts=self.num_experts, experts_module=self
    )


def _apply_grouped_lora(
    grouped_input: torch.Tensor,
    lora_weights,
    offsets: torch.Tensor,
    target_dtype: torch.dtype,
    active_expert_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply per-expert (grouped) LoRA to expert-sorted activations.

    Computes ``((grouped_input @ A) @ B) * scaling`` with one GEMM group per
    expert. ``lora_weights`` is a 3-tuple ``(A, B, scaling)`` of stacked
    per-expert LoRA factors; ``offsets`` holds the cumulative end index of
    each expert's token slice in ``grouped_input`` (as used by the fallback
    loop below). ``active_expert_ids``, when given, selects the subset of
    experts that actually received tokens so the stacked factors line up
    with the groups described by ``offsets``.
    """
    first_weight, second_weight, scaling = lora_weights
    if active_expert_ids is not None:
        active_expert_ids = active_expert_ids.to(first_weight.device)
        first_weight = first_weight.index_select(0, active_expert_ids)
        second_weight = second_weight.index_select(0, active_expert_ids)

    # Match the compute dtype and make layouts contiguous for the grouped GEMM.
    first_weight = first_weight.to(target_dtype).contiguous()
    second_weight = second_weight.to(target_dtype).contiguous()
    lora_out = _grouped_mm_with_backward_fix(
        grouped_input.to(target_dtype), first_weight, offsets
    ).contiguous()
    try:
        if second_weight.shape[-1] % 8 != 0:
            # The grouped GEMM path appears to require the trailing dim to be
            # a multiple of 8: pad B, multiply, then strip the padded columns.
            pad_size = 8 - (second_weight.shape[-1] % 8)
            second_weight_padded = F.pad(second_weight, (0, pad_size)).contiguous()
            lora_delta = _grouped_mm_with_backward_fix(
                lora_out, second_weight_padded, offsets
            )
            lora_delta = lora_delta[:, :-pad_size]
        else:
            lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets)
    except RuntimeError:
        # Fallback when the grouped GEMM kernel fails: one plain matmul per
        # expert slice, using offsets as cumulative end indices.
        lora_delta = torch.empty(
            (lora_out.shape[0], second_weight.shape[-1]),
            dtype=lora_out.dtype,
            device=lora_out.device,
        )
        cpu_offsets = offsets.cpu().tolist()
        prev_offset = 0
        for i, end in enumerate(cpu_offsets):
            if prev_offset < end:
                lora_delta[prev_offset:end] = torch.matmul(
                    lora_out[prev_offset:end], second_weight[i]
                )
            prev_offset = end
        # NOTE(review): rows at/after the final offset stay uninitialized —
        # assumes offsets cover every row of lora_out; confirm with callers.
    return lora_delta * scaling


def _expand_grouped_bias(
    bias: torch.Tensor,
    counts: torch.Tensor,
    expert_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Expand per-expert bias rows into per-token rows.

    Group ``i`` of the output is ``bias[e]`` repeated ``counts[i]`` times,
    where ``e`` is ``expert_ids[i]`` when provided, else ``i``.
    """
    repeats = counts.to(bias.device)
    if expert_ids is None:
        return bias.repeat_interleave(repeats, dim=0)
    selected = bias.index_select(0, expert_ids.to(bias.device))
    return selected.repeat_interleave(repeats, dim=0)


def _get_lora_wrapper_for_param(experts_module, param_name):
"""
Get the PEFT ParamWrapper for a specific parameter (gate_up_proj or down_proj).
Expand Down Expand Up @@ -672,27 +852,18 @@ def _patched_param_wrapper_forward(
"""
# This Unsloth Zoo code section is licensed under AGPL3

# CRITICAL: Use self.base_layer for forward call (immediate parent)
# NOT self.get_base_layer() which recursively traverses to deepest layer!
# The wrapper chain must be preserved: down_proj -> gate_up_proj -> Qwen3MoeExperts
immediate_base_layer = self.base_layer

# For storing LoRA data, we DO need the actual experts module
# Use get_base_layer() to find it (recursive traversal is correct here)
experts_module = self.get_base_layer()

use_separated = _should_use_separated_lora()
param_name = getattr(self, "parameter_name", None)

# Check if this is an MoE experts module that should use separated LoRA
if (
use_separated
and param_name in ("gate_up_proj", "down_proj")
and _is_moe_experts_module(experts_module)
):
# MoE experts: bypass PEFT's _activate_lora, use separated computation

# Check adapter state
if self.disable_adapters:
if self.merged:
self.unmerge()
Expand All @@ -701,7 +872,6 @@ def _patched_param_wrapper_forward(
if self.merged:
return immediate_base_layer(x, *args, **kwargs)

# Ensure wrapper.num_experts is set for LoRA weight reshaping
if not hasattr(self, "num_experts"):
if hasattr(experts_module, "num_experts"):
self.num_experts = experts_module.num_experts
Expand All @@ -710,29 +880,22 @@ def _patched_param_wrapper_forward(
if hasattr(p, "shape") and len(p.shape) >= 1:
self.num_experts = p.shape[0]

# Extract LoRA for this specific parameter
lora_data = _extract_lora_from_wrapper(self)

if lora_data is not None and param_name:
# Store LoRA data on the EXPERTS MODULE (not base_layer)
# e.g., _unsloth_lora_gate_up_proj or _unsloth_lora_down_proj
lora_attr = f"_unsloth_lora_{param_name}"
setattr(experts_module, lora_attr, lora_data)

try:
# Call IMMEDIATE base_layer to preserve wrapper chain
# (down_proj wrapper calls gate_up_proj wrapper calls Qwen3MoeExperts)
result = immediate_base_layer(x, *args, **kwargs)
finally:
# Clean up
if param_name:
lora_attr = f"_unsloth_lora_{param_name}"
if hasattr(experts_module, lora_attr):
delattr(experts_module, lora_attr)

return result

# Non-MoE: use original PEFT forward with _activate_lora
return _original_param_wrapper_forward(self, x, *args, **kwargs)


Expand Down
Loading
Loading