
Commit 4d94143

more flexpoint removal
1 parent af2196d commit 4d94143

8 files changed

Lines changed: 176 additions & 230 deletions


python/triton_kernels/tests/test_matmul.py

Lines changed: 5 additions & 14 deletions
@@ -328,7 +328,7 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
     do_scatter = do_scatter and mode != "batched"

     # --- create inputs ---
-    a, a_scales, a_ragged_metadata = make_random_tensor(
+    a, a_ragged_metadata = make_random_tensor(
         shape=(m, k),
         n_slices = n_slices,
         dtype = a_dtype,
@@ -339,8 +339,9 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
         ragged_padding = inner_expt_opt is not None and "pad_a" in inner_expt_opt,
         squeeze_batch_dim = mode == "plain",
         scale_hbm_swizzling = layout.make_default_matmul_mxfp8_act_scale_layout if a_hbm_swizzling else None,
+        scale_global = 1.25 if a_dtype.has_global_scale else None,
     )
-    b, b_scale_tri, b_ragged_metadata = make_random_tensor(
+    b, b_ragged_metadata = make_random_tensor(
         shape=(k, n),
         n_slices = n_slices,
         dtype = b_dtype,
@@ -353,6 +354,7 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
         is_mx_rowmajor = not colmajor_mxfp_weight,
         value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout(mx_axis=-2) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
         scale_hbm_swizzling = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=-2, num_warps=num_warps) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
+        scale_global = 1.25 if b_dtype.has_global_scale else None,
     )
     if not isinstance(a, Tensor):
         a = wrap_torch_tensor(a)
@@ -378,32 +380,21 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,

     # --- create precision config ---
     wrap_list = lambda vals: torch.tensor(vals, dtype=torch.float32, device=device)
-    a_scale_global = wrap_list([1.25]) if c_dtype.has_global_scale else None
-    b_scale_global = wrap_list([1.25]) if b_dtype.has_global_scale else None
     c_scale_global = wrap_list([4.00]) if c_dtype.has_global_scale else None
     c_absmax = wrap_list([0]) if c_dtype.has_global_scale else None
     precision_opt = PrecisionConfig(
         acc_scale=2.0 if c_dtype.has_global_scale or b_dtype.has_global_scale else 1.0,
         out_dtype=c_dtype.torch_dtype,
     )
-    c_scale_mx = None

     # --- create epilogue ---
+    c_scale_mx = None
     epilogue = None
     if c_dtype.has_mx_scale:
         c_scale_shape = c_shape[:-1] + (triton.cdiv(c_shape[-1], MXFP_BLOCK_SIZE),)
         c_scale_mx = torch.empty(c_scale_shape, dtype=torch.uint8, device=a.device)
         epilogue_spec = FnSpecs(FnName.QUANTIZE_MXFP8.name, quantize_mxfp8_fn, (), ())
         epilogue = Epilogue(epilogue_spec, tuple(), tuple(), effective_itemsize=6.0)
-
-    if isinstance(a, Tensor):
-        a = Tensor(a.storage, dtype=a.dtype, shape=a.shape, shape_max=a.shape_max, scale_global=a_scale_global, scale_mx=a_scales)
-    else:
-        a = wrap_torch_tensor(a, scale_global=a_scale_global, scale_mx=a_scales)
-    if isinstance(b, Tensor):
-        b = Tensor(b.storage, dtype=b.dtype, shape=b.shape, shape_max=b.shape_max, scale_global=b_scale_global, scale_mx=b_scale_tri)
-    else:
-        b = wrap_torch_tensor(b, scale_global=b_scale_global, scale_mx=b_scale_tri)
     c = wrap_torch_tensor(c, scale_global=c_scale_global, scale_mx=c_scale_mx)

     # --- triton implementation ---

python/triton_kernels/tests/test_reduce.py

Lines changed: 30 additions & 14 deletions
@@ -3,7 +3,6 @@
 from triton.testing import do_bench
 from triton_kernels.reduce import reduce, reduce_torch, PostprocessFn, FnSpecs
 from triton_kernels.numerics_details.mxfp import upcast_from_mxfp_torch, downcast_to_mxfp_torch
-from triton_kernels.numerics import InFlexData, OutFlexData
 from triton_kernels.target_info import is_cuda, is_hip, is_hip_cdna3, is_hip_cdna4
 import triton
 import triton.language as tl
@@ -70,24 +69,25 @@ def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn):
     torch.manual_seed(0)
     device = "cuda"
     x = torch.randn((B, M, N), device=device, dtype=torch.float32, requires_grad=True)
