Add Bnb4bit support for MoE models on transformers v5 - #4032 #527
Status: Open. sensai99 wants to merge 23 commits into unslothai:main from sensai99:moeFix.
Commits (23, all by sensai99):
dcde5aa  add fixes for moe
d07b488  fix quantized params4bit returns correct features via proxy param
a706108  remove the old code
50d3560  clean moe_bnb_transformers.py
eebb31d  clean moe_utils
615bc12  clean misc.py
8e79b07  replace only expert params
d98f6a1  "fix function doc"
4cb3a65  "fix logic for handling experts params that are not nn.Parameter"
87eb337  "clean moe_bnb_transformers"
52f0260  "fix comments in misc"
4b4ab68  "clean moe_utils comments"
d12c0c2  "rename file"
c4ad5e7  "fix minor issue"
268c6c0  "move patches to new file moe_bnb_transformers.py"
d9bb923  "restore moe_bnb.py"
e730337  "minor changes"
f3f2c6e  "fix Params4bit instance check in moe_utils"
d5b567c  "clean code"
d7edd84  Merge remote-tracking branch 'upstream/main' into moeFix
d5a2a6b  "fix docs @moe_utils.py"
fb69ead  "fix docs @misc.py"
493405b  "fix minor function patch issue"
New file: moe_bnb_transformers.py (+254 lines):
```python
# Unsloth Zoo - Utilities for Unsloth
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
MoE Expert Parameter 4-bit Quantization Patch for Transformers

Patches transformers' bitsandbytes quantization to handle MoE expert parameters
(gate_up_proj, down_proj) that are nn.Parameter instead of nn.Linear.
"""

import torch
import torch.nn as nn
from typing import Optional, List, Tuple, Union
import os
from .common import TEMPORARY_PATCHES, UNSLOTH_ENABLE_LOGGING, logger
from .utils import patch_function, raise_error

# Check bitsandbytes availability
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Params4bit
    HAS_BNB = True
except ImportError:
    HAS_BNB = False
    Params4bit = None


def _check_bnb_available():
    return HAS_BNB
```
The helper that identifies fused MoE expert modules:

```python
def _is_expert_module(module: nn.Module) -> bool:
    """
    Check if a module is an MoE experts module.
    Specifically, check if the module has gate_up_proj & down_proj attributes that are nn.Parameter.
    """
    return (
        hasattr(module, "gate_up_proj")
        and hasattr(module, "down_proj")
        and isinstance(module.gate_up_proj, nn.Parameter)
        and isinstance(module.down_proj, nn.Parameter)
    )
```
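For context, this is the kind of module the helper is meant to match. A minimal sketch; the class name and the 3D layouts (experts stacked on the leading dimension) are assumptions modeled on fused-expert MoE blocks, not something defined in this PR:

```python
import torch
import torch.nn as nn

class FusedExperts(nn.Module):
    # Hypothetical fused MoE experts block: all experts' weights live in two
    # large nn.Parameter tensors instead of per-expert nn.Linear layers.
    def __init__(self, num_experts=8, hidden_size=64, intermediate_size=128):
        super().__init__()
        # Assumed layout: [num_experts, hidden, 2 * intermediate] for the fused
        # gate/up projection, [num_experts, intermediate, hidden] for down.
        self.gate_up_proj = nn.Parameter(torch.empty(num_experts, hidden_size, 2 * intermediate_size))
        self.down_proj = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))

# _is_expert_module(FusedExperts())  -> True
# _is_expert_module(nn.Linear(4, 4)) -> False (weight/bias, not gate_up_proj/down_proj)
```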
The pre-loading replacement pass:

```python
def replace_expert_params_with_bnb_params(
    model: nn.Module,
    modules_to_not_convert: Optional[List[str]] = None,
    quantization_config = None,
    pre_quantized: bool = False,
) -> nn.Module:
    """
    Replace MoE parameters (gate_up_proj, down_proj) of nn.Parameter type with Params4bit.
    """
    try:
        from transformers.quantizers.quantizers_utils import should_convert_module
    except Exception as e:
        return raise_error("transformers.quantizers.quantizers_utils.should_convert_module", e)

    has_been_replaced = False

    for module_name, module in model.named_modules():
        if not should_convert_module(module_name, modules_to_not_convert):
            continue

        if not _is_expert_module(module):
            continue

        gate_up_proj = module.gate_up_proj
        down_proj = module.down_proj

        if isinstance(gate_up_proj, Params4bit) or isinstance(down_proj, Params4bit):
            continue

        # Build empty Params4bit placeholders on the meta device; the real
        # weights are quantized into them later, during weight loading.
        with torch.device("meta"):
            placeholder_gate_up = Params4bit(
                torch.zeros(gate_up_proj.shape),
                requires_grad=False,
                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                quant_type=quantization_config.bnb_4bit_quant_type,
                quant_storage=quantization_config.bnb_4bit_quant_storage,
            )
            placeholder_down = Params4bit(
                torch.zeros(down_proj.shape),
                requires_grad=False,
                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                quant_type=quantization_config.bnb_4bit_quant_type,
                quant_storage=quantization_config.bnb_4bit_quant_storage,
            )

        if pre_quantized:
            placeholder_gate_up.data = placeholder_gate_up.data.to(dtype=quantization_config.bnb_4bit_quant_storage)
            placeholder_down.data = placeholder_down.data.to(dtype=quantization_config.bnb_4bit_quant_storage)

        module.gate_up_proj = placeholder_gate_up
        module.down_proj = placeholder_down
        has_been_replaced = True

        # TODO: Can remove this?
        logger.info(
            f"Unsloth: Prepared {module_name}'s gate_up_proj & down_proj for BNB 4-bit quantization "
            f"(shapes: {gate_up_proj.shape}, {down_proj.shape})"
        )

    if not has_been_replaced:
        logger.warning(f"Unsloth: No expert parameters were found to be replaced for {model.name_or_path}")

    return model
```
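To see what the replacement pass produces, here is a rough, illustrative sketch that calls the function directly on the hypothetical FusedExperts module from above. It assumes a transformers version that ships should_convert_module (the v5 target of this PR); the wrapper class and name are made up for the example, and this is not a supported entry point:

```python
import torch.nn as nn
from transformers import BitsAndBytesConfig

class TinyMoE(nn.Module):
    # Wrap the hypothetical FusedExperts block so named_modules() finds it.
    def __init__(self):
        super().__init__()
        self.experts = FusedExperts()
        self.name_or_path = "toy-moe"  # only read by the warning log path

config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
model = replace_expert_params_with_bnb_params(
    TinyMoE(), modules_to_not_convert=[], quantization_config=config,
)
# The expert weights are now Params4bit placeholders on the meta device.
assert isinstance(model.experts.gate_up_proj, Params4bit)
assert isinstance(model.experts.down_proj, Params4bit)
```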
The weight-loading conversion patch:

```python
def patch_bnb4bit_quantize_convert():
    """
    Expert modules of nn.Parameter type are converted to Params4bit placeholders during weight loading.
    Also preserves the original shape of the expert parameters for PEFT LoRA compatibility.
    """
    try:
        from transformers.integrations.bitsandbytes import Bnb4bitQuantize
    except Exception as e:
        return raise_error("transformers.integrations.bitsandbytes.Bnb4bitQuantize", e)

    if hasattr(Bnb4bitQuantize.convert, "_unsloth_moe_patched"):
        return

    original_convert = Bnb4bitQuantize.convert

    def patched_convert(
        self,
        input_dict: dict[str, Union[list[torch.Tensor], torch.Tensor]],
        full_layer_name: str | None = None,
        model: torch.nn.Module | None = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        input_dict:
            - Dictionary containing raw tensors for the parameter to be quantized.
            - For an MoE expert parameter of nn.Parameter type, the value is a single tensor.
        full_layer_name: Fully qualified name of the parameter being quantized.
        """
        value = list(input_dict.values())[0]

        try:
            from transformers.quantizers.quantizers_utils import get_module_from_name
            module, _ = get_module_from_name(model, full_layer_name)

            if _is_expert_module(module):
                old_value = model.get_parameter_or_buffer(full_layer_name)

                # Re-create the placeholder's Params4bit settings on the real tensor.
                old_dict = {k: v for k, v in old_value.__dict__.items()}
                new_value = Params4bit(value, requires_grad=False, **old_dict).to(value.device)

                # Preserve _original_shape for expert params (critical for PEFT LoRA compatibility)
                original_shape = value.shape
                if original_shape is not None:
                    setattr(new_value, "_original_shape", original_shape)

                module._is_hf_initialized = True
                return {full_layer_name: new_value}

        except Exception as e:
            logger.warning(f"Unsloth: Error handling expert param quantization for {full_layer_name}: {e}")

        # Fall back to original convert for non-expert params or in case of any failure
        return original_convert(self, input_dict, full_layer_name=full_layer_name, model=model, **kwargs)

    patched_convert._unsloth_moe_patched = True
    patch_function(Bnb4bitQuantize, "convert", patched_convert)

    logger.info("Unsloth: Patched Bnb4bitQuantize.convert for MoE expert parameter support")
pass
TEMPORARY_PATCHES.append(patch_bnb4bit_quantize_convert)
```
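The `_original_shape` bookkeeping matters because bitsandbytes packs a quantized Params4bit into a flat storage buffer, so the expert's logical shape is no longer readable from `.data`. A small illustrative sketch (the shape is made up, and a CUDA device is required, since Params4bit quantizes when moved to GPU):

```python
import torch
from bitsandbytes.nn import Params4bit

w = torch.randn(8, 64, 256)                        # illustrative expert weight
p = Params4bit(w, requires_grad=False, quant_type="nf4")
p = p.to("cuda")                                   # moving to CUDA triggers 4-bit quantization
print(p.data.shape)         # packed storage, no longer (8, 64, 256)
print(p.quant_state.shape)  # original shape kept inside the quant state
# The patch additionally mirrors it onto the parameter as `_original_shape`
# so consumers like PEFT LoRA can recover the expert's logical dimensions.
```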
The quantization-eligibility patch:

```python
def patch_bnb4bit_quantizer_param_needs_quantization():
    """Recognize MoE expert modules of Params4bit type as needing quantization."""
    try:
        from transformers.quantizers.quantizer_bnb_4bit import Bnb4BitHfQuantizer
        from transformers.quantizers.quantizers_utils import get_module_from_name
    except Exception as e:
        return raise_error("transformers.quantizers.quantizer_bnb_4bit.Bnb4BitHfQuantizer", e)

    if hasattr(Bnb4BitHfQuantizer.param_needs_quantization, "_unsloth_moe_patched"):
        return

    original_param_needs_quantization = Bnb4BitHfQuantizer.param_needs_quantization

    def patched_param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        if original_param_needs_quantization(self, model, param_name, **kwargs):
            return True

        try:
            module, name = get_module_from_name(model, param_name)
            if name in ("gate_up_proj", "down_proj"):
                param = getattr(module, name, None)
                if isinstance(param, Params4bit):
                    return True
        except Exception as e:
            # TODO: Can we raise an error here?
            logger.warning(
                f"Unsloth: Error checking MoE expert param_needs_quantization for {param_name}: {e}"
            )

        return False

    patched_param_needs_quantization._unsloth_moe_patched = True
    patch_function(Bnb4BitHfQuantizer, "param_needs_quantization", patched_param_needs_quantization)

    logger.info("Unsloth: Patched Bnb4BitHfQuantizer.param_needs_quantization for MoE expert parameters")
pass
TEMPORARY_PATCHES.append(patch_bnb4bit_quantizer_param_needs_quantization)
```
Finally, the hook that runs the replacement pass before weights load:

```python
def patch_bnb4bit_quantizer_process_model():
    try:
        from transformers.quantizers.quantizer_bnb_4bit import Bnb4BitHfQuantizer
    except Exception as e:
        return raise_error("transformers.quantizers.quantizer_bnb_4bit.Bnb4BitHfQuantizer", e)

    # Fast return if already patched
    if hasattr(Bnb4BitHfQuantizer._process_model_before_weight_loading, "_unsloth_moe_patched"):
        return

    original_process_model_before_weight_loading = Bnb4BitHfQuantizer._process_model_before_weight_loading

    def patched_process_model_before_weight_loading(self, model, device_map, **kwargs):
        original_process_model_before_weight_loading(self, model, device_map, **kwargs)

        # Additionally swap MoE expert nn.Parameters for Params4bit placeholders
        model = replace_expert_params_with_bnb_params(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )
        return model

    patched_process_model_before_weight_loading._unsloth_moe_patched = True
    patch_function(Bnb4BitHfQuantizer, "_process_model_before_weight_loading", patched_process_model_before_weight_loading, match_level = "relaxed")
pass
TEMPORARY_PATCHES.append(patch_bnb4bit_quantizer_process_model)
```
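Taken together, the three patches are exercised by an ordinary 4-bit load. A minimal sketch; the model id below is a placeholder for any MoE checkpoint whose experts store weights as fused gate_up_proj/down_proj nn.Parameters:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Placeholder model id: any MoE checkpoint with fused expert parameters.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-moe-model",
    quantization_config=bnb_config,
    device_map="auto",
)
# With the patches applied, expert gate_up_proj/down_proj end up as Params4bit
# instead of being skipped by the nn.Linear-only conversion path.
```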
Review comment: Consider adding a more specific exception type instead of a general Exception, to catch only the expected error and avoid masking other potential issues.
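As one illustration of that suggestion, a hypothetical rework of the lookup inside patched_param_needs_quantization; the chosen exception types are an assumption about what get_module_from_name can realistically raise:

```python
# Hypothetical narrowing of the broad `except Exception` in the patch above.
def expert_param_is_params4bit(model, param_name) -> bool:
    from transformers.quantizers.quantizers_utils import get_module_from_name
    try:
        module, name = get_module_from_name(model, param_name)
    except (AttributeError, KeyError) as e:  # assumed failure modes of the name lookup
        logger.warning(f"Unsloth: Error checking MoE expert param_needs_quantization for {param_name}: {e}")
        return False
    return name in ("gate_up_proj", "down_proj") and isinstance(getattr(module, name, None), Params4bit)
```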