Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions tests/models/testing_utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ def _is_module_quantized(self, module):
except (AssertionError, AttributeError):
return False

def _get_dummy_inputs_for_model(self, model):
    """Return the dummy inputs with floating-point tensors cast to the model's dtype.

    Args:
        model: Object exposing ``parameters()``; its parameters determine the target dtype.

    Returns:
        A new dict with the same keys as ``self.get_dummy_inputs()``. Floating-point
        tensors are cast to the dtype of the model's first floating-point parameter.
        Non-floating tensors and non-tensor values are passed through unchanged. If
        the model has no floating-point parameter, nothing is cast.
    """
    inputs = self.get_dummy_inputs()

    # Quantized modules can hold integer-typed weights (e.g. uint8). Casting float
    # inputs to such a dtype would destroy them, so only floating-point parameters
    # are considered when picking the target dtype.
    model_dtype = next((p.dtype for p in model.parameters() if p.is_floating_point()), None)
    if model_dtype is None:
        # Nothing to align with (no parameters, or none floating-point).
        return dict(inputs)

    return {
        k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in inputs.items()
    }

def _load_unquantized_model(self):
    """Load the model through ``model_class.from_pretrained`` without extra quantization arguments.

    Returns:
        Whatever ``self.model_class.from_pretrained`` returns for
        ``self.pretrained_model_name_or_path``, called with the optional
        ``self.pretrained_model_kwargs`` (an empty dict when the attribute is absent).
    """
    extra_kwargs = getattr(self, "pretrained_model_kwargs", {})
    load = self.model_class.from_pretrained
    return load(self.pretrained_model_name_or_path, **extra_kwargs)
Expand Down Expand Up @@ -174,7 +182,7 @@ def _test_quantization_inference(self, config_kwargs):
model_quantized = self._create_quantized_model(config_kwargs)
model_quantized.to(torch_device)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model_quantized)
output = model_quantized(**inputs, return_dict=False)[0]

assert output is not None, "Model output is None"
Expand Down Expand Up @@ -222,7 +230,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
# Move LoRA adapter weights to device (they default to CPU)
model.to(torch_device)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]

assert output is not None, "Model output is None with LoRA"
Expand All @@ -236,7 +245,8 @@ def _test_quantization_serialization(self, config_kwargs, tmp_path):

model_loaded = self.model_class.from_pretrained(str(tmp_path))

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model_loaded)

output = model_loaded(**inputs, return_dict=False)[0]
assert not torch.isnan(output).any(), "Loaded model output contains NaN"

Expand Down Expand Up @@ -334,7 +344,8 @@ def _test_quantization_device_map(self, config_kwargs):
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
assert model.hf_device_map is not None, "hf_device_map should not be None"

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
Expand All @@ -359,7 +370,8 @@ def _test_dequantize(self, config_kwargs):
if isinstance(module, torch.nn.Linear):
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"

inputs = self.get_dummy_inputs()
# Get model dtype from first parameter
inputs = self._get_dummy_inputs_for_model(model)
output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None after dequantization"
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
Expand Down Expand Up @@ -405,7 +417,7 @@ def _test_quantization_training(self, config_kwargs):
pytest.skip("No attention layers found in model for adapter training test")

# Step 3: run forward and backward pass
inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

with torch.amp.autocast(torch_device, dtype=torch.float16):
out = model(**inputs, return_dict=False)[0]
Expand Down Expand Up @@ -587,7 +599,7 @@ def test_bnb_keep_modules_in_fp32(self):
f"Module {name} should be uint8 but is {module.weight.dtype}"
)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)
_ = model(**inputs)

def test_bnb_modules_to_not_convert(self):
Expand Down Expand Up @@ -902,7 +914,8 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path):

model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device))

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model_loaded)

output = model_loaded(**inputs, return_dict=False)[0]
assert not torch.isnan(output).any(), "Loaded model output contains NaN"

Expand Down Expand Up @@ -1159,6 +1172,14 @@ class QuantizationCompileTesterMixin:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""

def _get_dummy_inputs_for_model(self, model):
    """Return the dummy inputs with floating-point tensors cast to the model's dtype.

    Args:
        model: Object exposing ``parameters()``; its parameters determine the target dtype.

    Returns:
        A new dict with the same keys as ``self.get_dummy_inputs()``. Floating-point
        tensors are cast to the dtype of the model's first floating-point parameter.
        Non-floating tensors and non-tensor values are passed through unchanged. If
        the model has no floating-point parameter, nothing is cast.
    """
    inputs = self.get_dummy_inputs()

    # Quantized modules can hold integer-typed weights (e.g. uint8). Casting float
    # inputs to such a dtype would destroy them, so only floating-point parameters
    # are considered when picking the target dtype.
    model_dtype = next((p.dtype for p in model.parameters() if p.is_floating_point()), None)
    if model_dtype is None:
        # Nothing to align with (no parameters, or none floating-point).
        return dict(inputs)

    return {
        k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in inputs.items()
    }

def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
Expand All @@ -1184,7 +1205,8 @@ def _test_torch_compile(self, config_kwargs):
model = torch.compile(model, fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True):
inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
Expand Down Expand Up @@ -1215,7 +1237,8 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False
model.enable_group_offload(**group_offload_kwargs)
model = torch.compile(model)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
Expand Down
Loading