-    x_mscale, x_flex = None, None
-    y_flex_tri, y_flex_ref = None, None
+    x_scale_mx, x_scale_global = None, None
+    y_scale_global, y_absmax_tri, y_absmax_ref = None, None, None
     if is_mx := dtype_str.startswith("mx"):
         dtype = dtype_str_to_torch(dtype_str.removeprefix("mx"))
-        x, x_mscale = downcast_to_mxfp_torch(x.to(torch.float16), dtype, axis=-1)
+        x, x_scale_mx = downcast_to_mxfp_torch(x.to(torch.float16), dtype, axis=-1)
     if is_flex := dtype_str.startswith("flex"):
         dtype = dtype_str_to_torch(dtype_str.removeprefix("flex"))
         expected_scale = torch.tensor([4], device=device, dtype=torch.float32)
-        x_flex = InFlexData(scale=torch.tensor([2], device=device, dtype=torch.float32))
-        x = x / x_flex.scale
+        x_scale_global = torch.tensor([2], device=device, dtype=torch.float32)
+        x = x / x_scale_global
         x = x.to(dtype)
-        y_flex_tri = OutFlexData(expected_scale=expected_scale, actual_scale=torch.empty_like(expected_scale))
-        y_flex_ref = OutFlexData(expected_scale=expected_scale, actual_scale=torch.empty_like(expected_scale))
+        y_scale_global = expected_scale
+        y_absmax_tri = torch.zeros_like(expected_scale)
+        y_absmax_ref = torch.zeros_like(expected_scale)
     mask = init_mask(mask_mode, B, M, N, device)
     expected_exception = ValueError if dim == 2 and is_mx else None
     if expected_exception is not None:
         with pytest.raises(expected_exception):
-            reduce(x, dim=dim, mask=mask, x_mxscale=x_mscale)
+            reduce(x, dim=dim, mask=mask, x_scale_mx=x_scale_mx)
         return
     if postprocess_fn == "plus_ten":
         postprocess_fn_tri = PostprocessFn(specs=FnSpecs("plus_a", plus_a_reduce, ("a", ), reduction_n=2),
@@ -98,16 +98,32 @@ def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn):
     # run forward pass
     x_tri = x.clone().detach().requires_grad_(True)
     x_ref = x.clone().detach().requires_grad_(True)
-    y_tri, y_tri_mxscale = reduce(x_tri, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_tri,
-                                  postprocess_fn1=postprocess_fn_tri)
-    y_ref, y_ref_mxscale = reduce_torch(x_ref, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_ref,
-                                        postprocess_fn1=postprocess_fn_ref)
+    y_tri, y_tri_mxscale = reduce(
+        x_tri,
+        dim=dim,
+        mask=mask,
+        x_scale_mx=x_scale_mx,
+        x_scale_global=x_scale_global,
+        y_scale_global=y_scale_global,
+        y_absmax=y_absmax_tri,
+        postprocess_fn1=postprocess_fn_tri,
+    )
+    y_ref, y_ref_mxscale = reduce_torch(
+        x_ref,
+        dim=dim,
+        mask=mask,
+        x_scale_mx=x_scale_mx,
+        x_scale_global=x_scale_global,
+        y_scale_global=y_scale_global,
+        y_absmax=y_absmax_ref,
+        postprocess_fn1=postprocess_fn_ref,
+    )
     if is_mx:
         y_ref = upcast_from_mxfp_torch(y_ref, y_ref_mxscale, torch.float16, axis=-1)
         y_tri = upcast_from_mxfp_torch(y_tri, y_tri_mxscale, torch.float16, axis=-1)
     assert torch.allclose(y_tri.float(), y_ref.float(), atol=1e-3, rtol=1e-3)
     if is_flex:
