2424import modelopt .torch .opt as mto
2525import modelopt .torch .quantization as mtq
2626from modelopt .torch .quantization .algorithms import (
27+ AutoQuantizeGradientSearcher ,
2728 QuantRecipe ,
2829 QuantRecipeHparam ,
2930 estimate_quant_compression ,
3031)
3132from modelopt .torch .quantization .config import _base_disable_all , _default_disabled_quantizer_cfg
33+ from modelopt .torch .quantization .model_quant import _normalize_auto_quantize_formats
3234from modelopt .torch .utils import safe_load
3335from modelopt .torch .utils .distributed import DistributedProcessGroup
3436
@@ -62,6 +64,11 @@ def get_input(self):
6264 return torch .randn (1 , 4 , 32 )
6365
6466
67+ def _recipe (quant_cfg ):
68+ name = None if quant_cfg is None else QuantRecipe .get_auto_name_for_config (quant_cfg )
69+ return QuantRecipe (quant_cfg , name = name )
70+
71+
6572@pytest .mark .parametrize (
6673 ("quant_cfg" , "other_quant_cfg" , "is_less_than" ),
6774 [
@@ -71,30 +78,85 @@ def get_input(self):
7178 ],
7279)
7380def test_quant_recipe (quant_cfg , other_quant_cfg , is_less_than ):
74- qr_this = QuantRecipe (quant_cfg )
75- qr_other = QuantRecipe (other_quant_cfg )
81+ qr_this = _recipe (quant_cfg )
82+ qr_other = _recipe (other_quant_cfg )
7683 assert (qr_this < qr_other ) == is_less_than
7784
78- qr_this_duplicate = QuantRecipe (quant_cfg )
85+ qr_this_duplicate = _recipe (quant_cfg )
7986 assert qr_this_duplicate in {qr_this }
8087
8188
82- def test_quant_recipe_custom_quantize_config_requires_name ( ):
83- custom_cfg = mtq .QuantizeConfig (
89+ def _custom_quantize_config ( path ):
90+ return mtq .QuantizeConfig (
8491 quant_cfg = [
8592 mtq .QuantizerCfgEntry (
86- quantizer_name = "*weight_quantizer" ,
93+ quantizer_name = path ,
8794 cfg = mtq .QuantizerAttributeConfig (num_bits = 8 , axis = None ),
8895 )
8996 ]
9097 )
9198
99+
100+ def test_quant_recipe_custom_quantize_config_requires_name ():
101+ custom_cfg = _custom_quantize_config ("*custom_weight_quantizer" )
102+
92103 with pytest .raises (ValueError , match = "name must be provided" ):
93104 QuantRecipe (custom_cfg )
94105
95106 assert str (QuantRecipe (custom_cfg , name = "custom_cfg" )).startswith ("custom_cfg(" )
96107
97108
109+ def test_quant_recipe_none_requires_no_name ():
110+ assert str (QuantRecipe (quant_cfg = None )).startswith ("NONE(" )
111+
112+ with pytest .raises (ValueError , match = "name must be None" ):
113+ QuantRecipe (quant_cfg = None , name = "NONE" )
114+
115+
116+ def test_quant_recipe_honors_explicit_name ():
117+ assert str (QuantRecipe (mtq .INT8_DEFAULT_CFG , name = "int8_alias" )).startswith ("int8_alias(" )
118+
119+
120+ def test_auto_quantize_search_config_requires_named_formats ():
121+ custom_a = _custom_quantize_config ("*custom_weight_quantizer_a" )
122+ custom_b = _custom_quantize_config ("*custom_weight_quantizer_b" )
123+ searcher = AutoQuantizeGradientSearcher ()
124+
125+ with pytest .warns (UserWarning ) as records :
126+ quantization_formats = _normalize_auto_quantize_formats ([custom_a , custom_b ])
127+
128+ assert quantization_formats == [(custom_a , "CUSTOM_0" ), (custom_b , "CUSTOM_1" )]
129+ assert any ("CUSTOM_0" in str (record .message ) for record in records )
130+ assert any ("CUSTOM_1" in str (record .message ) for record in records )
131+
132+ config = searcher .sanitize_search_config (
133+ {
134+ "quantization_formats" : quantization_formats ,
135+ "data_loader" : [torch .randn (1 )],
136+ "forward_step" : lambda model , data : data ,
137+ "loss_func" : lambda output , data : output .sum (),
138+ }
139+ )
140+ assert config ["quantization_formats" ] == quantization_formats
141+
142+ with pytest .raises (TypeError , match = "Named quantization format tuples are internal" ):
143+ _normalize_auto_quantize_formats ([(custom_a , "custom_a" )])
144+
145+ with pytest .raises (TypeError , match = "must be a list of" ):
146+ searcher .sanitize_search_config (
147+ {
148+ "quantization_formats" : [custom_a ],
149+ "data_loader" : [torch .randn (1 )],
150+ "forward_step" : lambda model , data : data ,
151+ "loss_func" : lambda output , data : output .sum (),
152+ }
153+ )
154+
155+ recipes = AutoQuantizeGradientSearcher ._get_search_recipes (config ["quantization_formats" ])
156+ assert {str (recipe ).split ("(" , 1 )[0 ] for recipe in recipes } == {"CUSTOM_0" , "CUSTOM_1" }
157+ assert len (set (recipes )) == 2
158+
159+
98160def test_quant_recipe_hparam ():
99161 model_test = torch .nn .Linear (4 , 16 )
100162 model_ref = torch .nn .Linear (4 , 16 )
@@ -104,20 +166,20 @@ def test_quant_recipe_hparam():
104166 model_ref = mtq .quantize (model_ref , mtq .INT4_BLOCKWISE_WEIGHT_ONLY_CFG )
105167
106168 search_recipes = [
107- QuantRecipe (mtq .INT8_DEFAULT_CFG ),
108- QuantRecipe (mtq .INT4_BLOCKWISE_WEIGHT_ONLY_CFG ),
169+ _recipe (mtq .INT8_DEFAULT_CFG ),
170+ _recipe (mtq .INT4_BLOCKWISE_WEIGHT_ONLY_CFG ),
109171 ]
110172 hparam = QuantRecipeHparam (
111173 search_recipes ,
112174 quant_modules = [model_test ],
113175 )
114176 model_test ._register_hparam ("quant_recipe" , hparam )
115- assert model_test .quant_recipe == QuantRecipe (mtq .INT4_BLOCKWISE_WEIGHT_ONLY_CFG )
177+ assert model_test .quant_recipe == _recipe (mtq .INT4_BLOCKWISE_WEIGHT_ONLY_CFG )
116178 assert model_test .get_hparam ("quant_recipe" ).choices == sorted (
117179 [* search_recipes , QuantRecipe (quant_cfg = None )]
118180 )
119181
120- model_test .quant_recipe = QuantRecipe (mtq .INT4_BLOCKWISE_WEIGHT_ONLY_CFG )
182+ model_test .quant_recipe = _recipe (mtq .INT4_BLOCKWISE_WEIGHT_ONLY_CFG )
121183 inputs = torch .randn (1 , 4 , 4 )
122184 output_test = model_test (inputs )
123185 output_ref = model_ref (inputs )
@@ -244,7 +306,7 @@ def test_auto_quantize_disabled_layers_no_poison():
244306
245307 assert not best_model .mlp .input_quantizer .is_enabled
246308 hparam = best_model .attn .q_proj .get_hparam ("quant_recipe" )
247- assert QuantRecipe (mtq .INT4_BLOCKWISE_WEIGHT_ONLY_CFG ) in hparam .choices
309+ assert _recipe (mtq .INT4_BLOCKWISE_WEIGHT_ONLY_CFG ) in hparam .choices
248310
249311
250312INT4INT8_AWQ_CFG = {
0 commit comments