Fix Per Row scaling for inference #2253

Merged (1 commit) on May 27, 2025
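
This PR fixes per-row (PerRow) scaling for float8 inference: the weight scale is now derived from the row-wise block size, so it has one entry per output row. As a rough reference only, and not the PR's own code, the sketch below shows per-row float8 scaling in plain PyTorch; the helper name and the exact scale formula are illustrative assumptions, while torchao's real implementation lives in `choose_qparams_affine_float8` / `quantize_affine_float8`.

```python
# Illustrative sketch only (not the PR's code): per-row float8 scaling in
# plain PyTorch. One scale per output row is why the tests below expect the
# weight scale to have shape (out_features, 1) under PerRow() granularity.
import torch

def per_row_fp8_quantize(w: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    fp8_max = torch.finfo(float8_dtype).max
    # Reduce over the input dimension: one amax (and one scale) per row.
    amax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax.float() / fp8_max
    w_fp8 = (w / scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return w_fp8, scale  # dequantize reference: w_fp8.to(torch.float32) * scale

w = torch.randn(32, 64, dtype=torch.bfloat16)
w_fp8, scale = per_row_fp8_quantize(w)
assert scale.shape == (32, 1)  # (out_features, 1), matching the tests below
```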
308 changes: 295 additions & 13 deletions test/dtypes/test_affine_quantized_float.py
@@ -25,6 +25,7 @@
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils

from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.float8.float8_utils import compute_error
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
@@ -42,6 +43,9 @@
from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_affine,
choose_qparams_affine_float8,
dequantize_affine_float8,
quantize_affine_float8,
)
from torchao.utils import (
is_sm_at_least_89,
@@ -297,21 +301,299 @@ def test_fp8_weight_dimension_warning(self):
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_mm_float8dq(self):
@common_utils.parametrize(
"in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
)
@common_utils.parametrize(
"leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
) # fmt: skip
@common_utils.parametrize("bias", [True, False])
def test_mm_float8dq_per_row(
self, in_features, out_features, leading_shape, bias: bool
):
device = "cuda"
dtype = torch.bfloat16
input_shape = leading_shape + (in_features,)

ref_linear = (
torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype)
)
test_linear = copy.deepcopy(ref_linear)
quantize_(
test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
)

quant_weight = test_linear.weight

self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
weight_impl = quant_weight.original_weight_tensor.tensor_impl

self.assertTrue(hasattr(weight_impl, "float8_data"))
self.assertTrue(hasattr(weight_impl, "scale"))
self.assertFalse(weight_impl.transposed)

# Verify scale shape for row-wise quantization
expected_scale_shape = (out_features, 1)
actual_scale_shape = weight_impl.scale.shape
self.assertEqual(actual_scale_shape, expected_scale_shape)

self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))

input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)

with torch.no_grad():
ref_output = ref_linear(input_tensor)
quant_output = torch.nn.functional.linear(input_tensor, quant_weight)

expected_output_shape = input_tensor.shape[:-1] + (out_features,)
self.assertEqual(quant_output.shape, expected_output_shape)

error = compute_error(ref_output, quant_output)
assert error > 20, f"Quantization error is too high, got a SQNR of {error}"

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
@common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
"""Test dequantize_affine_float8 with various configurations"""

device = "cuda"
input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)

# Choose quantization parameters
scale = choose_qparams_affine_float8(
input_tensor, float8_dtype=float8_dtype, block_size=block_size
)

# Quantize
quantized = quantize_affine_float8(input_tensor, scale, float8_dtype)

# Dequantize
dequantized = dequantize_affine_float8(quantized, scale, output_dtype)

# Verify output properties
self.assertEqual(dequantized.dtype, output_dtype)
self.assertEqual(dequantized.shape, input_tensor.shape)
self.assertEqual(dequantized.device, input_tensor.device)

# Verify quantization/dequantization roundtrip is reasonable
error = torch.abs(input_tensor.to(output_dtype) - dequantized).mean()
self.assertLess(error, 0.1, "Quantization error too high")

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_dequantize_affine_float8_scale_broadcasting(self):
"""Test that scale broadcasting works correctly for block-wise quantization"""
device = "cuda"
# Create input tensor with known block structure
input_tensor = torch.randn(4, 32, device=device, dtype=torch.float32)
block_size = (2, 16)  # each block spans 2 rows and 16 columns -> a 2x2 grid of blocks for the 4x32 input

# Choose quantization parameters
scale = choose_qparams_affine_float8(
input_tensor, float8_dtype=torch.float8_e4m3fn, block_size=block_size
)

# Verify scale shape
expected_scale_shape = (
input_tensor.shape[0] // block_size[0],
input_tensor.shape[1] // block_size[1],
)
self.assertEqual(scale.shape, expected_scale_shape)

# Quantize
quantized = quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn)

# Dequantize
dequantized = dequantize_affine_float8(quantized, scale, torch.float32)

# Verify shapes match
self.assertEqual(dequantized.shape, input_tensor.shape)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize(
"granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
)
def test_float8_tensor_slicing_basic(self, granularity):
"""Test basic slicing operations on Float8 tensors"""
device = "cuda"
dtype = torch.bfloat16
weight = torch.randn(512, 1024).to(device).to(dtype)
weight = weight.t()

l = torch.nn.Linear(512, 1024).to(device).to(dtype)
l.weight = torch.nn.Parameter(weight)
quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
# weight shape: 1024 x 512
weight = l.weight

input = torch.randn(1, 512, device=device, dtype=dtype)
# make sure it runs
torch.nn.functional.linear(input, weight)

# Create and quantize a model
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
quantize_(
model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
)

weight_impl = model.weight.original_weight_tensor.tensor_impl

# Test dimension 0 slicing (rows)
sliced_0 = weight_impl[10:20]
self.assertEqual(sliced_0.shape, (10, 64))

# Test dimension 1 slicing (columns)
sliced_1 = weight_impl[:, 20:40]
self.assertEqual(sliced_1.shape, (32, 20))

# Test combined slicing
sliced_both = weight_impl[5:15, 10:30]
self.assertEqual(sliced_both.shape, (10, 20))

# Verify the sliced tensors are still Float8 tensors
self.assertTrue(isinstance(sliced_0, Float8AQTTensorImpl))
self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl))
self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl))

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_float8_tensor_slicing_per_tensor(self):
"""Test slicing with per-tensor quantization (scale should not change)"""
device = "cuda"
dtype = torch.bfloat16

# Create and quantize with per-tensor granularity
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
quantize_(
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
)

original_weight = model.weight
original_impl = original_weight.original_weight_tensor.tensor_impl
original_scale = original_impl.scale

# Test slicing
sliced_weight = original_weight[10:20, 20:40]
sliced_impl = sliced_weight.original_weight_tensor.tensor_impl

# For per-tensor quantization, scale should be identical
self.assertTrue(torch.equal(original_scale, sliced_impl.scale))
self.assertEqual(sliced_impl.scale.numel(), 1)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@unittest.skipIf(
not is_sm_at_least_90(),
"Per-row quantization requires compute capability >= 9.0",
)
def test_float8_tensor_slicing_per_row(self):
"""Test slicing with per-row quantization (scale should be sliced appropriately)"""
device = "cuda"
dtype = torch.bfloat16

# Create and quantize with per-row granularity
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
quantize_(
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
)

original_weight = model.weight # Shape: (32, 64)
original_impl = original_weight.original_weight_tensor.tensor_impl
original_scale = original_impl.scale # Shape: (32, 1)

# Test row slicing (dimension 0)
sliced_rows = original_weight[10:20] # Shape: (10, 64)
sliced_impl = sliced_rows.original_weight_tensor.tensor_impl

# Scale should be sliced to match the rows
expected_scale_shape = (10, 1)
self.assertEqual(sliced_impl.scale.shape, expected_scale_shape)

# Verify the scale values are correct (should be subset of original)
self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20]))

# Test column slicing (dimension 1) - scale should not change for per-row
sliced_cols = original_weight[:, 20:40] # Shape: (32, 20)
sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl

# Scale shape should remain the same since we're not changing rows
self.assertEqual(sliced_cols_impl.scale.shape, (32, 1))
self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale))

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_float8_tensor_slicing_edge_cases(self):
"""Test edge cases in slicing"""
device = "cuda"
dtype = torch.bfloat16

# Create and quantize a model
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
quantize_(
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
)

original_weight = model.weight

# Test empty slice
empty_slice = original_weight[0:0]
self.assertEqual(empty_slice.shape, (0, 64))

# Test single element slice
single_row = original_weight[0:1]
self.assertEqual(single_row.shape, (1, 64))

# Test out of bounds (should be handled by PyTorch)
large_slice = original_weight[:100] # More than available rows
self.assertEqual(large_slice.shape, (32, 64)) # Should clamp to available

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize(
"granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
)
def test_float8_tensor_slicing_functional_correctness(self, granularity):
"""Test that sliced tensors produce correct results in computations"""
device = "cuda"
dtype = torch.bfloat16

# Create reference and quantized models with dimensions that are multiples of 16
ref_model = (
torch.nn.Linear(64, 48, bias=False).to(device).to(dtype)
) # 48 is divisible by 16
quant_model = copy.deepcopy(ref_model)
quantize_(
quant_model,
Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
)

# Create input with batch size that works well with slicing
input_tensor = torch.randn(8, 64, device=device, dtype=dtype)

ref_weight_slice = ref_model.weight[0:16, 0:32]
quant_weight_slice = quant_model.weight[0:16, 0:32]

input_slice = input_tensor[:, 0:32] # (8, 32) to match sliced weight

# Compute with sliced weights
with torch.no_grad():
ref_output = torch.nn.functional.linear(input_slice, ref_weight_slice)
quant_output = torch.nn.functional.linear(input_slice, quant_weight_slice)

# Verify shapes
expected_shape = (8, 16) # batch_size x out_features_sliced
self.assertEqual(ref_output.shape, expected_shape)
self.assertEqual(quant_output.shape, expected_shape)

# Verify reasonable quantization error
error = compute_error(ref_output, quant_output)
self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
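
For context on the scale-broadcasting test above, here is a rough block-wise reference in plain PyTorch. It is a sketch under the assumption that each (br, bc) block gets one scale, giving a scale of shape (rows // br, cols // bc) as the test asserts; the helper is hypothetical, and the real kernels are the quant_primitives functions imported in the test file.

```python
# Hedged sketch: block-wise float8 scaling reference consistent with the
# scale-broadcasting test above; torchao's actual kernels may differ.
import torch

def blockwise_fp8_roundtrip(x: torch.Tensor, block_size=(2, 16),
                            float8_dtype=torch.float8_e4m3fn):
    br, bc = block_size
    rows, cols = x.shape
    fp8_max = torch.finfo(float8_dtype).max
    # One scale per (br x bc) block -> scale shape (rows // br, cols // bc).
    blocks = x.reshape(rows // br, br, cols // bc, bc)
    amax = blocks.abs().amax(dim=(1, 3)).clamp(min=1e-12)
    scale = amax / fp8_max
    # Broadcast each block's scale back over its elements.
    scale_full = scale.repeat_interleave(br, 0).repeat_interleave(bc, 1)
    x_fp8 = (x / scale_full).clamp(-fp8_max, fp8_max).to(float8_dtype)
    x_dq = x_fp8.to(torch.float32) * scale_full
    return x_dq, scale

x = torch.randn(4, 32)
x_dq, scale = blockwise_fp8_roundtrip(x)
assert scale.shape == (2, 2)  # (4 // 2, 32 // 16), as the test checks
```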
7 changes: 3 additions & 4 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -462,10 +462,10 @@ def from_hp_to_floatx(
if target_dtype in FP8_TYPES:
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)

scale = choose_qparams_affine_float8(input_float, float8_dtype=target_dtype)
scale = choose_qparams_affine_float8(
input_float, float8_dtype=target_dtype, block_size=block_size
)
data = quantize_affine_float8(input_float, scale, target_dtype)

data, scale, zero_point = _layout.post_process(
data, scale, None, block_size
)
@@ -503,7 +503,6 @@ def from_hp_to_floatx_static(
input_float,
scale,
target_dtype,
scale_dtype,
)

data, scale, zero_point = _layout.post_process(
Expand Down
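
End to end, the behavior this change restores is sketched below, adapted from the new tests (so the API calls match the test file). It assumes a CUDA device with compute capability >= 8.9 (9.0 for PerRow) and is illustrative rather than a supported recipe.

```python
# Usage sketch adapted from the tests above; assumes a CUDA GPU with
# compute capability >= 9.0 and the torchao APIs exercised in this PR.
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

linear = torch.nn.Linear(64, 32, bias=False).to("cuda").to(torch.bfloat16)
quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

impl = linear.weight.original_weight_tensor.tensor_impl
print(impl.scale.shape)        # expected: torch.Size([32, 1]), one scale per output row
print(impl.float8_data.shape)  # expected: torch.Size([32, 64])

x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
y = torch.nn.functional.linear(x, linear.weight)  # runs the float8 per-row path
```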