diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py index 21ac2a297a..56fbaf1c01 100644 --- a/benchmarks/mx_formats/cast_bench.py +++ b/benchmarks/mx_formats/cast_bench.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + from typing import Callable, Tuple import fire @@ -5,7 +11,7 @@ import triton from torch._inductor.utils import do_bench_using_profiling -from torchao.prototype.mx_formats.custom_cast import ( +from torchao.prototype.mx_formats.kernels import ( triton_to_mxfp8_dim1, ) from torchao.prototype.mx_formats.mx_tensor import to_mx diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_kernels.py similarity index 95% rename from test/prototype/mx_formats/test_custom_cast.py rename to test/prototype/mx_formats/test_kernels.py index bce0b3913c..276d180046 100644 --- a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -16,7 +16,17 @@ F6_E2M3_EXP_BIAS, F6_E3M2_EXP_BIAS, ) -from torchao.prototype.mx_formats.custom_cast import ( +from torchao.prototype.mx_formats.fp_format_spec import ( + _assert_equals, + dtype_to_interesting_values, + float4_e2m1_interesting_values, + float6_e2m3_interesting_values, + float6_e3m2_interesting_values, + get_sem_bits, + sem_bits_to_sem_vals, + sem_vals_to_f32, +) +from torchao.prototype.mx_formats.kernels import ( f4_unpacked_to_f32, f6_e2m3_unpacked_to_f32, f6_e3m2_unpacked_to_f32, @@ -33,17 +43,8 @@ triton_to_mxfp8_dim1_reference, unpack_uint4, ) -from torchao.prototype.mx_formats.fp_format_spec import ( - _assert_equals, - dtype_to_interesting_values, - float4_e2m1_interesting_values, - float6_e2m3_interesting_values, - float6_e3m2_interesting_values, - get_sem_bits, - sem_bits_to_sem_vals, - sem_vals_to_f32, -) from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.prototype.mx_formats.utils import to_blocked from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, @@ -465,3 +466,24 @@ def test_triton_mxfp8_dim1_randn(M, K): x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32) torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0) torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "shape", + [ + (63, 1023), + (128, 4), + (128, 8), + (256, 8), + (300, 9), + (133, 512), + (528, 512), + (128, 1), + ], +) +def test_rearrange(shape): + scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8) + eager = to_blocked(scales, False) + triton = to_blocked(scales, True) + torch.testing.assert_close(eager, triton, atol=0, rtol=0) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 76f340dc78..51ede29bcb 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -17,7 +17,7 @@ DTYPE_FP6_E3M2, SUPPORTED_ELEM_DTYPES, ) -from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6 +from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6 from torchao.prototype.mx_formats.mx_tensor import ( MXTensor, ScaleCalculationMode, diff --git a/torchao/prototype/mx_formats/fp_format_spec.py b/torchao/prototype/mx_formats/fp_format_spec.py index bdc0cc4dfd..fc9521ef66 100644 --- a/torchao/prototype/mx_formats/fp_format_spec.py +++ b/torchao/prototype/mx_formats/fp_format_spec.py @@ -20,7 +20,7 @@ DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, ) -from torchao.prototype.mx_formats.custom_cast import get_bits +from torchao.prototype.mx_formats.kernels import get_bits dtype_to_bitwidth = { torch.float: 32, diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/kernels.py similarity index 92% rename from torchao/prototype/mx_formats/custom_cast.py rename to torchao/prototype/mx_formats/kernels.py index 3f870b4f28..f643ac3106 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1383,6 +1383,124 @@ def triton_to_mxfp8_dim1_reference( scale_e8m0_dim1, ) + @triton.jit + def triton_scale_swizzle( + scale_ptr, + scale_rows, + scale_cols, + output_ptr, + input_row_stride, + output_block_stride, + BLOCK_ROWS: tl.constexpr, + BLOCK_COLS: tl.constexpr, + ): + """ + Rearranges tensor data from row-major to block-scaled swizzle format. + + Args: + scale_ptr: Pointer to the input scale tensor + scale_rows: Number of rows in the scale tensor + scale_cols: Number of columns in the scale tensor + output_ptr: Pointer to the output tensor + input_row_stride: Stride between rows in the input tensor + output_block_stride: Stride between blocks in the output tensor + BLOCK_ROWS: Number of rows in a tile (compile-time constant) + BLOCK_COLS: Number of columns in a tile (compile-time constant) + """ + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + rows = tl.arange(0, BLOCK_ROWS)[:, None] + cols = tl.arange(0, BLOCK_COLS)[None, :] + + # Calculate starting row and column for this tile + start_row = pid_row * BLOCK_ROWS + start_col = pid_col * BLOCK_COLS + global_rows = start_row + rows + global_cols = start_col + cols + + mask = (global_rows < scale_rows) & (global_cols < scale_cols) + + input_scales = tl.load( + scale_ptr + global_rows * input_row_stride + global_cols, + mask=mask, + other=0.0, + ) + + r_div_32 = rows // 32 + r_mod_32 = rows % 32 + + # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates + dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols + + # Flatten + dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) + scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) + + # Calculate block offset using provided output block stride + LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS + block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride) + + tl.store( + output_ptr + block_offset + dest_indices_flat, + scales_flat, + ) + + def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: + """ + Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format. + + This format is suitable for Tmem as described in NVIDIA documentation: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + scale_tensor: Input tensor in row-major format with 8-bit elements + + Returns: + Rearranged tensor in block-scaled swizzle format + """ + assert scale_tensor.element_size() == 1, ( + "Expected element size to be 1 byte (8 bits)" + ) + assert scale_tensor.is_contiguous(), "Input tensor must be contiguous" + + rows, cols = scale_tensor.shape + + # Calculate blocks needed + n_row_blocks = triton.cdiv(rows, 128) + n_col_blocks = triton.cdiv(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + out = scale_tensor.new_empty((padded_rows, padded_cols)) + + # Input stride (for row-major format) + input_row_stride = cols + + # We probably want handle multiple blocks per tile but for now keep it simple + BLOCK_ROWS, BLOCK_COLS = 128, 4 + + # Output block stride for the rearranged format + output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS) + + grid = lambda META: ( + triton.cdiv(padded_rows, BLOCK_ROWS), + triton.cdiv(padded_cols, BLOCK_COLS), + ) + + wrap_triton(triton_scale_swizzle)[grid]( + scale_tensor.view(torch.uint8), + rows, + cols, + out.view(torch.uint8), + input_row_stride, + output_block_stride, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) + + return out + else: def triton_to_mxfp8_dim1( @@ -1394,3 +1512,6 @@ def triton_to_mxfp8_dim1_reference( x_hp: torch.Tensor, block_size ) -> Tuple[torch.Tensor, torch.Tensor]: raise AssertionError("needs torch version 2.8+ and triton") + + def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: + raise AssertionError("needs torch version 2.8+ and triton") diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 067613afb7..4db029480f 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -18,7 +18,7 @@ MXInferenceLinearConfig, MXLinearConfig, ) -from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1 +from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1 from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.quantization.transform_module import ( register_quantize_module_handler, diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index f3aca15a73..3125f3c0cc 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -45,7 +45,7 @@ F32_MIN_NORMAL, SUPPORTED_ELEM_DTYPES, ) -from torchao.prototype.mx_formats.custom_cast import ( +from torchao.prototype.mx_formats.kernels import ( f4_unpacked_to_f32, f6_e2m3_unpacked_to_f32, f6_e3m2_unpacked_to_f32, diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 8b186f82d6..2c828e477c 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -6,6 +6,8 @@ import torch +from torchao.prototype.mx_formats.kernels import triton_mx_block_rearrange + Tensor = torch.Tensor @@ -13,7 +15,7 @@ def ceil_div(a, b): return (a + b - 1) // b -def to_blocked(input_matrix) -> Tensor: +def to_blocked(input_matrix, use_triton_kernel: bool = True) -> Tensor: """ Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. @@ -22,10 +24,15 @@ def to_blocked(input_matrix) -> Tensor: Args: input_matrix: Input tensor of shape (H, W) + use_triton_kernel: Whether to use a triton implementation instead of relying on + torch.compile Returns: Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) """ + if use_triton_kernel: + return triton_mx_block_rearrange(input_matrix).flatten() + rows, cols = input_matrix.shape n_row_blocks = ceil_div(rows, 128) n_col_blocks = ceil_div(cols, 4) @@ -35,6 +42,8 @@ def to_blocked(input_matrix) -> Tensor: padded_cols = n_col_blocks * 4 padded = input_matrix + # TODO This is to work around VLLM's usage of compile w/ dynamic shapes + # if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols): if (rows, cols) != (padded_rows, padded_cols): padded = torch.zeros( (padded_rows, padded_cols),