@@ -224,6 +224,7 @@ def matmul(a, b, bias,
            gammas: torch.Tensor | None = None,
            out_alpha: float | None = None,
            c: torch.Tensor | Tensor | None = None,
+           c_absmax: torch.Tensor | None = None,
            fused_comm: FusedComm | None = None,
            fused_activation: FusedActivation | None = None,
            epilogue: Epilogue | None = None,
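
The new `c_absmax` keyword looks like an out-parameter: judging from the `matmul_torch` changes further down, the callee fills a caller-provided buffer with the actual scale of the output instead of writing it to `c.scale_actual`. Below is a minimal self-contained sketch of that contract; `toy_matmul` and `reference_absmax_scale` are illustrative stand-ins, not the library's real `matmul` / `compute_actual_scale`.

import torch

def reference_absmax_scale(out: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    # Assumed behaviour only: absmax of the result, normalized by the target dtype's max.
    return out.abs().amax() / torch.finfo(out_dtype).max

def toy_matmul(a, b, c_absmax: torch.Tensor | None = None):
    out = a @ b
    if c_absmax is not None:           # same in-place fill pattern as in the diff
        c_absmax.copy_(reference_absmax_scale(out, torch.float16))
    return out

c_absmax = torch.zeros(())             # scalar buffer the caller allocates up front
y = toy_matmul(torch.randn(16, 8), torch.randn(8, 4), c_absmax=c_absmax)
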
@@ -259,32 +260,37 @@ def matmul(a, b, bias,
     if epilogue is None:
         epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
     n_slices = max(1, b.shape[0]) if a_ragged_metadata is None else a_ragged_metadata.n_slices
-    c_data = c.storage.data if isinstance(c, Tensor) else c
-    d_data = d.storage.data if isinstance(d, Tensor) else d
+    if c is not None and not isinstance(c, Tensor):
+        c = wrap_torch_tensor(c)
+    if d is not None and not isinstance(d, Tensor):
+        d = wrap_torch_tensor(d)
+    c_data = None if c is None else c.storage.data
+    d_data = None if d is None else d.storage.data
     if not isinstance(a, Tensor):
         a = wrap_torch_tensor(a)
     if not isinstance(b, Tensor):
-        dtype = FP4 if b.dtype == torch.uint8 else None
-        b = wrap_torch_tensor(b, dtype=dtype)
+        b_dtype = FP4 if b.dtype == torch.uint8 else None
+        b = wrap_torch_tensor(b, dtype=b_dtype)
     a_scale_global = a.scale_global
     a_scale = a.scale_mx
-    if a_scale is not None and not isinstance(a_scale, Tensor):
+    if isinstance(a_scale, torch.Tensor):
         a_scale = wrap_torch_tensor(a_scale)
     b_scale_global = b.scale_global
     b_scale = b.scale_mx
+    if isinstance(b_scale, torch.Tensor):
+        b_scale = wrap_torch_tensor(b_scale)
     b_has_mx = b_scale is not None
     if b_has_mx and (torch.cuda.get_device_capability()[0] < 10 or b.storage.layout is not None and not isinstance(b.storage.layout, StridedLayout)):
         assert b.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
-    if b_scale is not None and not isinstance(b_scale, Tensor):
-        b_scale = wrap_torch_tensor(b_scale)
     if b_scale is not None:
         b_scale.storage.data = b_scale.data.view(torch.uint8)
     is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and b.dtype.bitwidth == 8
     if is_hopper_fp8: assert b.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
-    c_scale_global = None if not isinstance(c, Tensor) else c.scale_global
-    c_absmax = None if not isinstance(c, Tensor) else c.scale_actual
-    c_scale_mx = None if not isinstance(c, Tensor) else c.scale_mx
-    d_scale_global = None if not isinstance(d, Tensor) else d.scale_global
+    c_scale_global = None if c is None else c.scale_global
+    c_scale_mx = None if c is None else c.scale_mx
+    if isinstance(c_scale_mx, torch.Tensor):
+        c_scale_mx = wrap_torch_tensor(c_scale_mx)
+    d_scale_global = None if d is None else d.scale_global

     # unpack a scale
     a_has_mx = a_scale is not None
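
The rewritten preamble settles on one normalization pattern: wrap any plain torch.Tensor argument into the library's Tensor wrapper once, up front, so everything downstream only has to distinguish None from wrapped. A tiny illustration of the idea follows, with `Wrapped` / `normalize` as made-up stand-ins for `Tensor` / `wrap_torch_tensor`.

from dataclasses import dataclass
import torch

@dataclass
class Wrapped:                        # stand-in for the library's Tensor wrapper
    data: torch.Tensor
    scale_global: torch.Tensor | None = None
    scale_mx: torch.Tensor | None = None

def normalize(x):
    # Wrap a plain torch.Tensor once; pass through None or an already-wrapped value.
    if x is not None and not isinstance(x, Wrapped):
        x = Wrapped(x)
    return x

c = normalize(torch.empty(4, 4))
c_data = None if c is None else c.data              # downstream code needs only a None check
c_scale_global = None if c is None else c.scale_global
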
@@ -597,7 +603,7 @@ def matmul(a, b, bias,
     if not (is_input_batched or b_ragged_metadata is not None):
         out_final = out_final.squeeze(0)
     if out_final_mx_scale is not None and c_scale_mx is not None:
-        c_scale_mx_torch = c_scale_mx.storage.data if isinstance(c_scale_mx, Tensor) else c_scale_mx
+        c_scale_mx_torch = c_scale_mx.storage.data
         if out_final_mx_scale.data_ptr() != c_scale_mx_torch.data_ptr():
             c_scale_mx_torch.copy_(out_final_mx_scale)
     return out_final
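
Since `c_scale_mx` is now always wrapped by this point, the copy-back no longer needs an isinstance check. The surrounding logic is simply "write the computed scale into the caller's buffer unless the two already share storage"; a small generic illustration (values made up):

import torch

computed = torch.rand(8)              # e.g. the scale produced by the kernel
user_buf = torch.empty(8)             # e.g. the buffer the caller passed in
if computed.data_ptr() != user_buf.data_ptr():    # skip the copy if they alias the same storage
    user_buf.copy_(computed)
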
@@ -675,14 +681,17 @@ def matmul_torch(a, b, bias,
                 gammas=None,
                 round_x=None, round_y=None,
                 c: torch.Tensor | Tensor | None = None,
+                c_absmax: torch.Tensor | None = None,
                 ):
     if precision_config is None:
         precision_config = PrecisionConfig()
+    if c is not None and not isinstance(c, Tensor):
+        c = wrap_torch_tensor(c)
     if not isinstance(a, Tensor):
         a = wrap_torch_tensor(a)
     if not isinstance(b, Tensor):
-        dtype = FP4 if b.dtype == torch.uint8 else None
-        b = wrap_torch_tensor(b, dtype=dtype)
+        b_dtype = FP4 if b.dtype == torch.uint8 else None
+        b = wrap_torch_tensor(b, dtype=b_dtype)
     a, b = apply_precision(a, b, precision_config)

     if b_ragged_metadata is not None:
@@ -708,9 +717,9 @@ def matmul_torch(a, b, bias,
                 round_y=round_y,
             )
             out[expt] = out_expt.to(out.dtype)
-        if isinstance(c, Tensor) and c.scale_actual is not None:
-            c.scale_actual.copy_(compute_actual_scale(out, precision_config.out_dtype))
-        return scale(out, c.scale_global if isinstance(c, Tensor) else None)
+        if c_absmax is not None:
+            c_absmax.copy_(compute_actual_scale(out, precision_config.out_dtype))
+        return scale(out, None if c is None else c.scale_global)

     is_input_batched = a.ndim == 3
     assert a.dtype.itemsize > 1
@@ -760,9 +769,9 @@ def matmul_torch(a, b, bias,
     out = torch.zeros((scatter_indx.shape[0], y.shape[-1]), dtype=y.dtype, device=a.device)
     msk = scatter_indx != -1
     out[scatter_indx[msk], :] = y[msk, :]
-    if isinstance(c, Tensor) and c.scale_actual is not None:
-        c.scale_actual.copy_(compute_actual_scale(out, precision_config.out_dtype))
-    return scale(out, c.scale_global if isinstance(c, Tensor) else None)
+    if c_absmax is not None:
+        c_absmax.copy_(compute_actual_scale(out, precision_config.out_dtype))
+    return scale(out, None if c is None else c.scale_global)


 def post_matmul_comm_torch(y: torch.Tensor, rank: int, n_reduce_shards: int,
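
For reference, the scatter step left unchanged above (out[scatter_indx[msk], :] = y[msk, :]) routes rows of y to the positions named in scatter_indx, with -1 marking dropped rows. A small self-contained example with made-up values:

import torch

y = torch.arange(12, dtype=torch.float32).reshape(4, 3)
scatter_indx = torch.tensor([2, -1, 0, 3])
out = torch.zeros((scatter_indx.shape[0], y.shape[-1]), dtype=y.dtype)
msk = scatter_indx != -1
out[scatter_indx[msk], :] = y[msk, :]
# Rows land as: out[2] <- y[0], out[0] <- y[2], out[3] <- y[3]; out[1] stays zero.
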