-        torch.allclose(y_flex_tri.actual_scale, y_flex_ref.actual_scale, atol=1e-3, rtol=1e-3)
+        assert torch.allclose(y_absmax_tri, y_absmax_ref, atol=1e-3, rtol=1e-3)
     run_bwd = postprocess_fn is None and "float8" not in dtype_str
     if run_bwd:
         dy = torch.randn_like(y_tri)

python/triton_kernels/triton_kernels/matmul.py

Lines changed: 25 additions & 46 deletions
@@ -7,7 +7,6 @@
 from enum import Enum, auto
 import math
 from typing import Callable
-from types import SimpleNamespace
 # utilities
 from triton_kernels import target_info
 from triton_kernels.meta import Closure
@@ -260,36 +259,25 @@ def matmul(a, b, bias,
     if epilogue is None:
         epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
     n_slices = max(1, b.shape[0]) if a_ragged_metadata is None else a_ragged_metadata.n_slices
-    if c is not None and not isinstance(c, Tensor):
-        c = wrap_torch_tensor(c)
-    if d is not None and not isinstance(d, Tensor):
-        d = wrap_torch_tensor(d)
-    c_data = None if c is None else c.storage.data
-    d_data = None if d is None else d.storage.data
     if not isinstance(a, Tensor):
         a = wrap_torch_tensor(a)
     if not isinstance(b, Tensor):
         b_dtype = FP4 if b.dtype == torch.uint8 else None
         b = wrap_torch_tensor(b, dtype=b_dtype)
-    a_scale_global = a.scale_global
+    if c is not None and not isinstance(c, Tensor):
+        c = wrap_torch_tensor(c)
+    if d is not None and not isinstance(d, Tensor):
+        d = wrap_torch_tensor(d)
+    d_data = None if d is None else d.storage.data
     a_scale = a.scale_mx
-    if isinstance(a_scale, torch.Tensor):
-        a_scale = wrap_torch_tensor(a_scale)
-    b_scale_global = b.scale_global
     b_scale = b.scale_mx
-    if isinstance(b_scale, torch.Tensor):
-        b_scale = wrap_torch_tensor(b_scale)
     b_has_mx = b_scale is not None
     if b_has_mx and (torch.cuda.get_device_capability()[0] < 10 or b.storage.layout is not None and not isinstance(b.storage.layout, StridedLayout)):
         assert b.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
-        if b_scale is not None:
-            b_scale.storage.data = b_scale.data.view(torch.uint8)
     is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and b.dtype.bitwidth == 8
     if is_hopper_fp8: assert b.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
     c_scale_global = None if c is None else c.scale_global
     c_scale_mx = None if c is None else c.scale_mx
-    if isinstance(c_scale_mx, torch.Tensor):
-        c_scale_mx = wrap_torch_tensor(c_scale_mx)
     d_scale_global = None if d is None else d.scale_global

     # unpack a scale
@@ -310,8 +298,8 @@ def matmul(a, b, bias,
         batch_size = b.shape[0]
     else:
         batch_size = 1
-    if d_data is not None:
-        d_is_c = c_data is not None and d_data.data_ptr() == c_data.data_ptr() and d_data.stride() == c_data.stride()
+    if d_data is not None and c is not None:
+        d_is_c = d_data.data_ptr() == c.storage.data.data_ptr() and d_data.stride() == c.storage.data.stride()
     else:
         d_is_c = None
     K = a.shape[-1]
@@ -327,8 +315,8 @@ def matmul(a, b, bias,
         (b_scale is None or is_tma_compliant(b_scale)) and
         (ragged_dimension != "M" or a.stride(-1) == 1) and
         # Currently we don't support tma if y is column major; may revisit later if this becomes an issue.
-        (c_data is None or c_data.stride(-1) == 1) and
-        (d_data is None or d_is_c) and
+        (c is None or c.storage.data.stride(-1) == 1) and
+        (d is None or d_is_c) and
         # if ragged dimension is K, w must be either padded or row major to ensure alignment
         (ragged_dimension != "K" or b.stride(-1) == 1 or b_ragged_metadata.slice_sizes_divisibility is not None)
     )
@@ -382,7 +370,7 @@ def matmul(a, b, bias,
         gather_indx, scatter_indx, batch_size,
         fused_comm.n_reduce_shards if fused_comm is not None else 1,
         opt_flags)
-    memory = apply_allocation(allocation, c_data)
+    memory = apply_allocation(allocation, None if c is None else c.storage.data)
     # early exit
     if batch_size * M * N == 0:
         ret = memory["output"].squeeze(0)
@@ -420,10 +408,10 @@ def matmul(a, b, bias,
     # canonicalize storage
    has_scatter_tma = scatter_indx is not None and target_info.has_tma_gather()
     c = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if has_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:]))
