@@ -142,6 +142,14 @@ def _is_module_quantized(self, module):
         except (AssertionError, AttributeError):
             return False
 
+    def _get_dummy_inputs_for_model(self, model):
+        inputs = self.get_dummy_inputs()
+        model_dtype = next(model.parameters()).dtype
+        return {
+            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
+            for k, v in inputs.items()
+        }
+
     def _load_unquantized_model(self):
         kwargs = getattr(self, "pretrained_model_kwargs", {})
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
@@ -174,7 +182,7 @@ def _test_quantization_inference(self, config_kwargs):
         model_quantized = self._create_quantized_model(config_kwargs)
         model_quantized.to(torch_device)
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_quantized)
         output = model_quantized(**inputs, return_dict=False)[0]
 
         assert output is not None, "Model output is None"
@@ -222,7 +230,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
         # Move LoRA adapter weights to device (they default to CPU)
         model.to(torch_device)
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
 
         assert output is not None, "Model output is None with LoRA"
@@ -236,7 +245,8 @@ def _test_quantization_serialization(self, config_kwargs, tmp_path):
 
         model_loaded = self.model_class.from_pretrained(str(tmp_path))
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_loaded)
+
         output = model_loaded(**inputs, return_dict=False)[0]
         assert not torch.isnan(output).any(), "Loaded model output contains NaN"
 
@@ -334,7 +344,8 @@ def _test_quantization_device_map(self, config_kwargs):
         assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
         assert model.hf_device_map is not None, "hf_device_map should not be None"
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None"
         assert not torch.isnan(output).any(), "Model output contains NaN"
@@ -360,14 +371,7 @@ def _test_dequantize(self, config_kwargs):
             assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
 
         # Get model dtype from first parameter
-        model_dtype = next(model.parameters()).dtype
-
-        inputs = self.get_dummy_inputs()
-        # Cast inputs to model dtype
-        inputs = {
-            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
-            for k, v in inputs.items()
-        }
+        inputs = self._get_dummy_inputs_for_model(model)
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None after dequantization"
         assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
@@ -413,7 +417,7 @@ def _test_quantization_training(self, config_kwargs):
             pytest.skip("No attention layers found in model for adapter training test")
 
         # Step 3: run forward and backward pass
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
 
         with torch.amp.autocast(torch_device, dtype=torch.float16):
             out = model(**inputs, return_dict=False)[0]
@@ -597,7 +601,8 @@ def test_bnb_keep_modules_in_fp32(self):
                     f"Module {name} should be uint8 but is {module.weight.dtype}"
                 )
 
-            inputs = self.get_dummy_inputs()
+            inputs = self._get_dummy_inputs_for_model(model)
+
             _ = model(**inputs)
         finally:
             if original_fp32_modules is not None:
@@ -915,7 +920,8 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path):
 
         model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device))
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_loaded)
+
         output = model_loaded(**inputs, return_dict=False)[0]
         assert not torch.isnan(output).any(), "Loaded model output contains NaN"
 
@@ -1197,7 +1203,8 @@ def _test_torch_compile(self, config_kwargs):
         model = torch.compile(model, fullgraph=True)
 
         with torch._dynamo.config.patch(error_on_recompile=True):
-            inputs = self.get_dummy_inputs()
+            inputs = self._get_dummy_inputs_for_model(model)
+
             output = model(**inputs, return_dict=False)[0]
             assert output is not None, "Model output is None"
             assert not torch.isnan(output).any(), "Model output contains NaN"
@@ -1228,7 +1235,8 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False
         model.enable_group_offload(**group_offload_kwargs)
         model = torch.compile(model)
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None"
         assert not torch.isnan(output).any(), "Model output contains NaN"
0 commit comments