15 changes: 12 additions & 3 deletions unsloth_zoo/saving_utils.py
@@ -390,6 +390,18 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict
remove_keys.add(name)
pass
pass
# MoE target_parameters (ParamWrapper) entries have lora_A/B/scaling but
# may lack a corresponding .base_layer module, leaving module_count short.
# Count these so the diagnostic check below stays accurate (#3405, #3701).
for key, stats in lora_weights.items():
if (
stats.lora_A is not None
and stats.lora_B is not None
and stats.module is None
and ".mlp.experts" in key
):
module_count += 1
Collaborator:
So this pretty much fixes only the count part. I think my previous changes (in #450, perhaps) would automatically handle the right tensor and file placement, I presume.


Exactly right — this fix only addresses the module_count alignment so the mismatch warning no longer fires for nn.Parameter targets. The actual tensor placement and file writing is handled by your work in #450.

The two fixes are complementary: #450 handles the merge mechanics, this PR ensures the diagnostic counts are correct so users don't see a misleading warning during an otherwise successful merge.


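For illustration, a minimal sketch of the count-alignment logic being discussed. The `LoraStats` dataclass and the sample keys here are hypothetical stand-ins, not the actual `unsloth_zoo` data structures; the point is that a ParamWrapper entry (`module is None`) under `.mlp.experts` would otherwise leave `module_count` one short of the other counts:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class LoraStats:
    # Hypothetical stand-in for the per-key stats in lora_weights.
    lora_A: Optional[object] = None
    lora_B: Optional[object] = None
    module: Optional[object] = None  # None for nn.Parameter (ParamWrapper) targets
    scaling: Optional[float] = None

lora_weights = {
    # Ordinary nn.Linear target: has a .base_layer module.
    "model.layers.0.self_attn.q_proj": LoraStats("A", "B", module=object(), scaling=1.0),
    # MoE target_parameters entry: lora_A/B/scaling present, but no module.
    "model.layers.0.mlp.experts.gate_up_proj": LoraStats("A", "B", module=None, scaling=1.0),
}

module_count  = sum(s.module  is not None for s in lora_weights.values())
lora_A_count  = sum(s.lora_A  is not None for s in lora_weights.values())
lora_B_count  = sum(s.lora_B  is not None for s in lora_weights.values())
scaling_count = sum(s.scaling is not None for s in lora_weights.values())

# Without the fix: module_count == 1 while the other counts are 2,
# so the mismatch warning fires even though the merge itself is fine.
for key, stats in lora_weights.items():
    if (
        stats.lora_A is not None
        and stats.lora_B is not None
        and stats.module is None
        and ".mlp.experts" in key
    ):
        module_count += 1

assert module_count == lora_A_count == lora_B_count == scaling_count
```

The sketch only models the diagnostic counting; the actual tensor merge and file placement are outside its scope, matching the division of labor described in the comments above.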
if not (module_count == lora_A_count == lora_B_count == scaling_count):
print(
f"[Unsloth merge debug] LoRA count mismatch: modules={module_count}, "
@@ -405,9 +417,6 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict
print(f" key={k} param={param_name} A={a_shape} B={b_shape}")
except Exception:
pass
# Allow merge to continue; downstream checks will still fail loudly if tensors are missing
# but this avoids silent assertion without context.
# TODO: handle MoE target_parameters to align counts.

# Also return state_dict if needed
if return_state_dict: