
Commit a1b6365

Add a triton kernel for swizzling
stack-info: PR: #2168, branch: drisspg/stack/53
1 parent: 44a878b

4 files changed (+142, -2 lines)

test/prototype/mx_formats/test_custom_cast.py (10 additions, 0 deletions)

@@ -44,6 +44,7 @@
     sem_vals_to_f32,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
@@ -465,3 +466,12 @@ def test_triton_mxfp8_dim1_randn(M, K):
     x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("shape", [(63, 1023), (128, 4), (128, 8), (256, 8), (300, 9)])
+def test_rearrange(shape):
+    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
+    eager = to_blocked(scales)
+    triton = to_blocked(scales, True)
+    torch.testing.assert_close(eager, triton, atol=0, rtol=0)

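A minimal standalone sketch of what the new test_rearrange exercises (this snippet is an illustration, not part of the commit): the eager reshape path and the Triton kernel path of to_blocked are expected to match bit-for-bit. The flag is passed explicitly on both calls so each path is unambiguous; swizzle_kernel is the keyword added to to_blocked in utils.py below.

    import torch

    from torchao.prototype.mx_formats.utils import to_blocked

    if torch.cuda.is_available():
        # uint8 stand-ins for E8M0 scales, using one of the test's odd shapes
        scales = torch.randint(256, size=(300, 9), device="cuda", dtype=torch.uint8)
        eager_out = to_blocked(scales, swizzle_kernel=False)  # pure PyTorch pad + reshape
        triton_out = to_blocked(scales, swizzle_kernel=True)  # new Triton kernel
        torch.testing.assert_close(eager_out, triton_out, atol=0, rtol=0)
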
torchao/prototype/mx_formats/custom_cast.py (122 additions, 0 deletions)

@@ -1383,6 +1383,128 @@ def triton_to_mxfp8_dim1_reference(
             scale_e8m0_dim1,
         )

+    @triton.jit
+    def scale_swizzle(
+        scale_ptr,
+        scale_rows,
+        scale_cols,
+        output_ptr,
+        input_row_stride,  # stride between consecutive input rows, in elements
+        output_block_stride,  # stride between consecutive output row-blocks, in elements
+        BLOCK_ROWS: tl.constexpr,
+        BLOCK_COLS: tl.constexpr,
+    ):
+        """
+        Rearranges tensor data from row-major to block-scaled swizzle format.
+
+        The transformation follows NVIDIA's block scaling factors layout:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+        """
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        rows = tl.arange(0, BLOCK_ROWS)[:, None]
+        cols = tl.arange(0, BLOCK_COLS)[None, :]
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * BLOCK_ROWS
+        start_col = pid_col * BLOCK_COLS
+        global_rows = start_row + rows
+        global_cols = start_col + cols
+
+        mask = (global_rows < scale_rows) & (global_cols < scale_cols)
+
+        input_scales = tl.load(
+            scale_ptr + global_rows * input_row_stride + global_cols,
+            mask=mask,
+            other=0.0,
+        )
+
+        # Block rearrangement logic for the _to_blocked_single transformation:
+        # 1) Divide into 4×32 blocks
+        r_div_32 = rows // 32
+        r_mod_32 = rows % 32
+
+        # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates
+        #    row = r_mod_32, col = (r_div_32 * 4 + inner_col)
+        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
+
+        # Flatten indices for storage
+        dest_indices_flat = tl.reshape(
+            dest_indices, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True
+        )
+
+        # Calculate block offset using the provided output block stride
+        LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
+        block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
+
+        # Store the rearranged values
+        tl.store(
+            output_ptr + block_offset + dest_indices_flat,
+            tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS), can_reorder=True),
+        )
+
+    def triton_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format.
+
+        This format is suitable for Tmem as described in NVIDIA documentation:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+        Args:
+            scale_tensor: Input tensor in row-major format with 8-bit elements
+
+        Returns:
+            Rearranged tensor in block-scaled swizzle format
+        """
+        # Validate input
+        assert scale_tensor.element_size() == 1, (
+            "Expected element size to be 1 byte (8 bits)"
+        )
+        assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"

+        # Get dimensions
+        rows, cols = scale_tensor.shape
+
+        # Calculate blocks needed
+        n_row_blocks = triton.cdiv(rows, 128)
+        n_col_blocks = triton.cdiv(cols, 4)
+
+        # Calculate padded dimensions
+        padded_rows = n_row_blocks * 128
+        padded_cols = n_col_blocks * 4
+
+        # Create output tensor
+        out = scale_tensor.new_empty((padded_rows, padded_cols))
+
+        # Input stride (for row-major format)
+        input_row_stride = cols
+
+        BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+        # Output block stride for the rearranged format
+        output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
+
+        # Calculate grid dimensions
+        grid = lambda META: (
+            triton.cdiv(padded_rows, BLOCK_ROWS),
+            triton.cdiv(padded_cols, BLOCK_COLS),
+        )
+
+        # Launch the kernel
+        wrap_triton(scale_swizzle)[grid](
+            scale_tensor.view(torch.uint8),
+            rows,
+            cols,
+            out.view(torch.uint8),
+            input_row_stride,
+            output_block_stride,
+            BLOCK_ROWS=BLOCK_ROWS,
+            BLOCK_COLS=BLOCK_COLS,
+        )
+
+        return out
+
 else:

     def triton_to_mxfp8_dim1(

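As a sanity check on the kernel's index math, the following eager-PyTorch sketch (an illustration, not part of the commit) shows that dest_indices = r_mod_32 * 16 + r_div_32 * 4 + col reproduces the (4, 32, 4) -> (32, 4, 4) -> (32, 16) rearrangement described in the kernel comments, for a single 128x4 tile:

    import torch

    # One 128x4 tile of uint8 scales; values are arbitrary for this check.
    tile = torch.randint(0, 256, (128, 4), dtype=torch.uint8)

    # Reference rearrangement: (128, 4) -> (4, 32, 4) -> (32, 4, 4) -> (32, 16).
    ref = tile.view(4, 32, 4).transpose(0, 1).reshape(32, 16)

    # The same mapping via the kernel's flat destination index.
    r = torch.arange(128)[:, None]  # source row within the tile
    c = torch.arange(4)[None, :]    # source column within the tile
    dest = (r % 32) * 16 + (r // 32) * 4 + c

    out = torch.empty(32 * 16, dtype=torch.uint8)
    out[dest.reshape(-1)] = tile.reshape(-1)
    assert torch.equal(out.reshape(32, 16), ref)

Each 128x4 tile is swizzled independently into 512 contiguous output bytes; the launcher's block_offset only decides where that chunk lands in the padded output.
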
torchao/prototype/mx_formats/mx_tensor.py (3 additions, 0 deletions)

@@ -178,6 +178,9 @@ def to_mx(
     # Add an epsilon to prevent the log2 function call from returning -inf
     # where the values are zero.
     eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
+    # eps = torch.finfo(torch.float32).tiny
+    # eps_tensor = torch.full(max_abs.shape, eps, device=max_abs.device, dtype=max_abs.dtype)
+    # max_abs = torch.maximum(max_abs, eps_tensor)

     # Set X to be the largest power-of-two less than or equal to
     # max_abs(v), divided by the largest power of two representable

torchao/prototype/mx_formats/utils.py (7 additions, 2 deletions)

@@ -6,14 +6,16 @@

 import torch

+from torchao.prototype.mx_formats.custom_cast import triton_block_rearrange
+
 Tensor = torch.Tensor


 def ceil_div(a, b):
     return (a + b - 1) // b


-def to_blocked(input_matrix) -> Tensor:
+def to_blocked(input_matrix, swizzle_kernel: bool = True) -> Tensor:
     """
     Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

@@ -26,6 +28,8 @@ def to_blocked(input_matrix) -> Tensor:
     Returns:
         Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
     """
+    if swizzle_kernel:
+        return triton_block_rearrange(input_matrix).flatten()
     rows, cols = input_matrix.shape
     n_row_blocks = ceil_div(rows, 128)
     n_col_blocks = ceil_div(cols, 4)
@@ -36,7 +40,8 @@ def to_blocked(input_matrix) -> Tensor:

     padded = input_matrix
     # TODO This is to work around VLLM's usage of compile w/ dynamic shapes
-    if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
+    # if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
+    if (rows, cols) != (padded_rows, padded_cols):
         padded = torch.zeros(
             (padded_rows, padded_cols),
             device=input_matrix.device,

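One worked example of the padding both paths share, using the oddest shape from the new test (the arithmetic below is an illustration, not part of the commit):

    rows, cols = 300, 9               # from the test's parametrize list
    n_row_blocks = -(-rows // 128)    # ceil_div(300, 128) = 3
    n_col_blocks = -(-cols // 4)      # ceil_div(9, 4) = 3
    padded_rows = n_row_blocks * 128  # 384
    padded_cols = n_col_blocks * 4    # 12
    assert (padded_rows, padded_cols) == (384, 12)
    # Both the eager path and triton_block_rearrange(...).flatten()
    # produce padded_rows * padded_cols = 4608 swizzled scale bytes for this input.
    assert padded_rows * padded_cols == 4608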