Skip to content

Commit 02ee93b

Browse files
committed
Add test for partial torch.compile unwrap in extract_model_from_parallel
1 parent f88344d commit 02ee93b

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

tests/test_modeling_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
retie_parameters,
5353
set_module_tensor_to_device,
5454
)
55+
from accelerate.utils.other import extract_model_from_parallel
5556

5657

5758
torch_device = f"{torch_device}:0" if torch_device != "cpu" else "cpu"
@@ -1065,3 +1066,24 @@ def test_align_module_device_offloaded_nested(self):
10651066
with align_module_device(module, align_device):
10661067
for param in model.parameters(recurse=False):
10671068
assert param.device == align_device
1069+
1070+
def test_extract_model_from_parallel_partial_compile(self):
    """Extraction must tolerate a model whose *submodule* — not the top
    level — was wrapped by ``torch.compile``, and must leave that
    compiled wrapper in place.
    """
    net = ModelForTest()
    net.linear2 = torch.compile(net.linear2)

    # Precondition: only the submodule carries the compile wrapper.
    assert hasattr(net.linear2, "_orig_mod")
    assert not hasattr(net, "_orig_mod")

    sample = torch.randn(2, 3)

    # Default extraction: no crash, same outputs, wrapper preserved.
    unwrapped = extract_model_from_parallel(net)
    torch.testing.assert_close(net(sample), unwrapped(sample))
    assert isinstance(unwrapped, ModelForTest)
    assert hasattr(unwrapped.linear2, "_orig_mod")

    # keep_torch_compile=False targets a top-level wrapper only, so the
    # submodule's compiled wrapper survives here as well.
    plain = extract_model_from_parallel(net, keep_torch_compile=False)
    assert hasattr(plain.linear2, "_orig_mod")
    torch.testing.assert_close(net(sample), plain(sample))

0 commit comments

Comments (0)