Commit 44a878b

Add subclass based method for inference w/ MXFP8

stack-info: PR: #2132, branch: drisspg/stack/50
1 parent 955ebb0

File tree

8 files changed: +429, -61 lines changed


test/prototype/mx_formats/test_mx_linear.py
Lines changed: 32 additions & 0 deletions

@@ -25,6 +25,7 @@
     MXInferenceLinear,
     MXLinear,
 )
+from torchao.prototype.mx_formats.mx_subclass import MXFPConfig
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
@@ -372,3 +373,34 @@ def test_inference_print_str():
     s = str(m)
     assert "bl_sz=32" in s
     assert "kernel=emulated" in s
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(not is_sm_at_least_100(), reason="Reqs sm100")
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@torch.no_grad()
+def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
+    """
+    Smoke test for inference compile
+    """
+    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        if not is_sm_at_least_89():
+            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m_mx = copy.deepcopy(m)
+    config = MXFPConfig()
+    quantize_(m_mx, config=config)
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+    assert sqnr >= 25.0, f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
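
For reference, the user-facing flow this test exercises is a plain quantize_-with-config call on an existing module. A minimal sketch distilled from the test above, assuming a CUDA device that satisfies the capability checks and the prototype import paths shown in the diff:

    import copy

    import torch
    import torch.nn as nn

    from torchao.prototype.mx_formats import MXFPConfig
    from torchao.quantization import quantize_

    # Quantize a copy of a bf16 linear layer to MXFP8 in place.
    m = nn.Linear(32, 128, bias=True, dtype=torch.bfloat16, device="cuda")
    m_mx = copy.deepcopy(m)
    quantize_(m_mx, config=MXFPConfig())

    # The test covers both eager and fullgraph-compiled execution.
    m_mx = torch.compile(m_mx, fullgraph=True)

    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        y_mx = m_mx(x)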

torchao/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -43,7 +43,7 @@
     quantize_,
 )

-from . import dtypes, optim, swizzle, testing
+from . import dtypes, optim, quantization, swizzle, testing

 __all__ = [
     "dtypes",
@@ -53,4 +53,5 @@
     "swizzle",
     "testing",
     "ops",
+    "quantization",
 ]
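
This makes quantization an eagerly imported, publicly listed submodule of the top-level package, so it resolves as an attribute of torchao without a separate import. A small sketch of what the change enables:

    import torchao

    # Before this change, torchao.quantization needed its own import
    # statement; now `from . import quantization` in __init__.py binds
    # the submodule when the package itself is imported.
    quantize_ = torchao.quantization.quantize_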

torchao/core/config.py
Lines changed: 1 addition & 0 deletions

@@ -175,6 +175,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
     "torchao.quantization",
     "torchao.sparsity.sparse_api",
     "torchao.prototype.quantization",
+    "torchao.prototype.mx_formats",
 }
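
This set allowlists the modules whose configs may be reconstructed during deserialization, so adding torchao.prototype.mx_formats lets MXFPConfig round-trip through the dict form. A hedged sketch, assuming the config_from_dict counterpart that lives alongside config_to_dict in torchao.core.config:

    from torchao.core.config import config_from_dict, config_to_dict
    from torchao.prototype.mx_formats import MXFPConfig

    config = MXFPConfig()

    # config_to_dict records the config's defining module; deserialization
    # only accepts modules on the allowlist, which now includes mx_formats.
    restored = config_from_dict(config_to_dict(config))
    assert isinstance(restored, MXFPConfig)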
torchao/prototype/mx_formats/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@
     MXLinearConfig,
     MXLinearRecipeName,
 )
+from torchao.prototype.mx_formats.mx_subclass import MXFPConfig

 # import mx_linear here to register the quantize_ transform logic
 # ruff: noqa: I001
@@ -14,4 +15,5 @@
     "MXInferenceLinearConfig",
     "MXLinearConfig",
     "MXLinearRecipeName",
+    "MXFPConfig",
 ]
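
Re-exporting MXFPConfig here means callers can import it from the package root rather than reaching into the mx_subclass module; both paths resolve to the same class object:

    from torchao.prototype.mx_formats import MXFPConfig
    from torchao.prototype.mx_formats.mx_subclass import MXFPConfig as _MXFPConfig

    assert MXFPConfig is _MXFPConfig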
