
Commit 81e48a3

Add a triton kernel for swizzling (#2168)
stack-info: PR: #2168, branch: drisspg/stack/53
1 parent 7192edf commit 81e48a3

8 files changed (+175, -17 lines)

benchmarks/mx_formats/cast_bench.py

Lines changed: 7 additions & 1 deletion
@@ -1,11 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Callable, Tuple
 
 import fire
 import torch
 import triton
 from torch._inductor.utils import do_bench_using_profiling
 
-from torchao.prototype.mx_formats.custom_cast import (
+from torchao.prototype.mx_formats.kernels import (
     triton_to_mxfp8_dim1,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx

test/prototype/mx_formats/test_custom_cast.py renamed to test/prototype/mx_formats/test_kernels.py

Lines changed: 33 additions & 11 deletions
@@ -16,7 +16,17 @@
     F6_E2M3_EXP_BIAS,
     F6_E3M2_EXP_BIAS,
 )
-from torchao.prototype.mx_formats.custom_cast import (
+from torchao.prototype.mx_formats.fp_format_spec import (
+    _assert_equals,
+    dtype_to_interesting_values,
+    float4_e2m1_interesting_values,
+    float6_e2m3_interesting_values,
+    float6_e3m2_interesting_values,
+    get_sem_bits,
+    sem_bits_to_sem_vals,
+    sem_vals_to_f32,
+)
+from torchao.prototype.mx_formats.kernels import (
     f4_unpacked_to_f32,
     f6_e2m3_unpacked_to_f32,
     f6_e3m2_unpacked_to_f32,
@@ -33,17 +43,8 @@
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
-from torchao.prototype.mx_formats.fp_format_spec import (
-    _assert_equals,
-    dtype_to_interesting_values,
-    float4_e2m1_interesting_values,
-    float6_e2m3_interesting_values,
-    float6_e3m2_interesting_values,
-    get_sem_bits,
-    sem_bits_to_sem_vals,
-    sem_vals_to_f32,
-)
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
@@ -465,3 +466,24 @@ def test_triton_mxfp8_dim1_randn(M, K):
     x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (63, 1023),
+        (128, 4),
+        (128, 8),
+        (256, 8),
+        (300, 9),
+        (133, 512),
+        (528, 512),
+        (128, 1),
+    ],
+)
+def test_rearrange(shape):
+    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
+    eager = to_blocked(scales, False)
+    triton = to_blocked(scales, True)
+    torch.testing.assert_close(eager, triton, atol=0, rtol=0)
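
For context on the shape list: the blocked layout works on (128, 4) tiles of scales, so shapes such as (63, 1023) or (300, 9) force padding in one or both dimensions, while (128, 4) and (256, 8) are already aligned. A rough sketch of the padding arithmetic (illustrative only, using plain Python rather than the library's ceil_div helper):

```python
import math

# Padding used by the blocked layout: rows round up to a multiple of 128,
# columns to a multiple of 4 (see the to_blocked / triton_mx_block_rearrange
# changes later in this diff).
for rows, cols in [(63, 1023), (128, 4), (300, 9)]:
    padded_rows = math.ceil(rows / 128) * 128
    padded_cols = math.ceil(cols / 4) * 4
    print((rows, cols), "->", (padded_rows, padded_cols))
# (63, 1023) -> (128, 1024)
# (128, 4)   -> (128, 4)
# (300, 9)   -> (384, 12)
```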

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
     DTYPE_FP6_E3M2,
     SUPPORTED_ELEM_DTYPES,
 )
-from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
+from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6
 from torchao.prototype.mx_formats.mx_tensor import (
     MXTensor,
     ScaleCalculationMode,

torchao/prototype/mx_formats/fp_format_spec.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
 )
-from torchao.prototype.mx_formats.custom_cast import get_bits
+from torchao.prototype.mx_formats.kernels import get_bits
 
 dtype_to_bitwidth = {
     torch.float: 32,

torchao/prototype/mx_formats/custom_cast.py renamed to torchao/prototype/mx_formats/kernels.py

Lines changed: 121 additions & 0 deletions
@@ -1383,6 +1383,124 @@ def triton_to_mxfp8_dim1_reference(
             scale_e8m0_dim1,
         )
 
+    @triton.jit
+    def triton_scale_swizzle(
+        scale_ptr,
+        scale_rows,
+        scale_cols,
+        output_ptr,
+        input_row_stride,
+        output_block_stride,
+        BLOCK_ROWS: tl.constexpr,
+        BLOCK_COLS: tl.constexpr,
+    ):
+        """
+        Rearranges tensor data from row-major to block-scaled swizzle format.
+
+        Args:
+            scale_ptr: Pointer to the input scale tensor
+            scale_rows: Number of rows in the scale tensor
+            scale_cols: Number of columns in the scale tensor
+            output_ptr: Pointer to the output tensor
+            input_row_stride: Stride between rows in the input tensor
+            output_block_stride: Stride between blocks in the output tensor
+            BLOCK_ROWS: Number of rows in a tile (compile-time constant)
+            BLOCK_COLS: Number of columns in a tile (compile-time constant)
+        """
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        rows = tl.arange(0, BLOCK_ROWS)[:, None]
+        cols = tl.arange(0, BLOCK_COLS)[None, :]
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * BLOCK_ROWS
+        start_col = pid_col * BLOCK_COLS
+        global_rows = start_row + rows
+        global_cols = start_col + cols
+
+        mask = (global_rows < scale_rows) & (global_cols < scale_cols)
+
+        input_scales = tl.load(
+            scale_ptr + global_rows * input_row_stride + global_cols,
+            mask=mask,
+            other=0.0,
+        )
+
+        r_div_32 = rows // 32
+        r_mod_32 = rows % 32
+
+        # Rearrange to (32, 4, 4) and then to the final (32, 16) coordinates
+        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
+
+        # Flatten
+        dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
+        scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))
+
+        # Calculate block offset using the provided output block stride
+        LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
+        block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
+
+        tl.store(
+            output_ptr + block_offset + dest_indices_flat,
+            scales_flat,
+        )
+
+    def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format.
+
+        This format is suitable for Tmem as described in NVIDIA documentation:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+        Args:
+            scale_tensor: Input tensor in row-major format with 8-bit elements
+
+        Returns:
+            Rearranged tensor in block-scaled swizzle format
+        """
+        assert scale_tensor.element_size() == 1, (
+            "Expected element size to be 1 byte (8 bits)"
+        )
+        assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
+
+        rows, cols = scale_tensor.shape
+
+        # Calculate blocks needed
+        n_row_blocks = triton.cdiv(rows, 128)
+        n_col_blocks = triton.cdiv(cols, 4)
+        padded_rows = n_row_blocks * 128
+        padded_cols = n_col_blocks * 4
+
+        out = scale_tensor.new_empty((padded_rows, padded_cols))
+
+        # Input stride (for row-major format)
+        input_row_stride = cols
+
+        # We probably want to handle multiple blocks per tile, but for now keep it simple
+        BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+        # Output block stride for the rearranged format
+        output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
+
+        grid = lambda META: (
+            triton.cdiv(padded_rows, BLOCK_ROWS),
+            triton.cdiv(padded_cols, BLOCK_COLS),
+        )
+
+        wrap_triton(triton_scale_swizzle)[grid](
+            scale_tensor.view(torch.uint8),
+            rows,
+            cols,
+            out.view(torch.uint8),
+            input_row_stride,
+            output_block_stride,
+            BLOCK_ROWS=BLOCK_ROWS,
+            BLOCK_COLS=BLOCK_COLS,
+        )
+
+        return out
+
 else:
 
     def triton_to_mxfp8_dim1(
@@ -1394,3 +1512,6 @@ def triton_to_mxfp8_dim1_reference(
         x_hp: torch.Tensor, block_size
     ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise AssertionError("needs torch version 2.8+ and triton")
+
+    def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        raise AssertionError("needs torch version 2.8+ and triton")
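
For intuition, the kernel's destination index math, dest = r_mod_32 * 16 + r_div_32 * 4 + cols, is equivalent to viewing a single (128, 4) tile as (4, 32, 4), swapping the first two axes, and flattening the result to (32, 16). A quick standalone check of that equivalence (an illustrative sketch, not code from this PR):

```python
import torch

# Illustrative check (not part of the PR): the kernel's per-tile index math
#   dest = (r % 32) * 16 + (r // 32) * 4 + c
# sends element (r, c) of a (128, 4) tile to the same slot as the
# (4, 32, 4) -> transpose(0, 1) -> (32, 16) view-based rearrangement.
tile = torch.arange(128 * 4).reshape(128, 4)

# View/transpose formulation of the swizzle for one tile, flattened.
reference = tile.reshape(4, 32, 4).transpose(0, 1).reshape(32, 16).flatten()

# Index-math formulation: scatter each (r, c) element to its destination index.
r = torch.arange(128)[:, None]
c = torch.arange(4)[None, :]
dest = (r % 32) * 16 + (r // 32) * 4 + c
swizzled = torch.empty(128 * 4, dtype=tile.dtype)
swizzled[dest.flatten()] = tile.flatten()

assert torch.equal(swizzled, reference)
```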

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
     MXInferenceLinearConfig,
     MXLinearConfig,
 )
-from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1
+from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@
     F32_MIN_NORMAL,
     SUPPORTED_ELEM_DTYPES,
 )
-from torchao.prototype.mx_formats.custom_cast import (
+from torchao.prototype.mx_formats.kernels import (
     f4_unpacked_to_f32,
     f6_e2m3_unpacked_to_f32,
     f6_e3m2_unpacked_to_f32,

torchao/prototype/mx_formats/utils.py

Lines changed: 10 additions & 1 deletion
@@ -6,14 +6,16 @@
 
 import torch
 
+from torchao.prototype.mx_formats.kernels import triton_mx_block_rearrange
+
 Tensor = torch.Tensor
 
 
 def ceil_div(a, b):
     return (a + b - 1) // b
 
 
-def to_blocked(input_matrix) -> Tensor:
+def to_blocked(input_matrix, use_triton_kernel: bool = True) -> Tensor:
     """
     Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
 
@@ -22,10 +24,15 @@ def to_blocked(input_matrix) -> Tensor:
 
     Args:
         input_matrix: Input tensor of shape (H, W)
+        use_triton_kernel: Whether to use a triton implementation instead of relying on
+            torch.compile
 
     Returns:
         Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
     """
+    if use_triton_kernel:
+        return triton_mx_block_rearrange(input_matrix).flatten()
+
     rows, cols = input_matrix.shape
     n_row_blocks = ceil_div(rows, 128)
     n_col_blocks = ceil_div(cols, 4)
@@ -35,6 +42,8 @@ def to_blocked(input_matrix) -> Tensor:
     padded_cols = n_col_blocks * 4
 
     padded = input_matrix
+    # TODO This is to work around VLLM's usage of compile w/ dynamic shapes
+    # if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
     if (rows, cols) != (padded_rows, padded_cols):
         padded = torch.zeros(
             (padded_rows, padded_cols),
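
With the new flag, callers choose between the compile-friendly eager path and the triton kernel; the two produce bit-identical results, which is what the new test_rearrange test asserts. A minimal usage sketch (assumes a CUDA device; not part of the diff):

```python
import torch
from torchao.prototype.mx_formats.utils import to_blocked

# E8M0 scales are stored as raw uint8 bytes; any 2D uint8 tensor works here.
scales = torch.randint(256, size=(300, 9), device="cuda", dtype=torch.uint8)

blocked_eager = to_blocked(scales, use_triton_kernel=False)   # pure PyTorch path
blocked_triton = to_blocked(scales, use_triton_kernel=True)   # new triton kernel

torch.testing.assert_close(blocked_eager, blocked_triton, atol=0, rtol=0)
```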
