|
13 | 13 |
|
14 | 14 | from benchmarks.microbenchmarks.benchmark_runner import (
|
15 | 15 | get_param_combinations,
|
| 16 | + get_quantization_sparsity_recipes, |
16 | 17 | get_shapes_for_config,
|
17 | 18 | load_benchmark_configs,
|
18 | 19 | run_inference_benchmarks_from_config,
|
@@ -88,6 +89,109 @@ def test_run_inference_benchmarks_from_config(self):
|
88 | 89 | results_file = Path(self.temp_dir) / "results.csv"
|
89 | 90 | self.assertTrue(results_file.exists())
|
90 | 91 |
|
| 92 | + def test_get_quantization_sparsity_recipes(self): |
| 93 | + """Test generation of valid quantization and sparsity recipe combinations""" |
| 94 | + # Test basic combinations |
| 95 | + quant_recipes = ["baseline", "int8wo"] |
| 96 | + sparse_recipes = [None, "semi-sparse"] |
| 97 | + recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes) |
| 98 | + self.assertIn(("baseline", None), recipes) |
| 99 | + self.assertIn(("int8wo", None), recipes) |
| 100 | + self.assertIn(("baseline", "semi-sparse"), recipes) |
| 101 | + |
| 102 | + # Test marlin with semi-sparse |
| 103 | + quant_recipes = ["marlin", "baseline"] |
| 104 | + sparse_recipes = [None, "semi-sparse"] |
| 105 | + recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes) |
| 106 | + self.assertIn(("marlin", "semi-sparse"), recipes) |
| 107 | + self.assertIn(("baseline", None), recipes) |
| 108 | + |
| 109 | + # Test block sparsity |
| 110 | + quant_recipes = ["baseline"] |
| 111 | + sparse_recipes = [None, "block"] |
| 112 | + recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes) |
| 113 | + self.assertIn(("baseline", "block"), recipes) |
| 114 | + |
| 115 | + def test_none_string_raises_error(self): |
| 116 | + """Test that passing 'None' as a string raises an error""" |
| 117 | + quant_recipes = ["baseline"] |
| 118 | + sparse_recipes = ["None"] # "None" as a string should raise an error |
| 119 | + with self.assertRaises(ValueError): |
| 120 | + get_quantization_sparsity_recipes(quant_recipes, sparse_recipes) |
| 121 | + |
| 122 | + def test_block_sparsity_with_quantization(self): |
| 123 | + """Test that block sparsity is only paired with baseline quantization""" |
| 124 | + quant_recipes = ["baseline", "int8wo", "int4wo", "marlin"] |
| 125 | + sparse_recipes = ["block"] |
| 126 | + recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes) |
| 127 | + |
| 128 | + # Block sparsity should only be paired with baseline |
| 129 | + self.assertIn(("baseline", "block"), recipes) |
| 130 | + self.assertNotIn(("int8wo", "block"), recipes) |
| 131 | + self.assertNotIn(("int4wo", "block"), recipes) |
| 132 | + self.assertNotIn(("marlin", "block"), recipes) |
| 133 | + |
| 134 | + # All quantization techniques should be run without sparsity |
| 135 | + self.assertIn(("baseline", None), recipes) |
| 136 | + self.assertIn(("int8wo", None), recipes) |
| 137 | + self.assertIn(("int4wo", None), recipes) |
| 138 | + self.assertIn(("marlin", None), recipes) |
| 139 | + |
| 140 | + def test_all_quantization_without_sparsity(self): |
| 141 | + """Test that all quantization techniques are run without sparsity""" |
| 142 | + quant_recipes = ["baseline", "int8wo", "int4wo", "marlin"] |
| 143 | + sparse_recipes = [None, "semi-sparse", "block"] |
| 144 | + recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes) |
| 145 | + |
| 146 | + # All quantization techniques should be run without sparsity |
| 147 | + for quant in quant_recipes: |
| 148 | + self.assertIn((quant, None), recipes) |
| 149 | + |
| 150 | + @patch( |
| 151 | + "benchmarks.microbenchmarks.benchmark_runner.get_quantization_sparsity_recipes" |
| 152 | + ) |
| 153 | + def test_load_benchmark_configs_with_sparsity(self, mock_get_recipes): |
| 154 | + """Test loading benchmark configs with sparsity options""" |
| 155 | + # Mock get_quantization_sparsity_recipes to return a valid set of recipes |
| 156 | + mock_get_recipes.return_value = {("baseline", None), ("marlin", "semi-sparse")} |
| 157 | + |
| 158 | + test_config = { |
| 159 | + "benchmark_mode": "inference", |
| 160 | + "quantization_config_recipe_names": ["baseline", "marlin"], |
| 161 | + "sparsity_config_recipe_names": [ |
| 162 | + None, |
| 163 | + "semi-sparse", |
| 164 | + ], # Use None instead of "None" |
| 165 | + "output_dir": self.temp_dir, |
| 166 | + "model_params": [ |
| 167 | + { |
| 168 | + "matrix_shapes": [ |
| 169 | + {"name": "custom", "shapes": [[1024, 1024, 1024]]} |
| 170 | + ], |
| 171 | + "high_precision_dtype": "torch.bfloat16", |
| 172 | + "device": "cpu", |
| 173 | + "model_type": "linear", |
| 174 | + } |
| 175 | + ], |
| 176 | + } |
| 177 | + |
| 178 | + config_path = Path(self.temp_dir) / "test_sparsity_config.yml" |
| 179 | + with open(config_path, "w") as f: |
| 180 | + yaml.dump(test_config, f) |
| 181 | + |
| 182 | + configs = load_benchmark_configs(argparse.Namespace(config=str(config_path))) |
| 183 | + |
| 184 | + # Check that we get configs for baseline and marlin with appropriate sparsity |
| 185 | + self.assertTrue( |
| 186 | + any(c.quantization == "baseline" and c.sparsity is None for c in configs) |
| 187 | + ) |
| 188 | + self.assertTrue( |
| 189 | + any( |
| 190 | + c.quantization == "marlin" and c.sparsity == "semi-sparse" |
| 191 | + for c in configs |
| 192 | + ) |
| 193 | + ) |
| 194 | + |
91 | 195 |
|
92 | 196 | if __name__ == "__main__":
|
93 | 197 | unittest.main()
|
0 commit comments