-    a = Tensor(_canonicalize_storage(a.storage, 2 if has_gather_tma else 3), dtype=a.dtype, shape=a.shape, shape_max=a.shape_max)
-    b = Tensor(_canonicalize_storage(b.storage, 3), dtype=b.dtype, shape=b.shape, shape_max=b.shape_max)
-    c = Tensor(_canonicalize_storage(c.storage, 2 if has_scatter_tma else 3), dtype=c.dtype, shape=c.shape, shape_max=c.shape_max)
-    # create tma descriptor for x
+    a = Tensor(_canonicalize_storage(a.storage, 2 if has_gather_tma else 3), dtype=a.dtype, shape=a.shape, shape_max=a.shape_max, scale_global=a.scale_global, scale_mx=a.scale_mx)
+    b = Tensor(_canonicalize_storage(b.storage, 3), dtype=b.dtype, shape=b.shape, shape_max=b.shape_max, scale_global=b.scale_global, scale_mx=b.scale_mx)
+    c = Tensor(_canonicalize_storage(c.storage, 2 if has_scatter_tma else 3), dtype=c.dtype, shape=c.shape, shape_max=c.shape_max, scale_global=c.scale_global, scale_mx=c.scale_mx)
+    # create tma descriptor for d
    if d_data is not None:
         assert opt_flags.split_k == 1, "d + split_k is not supported."
         assert scatter_indx is None, "d + scatter is not supported."
