
Commit 6b67b3c

Revert "[TRITON_KERNELS] some refactoring" (#9140)
Reverts #9134
1 parent: f91230f

32 files changed

Lines changed: 772 additions & 1037 deletions

python/triton_kernels/tests/test_matmul.py

Lines changed: 13 additions & 11 deletions
@@ -21,8 +21,6 @@
 from triton_kernels.swiglu import swiglu, swiglu_fn
 from triton_kernels.swiglu import PrecisionConfig as SwiGLUPrecisionConfig
 from triton_kernels.tensor_details import layout
-from triton_kernels.tensor_details.dtype import FP32
-
 # ---------------
 # numerics stuff
 # ---------------
@@ -136,9 +134,9 @@ def _build_test_op_cases():
 Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1"),
 Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9),
 Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9, b_hbm_swizzling=True),
-Case(300, 400, 416, "ragged", "float8_e5m2", "mxfloat8_e4m3fn"),
+Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn"),
 Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1"),
-Case(300, 400, 416, "batched", "float8_e5m2", "mxfloat8_e4m3fn"),
+Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn"),
 ])
 # mxfloat x mxfloat
 test_cases.extend([
@@ -147,11 +145,11 @@ def _build_test_op_cases():
 Case(1024, 1024, 1024, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", split_k=9, colmajor_mxfp_weight=False),
 Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True, a_hbm_swizzling=True),
 Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True, a_hbm_swizzling=True),
-Case(300, 400, 416, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True, a_hbm_swizzling=True),
+Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True, a_hbm_swizzling=True),
 Case(256, 1024, 512, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True, a_hbm_swizzling=True),
-Case(300, 400, 416, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"),
-Case(300, 400, 416, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", b_hbm_swizzling=True),
-Case(300, 400, 416, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"),
+Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"),
+Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", b_hbm_swizzling=True),
+Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn"),
 Case(1024, 1024, 1024, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", b_hbm_swizzling=True),
 ])
 # amd-specific float8
@@ -342,7 +340,9 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
 ragged_padding = inner_expt_opt is not None and "pad_a" in inner_expt_opt,
 squeeze_batch_dim = mode == "plain",
 scale_hbm_swizzling = layout.make_default_matmul_mxfp8_act_scale_layout if a_hbm_swizzling else None,
+scale_hbm_swizzling_args = {"ragged_metadata": None}, # ragged_metadata will be set in the make_random_tensor function
 )
+
 b, b_scale_tri, b_ragged_metadata = make_random_tensor(
 shape=(k, n),
 n_slices = n_slices,
@@ -354,8 +354,10 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
 ragged_padding = inner_expt_opt is not None and "pad_b" in inner_expt_opt,
 squeeze_batch_dim = mode == "plain",
 is_mx_rowmajor = not colmajor_mxfp_weight,
-value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout(mx_axis=-2) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
-scale_hbm_swizzling = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=-2, num_warps=num_warps) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
+value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
+value_hbm_swizzling_args = {"mx_axis":-2},
+scale_hbm_swizzling = layout.make_default_matmul_mxfp4_w_scale_layout if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
+scale_hbm_swizzling_args = dict(mx_axis=-2, num_warps=num_warps),
 )
 gather_indx = None if not do_gather else torch.randint(0, max(m, 1), (m, ), dtype=torch.int32, device=device)
 scatter_indx = None if not do_scatter else torch.randperm(m, dtype=torch.int32, device=device)
@@ -440,6 +442,6 @@ def test_set_idle_sms():
 from triton_kernels.matmul_details.opt_flags import make_opt_flags
 num_idle_sms = 24
 matmul_set_idle_sms(num_idle_sms)
-flags = make_opt_flags(FP32, FP32, FP32, PrecisionConfig(), \
+flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
 1, 1024, 1024, 1024, None, True, False, 1, False, False, None)
 assert flags.idle_sms == num_idle_sms
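
Note on the last hunk: the revert passes plain torch dtypes to make_opt_flags instead of the FP32 wrapper that the refactor imported from triton_kernels.tensor_details.dtype. A minimal sketch of the restored call, mirroring the test above; the import locations of PrecisionConfig and matmul_set_idle_sms are assumptions, since the test file imports them outside the hunks shown here:

    import torch
    from triton_kernels.matmul_details.opt_flags import make_opt_flags
    # Assumed import location for the two helpers below (not visible in this diff).
    from triton_kernels.matmul import PrecisionConfig, matmul_set_idle_sms

    matmul_set_idle_sms(24)  # number of SMs to leave idle; reflected in the opt flags below
    # After the revert, dtypes are plain torch dtypes rather than FP32 wrappers.
    flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(),
                           1, 1024, 1024, 1024, None, True, False, 1, False, False, None)
    assert flags.idle_sms == 24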

python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@
 import torch
 
 import triton_kernels.matmul_details.opt_flags as opt_flags
-from triton_kernels.tensor_details.dtype import FP16
+
 
 class _DummyPrecisionConfig:
 def __init__(self):
@@ -84,9 +84,9 @@ def test_make_default_opt_flags_amd_split_k_constraint(monkeypatch):
 
 precision_config = _DummyPrecisionConfig()
 flags = opt_flags.make_default_opt_flags_amd(
-FP16,
-FP16,
-FP16,
+torch.float16,
+torch.float16,
+torch.float16,
 precision_config,
 2,
 128,

python/triton_kernels/tests/test_mxfp.py

Lines changed: 6 additions & 11 deletions
@@ -29,10 +29,9 @@ def test_mxfp4_rounding_cases(dst_dtype, device):
 torch.float16: 0.250244140625,
 torch.float32: 0.2500000298023223877,
 }[dst_dtype]
-pad_values = [0] * 22
 # Construct an example where scale is 1 (when max value is 6.0, the maximum value of e2m1)
-x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3, -1.25, two_point_five_plus_ulp] + pad_values,
-dtype=dst_dtype, device=device).view(1, -1, 1)
+x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3, -1.25, two_point_five_plus_ulp], dtype=dst_dtype,
+device=device).view(1, -1, 1)
 quant, scale = downcast_to_mxfp(x, torch.uint8, axis=1)
 dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1)
 # Tie-breaking cases (RTNE):
@@ -43,7 +42,7 @@ def test_mxfp4_rounding_cases(dst_dtype, device):
 # - -1.25 is halfway between -1.0 and -1.5. RTNE selects -1.0 (even). Away-from-zero would pick -1.5;
 # towards-zero would pick -1.0.
 # - two_point_five_plus_ulp is slightly bigger than 0.25, so it rounds to 0.5.
-assert dequant.flatten().tolist() == [6, 0, 0, 0.0, 1.0, 1.0, 1.0, 1.5, -1.0, 0.5] + pad_values, f"{dequant=}"
+assert dequant.flatten().tolist() == [6, 0, 0, 0.0, 1.0, 1.0, 1.0, 1.5, -1.0, 0.5], f"{dequant=}"
 
 quant_torch, scale_torch = downcast_to_mxfp_torch(x, torch.uint8, axis=1)
 assert_equal(quant_torch, quant)
@@ -57,9 +56,7 @@ def test_mxfp4_rounding_cases(dst_dtype, device):
 # 2**floor(log2(33/(e2m1 max power of 2 = 4)) = 2**3 = 8 (exponent 127+3),
 # and the other values are multiples of representable FP4 values times 8
 # that allow exact reconstruction.
-pad_values = [0] * 24
-x = torch.tensor([33.0, 24.0, 16.0, 8.0, 4.0, 0.0, -32.0, 0.0] + pad_values,
-device=device).bfloat16().view(1, -1, 1)
+x = torch.tensor([33.0, 24.0, 16.0, 8.0, 4.0, 0.0, -32.0, 0.0], device=device).bfloat16().view(1, -1, 1)
 quant, scale = downcast_to_mxfp(
 x,
 torch.uint8,
@@ -91,8 +88,7 @@ def test_mxfp_extreme_values(src_dtype, dst_dtype, device):
 src_dtype = dtype_str_to_torch(src_dtype)
 dst_dtype = dtype_str_to_torch(dst_dtype)
 BIG_VALUE = 65470 if dst_dtype == torch.float16 else 3.3895e38
-pad_values = [0] * 30
-x = torch.tensor([BIG_VALUE, BIG_VALUE] + pad_values, dtype=dst_dtype, device=device)
+x = torch.tensor([BIG_VALUE, BIG_VALUE], dtype=dst_dtype, device=device)
 xq_value, xq_scale = downcast_to_mxfp(x, src_dtype, axis=-1)
 xdq = upcast_from_mxfp(xq_value, xq_scale, dst_dtype, axis=-1)
 xdq_ref = upcast_from_mxfp_torch(xq_value, xq_scale, dst_dtype, axis=-1)
@@ -131,7 +127,6 @@ def test_mxfp_quant_dequant(src_dtype, dst_dtype, device):
 weight = weight.repeat((9, 32)) # Repeat the dimensions to test multi block launches.
 weight = weight.reshape([1, weight.shape[0], weight.shape[1]])
 weight = weight.mT.contiguous().mT
-weight = torch.nn.functional.pad(weight, (0, 0, 0, 16))
 quant, scale = downcast_to_mxfp(weight, src_dtype, axis=1)
 dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1)
 assert_equal(weight, dequant)
@@ -148,7 +143,7 @@ def test_mxfp_quant_dequant(src_dtype, dst_dtype, device):
 ((0, 0, 1024), 2, "float8_e4m3fn", DequantScaleRoundingMode.ROUND_DOWN),
 
 ((3, 4096, 1024), 1, "float4_e2m1", DequantScaleRoundingMode.ROUND_UP),
-((32, 254, 60), 0, "float4_e2m1", DequantScaleRoundingMode.ROUND_DOWN),
+((10, 254, 60), 0, "float4_e2m1", DequantScaleRoundingMode.ROUND_DOWN),
 ((1, 320, 160), 2, "float8_e5m2", DequantScaleRoundingMode.ROUND_UP),
 ((2, 16, 512), -1, "float8_e4m3fn", DequantScaleRoundingMode.ROUND_DOWN),
 ],
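
The RTNE comments in the first hunk above describe round-to-nearest-even over the e2m1 grid when the shared scale is 1. A small self-contained sketch of that rule (a hypothetical helper, not part of triton_kernels) that reproduces the dequantized values asserted above:

    # Positive e2m1 values; entries at even indices have an even (zero) mantissa bit.
    E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

    def rtne_e2m1(v):
        """Round v to the nearest e2m1 value, breaking ties toward an even mantissa."""
        sign, v = (-1.0 if v < 0 else 1.0), abs(v)
        lo = max(g for g in E2M1_GRID if g <= v)
        hi = min((g for g in E2M1_GRID if g >= v), default=lo)
        if v - lo != hi - v:  # not a tie: take the closer grid point
            return sign * (lo if v - lo < hi - v else hi)
        return sign * (lo if E2M1_GRID.index(lo) % 2 == 0 else hi)  # tie: even mantissa wins

    # Prints [6.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.5, -1.0], matching the assert in the hunk above.
    print([rtne_e2m1(v) for v in [6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3, -1.25]])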

python/triton_kernels/tests/test_tensor.py

Lines changed: 2 additions & 3 deletions
@@ -1,14 +1,13 @@
 import pytest
 import torch
-from triton_kernels.tensor_details.dtype import BIT
+from triton_kernels.tensor import Bitmatrix, BIT
 from triton_kernels.tensor import (
 make_ragged_tensor_metadata,
 make_ragged_tensor_metadata_torch,
 remap_ragged_tensor_metadata,
 remap_ragged_tensor_metadata_torch,
 make_bitmatrix_metadata,
 make_bitmatrix_metadata_torch,
-wrap_torch_tensor,
 )
 from triton_kernels.testing import assert_equal
 
@@ -66,7 +65,7 @@ def test_make_bitmatrix_metadata(n_rows, n_cols, k):
 rows = torch.arange(n_rows, device=device).unsqueeze(1).expand_as(indx)
 bitmask_data = torch.zeros((n_rows, (n_cols + 31) // 32), dtype=torch.int32, device=device)
 bitmask_data.index_put_((rows, indx // 32), 1 << (indx % 32), accumulate=True)
-bitmask = wrap_torch_tensor(bitmask_data.view(torch.uint32), dtype=BIT, shape=(n_rows, n_cols))
+bitmask = Bitmatrix(bitmask_data.view(torch.uint32), dtype=BIT, shape=(n_rows, n_cols))
 # make metadata and compare
 metadata_tri = make_bitmatrix_metadata(indx, bitmask)
 metadata_ref = make_bitmatrix_metadata_torch(indx, bitmask)

python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py

Lines changed: 6 additions & 9 deletions
@@ -20,18 +20,16 @@
 )
 def test_mxfp4_scale_roundtrip(shape):
 x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
-layout = BlackwellMXScaleLayout()
-transformation = layout.make_transformation(x.shape, is_fp4=False)
-res = transformation.unswizzle_data(transformation.swizzle_data(x))
+layout = BlackwellMXScaleLayout(x.shape)
+res = layout.unswizzle_data(layout.swizzle_data(x))
 assert (res == x).all()
 
 
 @pytest.mark.parametrize("shape", [(2, 256, 192), (1, 128, 64)])
 def test_act_scale_roundtrip_batched(shape):
 x = torch.randn(shape, device="cuda", dtype=torch.float32)
-layout = BlackwellActMXScaleLayout(ragged_metadata=None)
-transformation = layout.make_transformation(x.shape, is_fp4=False)
-res = transformation.unswizzle_data(transformation.swizzle_data(x))
+layout = BlackwellActMXScaleLayout(x.shape)
+res = layout.unswizzle_data(layout.swizzle_data(x))
 torch.testing.assert_close(res, x)
 
 
@@ -47,9 +45,8 @@ def test_act_scale_roundtrip_ragged(slice_sizes, m, k, align_m):
 m = max(m, slice_sizes.sum().item()) # there can be padded tokens in the input
 ragged_metadata = make_ragged_tensor_metadata(slice_sizes, m)
 x = torch.randn((m, k), device="cuda", dtype=torch.float32)
-layout = BlackwellActMXScaleLayout(ragged_metadata=ragged_metadata)
-transformation = layout.make_transformation(x.shape, is_fp4=False)
-res = transformation.unswizzle_data(transformation.swizzle_data(x))
+layout = BlackwellActMXScaleLayout((m, k), ragged_metadata=ragged_metadata)
+res = layout.unswizzle_data(layout.swizzle_data(x))
 
 x_useful_rows = x[ragged_metadata.slice_offs[:-1], :]
 res_useful_rows = res[ragged_metadata.slice_offs[:-1], :]
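
Both hunks above restore the pre-refactor layout API: a layout object is constructed directly from the tensor shape (plus ragged metadata where relevant) and exposes swizzle_data/unswizzle_data itself, with no intermediate make_transformation step. A minimal roundtrip sketch of the restored pattern; the import path of BlackwellMXScaleLayout and the shape are assumptions made for illustration:

    import torch
    # Assumed import path, by analogy with the Hopper layouts imported in test_layout_hopper.py.
    from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout

    x = torch.randint(0, 256, (2, 256, 192), dtype=torch.uint8, device="cuda")
    layout = BlackwellMXScaleLayout(x.shape)              # layout is built from the shape directly
    swizzled = layout.swizzle_data(x)                     # reorder into the swizzled storage layout
    assert (layout.unswizzle_data(swizzled) == x).all()   # roundtrip recovers the original data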

python/triton_kernels/tests/test_tensor_details/test_layout_cdna4.py

Lines changed: 2 additions & 3 deletions
@@ -19,7 +19,6 @@
 )
 def test_mxfp4_scale_roundtrip(shape):
 x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
-layout = CDNA4MXScaleLayout()
-transformation = layout.make_transformation(x.shape, is_fp4=False)
-res = transformation.unswizzle_data(transformation.swizzle_data(x))
+layout = CDNA4MXScaleLayout(x.shape)
+res = layout.unswizzle_data(layout.swizzle_data(x))
 assert (res == x).all()

python/triton_kernels/tests/test_tensor_details/test_layout_hopper.py

Lines changed: 9 additions & 13 deletions
@@ -1,6 +1,6 @@
 import pytest
 from triton._internal_testing import is_cuda
-from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4
+from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4, get_layout
 from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout
 from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
 from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
@@ -25,11 +25,8 @@ def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
 x = x.mT
 if x.shape[1 - mx_axis] < 32:
 pytest.skip("Not enough elements along non-mx axis")
-layout = HopperMXValueLayout(mx_axis, mma_version)
-shape = list(x.shape)
-shape[-1] *= 2
-transformation = layout.make_transformation(shape, is_fp4=False)
-res = transformation.unswizzle_data(transformation.swizzle_data(x))
+layout = HopperMXValueLayout(x.shape, mx_axis, mma_version)
+res = layout.unswizzle_data(layout.swizzle_data(x))
 assert (res == x).all()
 
 
@@ -38,9 +35,8 @@ def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
 @pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)])
 def test_mxfp4_scale_roundtrip(shape, mx_axis, num_warps):
 x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
-layout = HopperMXScaleLayout(mx_axis=mx_axis, num_warps=num_warps)
-transformation = layout.make_transformation(x.shape, is_fp4=False)
-res = transformation.unswizzle_data(transformation.swizzle_data(x))
+layout = HopperMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps)
+res = layout.unswizzle_data(layout.swizzle_data(x))
 assert (res[:shape[0], :shape[1]] == x).all()
 
 
@@ -89,13 +85,13 @@ def test_upcast_mxfp4_to_bf16(num_warps, mx_axis):
 x_bf16 = upcast_from_mxfp(x_fp4_val, x_fp4_scale, x.dtype, axis=mx_axis)
 x_fp4_val = wrap_torch_tensor(x_fp4_val, dtype=FP4)
 x_fp4_scale = wrap_torch_tensor(x_fp4_scale)
-x_fp4_val = convert_layout(x_fp4_val, HopperMXValueLayout(mx_axis=mx_axis - 2, mma_version=3))
-x_fp4_scale = convert_layout(x_fp4_scale, HopperMXScaleLayout(mx_axis=mx_axis - 2, num_warps=num_warps))
+x_fp4_val = convert_layout(x_fp4_val, HopperMXValueLayout, mx_axis=mx_axis)
+x_fp4_scale = convert_layout(x_fp4_scale, HopperMXScaleLayout, mx_axis=mx_axis, num_warps=num_warps)
 y = torch.empty_like(x_bf16)
 scale_block = [s // 32 if i == mx_axis else s for i, s in enumerate(shape)]
-scale_block = x_fp4_scale.storage.layout.swizzle_block_shape(scale_block)
+scale_block = get_layout(x_fp4_scale).swizzle_block_shape(scale_block)
 value_block = [s // 2 if i == mx_axis else s for i, s in enumerate(shape)]
-value_block = x_fp4_val.storage.layout.swizzle_block_shape(value_block)
+value_block = get_layout(x_fp4_val).swizzle_block_shape(value_block)
 _upcast_mxfp4_to_bf16[(1, )](
 y, x_fp4_val.storage.data, x_fp4_scale.storage.data, #
 x_fp4_val.storage.data.stride(0), x_fp4_val.storage.data.stride(1), #
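
The last hunk also shows the restored tensor-side calling convention: convert_layout takes the layout class plus keyword arguments rather than a pre-built instance, and get_layout reads the layout back from a wrapped tensor. A rough sketch of that convention; the shape, mx_axis, and num_warps values are illustrative only:

    import torch
    from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
    from triton_kernels.tensor import wrap_torch_tensor, convert_layout, get_layout, FP4
    from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout

    x = torch.randn((256, 128), dtype=torch.bfloat16, device="cuda")
    val, scale = downcast_to_mxfp(x, torch.uint8, axis=0)       # quantize along mx_axis = 0
    val = convert_layout(wrap_torch_tensor(val, dtype=FP4), HopperMXValueLayout, mx_axis=0)
    scale = convert_layout(wrap_torch_tensor(scale), HopperMXScaleLayout, mx_axis=0, num_warps=4)
    # get_layout() recovers the layout, e.g. to compute the swizzled block shape of the scales.
    scale_block = get_layout(scale).swizzle_block_shape([256 // 32, 128])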

python/triton_kernels/triton_kernels/compaction.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 import torch
 from .compaction_details._masked_compaction import _masked_compaction
-from .tensor import Tensor
+from .tensor import Bitmatrix
 
 
 def compaction(yv, yi, bitmask, sentinel=-1):
@@ -32,7 +32,7 @@ def compaction(yv, yi, bitmask, sentinel=-1):
 n_rows, n_cols = yi.shape
 ret_yv = torch.empty_like(yv)
 ret_yi = torch.empty_like(yi)
-if isinstance(bitmask, Tensor):
+if isinstance(bitmask, Bitmatrix):
 bitmask = bitmask.storage.data
 
 _masked_compaction[(n_rows, )](
