Commit 4621f10

Add a triton kernel for swizzling
stack-info: PR: #2168, branch: drisspg/stack/53
1 parent 44a878b commit 4621f10

File tree

3 files changed: +152, -2 lines changed


test/prototype/mx_formats/test_custom_cast.py

Lines changed: 14 additions & 0 deletions
@@ -44,6 +44,7 @@
     sem_vals_to_f32,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
@@ -465,3 +466,16 @@ def test_triton_mxfp8_dim1_randn(M, K):
     x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize(
+    "shape",
+    # [(63, 1023), (128, 4), (128, 8), (256, 8), (300, 9), (133, 512), (528, 512), (128, 1)],
+    [(128, 1)],
+)
+def test_rearrange(shape):
+    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
+    eager = to_blocked(scales, False)
+    triton = to_blocked(scales, True)
+    torch.testing.assert_close(eager, triton, atol=0, rtol=0)
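
For context, the eager to_blocked path the test compares against reduces, for a single 128x4 tile of scales, to a view/transpose/reshape rearrangement. The sketch below is an illustration only, not code from this commit: the helper name is invented, and it ignores the padding and flattening the full function performs.

import torch

def blocked_single_tile(scales_tile: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch: one (128, 4) tile of uint8 E8M0 scales is split into
    # four (32, 4) sub-tiles, which are interleaved column-wise into (32, 16).
    assert scales_tile.shape == (128, 4)
    return scales_tile.view(4, 32, 4).transpose(0, 1).reshape(32, 16)

scales = torch.randint(256, size=(128, 4), dtype=torch.uint8)
# Element (r, c) lands at output row r % 32, column (r // 32) * 4 + c.
blocked = blocked_single_tile(scales)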

torchao/prototype/mx_formats/custom_cast.py

Lines changed: 130 additions & 0 deletions
@@ -1383,6 +1383,133 @@ def triton_to_mxfp8_dim1_reference(
             scale_e8m0_dim1,
         )
 
+    @triton.jit
+    def scale_swizzle(
+        scale_ptr,
+        scale_rows,
+        scale_cols,
+        output_ptr,
+        input_row_stride,
+        output_block_stride,
+        BLOCK_ROWS: tl.constexpr,
+        BLOCK_COLS: tl.constexpr,
+    ):
+        """
+        Rearranges tensor data from row-major to block-scaled swizzle format.
+
+        Args:
+            scale_ptr: Pointer to the input scale tensor
+            scale_rows: Number of rows in the scale tensor
+            scale_cols: Number of columns in the scale tensor
+            output_ptr: Pointer to the output tensor
+            input_row_stride: Stride between rows in the input tensor
+            output_block_stride: Stride between blocks in the output tensor
+            BLOCK_ROWS: Number of rows in a tile (compile-time constant)
+            BLOCK_COLS: Number of columns in a tile (compile-time constant)
+        """
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        rows = tl.arange(0, BLOCK_ROWS)[:, None]
+        cols = tl.arange(0, BLOCK_COLS)[None, :]
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * BLOCK_ROWS
+        start_col = pid_col * BLOCK_COLS
+        global_rows = start_row + rows
+        global_cols = start_col + cols
+
+        mask = (global_rows < scale_rows) & (global_cols < scale_cols)
+
+        input_scales = tl.load(
+            scale_ptr + global_rows * input_row_stride + global_cols,
+            mask=mask,
+            other=0.0,
+        )
+
+        # Block rearrangement logic for the _to_blocked_single transformation:
+        # 1) Divide into 4×32 blocks
+        r_div_32 = rows // 32
+        r_mod_32 = rows % 32
+
+        # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates
+        # row = r_mod_32, col = (r_div_32 * 4 + inner_col)
+        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
+
+        # Flatten indices for storage
+        dest_indices_flat = tl.reshape(
+            dest_indices, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True
+        )
+
+        # Calculate block offset using provided output block stride
+        LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
+        block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
+
+        # Store the rearranged values
+        tl.store(
+            output_ptr + block_offset + dest_indices_flat,
+            tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True),
+        )
+
+    def mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Rearranges an E8M0 scale tensor from row-major format to block-scaled swizzle format.
+
+        This format is suitable for Tmem as described in NVIDIA documentation:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+        Args:
+            scale_tensor: Input tensor in row-major format with 8-bit elements
+
+        Returns:
+            Rearranged tensor in block-scaled swizzle format
+        """
+        assert scale_tensor.element_size() == 1, (
+            "Expected element size to be 1 byte (8 bits)"
+        )
+        assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
+
+        rows, cols = scale_tensor.shape
+
+        # Calculate blocks needed
+        n_row_blocks = triton.cdiv(rows, 128)
+        n_col_blocks = triton.cdiv(cols, 4)
+        padded_rows = n_row_blocks * 128
+        padded_cols = n_col_blocks * 4
+
+        out = scale_tensor.new_empty((padded_rows, padded_cols))
+
+        # Input stride (for row-major format)
+        input_row_stride = cols
+
+        # We probably want to handle multiple blocks per tile, but for now keep it simple
+        BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+        # Output block stride for the rearranged format
+        output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
+
+        # Calculate grid dimensions
+        grid = lambda META: (
+            triton.cdiv(padded_rows, BLOCK_ROWS),
+            triton.cdiv(padded_cols, BLOCK_COLS),
+        )
+
+        # Launch kernel with added stride parameters
+        # TODO fix before land
+        # wrap_triton(scale_swizzle)[grid](
+        scale_swizzle[grid](
+            scale_tensor.view(torch.uint8),
+            rows,
+            cols,
+            out.view(torch.uint8),
+            input_row_stride,
+            output_block_stride,
+            BLOCK_ROWS=BLOCK_ROWS,
+            BLOCK_COLS=BLOCK_COLS,
+        )
+
+        return out
+
 else:
 
     def triton_to_mxfp8_dim1(
@@ -1394,3 +1521,6 @@ def triton_to_mxfp8_dim1_reference(
         x_hp: torch.Tensor, block_size
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise AssertionError("needs torch version 2.8+ and triton")
+
+    def mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        raise AssertionError("needs torch version 2.8+ and triton")
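
A quick cross-check of the dest_indices arithmetic in scale_swizzle (not part of the commit; the view/transpose formulation below is used only as an independent reference): for one 128x4 tile, r_mod_32 * 16 + r_div_32 * 4 + col should agree with where a view(4, 32, 4) -> transpose(0, 1) -> reshape(32, 16) rearrangement sends each element.

import torch

r = torch.arange(128)[:, None]
c = torch.arange(4)[None, :]

# Kernel's formula: flat offset of element (r, c) inside the (32, 16) output tile
dest_kernel = (r % 32) * 16 + (r // 32) * 4 + c

# Independent reference: track where each source element lands after the
# view/transpose/reshape rearrangement, then invert that permutation
src = torch.arange(128 * 4).reshape(128, 4)
blocked = src.view(4, 32, 4).transpose(0, 1).reshape(32, 16)
dest_ref = torch.empty(128 * 4, dtype=torch.long)
dest_ref[blocked.reshape(-1)] = torch.arange(32 * 16)

assert torch.equal(dest_kernel.reshape(-1), dest_ref)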

torchao/prototype/mx_formats/utils.py

Lines changed: 8 additions & 2 deletions
@@ -6,14 +6,16 @@
 
 import torch
 
+from torchao.prototype.mx_formats.custom_cast import mx_block_rearrange
+
 Tensor = torch.Tensor
 
 
 def ceil_div(a, b):
     return (a + b - 1) // b
 
 
-def to_blocked(input_matrix) -> Tensor:
+def to_blocked(input_matrix, swizzle_kernel: bool = False) -> Tensor:
     """
     Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
 
@@ -26,6 +28,9 @@ def to_blocked(input_matrix) -> Tensor:
     Returns:
         Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
     """
+    if swizzle_kernel:
+        return mx_block_rearrange(input_matrix).flatten()
+
     rows, cols = input_matrix.shape
     n_row_blocks = ceil_div(rows, 128)
     n_col_blocks = ceil_div(cols, 4)
@@ -36,7 +41,8 @@ def to_blocked(input_matrix) -> Tensor:
 
     padded = input_matrix
     # TODO This is to work around VLLM's usage of compile w/ dynamic shapes
-    if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
+    # if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
+    if (rows, cols) != (padded_rows, padded_cols):
         padded = torch.zeros(
             (padded_rows, padded_cols),
             device=input_matrix.device,
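
Minimal usage sketch of the new flag, mirroring the test above (assumes a CUDA device with a recent enough torch and triton; only the (128, 1) shape is exercised by the commit's test):

import torch
from torchao.prototype.mx_formats.utils import to_blocked

scales = torch.randint(256, size=(128, 1), device="cuda", dtype=torch.uint8)

eager = to_blocked(scales)                          # existing PyTorch rearrangement
swizzled = to_blocked(scales, swizzle_kernel=True)  # new Triton kernel path
torch.testing.assert_close(eager, swizzled, atol=0, rtol=0)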
