@@ -7,7 +7,6 @@
 from enum import Enum, auto
 import math
 from typing import Callable
-from types import SimpleNamespace
 # utilities
 from triton_kernels import target_info
 from triton_kernels.meta import Closure
@@ -260,36 +259,25 @@ def matmul(a, b, bias,
     if epilogue is None:
         epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
     n_slices = max(1, b.shape[0]) if a_ragged_metadata is None else a_ragged_metadata.n_slices
-    if c is not None and not isinstance(c, Tensor):
-        c = wrap_torch_tensor(c)
-    if d is not None and not isinstance(d, Tensor):
-        d = wrap_torch_tensor(d)
-    c_data = None if c is None else c.storage.data
-    d_data = None if d is None else d.storage.data
     if not isinstance(a, Tensor):
         a = wrap_torch_tensor(a)
     if not isinstance(b, Tensor):
         b_dtype = FP4 if b.dtype == torch.uint8 else None
         b = wrap_torch_tensor(b, dtype=b_dtype)
-    a_scale_global = a.scale_global
+    if c is not None and not isinstance(c, Tensor):
+        c = wrap_torch_tensor(c)
+    if d is not None and not isinstance(d, Tensor):
+        d = wrap_torch_tensor(d)
+    d_data = None if d is None else d.storage.data
     a_scale = a.scale_mx
-    if isinstance(a_scale, torch.Tensor):
-        a_scale = wrap_torch_tensor(a_scale)
-    b_scale_global = b.scale_global
     b_scale = b.scale_mx
-    if isinstance(b_scale, torch.Tensor):
-        b_scale = wrap_torch_tensor(b_scale)
     b_has_mx = b_scale is not None
     if b_has_mx and (torch.cuda.get_device_capability()[0] < 10 or b.storage.layout is not None and not isinstance(b.storage.layout, StridedLayout)):
         assert b.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
-    if b_scale is not None:
-        b_scale.storage.data = b_scale.data.view(torch.uint8)
     is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and b.dtype.bitwidth == 8
     if is_hopper_fp8: assert b.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
     c_scale_global = None if c is None else c.scale_global
     c_scale_mx = None if c is None else c.scale_mx
-    if isinstance(c_scale_mx, torch.Tensor):
-        c_scale_mx = wrap_torch_tensor(c_scale_mx)
     d_scale_global = None if d is None else d.scale_global
 
     # unpack a scale
@@ -310,8 +298,8 @@ def matmul(a, b, bias,
         batch_size = b.shape[0]
     else:
         batch_size = 1
-    if d_data is not None:
-        d_is_c = c_data is not None and d_data.data_ptr() == c_data.data_ptr() and d_data.stride() == c_data.stride()
+    if d_data is not None and c is not None:
+        d_is_c = d_data.data_ptr() == c.storage.data.data_ptr() and d_data.stride() == c.storage.data.stride()
     else:
         d_is_c = None
     K = a.shape[-1]
@@ -327,8 +315,8 @@ def matmul(a, b, bias,
         (b_scale is None or is_tma_compliant(b_scale)) and
         (ragged_dimension != "M" or a.stride(-1) == 1) and
         # Currently we don't support tma if y is column major; may revisit later if this becomes an issue.
-        (c_data is None or c_data.stride(-1) == 1) and
-        (d_data is None or d_is_c) and
+        (c is None or c.storage.data.stride(-1) == 1) and
+        (d is None or d_is_c) and
         # if ragged dimension is K, w must be either padded or row major to ensure alignment
         (ragged_dimension != "K" or b.stride(-1) == 1 or b_ragged_metadata.slice_sizes_divisibility is not None)
     )
@@ -382,7 +370,7 @@ def matmul(a, b, bias,
         gather_indx, scatter_indx, batch_size,
         fused_comm.n_reduce_shards if fused_comm is not None else 1,
         opt_flags)
-    memory = apply_allocation(allocation, c_data)
+    memory = apply_allocation(allocation, None if c is None else c.storage.data)
     # early exit
     if batch_size * M * N == 0:
         ret = memory["output"].squeeze(0)
@@ -420,10 +408,10 @@ def matmul(a, b, bias,
     # canonicalize storage
     has_scatter_tma = scatter_indx is not None and target_info.has_tma_gather()
     c = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if has_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:]))
-    a = Tensor(_canonicalize_storage(a.storage, 2 if has_gather_tma else 3), dtype=a.dtype, shape=a.shape, shape_max=a.shape_max)
-    b = Tensor(_canonicalize_storage(b.storage, 3), dtype=b.dtype, shape=b.shape, shape_max=b.shape_max)
-    c = Tensor(_canonicalize_storage(c.storage, 2 if has_scatter_tma else 3), dtype=c.dtype, shape=c.shape, shape_max=c.shape_max)
-    # create tma descriptor for x
+    a = Tensor(_canonicalize_storage(a.storage, 2 if has_gather_tma else 3), dtype=a.dtype, shape=a.shape, shape_max=a.shape_max, scale_global=a.scale_global, scale_mx=a.scale_mx)
+    b = Tensor(_canonicalize_storage(b.storage, 3), dtype=b.dtype, shape=b.shape, shape_max=b.shape_max, scale_global=b.scale_global, scale_mx=b.scale_mx)
+    c = Tensor(_canonicalize_storage(c.storage, 2 if has_scatter_tma else 3), dtype=c.dtype, shape=c.shape, shape_max=c.shape_max, scale_global=c.scale_global, scale_mx=c.scale_mx)
+    # create tma descriptor for d
     if d_data is not None:
         assert opt_flags.split_k == 1, "d + split_k is not supported."
         assert scatter_indx is None, "d + scatter is not supported."
@@ -511,10 +499,10 @@ def matmul(a, b, bias,
         *((None, out_matmul_scale, None) if out_matmul_has_mx else (out_global_scale, out_absmax, None)),
         *out_matmul_scale_strides[-4:],
         a_tensor_or_tma, a.storage.data, *a_strides, a_transpose,
-        a_scale_global,
+        a.scale_global,
         a_scale_tensor_or_tma, *a_scale_strides,
         b_tensor_or_tma, b.storage.data, *b.storage.data.stride(), b_transpose,
-        b_scale_global,
+        b.scale_global,
         b_scale_tensor_or_tma, *b_scale_strides,
         d_data, *d_strides,
         d_scale_global, d_is_c,
@@ -536,7 +524,7 @@ def matmul(a, b, bias,
         precision_config.max_num_imprecise_acc,
         precision_config.allow_tf32,
         precision_config.flexpoint_saturate_inf,
-        _is_per_batch_scale(b_scale_global),
+        _is_per_batch_scale(b.scale_global),
         _is_per_batch_scale(out_global_scale),
         _is_per_batch_scale(d_scale_global),
         opt_flags.block_m,
@@ -569,33 +557,24 @@ def matmul(a, b, bias,
         assert not out_matmul_has_mx
         postprocess_fn1 = ReducePostprocessFn(specs=reduce_fused_activation.specs, fn_args=reduce_fused_activation.fn_args)
         postprocess_fn2 = ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize)
-        reduce_y_flex = None
-        if c_scale_global is not None or c_absmax is not None:
-            reduce_y_flex = SimpleNamespace(
-                expected_scale=c_scale_global,
-                actual_scale=c_absmax,
-                checksum_scale=None,
-                is_per_batch=_is_per_batch_scale(c_scale_global),
-                reinterpret=lambda x: x,
-            )
-        c, y_mx_scale = reduce(
+        c, c_mx_scale = reduce(
             x=out_matmul.view(out_matmul.shape[0], -1, out_matmul.shape[-1]),
             dim=0,
             # output data/metadata
             y=memory["output"].view(-1, memory["output"].shape[-1]),
             y_dtype=memory["output"].dtype,
-            x_flex=None,
-            y_flex=reduce_y_flex,
-            y_flex_saturate_inf=precision_config.flexpoint_saturate_inf,
+            y_scale_global=c_scale_global,
+            y_absmax=c_absmax,
+            y_saturate_inf=precision_config.flexpoint_saturate_inf,
             y_has_mx=c_scale_mx is not None,
             # fused functions
             postprocess_fn1=postprocess_fn1,
             postprocess_fn2=postprocess_fn2,
         )
         y_shape = out_matmul.shape[1:-1] + (out_matmul.shape[-1] // reduce_fused_activation.specs.reduction_n,)
         out_final = c.view(*y_shape)
-        if y_mx_scale is not None:
-            out_final_mx_scale = y_mx_scale.view(out_matmul.shape[-2], triton.cdiv(out_matmul.shape[-1], 32))
+        if c_mx_scale is not None:
+            out_final_mx_scale = c_mx_scale.view(out_matmul.shape[-2], triton.cdiv(out_matmul.shape[-1], 32))
     else:
         out_final = out_matmul.squeeze(0)
         out_final_mx_scale = out_matmul_scale
@@ -627,7 +606,7 @@ def apply(x, scale):
         return x.float() * scale
 
     if x_tri.scale_mx is not None:
-        a_scale = x_tri.scale_mx if isinstance(x_tri.scale_mx, Tensor) else wrap_torch_tensor(x_tri.scale_mx)
+        a_scale = x_tri.scale_mx
         mx_axis = x_tri.storage.data.ndim - 1
         canonical_layout = layout.StridedLayout(major_dim=mx_axis)
         x_tri = convert_layout(x_tri, canonical_layout)
@@ -637,7 +616,7 @@ def apply(x, scale):
     x_ref = apply(x_tri.storage.data, x_tri.scale_global)
 
     if w_tri.scale_mx is not None:
-        b_scale = w_tri.scale_mx if isinstance(w_tri.scale_mx, Tensor) else wrap_torch_tensor(w_tri.scale_mx)
+        b_scale = w_tri.scale_mx
         mx_axis = w_tri.storage.data.ndim - 2
         canonical_layout = layout.StridedLayout(major_dim=mx_axis)
         w_tri = convert_layout(w_tri, canonical_layout)
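
Taken together, the hunks above stop copying scale metadata into loose locals (a_scale_global, b_scale_global, the SimpleNamespace-based reduce_y_flex) and instead keep scale_global / scale_mx attached to the tensor wrappers, re-attaching them when storage is canonicalized and reading them as attributes at the kernel launch. Below is a minimal, self-contained sketch of that pattern; WrappedTensor and canonicalize are hypothetical stand-ins for illustration only, not the triton_kernels API.

# Sketch: scale metadata travels on the tensor wrapper rather than in locals.
# `WrappedTensor` and `canonicalize` are hypothetical names, not triton_kernels APIs.
from dataclasses import dataclass, replace
from typing import Optional

import torch


@dataclass
class WrappedTensor:
    data: torch.Tensor
    scale_global: Optional[torch.Tensor] = None  # per-tensor (or per-batch) scale
    scale_mx: Optional[torch.Tensor] = None      # microscaling (MX) block scales


def canonicalize(t: WrappedTensor, min_ndim: int) -> WrappedTensor:
    """Pad the storage to at least `min_ndim` dims, keeping scales attached."""
    data = t.data
    while data.ndim < min_ndim:
        data = data.unsqueeze(0)
    # carry scale_global / scale_mx through, as the diff does for a, b and c
    return replace(t, data=data)


a = WrappedTensor(torch.randn(128, 64), scale_global=torch.tensor(0.5))
a = canonicalize(a, 3)
# downstream code reads the attribute directly instead of a saved local
assert a.scale_global is not None and a.data.ndim == 3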