Commit c244b32

Add subclass-based method for inference w/ MXFP8
stack-info: PR: #2132, branch: drisspg/stack/50
1 parent f69bd4e commit c244b32

9 files changed: +444 −61 lines changed
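At a high level, this commit adds a tensor-subclass path for MXFP8 inference: a module is converted in place with quantize_ and the new MXFPConfig, and linear calls then route through the MX torch-function table added below. A minimal usage sketch, distilled from the new test (the test gates on CUDA and sm100 for the non-emulated path):

    import torch
    import torch.nn as nn
    from torchao.prototype.mx_formats import MXFPConfig
    from torchao.quantization import quantize_

    m = nn.Linear(32, 128, bias=True, dtype=torch.bfloat16, device="cuda")
    quantize_(m, config=MXFPConfig())     # convert weights to the MX subclass, in place
    m = torch.compile(m, fullgraph=True)  # optional; the test exercises both paths
    y = m(torch.randn(128, 32, device="cuda", dtype=torch.bfloat16))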

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 32 additions & 0 deletions

@@ -25,6 +25,7 @@
     MXInferenceLinear,
     MXLinear,
 )
+from torchao.prototype.mx_formats.mx_subclass import MXFPConfig
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
@@ -372,3 +373,34 @@ def test_inference_print_str():
     s = str(m)
     assert "bl_sz=32" in s
     assert "kernel=emulated" in s
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(not is_sm_at_least_100(), reason="Reqs sm100")
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@torch.no_grad()
+def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
+    """
+    Smoke test for inference compile
+    """
+    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        if not is_sm_at_least_89():
+            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m_mx = copy.deepcopy(m)
+    config = MXFPConfig()
+    quantize_(m_mx, config=config)
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+    assert sqnr >= 25.0, f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
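The pass/fail criterion is SQNR against the bf16 reference. For context, compute_error in torchao.quantization.utils returns the signal-to-quantized-noise ratio in dB; a minimal equivalent, assuming the standard norm-ratio definition used in torchao's tests:

    import torch

    def sqnr_db(ref: torch.Tensor, quant: torch.Tensor) -> torch.Tensor:
        # ratio of the reference-signal norm to the error norm, in dB
        Ps = torch.linalg.norm(ref)
        Pn = torch.linalg.norm(ref - quant)
        return 20 * torch.log10(Ps / Pn)

A 25 dB floor is a smoke-test bar: it catches wiring mistakes (wrong scales, transposed operands) rather than asserting tight numerics.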

torchao/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -43,7 +43,7 @@
     quantize_,
 )
 
-from . import dtypes, optim, testing
+from . import dtypes, optim, quantization, testing
 
 __all__ = [
     "dtypes",
@@ -52,4 +52,5 @@
     "quantize_",
     "testing",
     "ops",
+    "quantization",
 ]
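Importing quantization at the package root makes the submodule reachable as an attribute immediately after import torchao, for example:

    import torchao

    # resolves after this change without an explicit `import torchao.quantization`
    torchao.quantization.quantize_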

torchao/core/config.py

Lines changed: 1 addition & 0 deletions

@@ -175,6 +175,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
     "torchao.quantization",
     "torchao.sparsity.sparse_api",
     "torchao.prototype.quantization",
+    "torchao.prototype.mx_formats",
 }
torchao/prototype/mx_formats/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@
     MXLinearConfig,
     MXLinearRecipeName,
 )
+from torchao.prototype.mx_formats.mx_subclass import MXFPConfig
 
 # import mx_linear here to register the quantize_ transform logic
 # ruff: noqa: I001
@@ -14,4 +15,5 @@
     "MXInferenceLinearConfig",
     "MXLinearConfig",
     "MXLinearRecipeName",
+    "MXFPConfig",
 ]
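The re-export gives the new config a stable public import path alongside the other MX configs:

    # both resolve to the same class after this change
    from torchao.prototype.mx_formats import MXFPConfig
    from torchao.prototype.mx_formats.mx_subclass import MXFPConfig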
Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+This file defines the top level torch ops that are extended by MXTensor
+See: https://docs.pytorch.org/docs/stable/notes/extending.html#extending-torch-with-a-tensor-wrapper-type
+for more details.
+"""
+
+from typing import Any, Dict
+
+import torch
+
+from torchao.prototype.mx_formats.mx_ops import _addmm_mx_dispatch
+from torchao.prototype.mx_formats.mx_tensor import (  # noqa: E501
+    MXTensor,
+)
+
+aten = torch.ops.aten
+
+MX_FUNC_TABLE: Dict[Any, Any] = {}
+
+
+def implements_func(torch_ops):
+    """Register torch ops to the mx op table for torch function"""
+
+    def decorator(func):
+        for op in torch_ops:
+            MX_FUNC_TABLE[op] = func
+        return func
+
+    return decorator
+
+
+@implements_func([aten.linear.default])
+def mx_linear(func, types, args, kwargs):
+    a, b = args[0], args[1]
+    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
+    bias = args[2] if len(args) == 3 else None
+    return _addmm_mx_dispatch(a, b.t(), func, bias=bias)
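Nothing in this hunk consumes MX_FUNC_TABLE yet; per the linked docs on extending torch with a tensor wrapper type, the natural consumer is a __torch_function__ override on MXTensor, whose wiring lives in mx_tensor.py and is not part of the hunks shown. A hypothetical sketch of that dispatch side, with the exact func-to-key matching (F.linear vs. aten.linear.default) as an assumption:

    # Hypothetical: how MXTensor might consult MX_FUNC_TABLE; the real
    # override lives in mx_tensor.py and may match keys differently.
    class MXTensor(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func in MX_FUNC_TABLE:
                # e.g. linear -> mx_linear -> _addmm_mx_dispatch
                return MX_FUNC_TABLE[func](func, types, args, kwargs)
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)

Note that mx_linear asserts both operands are MXTensor, so the activation must already be quantized by the time the table entry fires; _addmm_mx_dispatch then receives the weight transposed to match the addmm convention.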
