25 | 25 | from torch._inductor.test_case import TestCase as InductorTestCase
26 | 26 | from torch.testing._internal import common_utils
27 | 27 |
| 28 | +from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl |
28 | 29 | from torchao.float8.float8_utils import compute_error
29 | 30 | from torchao.quantization import (
30 | 31 | Float8DynamicActivationFloat8WeightConfig,
42 | 43 | from torchao.quantization.quant_primitives import (
43 | 44 | MappingType,
44 | 45 | choose_qparams_affine,
| 46 | + choose_qparams_affine_float8, |
| 47 | + dequantize_affine_float8, |
| 48 | + quantize_affine_float8, |
45 | 49 | )
46 | 50 | from torchao.utils import (
47 | 51 | is_sm_at_least_89,
@@ -297,21 +301,299 @@ def test_fp8_weight_dimension_warning(self):
297 | 301 | @unittest.skipIf(
298 | 302 | not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
299 | 303 | )
300 | | - def test_mm_float8dq(self): |
| 304 | + @common_utils.parametrize( |
| 305 | + "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)] |
| 306 | + ) |
| 307 | + @common_utils.parametrize( |
| 308 | + "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)] |
| 309 | + ) # fmt: skip |
| 310 | + @common_utils.parametrize("bias", [True, False]) |
| 311 | + def test_mm_float8dq_per_row( |
| 312 | + self, in_features, out_features, leading_shape, bias: bool |
| 313 | + ): |
| 314 | + device = "cuda" |
| 315 | + dtype = torch.bfloat16 |
| 316 | + input_shape = leading_shape + (in_features,) |
| 317 | + |
| 318 | + ref_linear = ( |
| 319 | + torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype) |
| 320 | + ) |
| 321 | + test_linear = copy.deepcopy(ref_linear) |
| 322 | + quantize_( |
| 323 | + test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) |
| 324 | + ) |
| 325 | + |
| 326 | + quant_weight = test_linear.weight |
| 327 | + |
| 328 | + self.assertTrue(hasattr(quant_weight, "original_weight_tensor")) |
| 329 | + weight_impl = quant_weight.original_weight_tensor.tensor_impl |
| 330 | + |
| 331 | + self.assertTrue(hasattr(weight_impl, "float8_data")) |
| 332 | + self.assertTrue(hasattr(weight_impl, "scale")) |
| 333 | + self.assertFalse(weight_impl.transposed) |
| 334 | + |
| 335 | + # Verify scale shape for row-wise quantization |
| 336 | + expected_scale_shape = (out_features, 1) |
| 337 | + actual_scale_shape = weight_impl.scale.shape |
| 338 | + self.assertEqual(actual_scale_shape, expected_scale_shape) |
| 339 | + |
| 340 | + self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features)) |
| 341 | + |
| 342 | + input_tensor = torch.randn(*input_shape, device=device, dtype=dtype) |
| 343 | + |
| 344 | + with torch.no_grad(): |
| 345 | + ref_output = ref_linear(input_tensor) |
| 346 | + quant_output = torch.nn.functional.linear(input_tensor, quant_weight) |
| 347 | + |
| 348 | + expected_output_shape = input_tensor.shape[:-1] + (out_features,) |
| 349 | + self.assertEqual(quant_output.shape, expected_output_shape) |
| 350 | + |
| 351 | + error = compute_error(ref_output, quant_output) |
| 352 | + assert error > 20, f"Quantization error is too high, got SQNR of {error}" |
| 353 | + |
| 354 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 355 | + @unittest.skipIf( |
| 356 | + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" |
| 357 | + ) |
| 358 | + @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) |
| 359 | + @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16]) |
| 360 | + @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)]) |
| 361 | + def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size): |
| 362 | + """Test dequantize_affine_float8 with various configurations""" |
| 363 | + |
| 364 | + device = "cuda" |
| 365 | + input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32) |
| 366 | + |
| 367 | + # Choose quantization parameters |
| 368 | + scale = choose_qparams_affine_float8( |
| 369 | + input_tensor, float8_dtype=float8_dtype, block_size=block_size |
| 370 | + ) |
| 371 | + |
| 372 | + # Quantize |
| 373 | + quantized = quantize_affine_float8(input_tensor, scale, float8_dtype) |
| 374 | + |
| 375 | + # Dequantize |
| 376 | + dequantized = dequantize_affine_float8(quantized, scale, output_dtype) |
| 377 | + |
| 378 | + # Verify output properties |
| 379 | + self.assertEqual(dequantized.dtype, output_dtype) |
| 380 | + self.assertEqual(dequantized.shape, input_tensor.shape) |
| 381 | + self.assertEqual(dequantized.device, input_tensor.device) |
| 382 | + |
| 383 | + # Verify quantization/dequantization roundtrip is reasonable |
| 384 | + error = torch.abs(input_tensor.to(output_dtype) - dequantized).mean() |
| 385 | + self.assertLess(error, 0.1, "Quantization error too high") |
| 386 | + |
| 387 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 388 | + @unittest.skipIf( |
| 389 | + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" |
| 390 | + ) |
| 391 | + def test_dequantize_affine_float8_scale_broadcasting(self): |
| 392 | + """Test that scale broadcasting works correctly for block-wise quantization""" |
| 393 | + device = "cuda" |
| 394 | + # Create input tensor with known block structure |
| 395 | + input_tensor = torch.randn(4, 32, device=device, dtype=torch.float32) |
| 396 | + block_size = (2, 16)  # each block spans 2 rows and 16 columns |
| 397 | + |
| 398 | + # Choose quantization parameters |
| 399 | + scale = choose_qparams_affine_float8( |
| 400 | + input_tensor, float8_dtype=torch.float8_e4m3fn, block_size=block_size |
| 401 | + ) |
| 402 | + |
| 403 | + # Verify scale shape |
| 404 | + expected_scale_shape = ( |
| 405 | + input_tensor.shape[0] // block_size[0], |
| 406 | + input_tensor.shape[1] // block_size[1], |
| 407 | + ) |
| 408 | + self.assertEqual(scale.shape, expected_scale_shape) |
| 409 | + |
| 410 | + # Quantize |
| 411 | + quantized = quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn) |
| 412 | + |
| 413 | + # Dequantize |
| 414 | + dequantized = dequantize_affine_float8(quantized, scale, torch.float32) |
| 415 | + |
| 416 | + # Verify shapes match |
| 417 | + self.assertEqual(dequantized.shape, input_tensor.shape) |
| 418 | + |
| 419 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 420 | + @unittest.skipIf( |
| 421 | + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" |
| 422 | + ) |
| 423 | + @common_utils.parametrize( |
| 424 | + "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()] |
| 425 | + ) |
| 426 | + def test_float8_tensor_slicing_basic(self, granularity): |
| 427 | + """Test basic slicing operations on Float8 tensors""" |
301 | 428 | device = "cuda"
302 | 429 | dtype = torch.bfloat16
303 | | - weight = torch.randn(512, 1024).to(device).to(dtype) |
304 | | - weight = weight.t() |
305 | | - |
306 | | - l = torch.nn.Linear(512, 1024).to(device).to(dtype) |
307 | | - l.weight = torch.nn.Parameter(weight) |
308 | | - quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) |
309 | | - # weight shape: 1024 x 512 |
310 | | - weight = l.weight |
311 | | - |
312 | | - input = torch.randn(1, 512, device=device, dtype=dtype) |
313 | | - # make sure it runs |
314 | | - torch.nn.functional.linear(input, weight) |
| 430 | + |
| 431 | + # Create and quantize a model |
| 432 | + model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) |
| 433 | + quantize_( |
| 434 | + model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity) |
| 435 | + ) |
| 436 | + |
| 437 | + weight_impl = model.weight.original_weight_tensor.tensor_impl |
| 438 | + |
| 439 | + # Test dimension 0 slicing (rows) |
| 440 | + sliced_0 = weight_impl[10:20] |
| 441 | + self.assertEqual(sliced_0.shape, (10, 64)) |
| 442 | + |
| 443 | + # Test dimension 1 slicing (columns) |
| 444 | + sliced_1 = weight_impl[:, 20:40] |
| 445 | + self.assertEqual(sliced_1.shape, (32, 20)) |
| 446 | + |
| 447 | + # Test combined slicing |
| 448 | + sliced_both = weight_impl[5:15, 10:30] |
| 449 | + self.assertEqual(sliced_both.shape, (10, 20)) |
| 450 | + |
| 451 | + # Verify the sliced tensors are still Float8 tensors |
| 452 | + self.assertTrue(isinstance(sliced_0, Float8AQTTensorImpl)) |
| 453 | + self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl)) |
| 454 | + self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl)) |
| 455 | + |
| 456 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 457 | + @unittest.skipIf( |
| 458 | + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" |
| 459 | + ) |
| 460 | + def test_float8_tensor_slicing_per_tensor(self): |
| 461 | + """Test slicing with per-tensor quantization (scale should not change)""" |
| 462 | + device = "cuda" |
| 463 | + dtype = torch.bfloat16 |
| 464 | + |
| 465 | + # Create and quantize with per-tensor granularity |
| 466 | + model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) |
| 467 | + quantize_( |
| 468 | + model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()) |
| 469 | + ) |
| 470 | + |
| 471 | + original_weight = model.weight |
| 472 | + original_impl = original_weight.original_weight_tensor.tensor_impl |
| 473 | + original_scale = original_impl.scale |
| 474 | + |
| 475 | + # Test slicing |
| 476 | + sliced_weight = original_weight[10:20, 20:40] |
| 477 | + sliced_impl = sliced_weight.original_weight_tensor.tensor_impl |
| 478 | + |
| 479 | + # For per-tensor quantization, scale should be identical |
| 480 | + self.assertTrue(torch.equal(original_scale, sliced_impl.scale)) |
| 481 | + self.assertEqual(sliced_impl.scale.numel(), 1) |
| 482 | + |
| 483 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 484 | + @unittest.skipIf( |
| 485 | + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" |
| 486 | + ) |
| 487 | + @unittest.skipIf( |
| 488 | + not is_sm_at_least_90(), |
| 489 | + "Per-row quantization requires compute capability >= 9.0", |
| 490 | + ) |
| 491 | + def test_float8_tensor_slicing_per_row(self): |
| 492 | + """Test slicing with per-row quantization (scale should be sliced appropriately)""" |
| 493 | + device = "cuda" |
| 494 | + dtype = torch.bfloat16 |
| 495 | + |
| 496 | + # Create and quantize with per-row granularity |
| 497 | + model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) |
| 498 | + quantize_( |
| 499 | + model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) |
| 500 | + ) |
| 501 | + |
| 502 | + original_weight = model.weight # Shape: (32, 64) |
| 503 | + original_impl = original_weight.original_weight_tensor.tensor_impl |
| 504 | + original_scale = original_impl.scale # Shape: (32, 1) |
| 505 | + |
| 506 | + # Test row slicing (dimension 0) |
| 507 | + sliced_rows = original_weight[10:20] # Shape: (10, 64) |
| 508 | + sliced_impl = sliced_rows.original_weight_tensor.tensor_impl |
| 509 | + |
| 510 | + # Scale should be sliced to match the rows |
| 511 | + expected_scale_shape = (10, 1) |
| 512 | + self.assertEqual(sliced_impl.scale.shape, expected_scale_shape) |
| 513 | + |
| 514 | + # Verify the scale values are correct (should be subset of original) |
| 515 | + self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20])) |
| 516 | + |
| 517 | + # Test column slicing (dimension 1) - scale should not change for per-row |
| 518 | + sliced_cols = original_weight[:, 20:40] # Shape: (32, 20) |
| 519 | + sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl |
| 520 | + |
| 521 | + # Scale shape should remain the same since we're not changing rows |
| 522 | + self.assertEqual(sliced_cols_impl.scale.shape, (32, 1)) |
| 523 | + self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale)) |
| 524 | + |
| 525 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 526 | + @unittest.skipIf( |
| 527 | + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" |
| 528 | + ) |
| 529 | + def test_float8_tensor_slicing_edge_cases(self): |
| 530 | + """Test edge cases in slicing""" |
| 531 | + device = "cuda" |
| 532 | + dtype = torch.bfloat16 |
| 533 | + |
| 534 | + # Create and quantize a model |
| 535 | + model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) |
| 536 | + quantize_( |
| 537 | + model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()) |
| 538 | + ) |
| 539 | + |
| 540 | + original_weight = model.weight |
| 541 | + |
| 542 | + # Test empty slice |
| 543 | + empty_slice = original_weight[0:0] |
| 544 | + self.assertEqual(empty_slice.shape, (0, 64)) |
| 545 | + |
| 546 | + # Test single element slice |
| 547 | + single_row = original_weight[0:1] |
| 548 | + self.assertEqual(single_row.shape, (1, 64)) |
| 549 | + |
| 550 | + # Test out of bounds (should be handled by PyTorch) |
| 551 | + large_slice = original_weight[:100] # More than available rows |
| 552 | + self.assertEqual(large_slice.shape, (32, 64)) # Should clamp to available |
| 553 | + |
| 554 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 555 | + @unittest.skipIf( |
| 556 | + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" |
| 557 | + ) |
| 558 | + @common_utils.parametrize( |
| 559 | + "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()] |
| 560 | + ) |
| 561 | + def test_float8_tensor_slicing_functional_correctness(self, granularity): |
| 562 | + """Test that sliced tensors produce correct results in computations""" |
| 563 | + device = "cuda" |
| 564 | + dtype = torch.bfloat16 |
| 565 | + |
| 566 | + # Create reference and quantized models with dimensions that are multiples of 16 |
| 567 | + ref_model = ( |
| 568 | + torch.nn.Linear(64, 48, bias=False).to(device).to(dtype) |
| 569 | + ) # 48 is divisible by 16 |
| 570 | + quant_model = copy.deepcopy(ref_model) |
| 571 | + quantize_( |
| 572 | + quant_model, |
| 573 | + Float8DynamicActivationFloat8WeightConfig(granularity=granularity), |
| 574 | + ) |
| 575 | + |
| 576 | + # Create input with batch size that works well with slicing |
| 577 | + input_tensor = torch.randn(8, 64, device=device, dtype=dtype) |
| 578 | + |
| 579 | + ref_weight_slice = ref_model.weight[0:16, 0:32] |
| 580 | + quant_weight_slice = quant_model.weight[0:16, 0:32] |
| 581 | + |
| 582 | + input_slice = input_tensor[:, 0:32] # (8, 32) to match sliced weight |
| 583 | + |
| 584 | + # Compute with sliced weights |
| 585 | + with torch.no_grad(): |
| 586 | + ref_output = torch.nn.functional.linear(input_slice, ref_weight_slice) |
| 587 | + quant_output = torch.nn.functional.linear(input_slice, quant_weight_slice) |
| 588 | + |
| 589 | + # Verify shapes |
| 590 | + expected_shape = (8, 16) # batch_size x out_features_sliced |
| 591 | + self.assertEqual(ref_output.shape, expected_shape) |
| 592 | + self.assertEqual(quant_output.shape, expected_shape) |
| 593 | + |
| 594 | + # Verify reasonable quantization error |
| 595 | + error = compute_error(ref_output, quant_output) |
| 596 | + self.assertGreater(error, 15, f"Quantization SQNR too low: {error}") |
315 | 597 |
316 | 598 |
317 | 599 | common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)