Add a triton kernel for swizzling #2168

Merged
merged 1 commit on May 9, 2025
8 changes: 7 additions & 1 deletion benchmarks/mx_formats/cast_bench.py
@@ -1,11 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Tuple

import fire
import torch
import triton
from torch._inductor.utils import do_bench_using_profiling

-from torchao.prototype.mx_formats.custom_cast import (
+from torchao.prototype.mx_formats.kernels import (
    triton_to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx
Original file line number Diff line number Diff line change
@@ -16,7 +16,17 @@
    F6_E2M3_EXP_BIAS,
    F6_E3M2_EXP_BIAS,
)
-from torchao.prototype.mx_formats.custom_cast import (
from torchao.prototype.mx_formats.fp_format_spec import (
    _assert_equals,
    dtype_to_interesting_values,
    float4_e2m1_interesting_values,
    float6_e2m3_interesting_values,
    float6_e3m2_interesting_values,
    get_sem_bits,
    sem_bits_to_sem_vals,
    sem_vals_to_f32,
)
from torchao.prototype.mx_formats.kernels import (
    f4_unpacked_to_f32,
    f6_e2m3_unpacked_to_f32,
    f6_e3m2_unpacked_to_f32,
@@ -33,17 +43,8 @@
    triton_to_mxfp8_dim1_reference,
    unpack_uint4,
)
-from torchao.prototype.mx_formats.fp_format_spec import (
-    _assert_equals,
-    dtype_to_interesting_values,
-    float4_e2m1_interesting_values,
-    float6_e2m3_interesting_values,
-    float6_e3m2_interesting_values,
-    get_sem_bits,
-    sem_bits_to_sem_vals,
-    sem_vals_to_f32,
-)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    is_sm_at_least_89,
@@ -465,3 +466,24 @@ def test_triton_mxfp8_dim1_randn(M, K):
    x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "shape",
    [
        (63, 1023),
        (128, 4),
        (128, 8),
        (256, 8),
        (300, 9),
        (133, 512),
        (528, 512),
        (128, 1),
    ],
)
def test_rearrange(shape):
    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
    eager = to_blocked(scales, False)
    triton = to_blocked(scales, True)
    torch.testing.assert_close(eager, triton, atol=0, rtol=0)
2 changes: 1 addition & 1 deletion test/prototype/mx_formats/test_mx_tensor.py
@@ -17,7 +17,7 @@
    DTYPE_FP6_E3M2,
    SUPPORTED_ELEM_DTYPES,
)
-from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
+from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6
from torchao.prototype.mx_formats.mx_tensor import (
    MXTensor,
    ScaleCalculationMode,
2 changes: 1 addition & 1 deletion torchao/prototype/mx_formats/fp_format_spec.py
@@ -20,7 +20,7 @@
    DTYPE_FP6_E2M3,
    DTYPE_FP6_E3M2,
)
-from torchao.prototype.mx_formats.custom_cast import get_bits
+from torchao.prototype.mx_formats.kernels import get_bits

dtype_to_bitwidth = {
    torch.float: 32,
Original file line number Diff line number Diff line change
@@ -1383,6 +1383,124 @@ def triton_to_mxfp8_dim1_reference(
            scale_e8m0_dim1,
        )

    @triton.jit
    def triton_scale_swizzle(
        scale_ptr,
        scale_rows,
        scale_cols,
        output_ptr,
        input_row_stride,
        output_block_stride,
        BLOCK_ROWS: tl.constexpr,
        BLOCK_COLS: tl.constexpr,
    ):
        """
        Rearranges tensor data from row-major to block-scaled swizzle format.

        Args:
            scale_ptr: Pointer to the input scale tensor
            scale_rows: Number of rows in the scale tensor
            scale_cols: Number of columns in the scale tensor
            output_ptr: Pointer to the output tensor
            input_row_stride: Stride between rows in the input tensor
            output_block_stride: Stride between blocks in the output tensor
            BLOCK_ROWS: Number of rows in a tile (compile-time constant)
            BLOCK_COLS: Number of columns in a tile (compile-time constant)
        """
        pid_row = tl.program_id(0)
        pid_col = tl.program_id(1)

        rows = tl.arange(0, BLOCK_ROWS)[:, None]
        cols = tl.arange(0, BLOCK_COLS)[None, :]

        # Calculate the starting row and column for this tile
        start_row = pid_row * BLOCK_ROWS
        start_col = pid_col * BLOCK_COLS
        global_rows = start_row + rows
        global_cols = start_col + cols

        mask = (global_rows < scale_rows) & (global_cols < scale_cols)

        input_scales = tl.load(
            scale_ptr + global_rows * input_row_stride + global_cols,
            mask=mask,
            other=0.0,
        )

        # Compute each element's destination within the (32, 16) output block:
        # the 128 tile rows split into 4 groups of 32, and each group occupies
        # 4 consecutive output columns.
        r_div_32 = rows // 32
        r_mod_32 = rows % 32
        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols

        # Flatten the indices and values for the store
        dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
        scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))

        # Calculate this tile's offset into the output using the provided block stride
        LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
        block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)

        tl.store(
            output_ptr + block_offset + dest_indices_flat,
            scales_flat,
        )

    def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
        """
        Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format.

        This format is suitable for Tmem as described in NVIDIA documentation:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

        Args:
            scale_tensor: Input tensor in row-major format with 8-bit elements

        Returns:
            Rearranged tensor in block-scaled swizzle format
        """
        assert scale_tensor.element_size() == 1, (
            "Expected element size to be 1 byte (8 bits)"
        )
        assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"

        rows, cols = scale_tensor.shape

        # Calculate the number of 128x4 blocks needed to cover the input
        n_row_blocks = triton.cdiv(rows, 128)
        n_col_blocks = triton.cdiv(cols, 4)
        padded_rows = n_row_blocks * 128
        padded_cols = n_col_blocks * 4

        out = scale_tensor.new_empty((padded_rows, padded_cols))

        # Input stride (for row-major format)
        input_row_stride = cols

        # We probably want to handle multiple blocks per tile, but for now keep it simple
        BLOCK_ROWS, BLOCK_COLS = 128, 4

        # Output block stride for the rearranged format
        output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)

        grid = lambda META: (
            triton.cdiv(padded_rows, BLOCK_ROWS),
            triton.cdiv(padded_cols, BLOCK_COLS),
        )

        wrap_triton(triton_scale_swizzle)[grid](
            scale_tensor.view(torch.uint8),
            rows,
            cols,
            out.view(torch.uint8),
            input_row_stride,
            output_block_stride,
            BLOCK_ROWS=BLOCK_ROWS,
            BLOCK_COLS=BLOCK_COLS,
        )

        return out

else:

    def triton_to_mxfp8_dim1(
@@ -1394,3 +1512,6 @@ def triton_to_mxfp8_dim1_reference(
        x_hp: torch.Tensor, block_size
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise AssertionError("needs torch version 2.8+ and triton")

    def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
        raise AssertionError("needs torch version 2.8+ and triton")
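For intuition about the swizzle above, here is a minimal eager-mode sketch (not part of this diff; the helper name and the single-tile restriction are illustrative) of how one 128x4 tile of scales maps to the (32, 16) block layout that triton_scale_swizzle encodes in dest_indices:

import torch

def swizzle_tile_reference(tile: torch.Tensor) -> torch.Tensor:
    # Eager sketch: one (128, 4) tile of uint8 scales -> 512 elements in swizzle order.
    # Element (r, c) lands at (r % 32) * 16 + (r // 32) * 4 + c, matching dest_indices.
    assert tile.shape == (128, 4)
    grouped = tile.reshape(4, 32, 4)     # split the 128 rows into 4 groups of 32
    swizzled = grouped.permute(1, 0, 2)  # (32, 4, 4): one row from each group per output row
    return swizzled.reshape(32, 16).flatten()

For a padded (128 * n_row_blocks, 4 * n_col_blocks) input, the kernel applies this per-tile mapping and stores each tile's 512 elements contiguously, with output_block_stride (512 * n_col_blocks) separating consecutive rows of tiles.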
2 changes: 1 addition & 1 deletion torchao/prototype/mx_formats/mx_linear.py
@@ -18,7 +18,7 @@
    MXInferenceLinearConfig,
    MXLinearConfig,
)
-from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1
+from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
2 changes: 1 addition & 1 deletion torchao/prototype/mx_formats/mx_tensor.py
@@ -45,7 +45,7 @@
    F32_MIN_NORMAL,
    SUPPORTED_ELEM_DTYPES,
)
-from torchao.prototype.mx_formats.custom_cast import (
+from torchao.prototype.mx_formats.kernels import (
    f4_unpacked_to_f32,
    f6_e2m3_unpacked_to_f32,
    f6_e3m2_unpacked_to_f32,
11 changes: 10 additions & 1 deletion torchao/prototype/mx_formats/utils.py
@@ -6,14 +6,16 @@

import torch

from torchao.prototype.mx_formats.kernels import triton_mx_block_rearrange

Tensor = torch.Tensor


def ceil_div(a, b):
    return (a + b - 1) // b


-def to_blocked(input_matrix) -> Tensor:
+def to_blocked(input_matrix, use_triton_kernel: bool = True) -> Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

@@ -22,10 +24,15 @@ def to_blocked(input_matrix) -> Tensor:

    Args:
        input_matrix: Input tensor of shape (H, W)
        use_triton_kernel: Whether to use a triton implementation instead of relying on
            torch.compile

    Returns:
        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
    """
    if use_triton_kernel:
        return triton_mx_block_rearrange(input_matrix).flatten()

    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)
@@ -35,6 +42,8 @@
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    # TODO This is to work around VLLM's usage of compile w/ dynamic shapes
    # if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
    if (rows, cols) != (padded_rows, padded_cols):
        padded = torch.zeros(
            (padded_rows, padded_cols),
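As a usage note (not part of the diff): a minimal sketch of the new use_triton_kernel switch on to_blocked, assuming a CUDA device and torch 2.8+ with triton available; the shape and the exact numel check are illustrative:

import torch

from torchao.prototype.mx_formats.utils import to_blocked

# E8M0 scales stored as raw uint8 bytes; the shape need not be a multiple of (128, 4).
scales = torch.randint(256, size=(300, 9), device="cuda", dtype=torch.uint8)

eager_out = to_blocked(scales, use_triton_kernel=False)   # pure torch ops
triton_out = to_blocked(scales, use_triton_kernel=True)   # triton_mx_block_rearrange path

# Both paths pad to ceil(300/128)*128 = 384 rows and ceil(9/4)*4 = 12 columns,
# so each output holds 384 * 12 = 4608 scale bytes.
assert eager_out.numel() == triton_out.numel() == 4608
torch.testing.assert_close(eager_out, triton_out, atol=0, rtol=0)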