
Commit dd5758b

fix: SDXL DoRA LoRA fails with enable_partial_loading=true (invoke-ai#9063)
* fix: SDXL DoRA LoRA fails with enable_partial_loading=true

  cast_to_device returns a plain torch.Tensor instead of a torch.nn.Parameter, causing
  _aggregate_patch_parameters to replace valid weights with meta-device dummies and
  falsely trigger DoRA's quantization guard.

  Fixes invoke-ai#8624

* test: regression coverage for DoRA + partial-loading + CPU→device autocast

  Adds targeted coverage for the bug fixed in a0a8721 (invoke-ai#8624, PR invoke-ai#9063):

  - test_aggregate_patch_parameters_preserves_plain_tensor_with_dora: a CPU-only unit test
    that feeds a plain torch.Tensor (as handed in by _cast_weight_bias_for_input) into
    _aggregate_patch_parameters with a DoRA patch. Pre-fix, the tensor was replaced by a
    meta-device dummy, tripping DoRA's quantization guard.
  - "single_dora" variant in the patch_under_test fixture: exercises the full CUDA/MPS
    autocast hot path via test_linear_sidecar_patches_with_autocast_from_cpu_to_device.

---------

Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
1 parent: 0531108
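
Why the fix matters, in isolation: casting an nn.Parameter to a different device or dtype with an out-of-place Tensor.to(...)-style copy (which is what the cast_to_device helper named above is assumed to do under the hood) returns a plain torch.Tensor, so a strict isinstance(..., torch.nn.Parameter) check no longer matches. A minimal sketch, independent of InvokeAI:

import torch

p = torch.nn.Parameter(torch.randn(64, 32))

# A no-op cast (same device, same dtype) returns the Parameter itself...
same = p.to(p.dtype)
print(isinstance(same, torch.nn.Parameter))   # True

# ...but a cast that actually copies (different dtype here; moving to another device
# behaves the same way) yields a plain Tensor: tensor ops on a Parameter do not
# preserve the Parameter subclass.
moved = p.to(torch.float16)
print(isinstance(moved, torch.nn.Parameter))  # False
print(type(moved) is torch.Tensor)            # True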

2 files changed: 59 additions & 0 deletions


invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def _aggregate_patch_parameters(
             if isinstance(param, torch.nn.Parameter) and type(param.data) is torch.Tensor:
                 pass
             elif type(param) is torch.Tensor:
+                # Plain tensor (e.g. after cast_to_device moved a Parameter to another device).
                 pass
             elif type(param) is GGMLTensor:
                 # Move to device and dequantize here. Doing it in the patch layer can result in redundant casts /

tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py

Lines changed: 58 additions & 0 deletions
@@ -14,6 +14,7 @@
 )
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.dora_layer import DoRALayer
 from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
@@ -346,6 +347,7 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
         "concatenated_lora",
         "flux_control_lora",
         "single_lokr",
+        "single_dora",
     ]
 )
 def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
@@ -432,6 +434,20 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
         )
         input = torch.randn(1, in_features)
         return ([(lokr_layer, 0.7)], input)
+    elif layer_type == "single_dora":
+        # Regression coverage for #8624: DoRA + partial-loading + CPU->device autocast.
+        # Scaled down so the patched weight stays well-conditioned for allclose comparisons.
+        # dora_scale has shape (1, in_features) to broadcast against direction_norm in
+        # DoRALayer.get_weight — see dora_layer.py:74-82.
+        dora_layer = DoRALayer(
+            up=torch.randn(out_features, rank) * 0.01,
+            down=torch.randn(rank, in_features) * 0.01,
+            dora_scale=torch.ones(1, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features) * 0.01,
+        )
+        input = torch.randn(1, in_features)
+        return ([(dora_layer, 0.7)], input)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")
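
For context on the dora_scale shape noted in the comment above: in the DoRA formulation, the merged weight re-scales each column of the patched direction matrix to a learned per-input-feature magnitude, so a (1, in_features) scale broadcasts over out_features. A minimal sketch of that math (not InvokeAI's DoRALayer.get_weight implementation; the norm axis and alpha/rank scaling here are assumptions based on the DoRA paper):

import torch

def dora_merged_weight(
    w_orig: torch.Tensor,      # (out_features, in_features) base weight
    up: torch.Tensor,          # (out_features, rank)
    down: torch.Tensor,        # (rank, in_features)
    dora_scale: torch.Tensor,  # (1, in_features) learned magnitude vector
    alpha: float,
    rank: int,
) -> torch.Tensor:
    delta_w = (alpha / rank) * (up @ down)                 # low-rank update, (out, in)
    direction = w_orig + delta_w                           # patched direction, (out, in)
    direction_norm = direction.norm(dim=0, keepdim=True)   # column-wise norm, (1, in)
    return dora_scale * (direction / direction_norm)       # broadcasts back to (out, in)

With the fixture's small-magnitude up/down matrices, direction stays close to w_orig, which is what the "well-conditioned for allclose comparisons" comment is getting at.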

@@ -676,3 +692,45 @@ def test_conv2d_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):

     assert output.dtype == input.dtype
     assert output.shape == (2, 16, 3, 3)
+
+
+@torch.no_grad()
+def test_aggregate_patch_parameters_preserves_plain_tensor_with_dora():
+    """Regression test for #8624: when partial-loading autocasts a CPU Parameter onto the
+    compute device, cast_to_device returns a plain torch.Tensor (not a Parameter). The
+    aggregator must treat that as a real tensor and not substitute a meta-device dummy —
+    otherwise DoRA's quantization guard falsely triggers on non-quantized base models.
+
+    This test is CPU-only and simulates the hand-off by constructing a plain torch.Tensor
+    directly; the equivalent CUDA/MPS E2E flow is exercised by the "single_dora" variant
+    of test_linear_sidecar_patches_with_autocast_from_cpu_to_device.
+    """
+    layer = wrap_single_custom_layer(torch.nn.Linear(32, 64))
+
+    rank = 4
+    dora_patch = DoRALayer(
+        up=torch.randn(64, rank) * 0.01,
+        down=torch.randn(rank, 32) * 0.01,
+        dora_scale=torch.ones(1, 32),
+        alpha=1.0,
+        bias=None,
+    )
+
+    # Plain torch.Tensor — the shape _cast_weight_bias_for_input hands into
+    # _aggregate_patch_parameters after autocasting a Parameter across devices.
+    plain_weight = torch.randn(64, 32)
+    assert type(plain_weight) is torch.Tensor
+
+    orig_params = {"weight": plain_weight}
+    params = layer._aggregate_patch_parameters(
+        patches_and_weights=[(dora_patch, 1.0)],
+        orig_params=orig_params,
+        device=torch.device("cpu"),
+    )
+
+    # Pre-fix, orig_params["weight"] would have been replaced by a meta-device dummy,
+    # causing DoRALayer.get_parameters to raise "not compatible with DoRA patches".
+    assert orig_params["weight"].device.type == "cpu"
+    assert params["weight"].shape == (64, 32)
+    assert params["weight"].device.type == "cpu"
+    assert not torch.isnan(params["weight"]).any()
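
On the "meta-device dummy" these assertions guard against: a meta tensor carries only shape/dtype metadata and no storage, so once a real weight has been swapped for one, no downstream patch math can recover actual values. A small illustration (plain PyTorch, not InvokeAI code):

import torch

dummy = torch.empty(64, 32, device="meta")
print(dummy.device)           # meta
print((dummy * 2.0).device)   # still meta: ops propagate the meta device, no data involved

# Materializing values from a meta tensor fails outright, e.g.:
#   dummy.cpu()  ->  "Cannot copy out of meta tensor; no data!"
# which is why a patched weight built from such a dummy is unusable.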
