|
52 | 52 | retie_parameters, |
53 | 53 | set_module_tensor_to_device, |
54 | 54 | ) |
| 55 | +from accelerate.utils.other import extract_model_from_parallel |
55 | 56 |
|
56 | 57 |
|
57 | 58 | torch_device = f"{torch_device}:0" if torch_device != "cpu" else "cpu" |
@@ -1065,3 +1066,24 @@ def test_align_module_device_offloaded_nested(self): |
1065 | 1066 | with align_module_device(module, align_device): |
1066 | 1067 | for param in model.parameters(recurse=False): |
1067 | 1068 | assert param.device == align_device |
| 1069 | + |
| 1070 | + def test_extract_model_from_parallel_partial_compile(self): |
| 1071 | + """Partial torch.compile on a submodule should not crash and should preserve the compiled wrapper.""" |
| 1072 | + model = ModelForTest() |
| 1073 | + model.linear2 = torch.compile(model.linear2) |
| 1074 | + |
| 1075 | + # Precondition: top is not compiled, only submodule is |
| 1076 | + assert not hasattr(model, "_orig_mod") |
| 1077 | + assert hasattr(model.linear2, "_orig_mod") |
| 1078 | + |
| 1079 | + # Standard extraction |
| 1080 | + extracted = extract_model_from_parallel(model) |
| 1081 | + x = torch.randn(2, 3) |
| 1082 | + torch.testing.assert_close(model(x), extracted(x)) |
| 1083 | + assert isinstance(extracted, ModelForTest) |
| 1084 | + assert hasattr(extracted.linear2, "_orig_mod") |
| 1085 | + |
| 1086 | + # Extraction with keep_torch_compile=False |
| 1087 | + extracted_no_keep = extract_model_from_parallel(model, keep_torch_compile=False) |
| 1088 | + assert hasattr(extracted_no_keep.linear2, "_orig_mod") |
| 1089 | + torch.testing.assert_close(model(x), extracted_no_keep(x)) |
0 commit comments