|
| 1 | +import unittest |
| 2 | + |
| 3 | +import torch |
| 4 | +from parameterized import parameterized |
| 5 | + |
| 6 | +from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl |
| 7 | +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl |
| 8 | +from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl |
| 9 | +from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import ( |
| 10 | + MOEFeedForwardAOQuantizable, |
| 11 | +) |
| 12 | +from torchao.quantization.prototype.moe_quant.utils import ( |
| 13 | + FakeExtraDimTensor, |
| 14 | + MoEQuantConfig, |
| 15 | + UseFakeExtraDimTensor, |
| 16 | + cond_ffn_filter, |
| 17 | +) |
| 18 | +from torchao.quantization.quant_api import ( |
| 19 | + AffineQuantizedTensor, |
| 20 | + Float8DynamicActivationFloat8WeightConfig, |
| 21 | + Float8WeightOnlyConfig, |
| 22 | + Int4WeightOnlyConfig, |
| 23 | + Int8DynamicActivationInt8WeightConfig, |
| 24 | + Int8WeightOnlyConfig, |
| 25 | + LinearActivationQuantizedTensor, |
| 26 | + quantize_, |
| 27 | +) |
| 28 | +from torchao.quantization.utils import compute_error |
| 29 | +from torchao.utils import ( |
| 30 | + TORCH_VERSION_AT_LEAST_2_5, |
| 31 | + TORCH_VERSION_AT_LEAST_2_6, |
| 32 | + is_sm_at_least_90, |
| 33 | +) |
| 34 | + |
| 35 | + |
| 36 | +class TestMoEQuantCompile(unittest.TestCase): |
| 37 | + DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k |
| 38 | + |
| 39 | + @torch.no_grad() |
| 40 | + def _test_impl_moe_quant( |
| 41 | + self, |
| 42 | + config, |
| 43 | + num_tokens=1, |
| 44 | + model_params=None, |
| 45 | + base_class=AffineQuantizedTensor, |
| 46 | + tensor_impl_class=None, |
| 47 | + dtype=torch.bfloat16, |
| 48 | + device="cuda", |
| 49 | + fullgraph=False, |
| 50 | + ): |
| 51 | + """ |
| 52 | + Tests moe quant for techniques using fake extra dim |
| 53 | + """ |
| 54 | + if model_params is None: |
| 55 | + model_params = self.DEFAULT_PARAMS |
| 56 | + |
| 57 | + input_shape = (num_tokens, model_params[0]) |
| 58 | + model = ( |
| 59 | + MOEFeedForwardAOQuantizable(*model_params, empty_init=False) |
| 60 | + .to(dtype) |
| 61 | + .to(device) |
| 62 | + ) |
| 63 | + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) |
| 64 | + |
| 65 | + out = model(input) |
| 66 | + |
| 67 | + quantize_(model, config, cond_ffn_filter) |
| 68 | + |
| 69 | + if ( |
| 70 | + isinstance(config, MoEQuantConfig) |
| 71 | + and config.use_fake_extra_dim_tensor == UseFakeExtraDimTensor.TRUE |
| 72 | + ): |
| 73 | + self.assertIsInstance(model.experts.w1, FakeExtraDimTensor) |
| 74 | + if base_class is not None: |
| 75 | + self.assertIsInstance(model.experts.w1.head_tensor, base_class) |
| 76 | + if tensor_impl_class is not None: |
| 77 | + self.assertIsInstance( |
| 78 | + model.experts.w1.head_tensor.tensor_impl, tensor_impl_class |
| 79 | + ) |
| 80 | + else: |
| 81 | + if base_class is not None: |
| 82 | + self.assertIsInstance(model.experts.w1, base_class) |
| 83 | + if tensor_impl_class is not None: |
| 84 | + self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class) |
| 85 | + |
| 86 | + out_q = model(input) |
| 87 | + |
| 88 | + torch._dynamo.config.capture_scalar_outputs = True |
| 89 | + torch._dynamo.config.capture_dynamic_output_shape_ops = True |
| 90 | + model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph) |
| 91 | + |
| 92 | + model_c(input) |
| 93 | + model_c(input) |
| 94 | + out_qc = model_c(input).clone() |
| 95 | + |
| 96 | + for i in range(10): |
| 97 | + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) |
| 98 | + model_c(input) |
| 99 | + |
| 100 | + self.assertGreaterEqual(compute_error(out_q, out), 10) |
| 101 | + self.assertGreaterEqual(compute_error(out_qc, out), 10) |
| 102 | + |
| 103 | + @parameterized.expand( |
| 104 | + [ |
| 105 | + ("single_token", 1, False), |
| 106 | + ("multiple_tokens", 8, False), |
| 107 | + ] |
| 108 | + ) |
| 109 | + def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): |
| 110 | + if not torch.cuda.is_available(): |
| 111 | + self.skipTest("Need CUDA available") |
| 112 | + if not TORCH_VERSION_AT_LEAST_2_5: |
| 113 | + self.skipTest("Test only enabled for 2.5+") |
| 114 | + |
| 115 | + config = MoEQuantConfig( |
| 116 | + Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE |
| 117 | + ) |
| 118 | + tensor_impl_class = TensorCoreTiledAQTTensorImpl |
| 119 | + |
| 120 | + self._test_impl_moe_quant( |
| 121 | + config=config, |
| 122 | + num_tokens=num_tokens, |
| 123 | + tensor_impl_class=tensor_impl_class, |
| 124 | + fullgraph=fullgraph, |
| 125 | + ) |
| 126 | + |
| 127 | + @parameterized.expand( |
| 128 | + [ |
| 129 | + ("single_token", 1, True), |
| 130 | + ("multiple_tokens", 8, False), |
| 131 | + ] |
| 132 | + ) |
| 133 | + def test_int4wo_base(self, name, num_tokens, fullgraph): |
| 134 | + if not torch.cuda.is_available(): |
| 135 | + self.skipTest("Need CUDA available") |
| 136 | + if not is_sm_at_least_90(): |
| 137 | + self.skipTest("Requires CUDA capability >= 9.0") |
| 138 | + if not TORCH_VERSION_AT_LEAST_2_5: |
| 139 | + self.skipTest("Test only enabled for 2.5+") |
| 140 | + |
| 141 | + config = MoEQuantConfig(Int4WeightOnlyConfig()) |
| 142 | + tensor_impl_class = TensorCoreTiledAQTTensorImpl |
| 143 | + |
| 144 | + self._test_impl_moe_quant( |
| 145 | + config=config, |
| 146 | + num_tokens=num_tokens, |
| 147 | + tensor_impl_class=tensor_impl_class, |
| 148 | + fullgraph=fullgraph, |
| 149 | + ) |
| 150 | + |
| 151 | + @parameterized.expand( |
| 152 | + [ |
| 153 | + ("single_token", 1, False), |
| 154 | + ("multiple_tokens", 8, False), |
| 155 | + ] |
| 156 | + ) |
| 157 | + def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): |
| 158 | + if not torch.cuda.is_available(): |
| 159 | + self.skipTest("Need CUDA available") |
| 160 | + if not TORCH_VERSION_AT_LEAST_2_5: |
| 161 | + self.skipTest("Test only enabled for 2.5+") |
| 162 | + |
| 163 | + config = MoEQuantConfig( |
| 164 | + Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE |
| 165 | + ) |
| 166 | + tensor_impl_class = PlainAQTTensorImpl |
| 167 | + |
| 168 | + self._test_impl_moe_quant( |
| 169 | + config=config, |
| 170 | + num_tokens=num_tokens, |
| 171 | + tensor_impl_class=tensor_impl_class, |
| 172 | + fullgraph=fullgraph, |
| 173 | + ) |
| 174 | + |
| 175 | + @parameterized.expand( |
| 176 | + [ |
| 177 | + ("single_token", 1, True), |
| 178 | + ("multiple_tokens", 8, False), |
| 179 | + ] |
| 180 | + ) |
| 181 | + def test_int8wo_base(self, name, num_tokens, fullgraph): |
| 182 | + if not torch.cuda.is_available(): |
| 183 | + self.skipTest("Need CUDA available") |
| 184 | + if not TORCH_VERSION_AT_LEAST_2_6: |
| 185 | + self.skipTest("Test only enabled for 2.6+") |
| 186 | + |
| 187 | + config = MoEQuantConfig(Int8WeightOnlyConfig()) |
| 188 | + tensor_impl_class = PlainAQTTensorImpl |
| 189 | + |
| 190 | + self._test_impl_moe_quant( |
| 191 | + config=config, |
| 192 | + num_tokens=num_tokens, |
| 193 | + tensor_impl_class=tensor_impl_class, |
| 194 | + fullgraph=fullgraph, |
| 195 | + ) |
| 196 | + |
| 197 | + @parameterized.expand( |
| 198 | + [ |
| 199 | + ("single_token", 1, True), |
| 200 | + ("multiple_tokens", 8, False), |
| 201 | + ] |
| 202 | + ) |
| 203 | + def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): |
| 204 | + if not TORCH_VERSION_AT_LEAST_2_6: |
| 205 | + self.skipTest("Test only enabled for 2.6+") |
| 206 | + |
| 207 | + config = MoEQuantConfig(Int8WeightOnlyConfig()) |
| 208 | + tensor_impl_class = PlainAQTTensorImpl |
| 209 | + |
| 210 | + self._test_impl_moe_quant( |
| 211 | + config=config, |
| 212 | + num_tokens=num_tokens, |
| 213 | + tensor_impl_class=tensor_impl_class, |
| 214 | + fullgraph=fullgraph, |
| 215 | + device="cpu", |
| 216 | + ) |
| 217 | + |
| 218 | + @parameterized.expand( |
| 219 | + [ |
| 220 | + ("multiple_tokens", 32, False), |
| 221 | + ] |
| 222 | + ) |
| 223 | + def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): |
| 224 | + if not torch.cuda.is_available(): |
| 225 | + self.skipTest("Need CUDA available") |
| 226 | + if not TORCH_VERSION_AT_LEAST_2_5: |
| 227 | + self.skipTest("Test only enabled for 2.5+") |
| 228 | + |
| 229 | + config = MoEQuantConfig( |
| 230 | + Int8DynamicActivationInt8WeightConfig(), |
| 231 | + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, |
| 232 | + ) |
| 233 | + base_class = LinearActivationQuantizedTensor |
| 234 | + |
| 235 | + self._test_impl_moe_quant( |
| 236 | + model_params=(512, 256, 2, 2), |
| 237 | + config=config, |
| 238 | + num_tokens=num_tokens, |
| 239 | + base_class=base_class, |
| 240 | + fullgraph=fullgraph, |
| 241 | + ) |
| 242 | + |
| 243 | + @parameterized.expand( |
| 244 | + [ |
| 245 | + ("multiple_tokens", 32, False), |
| 246 | + ] |
| 247 | + ) |
| 248 | + def test_int8dq_base(self, name, num_tokens, fullgraph): |
| 249 | + if not torch.cuda.is_available(): |
| 250 | + self.skipTest("Need CUDA available") |
| 251 | + if not TORCH_VERSION_AT_LEAST_2_5: |
| 252 | + self.skipTest("Test only enabled for 2.5+") |
| 253 | + |
| 254 | + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) |
| 255 | + base_class = LinearActivationQuantizedTensor |
| 256 | + |
| 257 | + self._test_impl_moe_quant( |
| 258 | + model_params=(512, 256, 2, 2), |
| 259 | + config=config, |
| 260 | + num_tokens=num_tokens, |
| 261 | + base_class=base_class, |
| 262 | + fullgraph=fullgraph, |
| 263 | + ) |
| 264 | + |
| 265 | + @parameterized.expand( |
| 266 | + [ |
| 267 | + ("single_token", 1, False), |
| 268 | + ("multiple_tokens", 8, False), |
| 269 | + ] |
| 270 | + ) |
| 271 | + def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): |
| 272 | + if not torch.cuda.is_available(): |
| 273 | + self.skipTest("Need CUDA available") |
| 274 | + if not is_sm_at_least_90(): |
| 275 | + self.skipTest("Requires CUDA capability >= 9.0") |
| 276 | + |
| 277 | + config = MoEQuantConfig( |
| 278 | + Float8WeightOnlyConfig(), |
| 279 | + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, |
| 280 | + ) |
| 281 | + tensor_impl_class = Float8AQTTensorImpl |
| 282 | + |
| 283 | + self._test_impl_moe_quant( |
| 284 | + config=config, |
| 285 | + num_tokens=num_tokens, |
| 286 | + tensor_impl_class=tensor_impl_class, |
| 287 | + fullgraph=fullgraph, |
| 288 | + ) |
| 289 | + |
| 290 | + @parameterized.expand( |
| 291 | + [ |
| 292 | + ("single_token", 1, True), |
| 293 | + ("multiple_tokens", 8, False), |
| 294 | + ] |
| 295 | + ) |
| 296 | + def test_fp8wo_base(self, name, num_tokens, fullgraph): |
| 297 | + if not torch.cuda.is_available(): |
| 298 | + self.skipTest("Need CUDA available") |
| 299 | + if not is_sm_at_least_90(): |
| 300 | + self.skipTest("Requires CUDA capability >= 9.0") |
| 301 | + |
| 302 | + config = MoEQuantConfig(Float8WeightOnlyConfig()) |
| 303 | + tensor_impl_class = Float8AQTTensorImpl |
| 304 | + |
| 305 | + self._test_impl_moe_quant( |
| 306 | + config=config, |
| 307 | + num_tokens=num_tokens, |
| 308 | + tensor_impl_class=tensor_impl_class, |
| 309 | + fullgraph=fullgraph, |
| 310 | + ) |
| 311 | + |
| 312 | + @parameterized.expand( |
| 313 | + [ |
| 314 | + ("single_token", 1, False), |
| 315 | + ("multiple_tokens", 8, False), |
| 316 | + ] |
| 317 | + ) |
| 318 | + def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): |
| 319 | + if not torch.cuda.is_available(): |
| 320 | + self.skipTest("Need CUDA available") |
| 321 | + if not is_sm_at_least_90(): |
| 322 | + self.skipTest("Requires CUDA capability >= 9.0") |
| 323 | + |
| 324 | + config = MoEQuantConfig( |
| 325 | + Float8DynamicActivationFloat8WeightConfig(), |
| 326 | + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, |
| 327 | + ) |
| 328 | + base_class = LinearActivationQuantizedTensor |
| 329 | + |
| 330 | + self._test_impl_moe_quant( |
| 331 | + config=config, |
| 332 | + num_tokens=num_tokens, |
| 333 | + base_class=base_class, |
| 334 | + fullgraph=fullgraph, |
| 335 | + ) |
| 336 | + |
| 337 | + @parameterized.expand( |
| 338 | + [ |
| 339 | + ("single_token", 1, True), |
| 340 | + ("multiple_tokens", 8, False), |
| 341 | + ] |
| 342 | + ) |
| 343 | + def test_fp8dq_base(self, name, num_tokens, fullgraph): |
| 344 | + if not torch.cuda.is_available(): |
| 345 | + self.skipTest("Need CUDA available") |
| 346 | + if not is_sm_at_least_90(): |
| 347 | + self.skipTest("Requires CUDA capability >= 9.0") |
| 348 | + |
| 349 | + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig()) |
| 350 | + base_class = LinearActivationQuantizedTensor |
| 351 | + |
| 352 | + self._test_impl_moe_quant( |
| 353 | + config=config, |
| 354 | + num_tokens=num_tokens, |
| 355 | + base_class=base_class, |
| 356 | + fullgraph=fullgraph, |
| 357 | + ) |
| 358 | + |
| 359 | + |
| 360 | +if __name__ == "__main__": |
| 361 | + unittest.main() |
0 commit comments