
Commit af2196d

1 parent 4027076 commit af2196d

5 files changed

Lines changed: 59 additions & 51 deletions


python/triton_kernels/bench/bench_mlp.py

Lines changed: 5 additions & 6 deletions
@@ -15,7 +15,7 @@
 from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata # ragged tensor
 from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_assignment, SymmetricMemoryPool
 # quantization
-from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
+from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4, Tensor
 from triton_kernels.numerics_details.mxfp import downcast_to_mxfp


@@ -39,17 +39,16 @@ def quantize_weight(w, dtype, **opt):
         fp8e4_dtype = torch.float8_e4m3fn if get_cdna_version() != 3 else torch.float8_e4m3fnuz
         wq = w.to(fp8e4_dtype)
         wq = wq.transpose(-1, -2).contiguous().transpose(-1, -2)
-        wq = wrap_torch_tensor(wq)
-        wq.scale_global = w.abs().max().unsqueeze(0)
-        return wq
+        return wrap_torch_tensor(wq, scale_global=w.abs().max().unsqueeze(0))
     else:
         assert dtype == FP4, f"{dtype=}"
         w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
         if opt:
             w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
             w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
-        w.scale_mx = w_scale
-        return w
+        if isinstance(w, Tensor):
+            return Tensor(w.storage, dtype=w.dtype, shape=w.shape, shape_max=w.shape_max, scale_mx=w_scale)
+        return wrap_torch_tensor(w, dtype=FP4, scale_mx=w_scale)


 def run_mlp(x_dp_local_bf16, x_dp_local_fp8, # activations
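
The commit replaces post-construction mutation of scale fields with construction-time arguments: scales now travel with the `Tensor` from the moment it is wrapped. A minimal sketch of the new FP8 path in `quantize_weight` above (and the identical change in `bench_utils.py` below), assuming `triton_kernels` is installed and a CUDA device is available; the weight shape is illustrative:

    import torch
    from triton_kernels.tensor import wrap_torch_tensor

    w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    wq = w.to(torch.float8_e4m3fn)
    # keep the column-major layout the kernels expect
    wq = wq.transpose(-1, -2).contiguous().transpose(-1, -2)
    # scale_global is passed to the constructor instead of assigned afterwards
    wq = wrap_torch_tensor(wq, scale_global=w.abs().max().unsqueeze(0))

In the mx4 branch, `convert_layout` already returns a `Tensor`, so the code rebuilds a new `Tensor` around the same `storage` to attach `scale_mx`, rather than mutating the object in place.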

python/triton_kernels/bench/bench_utils.py

Lines changed: 4 additions & 5 deletions
@@ -20,17 +20,16 @@ def _quantize_weight(w, dtype, **opt):
         wq = w.to(fp8e4_dtype)
         if is_cuda() and not cuda_capability_geq(10, 0):
             wq = wq.transpose(-1, -2).contiguous().transpose(-1, -2)
-        wq = wrap_torch_tensor(wq)
-        wq.scale_global = w.abs().max().unsqueeze(0)
-        return wq
+        return wrap_torch_tensor(wq, scale_global=w.abs().max().unsqueeze(0))
     else:
         assert dtype == "mx4", f"{dtype=}"
         w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
         if opt:
             w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
             w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
-        w.scale_mx = w_scale
-        return w
+        if isinstance(w, Tensor):
+            return Tensor(w.storage, dtype=w.dtype, shape=w.shape, shape_max=w.shape_max, scale_mx=w_scale)
+        return wrap_torch_tensor(w, dtype=FP4, scale_mx=w_scale)


 @dataclass

python/triton_kernels/tests/test_matmul.py

Lines changed: 14 additions & 15 deletions
@@ -375,7 +375,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
     c = torch.empty(c_shape, dtype=c_dtype.torch_dtype, device=device)
     if c_transpose:
         c = c.mT.contiguous().mT
-    c = wrap_torch_tensor(c)

     # --- create precision config ---
     wrap_list = lambda vals: torch.tensor(vals, dtype=torch.float32, device=device)
@@ -387,48 +386,48 @@
         acc_scale=2.0 if c_dtype.has_global_scale or b_dtype.has_global_scale else 1.0,
         out_dtype=c_dtype.torch_dtype,
     )
-    a_scale_mx = a_scales
-    b_scale_mx = b_scale_tri
     c_scale_mx = None
-    a.scale_global = a_scale_global
-    b.scale_global = b_scale_global
-    c.scale_global = c_scale_global
-    c.scale_actual = c_absmax
-    a.scale_mx = a_scale_mx
-    b.scale_mx = b_scale_mx
-    c.scale_mx = c_scale_mx

     # --- create epilogue ---
     epilogue = None
     if c_dtype.has_mx_scale:
         c_scale_shape = c_shape[:-1] + (triton.cdiv(c_shape[-1], MXFP_BLOCK_SIZE),)
         c_scale_mx = torch.empty(c_scale_shape, dtype=torch.uint8, device=a.device)
-        c.scale_mx = c_scale_mx
         epilogue_spec = FnSpecs(FnName.QUANTIZE_MXFP8.name, quantize_mxfp8_fn, (), ())
         epilogue = Epilogue(epilogue_spec, tuple(), tuple(), effective_itemsize=6.0)

+    if isinstance(a, Tensor):
+        a = Tensor(a.storage, dtype=a.dtype, shape=a.shape, shape_max=a.shape_max, scale_global=a_scale_global, scale_mx=a_scales)
+    else:
+        a = wrap_torch_tensor(a, scale_global=a_scale_global, scale_mx=a_scales)
+    if isinstance(b, Tensor):
+        b = Tensor(b.storage, dtype=b.dtype, shape=b.shape, shape_max=b.shape_max, scale_global=b_scale_global, scale_mx=b_scale_tri)
+    else:
+        b = wrap_torch_tensor(b, scale_global=b_scale_global, scale_mx=b_scale_tri)
+    c = wrap_torch_tensor(c, scale_global=c_scale_global, scale_mx=c_scale_mx)

     # --- triton implementation ---
     try:
         tri_y = matmul(a, b, bias,
                        a_ragged_metadata, b_ragged_metadata,
                        gather_indx, scatter_indx, precision_opt,
-                       gammas=gammas, epilogue=epilogue, c=c,
+                       gammas=gammas, epilogue=epilogue, c=c, c_absmax=c_absmax,
                        fused_activation=fused_activation)
         if c_dtype.has_global_scale:
-            tri_y_scale = c.scale_actual.clone()
+            tri_y_scale = c_absmax.clone()
     except (opt_flags.InapplicableConstraint, NotImplementedError) as e:
         pytest.skip(f"inapplicable opt_flags constraint {e}")
     # --- torch implementation ---
     ref_y = matmul_torch(a, b, bias, #
                          a_ragged_metadata, b_ragged_metadata,
                          gather_indx, scatter_indx, precision_opt,
                          gammas=gammas,
-                         c=c)
+                         c=c,
+                         c_absmax=c_absmax)
     if swiglu_opts is not None:
         ref_y = swiglu(ref_y, alpha=swiglu_opts[0], precision_config=SwiGLUPrecisionConfig(swiglu_opts[1]))
     if c_dtype.has_global_scale:
-        ref_y_scale = c.scale_actual.clone()
+        ref_y_scale = c_absmax.clone()

     # --- check results ---
     if c_dtype.has_mx_scale:
python/triton_kernels/triton_kernels/matmul.py

Lines changed: 29 additions & 20 deletions
@@ -224,6 +224,7 @@ def matmul(a, b, bias,
            gammas: torch.Tensor | None = None,
            out_alpha: float | None = None,
            c: torch.Tensor | Tensor | None = None,
+           c_absmax: torch.Tensor | None = None,
            fused_comm: FusedComm | None = None,
            fused_activation: FusedActivation | None = None,
            epilogue: Epilogue | None = None,
@@ -259,32 +260,37 @@ def matmul(a, b, bias,
     if epilogue is None:
         epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
     n_slices = max(1, b.shape[0]) if a_ragged_metadata is None else a_ragged_metadata.n_slices
-    c_data = c.storage.data if isinstance(c, Tensor) else c
-    d_data = d.storage.data if isinstance(d, Tensor) else d
+    if c is not None and not isinstance(c, Tensor):
+        c = wrap_torch_tensor(c)
+    if d is not None and not isinstance(d, Tensor):
+        d = wrap_torch_tensor(d)
+    c_data = None if c is None else c.storage.data
+    d_data = None if d is None else d.storage.data
     if not isinstance(a, Tensor):
         a = wrap_torch_tensor(a)
     if not isinstance(b, Tensor):
-        dtype = FP4 if b.dtype == torch.uint8 else None
-        b = wrap_torch_tensor(b, dtype=dtype)
+        b_dtype = FP4 if b.dtype == torch.uint8 else None
+        b = wrap_torch_tensor(b, dtype=b_dtype)
     a_scale_global = a.scale_global
     a_scale = a.scale_mx
-    if a_scale is not None and not isinstance(a_scale, Tensor):
+    if isinstance(a_scale, torch.Tensor):
         a_scale = wrap_torch_tensor(a_scale)
     b_scale_global = b.scale_global
     b_scale = b.scale_mx
+    if isinstance(b_scale, torch.Tensor):
+        b_scale = wrap_torch_tensor(b_scale)
     b_has_mx = b_scale is not None
     if b_has_mx and (torch.cuda.get_device_capability()[0] < 10 or b.storage.layout is not None and not isinstance(b.storage.layout, StridedLayout)):
         assert b.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
-    if b_scale is not None and not isinstance(b_scale, Tensor):
-        b_scale = wrap_torch_tensor(b_scale)
     if b_scale is not None:
         b_scale.storage.data = b_scale.data.view(torch.uint8)
     is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and b.dtype.bitwidth == 8
     if is_hopper_fp8: assert b.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
-    c_scale_global = None if not isinstance(c, Tensor) else c.scale_global
-    c_absmax = None if not isinstance(c, Tensor) else c.scale_actual
-    c_scale_mx = None if not isinstance(c, Tensor) else c.scale_mx
-    d_scale_global = None if not isinstance(d, Tensor) else d.scale_global
+    c_scale_global = None if c is None else c.scale_global
+    c_scale_mx = None if c is None else c.scale_mx
+    if isinstance(c_scale_mx, torch.Tensor):
+        c_scale_mx = wrap_torch_tensor(c_scale_mx)
+    d_scale_global = None if d is None else d.scale_global

     # unpack a scale
     a_has_mx = a_scale is not None
@@ -597,7 +603,7 @@ def matmul(a, b, bias,
     if not (is_input_batched or b_ragged_metadata is not None):
         out_final = out_final.squeeze(0)
     if out_final_mx_scale is not None and c_scale_mx is not None:
-        c_scale_mx_torch = c_scale_mx.storage.data if isinstance(c_scale_mx, Tensor) else c_scale_mx
+        c_scale_mx_torch = c_scale_mx.storage.data
         if out_final_mx_scale.data_ptr() != c_scale_mx_torch.data_ptr():
             c_scale_mx_torch.copy_(out_final_mx_scale)
     return out_final
@@ -675,14 +681,17 @@ def matmul_torch(a, b, bias,
                  gammas = None,
                  round_x = None, round_y = None,
                  c: torch.Tensor | Tensor | None = None,
+                 c_absmax: torch.Tensor | None = None,
                  ):
     if precision_config is None:
         precision_config = PrecisionConfig()
+    if c is not None and not isinstance(c, Tensor):
+        c = wrap_torch_tensor(c)
     if not isinstance(a, Tensor):
         a = wrap_torch_tensor(a)
     if not isinstance(b, Tensor):
-        dtype = FP4 if b.dtype == torch.uint8 else None
-        b = wrap_torch_tensor(b, dtype=dtype)
+        b_dtype = FP4 if b.dtype == torch.uint8 else None
+        b = wrap_torch_tensor(b, dtype=b_dtype)
     a, b = apply_precision(a, b, precision_config)

     if b_ragged_metadata is not None:
@@ -708,9 +717,9 @@ def matmul_torch(a, b, bias,
                 round_y=round_y,
             )
             out[expt] = out_expt.to(out.dtype)
-        if isinstance(c, Tensor) and c.scale_actual is not None:
-            c.scale_actual.copy_(compute_actual_scale(out, precision_config.out_dtype))
-        return scale(out, c.scale_global if isinstance(c, Tensor) else None)
+        if c_absmax is not None:
+            c_absmax.copy_(compute_actual_scale(out, precision_config.out_dtype))
+        return scale(out, None if c is None else c.scale_global)

     is_input_batched = a.ndim == 3
     assert a.dtype.itemsize > 1
@@ -760,9 +769,9 @@ def matmul_torch(a, b, bias,
         out = torch.zeros((scatter_indx.shape[0], y.shape[-1]), dtype=y.dtype, device=a.device)
         msk = scatter_indx != -1
         out[scatter_indx[msk], :] = y[msk, :]
-        if isinstance(c, Tensor) and c.scale_actual is not None:
-            c.scale_actual.copy_(compute_actual_scale(out, precision_config.out_dtype))
-        return scale(out, c.scale_global if isinstance(c, Tensor) else None)
+        if c_absmax is not None:
+            c_absmax.copy_(compute_actual_scale(out, precision_config.out_dtype))
+        return scale(out, None if c is None else c.scale_global)


 def post_matmul_comm_torch(y: torch.Tensor, rank: int, n_reduce_shards: int,
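
Two structural changes stand out in matmul above: raw torch.Tensor arguments for `c` and `d` are coerced to `Tensor` as soon as they enter the function, and the per-output absmax moves from the removed `Tensor.scale_actual` field to the explicit `c_absmax` argument. A condensed, standalone restatement of the coercion pattern (`normalize` is a hypothetical name; the real code inlines this logic):

    # hypothetical helper illustrating the normalize-at-the-boundary pattern
    def normalize(x):
        if x is not None and not isinstance(x, Tensor):
            x = wrap_torch_tensor(x)  # raw torch.Tensor -> Tensor
        return x

    c = normalize(c)
    d = normalize(d)
    # after normalization, every later access drops its isinstance checks:
    c_data = None if c is None else c.storage.data
    c_scale_global = None if c is None else c.scale_global
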

python/triton_kernels/triton_kernels/tensor.py

Lines changed: 7 additions & 5 deletions
@@ -36,7 +36,6 @@ class Tensor:
     shape: list[int] | None = None
     shape_max: list[int] | None = None
     scale_global: torch.Tensor | None = None
-    scale_actual: torch.Tensor | None = None
     scale_mx: object | None = None

     def __post_init__(self):
@@ -209,7 +208,11 @@
 # ---------------------------------------------------------------------------- #


-def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layout=None):
+def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layout=None, scale_global=None,
+                      scale_mx=None):
+    assert isinstance(torch_tensor, torch.Tensor), f"`wrap_torch_tensor` expects torch.Tensor, got {type(torch_tensor)}"
+    if isinstance(scale_mx, torch.Tensor):
+        scale_mx = wrap_torch_tensor(scale_mx)
     if dtype is None:
         dtype = torch_tensor.dtype
     dtype = torch_dtype_to_dtype(dtype)
@@ -229,9 +232,8 @@ def wrap_torch_tensor(torch_tensor, dtype=None, shape=None, shape_max=None, layo
         dtype=dtype,
         shape=shape,
         shape_max=shape_max,
-        scale_global=None,
-        scale_actual=None,
-        scale_mx=None,
+        scale_global=scale_global,
+        scale_mx=scale_mx,
     )
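
With the extended `wrap_torch_tensor` signature, scales are supplied at wrap time, and a raw torch.Tensor passed as `scale_mx` is itself wrapped recursively into a `Tensor`. A short usage sketch mirroring the benchmark code above; shapes and values are illustrative only:

    import torch
    from triton_kernels.tensor import wrap_torch_tensor, FP4, Tensor

    # packed fp4 payload (two values per byte) plus per-block mx scales
    w_packed = torch.randint(0, 255, (4096, 2048), dtype=torch.uint8, device="cuda")
    w_scale = torch.randint(0, 255, (4096, 128), dtype=torch.uint8, device="cuda")
    w = wrap_torch_tensor(w_packed, dtype=FP4, scale_mx=w_scale)
    assert isinstance(w.scale_mx, Tensor)  # the raw scale tensor was auto-wrapped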