Commit c51a579

Add _get_dummy_inputs_for_model helper for quantization tests
Introduce a helper that casts floating-point input tensors to the model's parameter dtype, preventing dtype mismatches during quantized model inference.
1 parent c8c8401 commit c51a579

1 file changed: 25 additions & 17 deletions

tests/models/testing_utils/quantization.py
@@ -142,6 +142,14 @@ def _is_module_quantized(self, module):
         except (AssertionError, AttributeError):
             return False

+    def _get_dummy_inputs_for_model(self, model):
+        inputs = self.get_dummy_inputs()
+        model_dtype = next(model.parameters()).dtype
+        return {
+            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
+            for k, v in inputs.items()
+        }
+
     def _load_unquantized_model(self):
         kwargs = getattr(self, "pretrained_model_kwargs", {})
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
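
For illustration, here is a minimal standalone sketch of the casting behavior the new helper implements. The toy nn.Linear model and the input dict are assumptions made for this sketch, not part of the diff:

import torch

# Toy stand-in for a model whose parameters are stored in float16.
model = torch.nn.Linear(4, 4).half()

# Hypothetical dummy inputs: a float tensor, an integer tensor, a non-tensor.
inputs = {
    "hidden_states": torch.randn(1, 4),                    # float32
    "attention_mask": torch.ones(1, 4, dtype=torch.long),  # integer
    "return_dict": False,                                  # non-tensor
}

# Same logic as _get_dummy_inputs_for_model: cast only floating-point tensors.
model_dtype = next(model.parameters()).dtype
cast_inputs = {
    k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
    for k, v in inputs.items()
}

assert cast_inputs["hidden_states"].dtype == torch.float16  # cast to model dtype
assert cast_inputs["attention_mask"].dtype == torch.long    # integers pass through
assert cast_inputs["return_dict"] is False                  # non-tensors pass through

Only floating-point tensors are cast: converting integer tensors such as token IDs or attention masks to a float dtype would corrupt them, so they pass through unchanged.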
@@ -174,7 +182,7 @@ def _test_quantization_inference(self, config_kwargs):
         model_quantized = self._create_quantized_model(config_kwargs)
         model_quantized.to(torch_device)

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_quantized)
         output = model_quantized(**inputs, return_dict=False)[0]

         assert output is not None, "Model output is None"
@@ -222,7 +230,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
         # Move LoRA adapter weights to device (they default to CPU)
         model.to(torch_device)

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]

         assert output is not None, "Model output is None with LoRA"
@@ -236,7 +245,8 @@ def _test_quantization_serialization(self, config_kwargs, tmp_path):

         model_loaded = self.model_class.from_pretrained(str(tmp_path))

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_loaded)
+
         output = model_loaded(**inputs, return_dict=False)[0]
         assert not torch.isnan(output).any(), "Loaded model output contains NaN"

@@ -334,7 +344,8 @@ def _test_quantization_device_map(self, config_kwargs):
         assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
         assert model.hf_device_map is not None, "hf_device_map should not be None"

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None"
         assert not torch.isnan(output).any(), "Model output contains NaN"
@@ -360,14 +371,7 @@ def _test_dequantize(self, config_kwargs):
             assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"

         # Get model dtype from first parameter
-        model_dtype = next(model.parameters()).dtype
-
-        inputs = self.get_dummy_inputs()
-        # Cast inputs to model dtype
-        inputs = {
-            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
-            for k, v in inputs.items()
-        }
+        inputs = self._get_dummy_inputs_for_model(model)
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None after dequantization"
         assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
@@ -413,7 +417,7 @@ def _test_quantization_training(self, config_kwargs):
             pytest.skip("No attention layers found in model for adapter training test")

         # Step 3: run forward and backward pass
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)

         with torch.amp.autocast(torch_device, dtype=torch.float16):
             out = model(**inputs, return_dict=False)[0]
@@ -597,7 +601,8 @@ def test_bnb_keep_modules_in_fp32(self):
                     f"Module {name} should be uint8 but is {module.weight.dtype}"
                 )

-            inputs = self.get_dummy_inputs()
+            inputs = self._get_dummy_inputs_for_model(model)
+
             _ = model(**inputs)
         finally:
             if original_fp32_modules is not None:
@@ -915,7 +920,8 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path):

         model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device))

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_loaded)
+
         output = model_loaded(**inputs, return_dict=False)[0]
         assert not torch.isnan(output).any(), "Loaded model output contains NaN"

@@ -1197,7 +1203,8 @@ def _test_torch_compile(self, config_kwargs):
         model = torch.compile(model, fullgraph=True)

         with torch._dynamo.config.patch(error_on_recompile=True):
-            inputs = self.get_dummy_inputs()
+            inputs = self._get_dummy_inputs_for_model(model)
+
             output = model(**inputs, return_dict=False)[0]
             assert output is not None, "Model output is None"
             assert not torch.isnan(output).any(), "Model output contains NaN"
@@ -1228,7 +1235,8 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False
         model.enable_group_offload(**group_offload_kwargs)
         model = torch.compile(model)

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None"
         assert not torch.isnan(output).any(), "Model output contains NaN"
