Skip to content

Commit f91230f

Browse files
authored
[TRITON_KERNELS] some refactoring (#9134)
1 parent 2de1c5b commit f91230f

32 files changed

Lines changed: 1037 additions & 772 deletions

python/triton_kernels/tests/test_matmul.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from triton_kernels.swiglu import swiglu, swiglu_fn
2222
from triton_kernels.swiglu import PrecisionConfig as SwiGLUPrecisionConfig
2323
from triton_kernels.tensor_details import layout
24+
from triton_kernels.tensor_details.dtype import FP32
25+
2426
# ---------------
2527
# numerics stuff
2628
# ---------------
@@ -134,9 +136,9 @@ def _build_test_op_cases():
134136
Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1"),
135137
Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9),
136138
Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9, b_hbm_swizzling=True),
137-
Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn"),
139+
Case(300, 400, 416, "ragged", "float8_e5m2", "mxfloat8_e4m3fn"),
138140
Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1"),
139-
Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn"),
141+
Case(300, 400, 416, "batched", "float8_e5m2", "mxfloat8_e4m3fn"),
140142
])
141143
# mxfloat x mxfloat
142144
test_cases.extend([
@@ -145,11 +147,11 @@ def _build_test_op_cases():
145147
Case(1024, 1024, 1024, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", split_k=9, colmajor_mxfp_weight=False),
146148
Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True, a_hbm_swizzling=True),
147149
Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True, a_hbm_swizzling=True),
148-
Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True, a_hbm_swizzling=True),
150+
Case(300, 400, 416, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True, a_hbm_swizzling=True),
149151
Case(256, 1024, 512, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True, a_hbm_swizzling=True),
150-
Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"),
151-
Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", b_hbm_swizzling=True),
152-
Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"),
152+
Case(300, 400, 416, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"),
153+
Case(300, 400, 416, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", b_hbm_swizzling=True),
154+
Case(300, 400, 416, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"),
153155
Case(1024, 1024, 1024, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True),
154156
])
155157
# amd-specific float8
@@ -340,9 +342,7 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
340342
ragged_padding = inner_expt_opt is not None and "pad_a" in inner_expt_opt,
341343
squeeze_batch_dim = mode == "plain",
342344
scale_hbm_swizzling = layout.make_default_matmul_mxfp8_act_scale_layout if a_hbm_swizzling else None,
343-
scale_hbm_swizzling_args = {"ragged_metadata": None}, # ragged_metadata will be set in the make_random_tensor function
344345
)
345-
346346
b, b_scale_tri, b_ragged_metadata = make_random_tensor(
347347
shape=(k, n),
348348
n_slices = n_slices,
@@ -354,10 +354,8 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
354354
ragged_padding = inner_expt_opt is not None and "pad_b" in inner_expt_opt,
355355
squeeze_batch_dim = mode == "plain",
356356
is_mx_rowmajor = not colmajor_mxfp_weight,
357-
value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
358-
value_hbm_swizzling_args = {"mx_axis":-2},
359-
scale_hbm_swizzling = layout.make_default_matmul_mxfp4_w_scale_layout if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
360-
scale_hbm_swizzling_args = dict(mx_axis=-2, num_warps=num_warps),
357+
value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout(mx_axis=-2) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
358+
scale_hbm_swizzling = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=-2, num_warps=num_warps) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
361359
)
362360
gather_indx = None if not do_gather else torch.randint(0, max(m, 1), (m, ), dtype=torch.int32, device=device)
363361
scatter_indx = None if not do_scatter else torch.randperm(m, dtype=torch.int32, device=device)
@@ -442,6 +440,6 @@ def test_set_idle_sms():
442440
from triton_kernels.matmul_details.opt_flags import make_opt_flags
443441
num_idle_sms = 24
444442
matmul_set_idle_sms(num_idle_sms)
445-
flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
443+
flags = make_opt_flags(FP32, FP32, FP32, PrecisionConfig(), \
446444
1, 1024, 1024, 1024, None, True, False, 1, False, False, None)
447445
assert flags.idle_sms == num_idle_sms

python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77

88
import triton_kernels.matmul_details.opt_flags as opt_flags
9-
9+
from triton_kernels.tensor_details.dtype import FP16
1010

1111
class _DummyPrecisionConfig:
1212
def __init__(self):
@@ -84,9 +84,9 @@ def test_make_default_opt_flags_amd_split_k_constraint(monkeypatch):
8484

