
Commit 1017c7e

Fix Per Row scaling for inference (#2253)
1 parent e0e8b39 commit 1017c7e

File tree

6 files changed: +628 lines added, -179 lines removed
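
For context: with per-row (row-wise) scaling, each output row of a linear weight carries its own float8 scale, so an (out_features, in_features) weight is stored alongside an (out_features, 1) scale. Below is a minimal sketch of the user-facing flow the new tests in this commit verify; it is not part of the commit, the layer sizes are hypothetical, and it assumes a CUDA device with float8 support and the torchao APIs exercised in the diff:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Quantize a bf16 linear layer with per-row weight scales (hypothetical sizes).
linear = torch.nn.Linear(128, 256, bias=False).to("cuda", torch.bfloat16)
quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

impl = linear.weight.original_weight_tensor.tensor_impl
print(impl.float8_data.shape)  # expected (256, 128): the quantized weight
print(impl.scale.shape)        # expected (256, 1): one scale per output row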

test/dtypes/test_affine_quantized_float.py

Lines changed: 295 additions & 13 deletions
@@ -25,6 +25,7 @@
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
 
+from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
@@ -42,6 +43,9 @@
 from torchao.quantization.quant_primitives import (
     MappingType,
     choose_qparams_affine,
+    choose_qparams_affine_float8,
+    dequantize_affine_float8,
+    quantize_affine_float8,
 )
 from torchao.utils import (
     is_sm_at_least_89,
@@ -297,21 +301,299 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_mm_float8dq(self):
+    @common_utils.parametrize(
+        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
+    )
+    @common_utils.parametrize(
+        "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
+    )  # fmt: skip
+    @common_utils.parametrize("bias", [True, False])
+    def test_mm_float8dq_per_row(
+        self, in_features, out_features, leading_shape, bias: bool
+    ):
+        device = "cuda"
+        dtype = torch.bfloat16
+        input_shape = leading_shape + (in_features,)
+
+        ref_linear = (
+            torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype)
+        )
+        test_linear = copy.deepcopy(ref_linear)
+        quantize_(
+            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+
+        quant_weight = test_linear.weight
+
+        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
+        weight_impl = quant_weight.original_weight_tensor.tensor_impl
+
+        self.assertTrue(hasattr(weight_impl, "float8_data"))
+        self.assertTrue(hasattr(weight_impl, "scale"))
+        self.assertFalse(weight_impl.transposed)
+
+        # Verify scale shape for row-wise quantization
+        expected_scale_shape = (out_features, 1)
+        actual_scale_shape = weight_impl.scale.shape
+        self.assertEqual(actual_scale_shape, expected_scale_shape)
+
+        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+
+        input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            ref_output = ref_linear(input_tensor)
+            quant_output = torch.nn.functional.linear(input_tensor, quant_weight)
+
+        expected_output_shape = input_tensor.shape[:-1] + (out_features,)
+        self.assertEqual(quant_output.shape, expected_output_shape)
+
+        error = compute_error(ref_output, quant_output)
+        assert error > 20, f"Quantization error is too high got a SQNR of {error}"
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
+    @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
+    def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
+        """Test dequantize_affine_float8 with various configurations"""
+
+        device = "cuda"
+        input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)
+
+        # Choose quantization parameters
+        scale = choose_qparams_affine_float8(
+            input_tensor, float8_dtype=float8_dtype, block_size=block_size
+        )
+
+        # Quantize
+        quantized = quantize_affine_float8(input_tensor, scale, float8_dtype)
+
+        # Dequantize
+        dequantized = dequantize_affine_float8(quantized, scale, output_dtype)
+
+        # Verify output properties
+        self.assertEqual(dequantized.dtype, output_dtype)
+        self.assertEqual(dequantized.shape, input_tensor.shape)
+        self.assertEqual(dequantized.device, input_tensor.device)
+
+        # Verify quantization/dequantization roundtrip is reasonable
+        error = torch.abs(input_tensor.to(output_dtype) - dequantized).mean()
+        self.assertLess(error, 0.1, "Quantization error too high")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_dequantize_affine_float8_scale_broadcasting(self):
+        """Test that scale broadcasting works correctly for block-wise quantization"""
+        device = "cuda"
+        # Create input tensor with known block structure
+        input_tensor = torch.randn(4, 32, device=device, dtype=torch.float32)
+        block_size = (2, 16)  # 2x2 blocks in first dim, 2x16 blocks in second dim
+
+        # Choose quantization parameters
+        scale = choose_qparams_affine_float8(
+            input_tensor, float8_dtype=torch.float8_e4m3fn, block_size=block_size
+        )
+
+        # Verify scale shape
+        expected_scale_shape = (
+            input_tensor.shape[0] // block_size[0],
+            input_tensor.shape[1] // block_size[1],
+        )
+        self.assertEqual(scale.shape, expected_scale_shape)
+
+        # Quantize
+        quantized = quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn)
+
+        # Dequantize
+        dequantized = dequantize_affine_float8(quantized, scale, torch.float32)
+
+        # Verify shapes match
+        self.assertEqual(dequantized.shape, input_tensor.shape)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize(
+        "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
+    )
+    def test_float8_tensor_slicing_basic(self, granularity):
+        """Test basic slicing operations on Float8 tensors"""
         device = "cuda"
         dtype = torch.bfloat16
-        weight = torch.randn(512, 1024).to(device).to(dtype)
-        weight = weight.t()
-
-        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
-        l.weight = torch.nn.Parameter(weight)
-        quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
-        # weight shape: 1024 x 512
-        weight = l.weight
-
-        input = torch.randn(1, 512, device=device, dtype=dtype)
-        # make sure it runs
-        torch.nn.functional.linear(input, weight)
+
+        # Create and quantize a model
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+        quantize_(
+            model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        )
+
+        weight_impl = model.weight.original_weight_tensor.tensor_impl
+
+        # Test dimension 0 slicing (rows)
+        sliced_0 = weight_impl[10:20]
+        self.assertEqual(sliced_0.shape, (10, 64))
+
+        # Test dimension 1 slicing (columns)
+        sliced_1 = weight_impl[:, 20:40]
+        self.assertEqual(sliced_1.shape, (32, 20))
+
+        # Test combined slicing
+        sliced_both = weight_impl[5:15, 10:30]
+        self.assertEqual(sliced_both.shape, (10, 20))
+
+        # Verify the sliced tensors are still Float8 tensors
+        self.assertTrue(isinstance(sliced_0, Float8AQTTensorImpl))
+        self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl))
+        self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl))
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_float8_tensor_slicing_per_tensor(self):
+        """Test slicing with per-tensor quantization (scale should not change)"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create and quantize with per-tensor granularity
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+        quantize_(
+            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+        )
+
+        original_weight = model.weight
+        original_impl = original_weight.original_weight_tensor.tensor_impl
+        original_scale = original_impl.scale
+
+        # Test slicing
+        sliced_weight = original_weight[10:20, 20:40]
+        sliced_impl = sliced_weight.original_weight_tensor.tensor_impl
+
+        # For per-tensor quantization, scale should be identical
+        self.assertTrue(torch.equal(original_scale, sliced_impl.scale))
+        self.assertEqual(sliced_impl.scale.numel(), 1)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @unittest.skipIf(
+        not is_sm_at_least_90(),
+        "Per-row quantization requires compute capability >= 9.0",
+    )
+    def test_float8_tensor_slicing_per_row(self):
+        """Test slicing with per-row quantization (scale should be sliced appropriately)"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create and quantize with per-row granularity
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+        quantize_(
+            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+
+        original_weight = model.weight  # Shape: (32, 64)
+        original_impl = original_weight.original_weight_tensor.tensor_impl
+        original_scale = original_impl.scale  # Shape: (32, 1)
+
+        # Test row slicing (dimension 0)
+        sliced_rows = original_weight[10:20]  # Shape: (10, 64)
+        sliced_impl = sliced_rows.original_weight_tensor.tensor_impl
+
+        # Scale should be sliced to match the rows
+        expected_scale_shape = (10, 1)
+        self.assertEqual(sliced_impl.scale.shape, expected_scale_shape)
+
+        # Verify the scale values are correct (should be subset of original)
+        self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20]))
+
+        # Test column slicing (dimension 1) - scale should not change for per-row
+        sliced_cols = original_weight[:, 20:40]  # Shape: (32, 20)
+        sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl
+
+        # Scale shape should remain the same since we're not changing rows
+        self.assertEqual(sliced_cols_impl.scale.shape, (32, 1))
+        self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale))
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_float8_tensor_slicing_edge_cases(self):
+        """Test edge cases in slicing"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create and quantize a model
+        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+        quantize_(
+            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+        )
+
+        original_weight = model.weight
+
+        # Test empty slice
+        empty_slice = original_weight[0:0]
+        self.assertEqual(empty_slice.shape, (0, 64))
+
+        # Test single element slice
+        single_row = original_weight[0:1]
+        self.assertEqual(single_row.shape, (1, 64))
+
+        # Test out of bounds (should be handled by PyTorch)
+        large_slice = original_weight[:100]  # More than available rows
+        self.assertEqual(large_slice.shape, (32, 64))  # Should clamp to available
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize(
+        "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
+    )
+    def test_float8_tensor_slicing_functional_correctness(self, granularity):
+        """Test that sliced tensors produce correct results in computations"""
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        # Create reference and quantized models with dimensions that are multiples of 16
+        ref_model = (
+            torch.nn.Linear(64, 48, bias=False).to(device).to(dtype)
+        )  # 48 is divisible by 16
+        quant_model = copy.deepcopy(ref_model)
+        quantize_(
+            quant_model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+        )
+
+        # Create input with batch size that works well with slicing
+        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
+
+        ref_weight_slice = ref_model.weight[0:16, 0:32]
+        quant_weight_slice = quant_model.weight[0:16, 0:32]
+
+        input_slice = input_tensor[:, 0:32]  # (8, 32) to match sliced weight
+
+        # Compute with sliced weights
+        with torch.no_grad():
+            ref_output = torch.nn.functional.linear(input_slice, ref_weight_slice)
+            quant_output = torch.nn.functional.linear(input_slice, quant_weight_slice)
+
+        # Verify shapes
+        expected_shape = (8, 16)  # batch_size x out_features_sliced
+        self.assertEqual(ref_output.shape, expected_shape)
+        self.assertEqual(quant_output.shape, expected_shape)
+
+        # Verify reasonable quantization error
+        error = compute_error(ref_output, quant_output)
+        self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
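
The new tests above also call the float8 quant primitives directly with an explicit block_size. Below is a minimal roundtrip sketch mirroring those calls; the primitive signatures are assumed from how the diff uses them, and a CUDA device with float8 support is assumed:

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(4, 32, device="cuda", dtype=torch.float32)
block_size = (2, 16)  # one scale per 2x16 block of elements

# Scale has one entry per block: (4 // 2, 32 // 16) == (2, 2)
scale = choose_qparams_affine_float8(
    x, float8_dtype=torch.float8_e4m3fn, block_size=block_size
)
q = quantize_affine_float8(x, scale, torch.float8_e4m3fn)
x_hat = dequantize_affine_float8(q, scale, torch.float32)

assert scale.shape == (2, 2)
assert x_hat.shape == x.shape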

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 4 deletions
@@ -462,10 +462,10 @@ def from_hp_to_floatx(
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
-
-            scale = choose_qparams_affine_float8(input_float, float8_dtype=target_dtype)
+            scale = choose_qparams_affine_float8(
+                input_float, float8_dtype=target_dtype, block_size=block_size
+            )
             data = quantize_affine_float8(input_float, scale, target_dtype)
-
             data, scale, zero_point = _layout.post_process(
                 data, scale, None, block_size
             )
@@ -503,7 +503,6 @@ def from_hp_to_floatx_static(
                 input_float,
                 scale,
                 target_dtype,
-                scale_dtype,
             )
 
             data, scale, zero_point = _layout.post_process(
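
In the first hunk for this file, block_size is now forwarded to choose_qparams_affine_float8 inside from_hp_to_floatx, whereas the removed call computed the scale without it. A hypothetical before/after illustration follows; the per-row block size of (1, in_features) is an assumption about how PerRow granularity is lowered, and the resulting shapes follow the tests earlier in this commit:

import torch
from torchao.quantization.quant_primitives import choose_qparams_affine_float8

w = torch.randn(32, 64, device="cuda", dtype=torch.float32)

# Old call (removed above): no block_size, a single scale for the whole tensor.
scale_per_tensor = choose_qparams_affine_float8(w, float8_dtype=torch.float8_e4m3fn)

# New call: a per-row block size of (1, in_features) yields one scale per row.
scale_per_row = choose_qparams_affine_float8(
    w, float8_dtype=torch.float8_e4m3fn, block_size=(1, w.shape[1])
)
assert scale_per_row.shape == (32, 1)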
