unsloth_zoo/saving_utils.py (55 additions, 3 deletions)
@@ -390,6 +390,18 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict
remove_keys.add(name)
pass
pass
# PEFT target_parameters (ParamWrapper for nn.Parameter, not nn.Linear)
# have lora_A/B/scaling but no .base_layer, leaving module_count short.
# nn.Linear targets always get .base_layer set, so module=None reliably
# identifies nn.Parameter targets. Count them to align (#3405, #3701).
for _key, _stats in lora_weights.items():
if (
_stats.lora_A is not None
and _stats.lora_B is not None
and _stats.module is None
):
module_count += 1
Collaborator

So this pretty much fixes only the count part. I presume my previous changes (in #450, perhaps) would automatically handle the right tensor and file placement.

Contributor Author

Exactly right — this fix only addresses the module_count alignment so the mismatch warning no longer fires for nn.Parameter targets. The actual tensor placement and file writing is handled by your work in #450.

The two fixes are complementary: #450 handles the merge mechanics, this PR ensures the diagnostic counts are correct so users don't see a misleading warning during an otherwise successful merge.
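
For future readers, here is a toy, self-contained sketch of why the counts drifted and how the added loop realigns them. The `LoraStats` dataclass below is a simplified stand-in and the module keys are made up, not the library's real objects:

```python
from dataclasses import dataclass

@dataclass
class LoraStats:  # simplified stand-in for the real LoraStats container
    module: object
    lora_A: object
    lora_B: object
    scaling: object

# Toy keys: one nn.Parameter target (module stays None) and one nn.Linear target.
lora_weights = {
    "model.layers.0.mlp.experts.gate_up_proj": LoraStats(None, "A0", "B0", 2.0),
    "model.layers.0.self_attn.q_proj":         LoraStats("base_layer", "A1", "B1", 2.0),
}

module_count  = sum(s.module  is not None for s in lora_weights.values())
lora_A_count  = sum(s.lora_A  is not None for s in lora_weights.values())
lora_B_count  = sum(s.lora_B  is not None for s in lora_weights.values())
scaling_count = sum(s.scaling is not None for s in lora_weights.values())
# Before the fix: modules=1 but A=B=scaling=2, so the mismatch warning fired.

# The added loop counts nn.Parameter targets (A/B present, no module):
for stats in lora_weights.values():
    if stats.lora_A is not None and stats.lora_B is not None and stats.module is None:
        module_count += 1

assert module_count == lora_A_count == lora_B_count == scaling_count
```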


if not (module_count == lora_A_count == lora_B_count == scaling_count):
print(
f"[Unsloth merge debug] LoRA count mismatch: modules={module_count}, "
@@ -405,9 +417,6 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict
print(f" key={k} param={param_name} A={a_shape} B={b_shape}")
except Exception:
pass
# Allow merge to continue; downstream checks will still fail loudly if tensors are missing
# but this avoids silent assertion without context.
# TODO: handle MoE target_parameters to align counts.

# Also return state_dict if needed
if return_state_dict:
@@ -2576,6 +2585,46 @@ def detect_keys_format(keys_to_check, forward_mapping):
return "new" # Default, assuming most models/keys will be in the "new" (current HF) format.
pass

def _infer_prefix_and_remap(lora_weights, safetensor_keys):
"""Infer a missing key prefix by matching LoRA keys against safetensor keys.

Some composite models (e.g. Qwen3.5) store safetensors with an extra
prefix like ``model.language_model.`` that differs from the runtime key
namespace ``model.``. When no explicit ``_checkpoint_conversion_mapping``
exists, this helper detects the discrepancy and remaps LoRA keys so that
the merge loop can match them.

Returns a remapped ``defaultdict`` on success, or ``None`` if no prefix
could be inferred (caller should fall back to returning keys unchanged).
"""
if not safetensor_keys:
return None

inferred_prefix = None
for lora_key in lora_weights:
if not isinstance(lora_key, str):
continue
suffix = lora_key + ".weight"
for sf_key in safetensor_keys:
if sf_key.endswith(suffix):
candidate = sf_key[: -len(suffix)]
if candidate: # non-empty extra prefix
inferred_prefix = candidate
break
if inferred_prefix is not None:
break

if inferred_prefix is None:
return None

remapped = defaultdict(lora_weights.default_factory)
for k, v in lora_weights.items():
new_key = inferred_prefix + k if isinstance(k, str) else k
remapped[new_key] = v
return remapped
pass
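
A quick toy run of the prefix inference (hypothetical keys; assumes the `_infer_prefix_and_remap` definition above is in scope, and uses plain dicts where the real code stores LoraStats objects):

```python
from collections import defaultdict

# Hypothetical runtime LoRA key, without the checkpoint's extra prefix.
lora_weights = defaultdict(dict)
lora_weights["layers.0.self_attn.q_proj"] = {"lora_A": "A", "lora_B": "B"}

# The shard stores the corresponding tensor under "model.language_model.".
safetensor_keys = ["model.language_model.layers.0.self_attn.q_proj.weight"]

remapped = _infer_prefix_and_remap(lora_weights, safetensor_keys)
print(list(remapped))
# -> ['model.language_model.layers.0.self_attn.q_proj']
```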


def _convert_lora_keys_to_safetensor_format(
lora_weights, # Global dict of LoraStats objects
safetensor_keys, # List of keys from the CURRENT shard
@@ -2587,6 +2636,9 @@ def _convert_lora_keys_to_safetensor_format(
forward_mapping = _get_checkpoint_conversion_mapping(model_class_name)

if not forward_mapping:
remapped = _infer_prefix_and_remap(lora_weights, safetensor_keys)
if remapped is not None:
return remapped
return defaultdict(lora_weights.default_factory, lora_weights)

# Create reverse mapping