8585
precision_config = _DummyPrecisionConfig()
8686
flags = opt_flags.make_default_opt_flags_amd(
87-
torch.float16,
88-
torch.float16,
89-
torch.float16,
87+
FP16,
88+
FP16,
89+
FP16,
9090
precision_config,
9191
2,
9292
128,

python/triton_kernels/tests/test_mxfp.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ def test_mxfp4_rounding_cases(dst_dtype, device):
2929
torch.float16: 0.250244140625,
3030
torch.float32: 0.2500000298023223877,
3131
}[dst_dtype]
32+
pad_values = [0] * 22
3233
# Construct an example where scale is 1 (when max value is 6.0, the maximum value of e2m1)
33-
x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3, -1.25, two_point_five_plus_ulp], dtype=dst_dtype,
34-
device=device).view(1, -1, 1)
34+
x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3, -1.25, two_point_five_plus_ulp] + pad_values,
35+
dtype=dst_dtype, device=device).view(1, -1, 1)
3536
quant, scale = downcast_to_mxfp(x, torch.uint8, axis=1)
3637
dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1)
3738
# Tie-breaking cases (RTNE):
@@ -42,7 +43,7 @@ def test_mxfp4_rounding_cases(dst_dtype, device):
4243
# - -1.25 is halfway between -1.0 and -1.5. RTNE selects -1.0 (even). Away-from-zero would pick -1.5;
4344
# towards-zero would pick -1.0.
4445
# - two_point_five_plus_ulp (despite its name, this is 0.25 plus one ULP) is slightly bigger than 0.25, so it rounds up to 0.5 rather than tying to 0.
45-
assert dequant.flatten().tolist() == [6, 0, 0, 0.0, 1.0, 1.0, 1.0, 1.5, -1.0, 0.5], f"{dequant=}"
46+
assert dequant.flatten().tolist() == [6, 0, 0, 0.0, 1.0, 1.0, 1.0, 1.5, -1.0, 0.5] + pad_values, f"{dequant=}"
4647

4748
quant_torch, scale_torch = downcast_to_mxfp_torch(x, torch.uint8, axis=1)
4849
assert_equal(quant_torch, quant)
@@ -56,7 +57,9 @@ def test_mxfp4_rounding_cases(dst_dtype, device):
5657
# 2**floor(log2(33/(e2m1 max power of 2 = 4))) = 2**3 = 8 (exponent 127+3),
5758
# and the other values are multiples of representable FP4 values times 8
5859
# that allow exact reconstruction.
59-
x = torch.tensor([33.0, 24.0, 16.0, 8.0, 4.0, 0.0, -32.0, 0.0], device=device).bfloat16().view(1, -1, 1)
60+
pad_values = [0] * 24
61+
x = torch.tensor([33.0, 24.0, 16.0, 8.0, 4.0, 0.0, -32.0, 0.0] + pad_values,
62+
device=device).bfloat16().view(1, -1, 1)
6063
quant, scale = downcast_to_mxfp(
6164
x,
6265
torch.uint8,
@@ -88,7 +91,8 @@ def test_mxfp_extreme_values(src_dtype, dst_dtype, device):
8891
src_dtype = dtype_str_to_torch(src_dtype)
8992
dst_dtype = dtype_str_to_torch(dst_dtype)
9093
BIG_VALUE = 65470 if dst_dtype == torch.float16 else 3.3895e38
91-
x = torch.tensor([BIG_VALUE, BIG_VALUE], dtype=dst_dtype, device=device)
94+
pad_values = [0] * 30
95+
x = torch.tensor([BIG_VALUE, BIG_VALUE] + pad_values, dtype=dst_dtype, device=device)
9296
xq_value, xq_scale = downcast_to_mxfp(x, src_dtype, axis=-1)
9397
xdq = upcast_from_mxfp(xq_value, xq_scale, dst_dtype, axis=-1)
9498
xdq_ref = upcast_from_mxfp_torch(xq_value, xq_scale, dst_dtype, axis=-1)
@@ -127,6 +131,7 @@ def test_mxfp_quant_dequant(src_dtype, dst_dtype, device):
127131
weight = weight.repeat((9, 32)) # Repeat the dimensions to test multi block launches.
128132
weight = weight.reshape([1, weight.shape[0], weight.shape[1]])
129133
weight = weight.mT.contiguous().mT
134+
weight = torch.nn.functional.pad(weight, (0, 0, 0, 16))
130135
quant, scale = downcast_to_mxfp(weight, src_dtype, axis=1)
131136
dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1)
132137
assert_equal(weight, dequant)
@@ -143,7 +148,7 @@ def test_mxfp_quant_dequant(src_dtype, dst_dtype, device):
143148
((0, 0, 1024), 2, "float8_e4m3fn", DequantScaleRoundingMode.ROUND_DOWN),
144149
145150
((3, 4096, 1024), 1, "float4_e2m1", DequantScaleRoundingMode.ROUND_UP),
146-
((10, 254, 60), 0, "float4_e2m1", DequantScaleRoundingMode.ROUND_DOWN),
151+
((32, 254, 60), 0, "float4_e2m1", DequantScaleRoundingMode.ROUND_DOWN),
147152
((1, 320, 160), 2, "float8_e5m2", DequantScaleRoundingMode.ROUND_UP),
148153
((2, 16, 512), -1, "float8_e4m3fn", DequantScaleRoundingMode.ROUND_DOWN),
149154
],

python/triton_kernels/tests/test_tensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import pytest
22
import torch
3-
from triton_kernels.tensor import Bitmatrix, BIT
3+
from triton_kernels.tensor_details.dtype import BIT
44
from triton_kernels.tensor import (
55
make_ragged_tensor_metadata,
66
make_ragged_tensor_metadata_torch,
77
remap_ragged_tensor_metadata,
88
remap_ragged_tensor_metadata_torch,
99
make_bitmatrix_metadata,
1010
make_bitmatrix_metadata_torch,
11+
wrap_torch_tensor,
1112
)
1213
from triton_kernels.testing import assert_equal
1314

@@ -65,7 +66,7 @@ def test_make_bitmatrix_metadata(n_rows, n_cols, k):
6566
rows = torch.arange(n_rows, device=device).unsqueeze(1).expand_as(indx)
6667
bitmask_data = torch.zeros((n_rows, (n_cols + 31) // 32), dtype=torch.int32, device=device)
6768
bitmask_data.index_put_((rows, indx // 32), 1 << (indx % 32), accumulate=True)
68-
bitmask = Bitmatrix(bitmask_data.view(torch.uint32), dtype=BIT, shape=(n_rows, n_cols))
69+
bitmask = wrap_torch_tensor(bitmask_data.view(torch.uint32), dtype=BIT, shape=(n_rows, n_cols))
6970
# make metadata and compare
7071
metadata_tri = make_bitmatrix_metadata(indx, bitmask)
7172
metadata_ref = make_bitmatrix_metadata_torch(indx, bitmask)

python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,18 @@
2020
)
2121
def test_mxfp4_scale_roundtrip(shape):
2222
x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
23-
layout = BlackwellMXScaleLayout(x.shape)
24-
res = layout.unswizzle_data(layout.swizzle_data(x))
23+
layout = BlackwellMXScaleLayout()
24+
transformation = layout.make_transformation(x.shape, is_fp4=False)
25+
res = transformation.unswizzle_data(transformation.swizzle_data(x))
2526
assert (res == x).all()
2627

2728

2829
@pytest.mark.parametrize("shape", [(2, 256, 192), (1, 128, 64)])
2930
def test_act_scale_roundtrip_batched(shape):
3031
x = torch.randn(shape, device="cuda", dtype=torch.float32)
31-
layout = BlackwellActMXScaleLayout(x.shape)
32-
res = layout.unswizzle_data(layout.swizzle_data(x))
32+
layout = BlackwellActMXScaleLayout(ragged_metadata=None)
33+
transformation = layout.make_transformation(x.shape, is_fp4=False)
34+
res = transformation.unswizzle_data(transformation.swizzle_data(x))
3335
torch.testing.assert_close(res, x)
3436

3537

@@ -45,8 +47,9 @@ def test_act_scale_roundtrip_ragged(slice_sizes, m, k, align_m):
4547
m = max(m, slice_sizes.sum().item()) # there can be padded tokens in the input
4648
ragged_metadata = make_ragged_tensor_metadata(slice_sizes, m)
4749
x = torch.randn((m, k), device="cuda", dtype=torch.float32)
48-
layout = BlackwellActMXScaleLayout((m, k), ragged_metadata=ragged_metadata)
49-
res = layout.unswizzle_data(layout.swizzle_data(x))
50+
layout = BlackwellActMXScaleLayout(ragged_metadata=ragged_metadata)
51+
transformation = layout.make_transformation(x.shape, is_fp4=False)
52+
res = transformation.unswizzle_data(transformation.swizzle_data(x))
5053

5154
x_useful_rows = x[ragged_metadata.slice_offs[:-1], :]
5255
res_useful_rows = res[ragged_metadata.slice_offs[:-1], :]

python/triton_kernels/tests/test_tensor_details/test_layout_cdna4.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020
def test_mxfp4_scale_roundtrip(shape):
2121
x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
22-
layout = CDNA4MXScaleLayout(x.shape)
23-
res = layout.unswizzle_data(layout.swizzle_data(x))
22+
layout = CDNA4MXScaleLayout()
23+
transformation = layout.make_transformation(x.shape, is_fp4=False)
24+
res = transformation.unswizzle_data(transformation.swizzle_data(x))
2425
assert (res == x).all()

python/triton_kernels/tests/test_tensor_details/test_layout_hopper.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from triton._internal_testing import is_cuda
3-
from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4, get_layout
3+
from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4
44
from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout
55
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
66
from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
@@ -25,8 +25,11 @@ def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
2525
x = x.mT
2626
if x.shape[1 - mx_axis] < 32:
2727
pytest.skip("Not enough elements along non-mx axis")
28-
layout = HopperMXValueLayout(x.shape, mx_axis, mma_version)
29-
res = layout.unswizzle_data(layout.swizzle_data(x))
28+
layout = HopperMXValueLayout(mx_axis, mma_version)
29+
shape = list(x.shape)
30+
shape[-1] *= 2
31+
transformation = layout.make_transformation(shape, is_fp4=False)
32+
res = transformation.unswizzle_data(transformation.swizzle_data(x))
3033
assert (res == x).all()
3134

3235

@@ -35,8 +38,9 @@ def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
3538
@pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)])
3639
def test_mxfp4_scale_roundtrip(shape, mx_axis, num_warps):
3740
x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
38-
layout = HopperMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps)
39-
res = layout.unswizzle_data(layout.swizzle_data(x))
41+
layout = HopperMXScaleLayout(mx_axis=mx_axis, num_warps=num_warps)
42+
transformation = layout.make_transformation(x.shape, is_fp4=False)
43+
res = transformation.unswizzle_data(transformation.swizzle_data(x))
4044
assert (res[:shape[0], :shape[1]] == x).all()
4145

4246

@@ -85,13 +89,13 @@ def test_upcast_mxfp4_to_bf16(num_warps, mx_axis):
8589
x_bf16 = upcast_from_mxfp(x_fp4_val, x_fp4_scale, x.dtype, axis=mx_axis)
8690
x_fp4_val = wrap_torch_tensor(x_fp4_val, dtype=FP4)
8791
x_fp4_scale = wrap_torch_tensor(x_fp4_scale)
88-
x_fp4_val = convert_layout(x_fp4_val, HopperMXValueLayout, mx_axis=mx_axis)
89-
x_fp4_scale = convert_layout(x_fp4_scale, HopperMXScaleLayout, mx_axis=mx_axis, num_warps=num_warps)
92+
x_fp4_val = convert_layout(x_fp4_val, HopperMXValueLayout(mx_axis=mx_axis - 2, mma_version=3))
93+
x_fp4_scale = convert_layout(x_fp4_scale, HopperMXScaleLayout(mx_axis=mx_axis - 2, num_warps=num_warps))
9094
y = torch.empty_like(x_bf16)
9195
scale_block = [s // 32 if i == mx_axis else s for i, s in enumerate(shape)]
92-
scale_block = get_layout(x_fp4_scale).swizzle_block_shape(scale_block)
96+
scale_block = x_fp4_scale.storage.layout.swizzle_block_shape(scale_block)
9397
value_block = [s // 2 if i == mx_axis else s for i, s in enumerate(shape)]
94-
value_block = get_layout(x_fp4_val).swizzle_block_shape(value_block)
98+
value_block = x_fp4_val.storage.layout.swizzle_block_shape(value_block)
9599
_upcast_mxfp4_to_bf16[(1, )](
96100
y, x_fp4_val.storage.data, x_fp4_scale.storage.data, #
97101
x_fp4_val.storage.data.stride(0), x_fp4_val.storage.data.stride(1), #

python/triton_kernels/triton_kernels/compaction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
from .compaction_details._masked_compaction import _masked_compaction
3-
from .tensor import Bitmatrix
3+
from .tensor import Tensor
44

55

66
def compaction(yv, yi, bitmask, sentinel=-1):
@@ -32,7 +32,7 @@ def compaction(yv, yi, bitmask, sentinel=-1):
3232
n_rows, n_cols = yi.shape
3333
ret_yv = torch.empty_like(yv)
3434
ret_yi = torch.empty_like(yi)
35-
if isinstance(bitmask, Bitmatrix):
35+
if isinstance(bitmask, Tensor):
3636
bitmask = bitmask.storage.data
3737

3838
_masked_compaction[(n_rows, )](

0 commit comments

Comments
 (0)