@@ -511,10 +499,10 @@ def matmul(a, b, bias,
         *((None, out_matmul_scale, None) if out_matmul_has_mx else (out_global_scale, out_absmax, None)),
         *out_matmul_scale_strides[-4:],
         a_tensor_or_tma, a.storage.data, *a_strides, a_transpose,
-        a_scale_global,
+        a.scale_global,
         a_scale_tensor_or_tma, *a_scale_strides,
         b_tensor_or_tma, b.storage.data, *b.storage.data.stride(), b_transpose,
-        b_scale_global,
+        b.scale_global,
         b_scale_tensor_or_tma, *b_scale_strides,
         d_data, *d_strides,
         d_scale_global, d_is_c,
@@ -536,7 +524,7 @@ def matmul(a, b, bias,
         precision_config.max_num_imprecise_acc,
         precision_config.allow_tf32,
         precision_config.flexpoint_saturate_inf,
-        _is_per_batch_scale(b_scale_global),
+        _is_per_batch_scale(b.scale_global),
         _is_per_batch_scale(out_global_scale),
         _is_per_batch_scale(d_scale_global),
         opt_flags.block_m,
@@ -569,33 +557,24 @@ def matmul(a, b, bias,
         assert not out_matmul_has_mx
         postprocess_fn1 = ReducePostprocessFn(specs=reduce_fused_activation.specs, fn_args=reduce_fused_activation.fn_args)
         postprocess_fn2 = ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize)
-        reduce_y_flex = None
-        if c_scale_global is not None or c_absmax is not None:
-            reduce_y_flex = SimpleNamespace(
-                expected_scale=c_scale_global,
-                actual_scale=c_absmax,
-                checksum_scale=None,
-                is_per_batch=_is_per_batch_scale(c_scale_global),
-                reinterpret=lambda x: x,
-            )
-        c, y_mx_scale = reduce(
+        c, c_mx_scale = reduce(
            x = out_matmul.view(out_matmul.shape[0], -1, out_matmul.shape[-1]),
            dim = 0,
            # output data/metadata
            y = memory["output"].view(-1, memory["output"].shape[-1]),
            y_dtype = memory["output"].dtype,
-           x_flex = None,
-           y_flex = reduce_y_flex,
-           y_flex_saturate_inf = precision_config.flexpoint_saturate_inf,
+           y_scale_global = c_scale_global,
+           y_absmax = c_absmax,
+           y_saturate_inf = precision_config.flexpoint_saturate_inf,
            y_has_mx = c_scale_mx is not None,
            # fused functions
            postprocess_fn1 = postprocess_fn1,
            postprocess_fn2 = postprocess_fn2,
        )
         y_shape = out_matmul.shape[1:-1] + (out_matmul.shape[-1] // reduce_fused_activation.specs.reduction_n,)
         out_final = c.view(*y_shape)
-        if y_mx_scale is not None:
-            out_final_mx_scale = y_mx_scale.view(out_matmul.shape[-2], triton.cdiv(out_matmul.shape[-1], 32))
+        if c_mx_scale is not None:
+            out_final_mx_scale = c_mx_scale.view(out_matmul.shape[-2], triton.cdiv(out_matmul.shape[-1], 32))
     else:
         out_final = out_matmul.squeeze(0)
         out_final_mx_scale = out_matmul_scale
@@ -627,7 +606,7 @@ def apply(x, scale):
         return x.float() * scale

     if x_tri.scale_mx is not None:
-        a_scale = x_tri.scale_mx if isinstance(x_tri.scale_mx, Tensor) else wrap_torch_tensor(x_tri.scale_mx)
+        a_scale = x_tri.scale_mx
         mx_axis = x_tri.storage.data.ndim -1
         canonical_layout = layout.StridedLayout(major_dim=mx_axis)
         x_tri = convert_layout(x_tri, canonical_layout)
@@ -637,7 +616,7 @@ def apply(x, scale):
     x_ref = apply(x_tri.storage.data, x_tri.scale_global)

     if w_tri.scale_mx is not None:
-        b_scale = w_tri.scale_mx if isinstance(w_tri.scale_mx, Tensor) else wrap_torch_tensor(w_tri.scale_mx)
+        b_scale = w_tri.scale_mx
         mx_axis = w_tri.storage.data.ndim - 2
         canonical_layout = layout.StridedLayout(major_dim=mx_axis)
         w_tri = convert_layout(w_tri, canonical_layout)

Lines changed: 0 additions & 44 deletions
@@ -1,51 +1,7 @@
-import torch
-from dataclasses import dataclass
-
 # ------ global scaling -------

 MAX_FINITE_FLOAT8E5 = 57344.0
 MAX_FINITE_FLOAT8E4NV = 448.0
 MAX_FINITE_FLOAT8E4B8 = 240.0

-
-@dataclass(frozen=True)
-class BaseFlexData:
-    dtype: torch.dtype | None = None
-
-    def view(self, x: torch.Tensor):
-        if self.dtype is None:
-            return x
-        return x.view(self.dtype)
-
-    def reinterpret(self, x):
-        if self.dtype is None or x.dtype.itemsize > 1:
-            return x
-        return x.view(self.dtype)
-
-
-@dataclass(frozen=True)
-class InFlexData(BaseFlexData):
-    scale: torch.Tensor | None = None
-
-    @property
-    def is_per_batch(self):
-        return False if self.scale is None else len(self.scale) > 1
-
-
-@dataclass(frozen=True)
-class OutFlexData(BaseFlexData):
-    expected_scale: torch.Tensor | None = None
-    actual_scale: torch.Tensor | None = None
-    checksum_scale: torch.Tensor | None = None
-
-    @property
-    def is_per_batch(self):
-        return False if self.expected_scale is None else len(self.expected_scale) > 1
-
-    def __iter__(self):
-        yield self.expected_scale
-        yield self.actual_scale
-        yield self.checksum_scale
-
-
 # ------ block scaling -------
