Skip to content

test_affine_quantized_float.py pytest too unittest #2261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 54 additions & 57 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,21 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import unittest

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
)

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
unittest.skip("Unsupported PyTorch version")

import copy
import io
import random
import unittest
from contextlib import nullcontext
from functools import partial
from typing import Tuple

import pytest
import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils
Expand Down Expand Up @@ -95,68 +92,68 @@ def test_fp8_linear_variants(
"Static quantization only supports PerTensor granularity"
)

error_context = (
pytest.raises(AssertionError, match=error_message)
if error_message
else nullcontext()
if error_message:
with self.assertRaisesRegex(AssertionError, error_message):
self._run_fp8_linear_variant(dtype, mode, compile, sizes, granularity)
else:
self._run_fp8_linear_variant(dtype, mode, compile, sizes, granularity)

def _run_fp8_linear_variant(self, dtype, mode, compile, sizes, granularity):
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
# Get a "reasonable" scale for the input tensor even though
# we use the same scale for multiple activations
scale, _ = choose_qparams_affine(
input_tensor,
MappingType.SYMMETRIC,
input_tensor.shape,
torch.float8_e4m3fn,
scale_dtype=torch.float32,
)
mode_map = {
"dynamic": partial(
float8_dynamic_activation_float8_weight, granularity=granularity
),
"weight-only": float8_weight_only,
"static": partial(
float8_static_activation_float8_weight,
scale=scale,
granularity=granularity,
),
}

with error_context:
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
# Get a "reasonable" scale for the input tensor even though
# we use the same scale for multiple activations
scale, _ = choose_qparams_affine(
input_tensor,
MappingType.SYMMETRIC,
input_tensor.shape,
torch.float8_e4m3fn,
scale_dtype=torch.float32,
)
mode_map = {
"dynamic": partial(
float8_dynamic_activation_float8_weight, granularity=granularity
),
"weight-only": float8_weight_only,
"static": partial(
float8_static_activation_float8_weight,
scale=scale,
granularity=granularity,
),
}

# Create a linear layer with bfloat16 dtype
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

quantized_model = copy.deepcopy(model)
factory = mode_map[mode]()
quantize_(quantized_model, factory)

if compile:
quantized_model = torch.compile(quantized_model, fullgraph=True)

output_original = model(input_tensor)
output_quantized = quantized_model(input_tensor)

error = compute_error(output_original, output_quantized)
assert compute_error(output_original, output_quantized) > 20, (
f"Quantization error is too high got a SQNR of {error}"
)
# Create a linear layer with bfloat16 dtype
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

quantized_model = copy.deepcopy(model)
factory = mode_map[mode]()
quantize_(quantized_model, factory)

if compile:
quantized_model = torch.compile(quantized_model, fullgraph=True)

output_original = model(input_tensor)
output_quantized = quantized_model(input_tensor)

error = compute_error(output_original, output_quantized)
self.assertGreater(
error, 20, f"Quantization error is too high got a SQNR of {error}"
)

@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_invalid_granularity(self):
with pytest.raises(ValueError, match="Invalid granularity specification"):
with self.assertRaisesRegex(ValueError, "Invalid granularity specification"):
float8_dynamic_activation_float8_weight(granularity="invalid")

@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_mismatched_granularity(self):
with pytest.raises(
with self.assertRaisesRegex(
ValueError,
match="Different granularities for activation and weight are not supported",
"Different granularities for activation and weight are not supported",
):
float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

Expand All @@ -167,7 +164,7 @@ def test_unsupported_granularity(self):
class UnsupportedGranularity:
pass

with pytest.raises(ValueError, match="Invalid granularity types"):
with self.assertRaisesRegex(ValueError, "Invalid granularity types"):
float8_dynamic_activation_float8_weight(
granularity=(UnsupportedGranularity(), UnsupportedGranularity())
)
Expand All @@ -177,9 +174,9 @@ class UnsupportedGranularity:
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_per_row_with_float32(self):
with pytest.raises(
with self.assertRaisesRegex(
AssertionError,
match="PerRow quantization only works for bfloat16 precision",
"PerRow quantization only works for bfloat16 precision",
):
model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
quantize_(
Expand Down Expand Up @@ -317,4 +314,4 @@ def test_mm_float8dq(self):
common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

if __name__ == "__main__":
pytest.main([__file__])
unittest.main()
168 changes: 102 additions & 66 deletions test/dtypes/test_bitpacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import unittest

import torch
from torch.testing._internal import common_utils
from torch.utils._triton import has_triton

from torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu
Expand All @@ -13,68 +15,102 @@
dimensions = (0, -1, 1)


@pytest.fixture(autouse=True)
def run_before_and_after_tests():
yield
torch._dynamo.reset() # reset cache between tests


@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("dim", dimensions)
def test_CPU(bit_width, dim):
test_tensor = torch.randint(
0, 2**bit_width, (32, 32, 32), dtype=torch.uint8, device="cpu"
)
packed = pack_cpu(test_tensor, bit_width, dim=dim)
unpacked = unpack_cpu(packed, bit_width, dim=dim)
assert unpacked.allclose(test_tensor)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("dim", dimensions)
def test_GPU(bit_width, dim):
test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
packed = pack(test_tensor, bit_width, dim=dim)
unpacked = unpack(packed, bit_width, dim=dim)
assert unpacked.allclose(test_tensor)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("dim", dimensions)
def test_compile(bit_width, dim):
torch._dynamo.config.specialize_int = True
torch.compile(pack, fullgraph=True)
torch.compile(unpack, fullgraph=True)
test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
packed = pack(test_tensor, bit_width, dim=dim)
unpacked = unpack(packed, bit_width, dim=dim)
assert unpacked.allclose(test_tensor)


# these test cases are for the example pack walk through in the bitpacking.py file
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pack_example():
test_tensor = torch.tensor(
[0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
).cuda()
shard_4, shard_2 = pack(test_tensor, 6)
print(shard_4, shard_2)
assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
unpacked = unpack([shard_4, shard_2], 6)
assert unpacked.allclose(test_tensor)


def test_pack_example_CPU():
test_tensor = torch.tensor(
[0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
)
shard_4, shard_2 = pack(test_tensor, 6)
print(shard_4, shard_2)
assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4)
assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2)
unpacked = unpack([shard_4, shard_2], 6)
assert unpacked.allclose(test_tensor)
class TestBitpacking(unittest.TestCase):
def setUp(self):
"""Set up test fixtures before each test method."""
# Set random seeds for reproducibility
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed(0)

# Initialize any common test data
self.test_shape = (32, 32, 32)

# Set default device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reset any cached state
torch._dynamo.reset()

def tearDown(self):
"""Clean up after each test method."""
# Reset torch._dynamo cache
torch._dynamo.reset()

# Clear CUDA cache if available
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Reset any other state if needed
torch._dynamo.config.specialize_int = False

@common_utils.parametrize("bit_width", bit_widths)
@common_utils.parametrize("dim", dimensions)
def test_CPU(self, bit_width, dim):
test_tensor = torch.randint(
0, 2**bit_width, (32, 32, 32), dtype=torch.uint8, device="cpu"
)
packed = pack_cpu(test_tensor, bit_width, dim=dim)
unpacked = unpack_cpu(packed, bit_width, dim=dim)
self.assertTrue(unpacked.allclose(test_tensor))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@common_utils.parametrize("bit_width", bit_widths)
@common_utils.parametrize("dim", dimensions)
def test_GPU(self, bit_width, dim):
test_tensor = torch.randint(
0, 2**bit_width, (32, 32, 32), dtype=torch.uint8
).cuda()
packed = pack(test_tensor, bit_width, dim=dim)
unpacked = unpack(packed, bit_width, dim=dim)
self.assertTrue(unpacked.allclose(test_tensor))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not has_triton(), "unsupported without triton")
@common_utils.parametrize("bit_width", bit_widths)
@common_utils.parametrize("dim", dimensions)
def test_compile(self, bit_width, dim):
torch._dynamo.config.specialize_int = True
torch.compile(pack, fullgraph=True)
torch.compile(unpack, fullgraph=True)
test_tensor = torch.randint(
0, 2**bit_width, (32, 32, 32), dtype=torch.uint8
).cuda()
packed = pack(test_tensor, bit_width, dim=dim)
unpacked = unpack(packed, bit_width, dim=dim)
self.assertTrue(unpacked.allclose(test_tensor))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_pack_example(self):
test_tensor = torch.tensor(
[0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
).cuda()
shard_4, shard_2 = pack(test_tensor, 6)
print(shard_4, shard_2)
self.assertTrue(
torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
)
self.assertTrue(
torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
)
unpacked = unpack([shard_4, shard_2], 6)
self.assertTrue(unpacked.allclose(test_tensor))

def test_pack_example_CPU(self):
test_tensor = torch.tensor(
[0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
)
shard_4, shard_2 = pack(test_tensor, 6)
print(shard_4, shard_2)
self.assertTrue(
torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4)
)
self.assertTrue(torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2))
unpacked = unpack([shard_4, shard_2], 6)
self.assertTrue(unpacked.allclose(test_tensor))


common_utils.instantiate_parametrized_tests(TestBitpacking)

if __name__ == "__main__":
unittest.main()