
Commit 3fd1ef8

cleanup
1 parent 4d94143 commit 3fd1ef8

9 files changed

Lines changed: 40 additions & 318 deletions


python/triton_kernels/tests/test_matmul.py

Lines changed: 2 additions & 11 deletions
@@ -381,7 +381,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
     # --- create precision config ---
     wrap_list = lambda vals: torch.tensor(vals, dtype=torch.float32, device=device)
     c_scale_global = wrap_list([4.00]) if c_dtype.has_global_scale else None
-    c_absmax = wrap_list([0]) if c_dtype.has_global_scale else None
     precision_opt = PrecisionConfig(
         acc_scale=2.0 if c_dtype.has_global_scale or b_dtype.has_global_scale else 1.0,
         out_dtype=c_dtype.torch_dtype,
@@ -402,23 +401,18 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
         tri_y = matmul(a, b, bias,
                        a_ragged_metadata, b_ragged_metadata,
                        gather_indx, scatter_indx, precision_opt,
-                       gammas=gammas, epilogue=epilogue, c=c, c_absmax=c_absmax,
+                       gammas=gammas, epilogue=epilogue, c=c,
                        fused_activation=fused_activation)
-        if c_dtype.has_global_scale:
-            tri_y_scale = c_absmax.clone()
     except (opt_flags.InapplicableConstraint, NotImplementedError) as e:
         pytest.skip(f"inapplicable opt_flags constraint {e}")
     # --- torch implementation ---
     ref_y = matmul_torch(a, b, bias, #
                          a_ragged_metadata, b_ragged_metadata,
                          gather_indx, scatter_indx, precision_opt,
                          gammas=gammas,
-                         c=c,
-                         c_absmax=c_absmax)
+                         c=c)
     if swiglu_opts is not None:
         ref_y = swiglu(ref_y, alpha=swiglu_opts[0], precision_config=SwiGLUPrecisionConfig(swiglu_opts[1]))
-    if c_dtype.has_global_scale:
-        ref_y_scale = c_absmax.clone()

     # --- check results ---
     if c_dtype.has_mx_scale:
@@ -430,9 +424,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
     elif b_dtype.is_mxfloat4:
         maxtol, rmstol = 3e-2, None
     assert_close(ref_y, tri_y, maxtol=maxtol, rmstol=rmstol)
-    if c_dtype.has_global_scale:
-        assert torch.all((ref_y_scale - tri_y_scale).abs() < 1e-10), \
-            f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"


 def test_set_idle_sms():

python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py

Lines changed: 3 additions & 3 deletions
@@ -44,7 +44,7 @@ def setup_nvidia(monkeypatch):
     monkeypatch.setattr(
         opt_flags.opt_flags_nvidia,
         "compute_block_n",
-        lambda n, arch, precision_config: (64, 32),
+        lambda n, arch, precision_config, **kwargs: (64, 32),
     )
     monkeypatch.setattr(
         opt_flags.opt_flags_nvidia,
@@ -54,7 +54,7 @@ def setup_nvidia(monkeypatch):
     monkeypatch.setattr(
         opt_flags.opt_flags_nvidia,
         "compute_block_k",
-        lambda m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in: 32,
+        lambda m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in, **kwargs: 32,
     )
     monkeypatch.setattr(
         opt_flags.opt_flags_nvidia,
@@ -69,7 +69,7 @@ def setup_nvidia(monkeypatch):
     monkeypatch.setattr(
         opt_flags.opt_flags_nvidia,
         "compute_num_warps",
-        lambda block_m, block_n, is_persistent, precision_config, constraints: 4,
+        lambda block_m, block_n, is_persistent, precision_config, constraints, **kwargs: 4,
     )

     fake_target = types.SimpleNamespace(backend="cuda", arch=100)
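
The `**kwargs` added to each stubbed lambda keeps the monkeypatched helpers tolerant of extra keyword arguments passed by newer callers. A minimal standalone illustration of the pattern, with made-up names that are not the library's API:

# A stub with a fixed signature breaks as soon as the patched helper gains a
# new keyword argument; accepting **kwargs keeps the pinned return value working.
def stub_rigid(n, arch, precision_config):
    return 64, 32

def stub_tolerant(n, arch, precision_config, **kwargs):
    return 64, 32

stub_tolerant(128, "sm100", None, has_y_acc_in=True)   # fine
# stub_rigid(128, "sm100", None, has_y_acc_in=True)    # would raise TypeError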

python/triton_kernels/tests/test_reduce.py

Lines changed: 1 addition & 14 deletions
@@ -70,19 +70,10 @@ def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn):
     device = "cuda"
     x = torch.randn((B, M, N), device=device, dtype=torch.float32, requires_grad=True)
     x_scale_mx, x_scale_global = None, None
-    y_scale_global, y_absmax_tri, y_absmax_ref = None, None, None
+    y_scale_global = None
     if is_mx := dtype_str.startswith("mx"):
         dtype = dtype_str_to_torch(dtype_str.removeprefix("mx"))
         x, x_scale_mx = downcast_to_mxfp_torch(x.to(torch.float16), dtype, axis=-1)
-    if is_flex := dtype_str.startswith("flex"):
-        dtype = dtype_str_to_torch(dtype_str.removeprefix("flex"))
-        expected_scale = torch.tensor([4], device=device, dtype=torch.float32)
-        x_scale_global = torch.tensor([2], device=device, dtype=torch.float32)
-        x = x / x_scale_global
-        x = x.to(dtype)
-        y_scale_global = expected_scale
-        y_absmax_tri = torch.zeros_like(expected_scale)
-        y_absmax_ref = torch.zeros_like(expected_scale)
     mask = init_mask(mask_mode, B, M, N, device)
     expected_exception = ValueError if dim == 2 and is_mx else None
     if expected_exception is not None:
@@ -105,7 +96,6 @@ def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn):
         x_scale_mx=x_scale_mx,
         x_scale_global=x_scale_global,
         y_scale_global=y_scale_global,
-        y_absmax=y_absmax_tri,
         postprocess_fn1=postprocess_fn_tri,
     )
     y_ref, y_ref_mxscale = reduce_torch(
@@ -115,15 +105,12 @@ def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn):
         x_scale_mx=x_scale_mx,
         x_scale_global=x_scale_global,
         y_scale_global=y_scale_global,
-        y_absmax=y_absmax_ref,
         postprocess_fn1=postprocess_fn_ref,
     )
     if is_mx:
         y_ref = upcast_from_mxfp_torch(y_ref, y_ref_mxscale, torch.float16, axis=-1)
         y_tri = upcast_from_mxfp_torch(y_tri, y_tri_mxscale, torch.float16, axis=-1)
     assert torch.allclose(y_tri.float(), y_ref.float(), atol=1e-3, rtol=1e-3)
-    if is_flex:
-        assert torch.allclose(y_absmax_tri, y_absmax_ref, atol=1e-3, rtol=1e-3)
     run_bwd = postprocess_fn is None and "float8" not in dtype_str
     if run_bwd:
         dy = torch.randn_like(y_tri)

python/triton_kernels/triton_kernels/matmul.py

Lines changed: 2 additions & 22 deletions
@@ -223,7 +223,6 @@ def matmul(a, b, bias,
            gammas: torch.Tensor | None = None,
            out_alpha: float | None = None,
            c: torch.Tensor | Tensor | None = None,
-           c_absmax: torch.Tensor | None = None,
            fused_comm: FusedComm | None = None,
            fused_activation: FusedActivation | None = None,
            epilogue: Epilogue | None = None,
@@ -491,12 +490,11 @@ def matmul(a, b, bias,
     } if fused_comm is not None else {}
     n_valid_slices = b_tensor_or_tma.shape[0] if ragged_dimension == "M" else n_slices
     # split-k scratchpad is fp32/fp16 accumulation, not the final output dtype.
-    # output flex scaling is applied in the reduce step.
+    # output scaling is applied in the reduce step.
     out_global_scale = None if has_scratchpad else c_scale_global
-    out_absmax = None if has_scratchpad else c_absmax
     (kernels._p_matmul if opt_flags.is_persistent else kernels._matmul)[(grid,)](
         c_tensor_or_tma, c.storage.data, *out_matmul.stride(),
-        *((None, out_matmul_scale, None) if out_matmul_has_mx else (out_global_scale, out_absmax, None)),
+        *((None, out_matmul_scale, None) if out_matmul_has_mx else (out_global_scale, None, None)),
         *out_matmul_scale_strides[-4:],
         a_tensor_or_tma, a.storage.data, *a_strides, a_transpose,
         a.scale_global,
@@ -564,8 +562,6 @@ def matmul(a, b, bias,
         y = memory["output"].view(-1, memory["output"].shape[-1]),
         y_dtype = memory["output"].dtype,
         y_scale_global = c_scale_global,
-        y_absmax = c_absmax,
-        y_saturate_inf = precision_config.flexpoint_saturate_inf,
         y_has_mx = c_scale_mx is not None,
         # fused functions
         postprocess_fn1 = postprocess_fn1,
@@ -639,17 +635,6 @@ def scale(val, scal):
     assert val.ndim == 3
     return val / scal[:, None, None]

-def compute_actual_scale(x, dtype, per_batch_scale=False):
-    from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
-    max_finite = {
-        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
-        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
-        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
-    }[dtype]
-    maxvals = x.abs().amax(dim=tuple(range(1, x.ndim))) if per_batch_scale else x.abs().max()
-    return maxvals / max_finite
-
-
 def matmul_torch(a, b, bias,
                  a_ragged_metadata: RaggedTensorMetadata | None = None,
                  b_ragged_metadata: RaggedTensorMetadata | None = None,
@@ -660,7 +645,6 @@ def matmul_torch(a, b, bias,
                  gammas = None,
                  round_x = None, round_y = None,
                  c: torch.Tensor | Tensor | None = None,
-                 c_absmax: torch.Tensor | None = None,
                  ):
     if precision_config is None:
         precision_config = PrecisionConfig()
@@ -696,8 +680,6 @@ def matmul_torch(a, b, bias,
                 round_y=round_y,
             )
            out[expt] = out_expt.to(out.dtype)
-        if c_absmax is not None:
-            c_absmax.copy_(compute_actual_scale(out, precision_config.out_dtype))
         return scale(out, None if c is None else c.scale_global)

     is_input_batched = a.ndim == 3
@@ -748,8 +730,6 @@ def matmul_torch(a, b, bias,
         out = torch.zeros((scatter_indx.shape[0], y.shape[-1]), dtype=y.dtype, device=a.device)
         msk = scatter_indx != -1
         out[scatter_indx[msk], :] = y[msk, :]
-        if c_absmax is not None:
-            c_absmax.copy_(compute_actual_scale(out, precision_config.out_dtype))
         return scale(out, None if c is None else c.scale_global)

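With `compute_actual_scale` and the `c_absmax` output argument gone, a caller that still wants a per-tensor scale for a float8 output can derive it from the returned tensor itself. A minimal sketch of that computation outside the library, with an illustrative helper name (448.0 is the largest finite float8_e4m3fn value):

import torch

# Hedged sketch: reproduce the removed absmax / max_finite scale computation
# on the caller side; this helper is not part of the matmul() API.
def output_scale_from_result(y: torch.Tensor, max_finite: float = 448.0) -> torch.Tensor:
    return y.abs().max() / max_finite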

python/triton_kernels/triton_kernels/matmul_details/_matmul.py

Lines changed: 7 additions & 3 deletions
@@ -6,7 +6,6 @@
 from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
 from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
 from triton_kernels.tensor_details.layout_details.cdna4_scale import unswizzle_mx_scale_cdna4
-from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from triton_kernels.target_info import cuda_capability_geq
 from ._common import (
@@ -23,6 +22,11 @@ def round_f32_to_tf32(x: tl.tensor):
     ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
     return tl.inline_asm_elementwise(ASM, "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1)

+
+@triton.jit
+def load_scale(scale_ptr):
+    return 1.0 if scale_ptr is None else tl.load(scale_ptr)
+
 _matmul_repr = make_matmul_repr("_matmul", [0, 1, 2])
 @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
             repr=_matmul_repr, launch_metadata=matmul_launch_metadata)
@@ -483,8 +487,8 @@ def _matmul(
     else:
         if PER_BATCH_OUT_SCALE:
             YExpectedScale = YExpectedScale + start_z_out
-            YActualScale = YActualScale + start_z_out
-        out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF)
+        if YExpectedScale is not None:
+            out = out / load_scale(YExpectedScale)
     if EPILOGUE_FN is not None and not IS_EPILOGUE_QUANT_MXFP8:
         out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty)
     if pYPtrs is None:
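With `float_to_flex` removed, the epilogue's output scaling reduces to dividing by the expected global scale when one is provided; no running absmax is tracked and no saturation is applied. A host-side sketch of the same semantics, with illustrative names rather than the kernel's actual variables:

import torch

# Hedged sketch of the new epilogue behavior at the tensor level:
# divide by the statically provided scale if there is one, otherwise pass through.
def scale_output(acc: torch.Tensor, expected_scale: torch.Tensor | None) -> torch.Tensor:
    return acc if expected_scale is None else acc / expected_scale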

python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py

Lines changed: 8 additions & 31 deletions
@@ -7,12 +7,6 @@
 from triton.tools.ragged_tma import load_ragged, store_ragged
 from triton_kernels import target_info
 from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw, unswizzle_act_mx_scale_bw
-from triton_kernels.numerics_details.flexpoint import (
-    float_to_flex,
-    load_scale,
-    nan_propagating_absmax_reduce,
-    compute_scale,
-)
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
 from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
@@ -38,6 +32,10 @@ def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
     else:
         raise ValueError(f"Invalid type: {type(tensor_or_desc)}")

+@triton.jit
+def load_scale(scale_ptr):
+    return 1.0 if scale_ptr is None else tl.load(scale_ptr)
+
 @triton.jit
 def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
     mask = mask & (offs < writeback_size)
@@ -172,7 +170,7 @@ def _p_matmul(

     index_type: tl.constexpr = tl.int64

-    USE_FLEXPOINT_SCALE: tl.constexpr = YActualScale is not None or YChecksumScale is not None
+    USE_FLEXPOINT_SCALE: tl.constexpr = YExpectedScale is not None or YChecksumScale is not None
     HAS_SCATTER: tl.constexpr = WriteBackIndx is not None
     HAS_GATHER: tl.constexpr = GatherIndx is not None
     USE_GATHER_TMA: tl.constexpr = HAS_GATHER and X_TMA_MODE == "dense"
@@ -200,12 +198,6 @@ def _p_matmul(
     if INDEPENDENT_EPILOGUE:
         tile_id1 = tl.program_id(0) - NUM_SMS

-    # Keep track of local max for updating flexpoint scales.
-    USE_LOCAL_ABSMAX: tl.constexpr = (YActualScale is not None) and (not PER_BATCH_OUT_SCALE) and (not is_out_microscaled) and (pYPtrs is None)
-    if USE_LOCAL_ABSMAX:
-        THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
-        local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32)
-
     DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_w_microscaled and BLOCK_M * BLOCK_N >= 128 * 256

     for block_id in tl.range(
@@ -566,23 +558,13 @@ def _p_matmul(
                 YActualScalePtrs = YActualScale + offs_y_mx_k.to(index_type) * stride_y_mx_k + offs_y_mx_z.to(index_type) * stride_y_mx_z + offs_y_mx_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
                 tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
             else:
-                # Flexpoint
-                if USE_LOCAL_ABSMAX:
-                    out_view = tl.reshape(out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True)
-                    local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0))
-
+                # Global scale
                 if PER_BATCH_OUT_SCALE:
                     ExpectedScale = YExpectedScale + start_z1
-                    ActualScale = YActualScale + start_z1
                 else:
                     ExpectedScale = YExpectedScale
-                    ActualScale = None  # local absmax is tracked and updated after the loop
-
-                out = float_to_flex(
-                    out, ExpectedScale, ActualScale, YChecksumScale,
-                    None,  # mask: out is manually masked to 0
-                    YPtr, FLEXPOINT_SATURATE_INF
-                )
+                if ExpectedScale is not None:
+                    out = out / load_scale(ExpectedScale)
             if EPILOGUE_FN is not None and not IS_EPILOGUE_QUANT_MXFP8:
                 out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtr.dtype.element_ty, pid=len(accs)*tile_id1 + a_i)

@@ -636,11 +618,6 @@ def _p_matmul(
             tl.multiple_of(peer_Y_ptr, [16, 16])
             tl.store(peer_Y_ptr + offs_kzmn, out, mask=mask)

-
-    # Update the flexpoint scales
-    if USE_LOCAL_ABSMAX:
-        tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), YPtr), sem="relaxed")
-
     if pYPtrs is not None:
         all_writes_issued.fn(*all_writes_issued.captured)

