Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions mlx_vlm/tests/test_trainer_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import unittest
from unittest.mock import MagicMock, patch

import mlx.core as mx
import mlx.nn as nn

from mlx_vlm.trainer.lora import LoRaLayer
from mlx_vlm.trainer.utils import (
find_all_linear_names,
get_module_by_name,
Expand Down Expand Up @@ -54,6 +56,46 @@ def test_find_all_linear_names(self):
result = find_all_linear_names(model)
self.assertEqual(set(result), {"layer1", "layer2"})

def test_lora_layer_uses_alpha_over_rank_scaling(self):
    """Verify LoRaLayer stores the canonical alpha/rank scaling factor.

    Guards against a regression of issue #845, where the LoRA update
    was multiplied by raw ``alpha`` rather than ``alpha / rank``,
    inflating the effective scaling by a factor of ``rank`` for the
    documented defaults.
    """
    lora_rank, lora_alpha = 8, 16
    wrapped = LoRaLayer(nn.Linear(64, 64), rank=lora_rank, alpha=lora_alpha, dropout=0.0)

    self.assertEqual(wrapped.rank, lora_rank)
    self.assertEqual(wrapped.alpha, lora_alpha)
    self.assertAlmostEqual(wrapped.scaling, lora_alpha / lora_rank)

def test_lora_layer_forward_matches_alpha_over_rank(self):
    """Check the forward pass computes base + (alpha/rank) * (x A B).

    Confirms ``__call__`` itself applies the corrected scaling, not
    merely that the attribute is stored. ``B`` is overwritten with a
    non-zero matrix because the standard PEFT zero-init would make the
    low-rank update invisible in the output.
    """
    lora_rank, lora_alpha = 4, 8
    base = nn.Linear(8, 8, bias=False)
    wrapped = LoRaLayer(base, rank=lora_rank, alpha=lora_alpha, dropout=0.0)

    # Deterministic, non-zero B so the update contributes to the output.
    wrapped.B = mx.ones((lora_rank, 8))
    inputs = mx.ones((1, 8))

    low_rank_term = (lora_alpha / lora_rank) * ((inputs @ wrapped.A) @ wrapped.B)
    want = base(inputs) + low_rank_term.astype(inputs.dtype)

    got = wrapped(inputs)
    mx.eval(got, want)
    self.assertTrue(mx.allclose(got, want, atol=1e-5).item())


# Allow running this test module directly (e.g. `python test_trainer_utils.py`)
# in addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()
16 changes: 13 additions & 3 deletions mlx_vlm/trainer/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,29 @@ def __init__(
shape=(input_dims, rank),
)
self.B = mx.zeros((rank, output_dims))
self.rank = rank
self.alpha = alpha
# Standard LoRA scaling factor (Hu et al. 2021): alpha / rank.
# Computed once at construction so __call__ and
# replace_lora_with_linear apply consistent scaling. Previously
# this layer multiplied the update by raw `alpha`, which made
# the effective scaling rank-times too large for the documented
# defaults (e.g. r=8 alpha=16 gave an effective scaling of 16
# instead of 2). See issue #845.
self.scaling = alpha / rank

def __call__(self, x):
    """Apply the wrapped linear layer plus the scaled low-rank update.

    Returns ``original_layer(x) + scaling * ((dropout(x) @ A) @ B)``,
    with the LoRA term cast back to the input dtype so the adapter
    branch does not widen the activation precision.
    """
    base_out = self.original_layer(x)
    delta = self.dropout(x) @ self.A @ self.B
    return base_out + (self.scaling * delta).astype(x.dtype)


def replace_lora_with_linear(model):
for i, layer in enumerate(model.layers):
if isinstance(layer, LoRaLayer):
# Compute the final merged weight
lora_update = layer.alpha * (layer.A @ layer.B)
# Compute the final merged weight using the same alpha/rank
# scaling that LoRaLayer.__call__ applies during training.
lora_update = layer.scaling * (layer.A @ layer.B)
updated_weight = layer.original_layer.weight + lora_update
use_bias = layer.original_layer.bias is not None

Expand Down