@@ -223,7 +223,6 @@ def matmul(a, b, bias,
            gammas: torch.Tensor | None = None,
            out_alpha: float | None = None,
            c: torch.Tensor | Tensor | None = None,
-           c_absmax: torch.Tensor | None = None,
            fused_comm: FusedComm | None = None,
            fused_activation: FusedActivation | None = None,
            epilogue: Epilogue | None = None,
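The hunk above drops `c_absmax` from the public `matmul` signature. A hypothetical call site, before and after (the buffer setup is illustrative, not taken from this diff):

    # before: caller allocates a buffer and matmul records the observed output absmax into it
    c_absmax = torch.empty((), dtype=torch.float32, device="cuda")
    y = matmul(a, b, bias, c=c, c_absmax=c_absmax)

    # after: the argument is gone; the kernel no longer tracks output absmax
    y = matmul(a, b, bias, c=c)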
@@ -491,12 +490,11 @@ def matmul(a, b, bias,
     } if fused_comm is not None else {}
     n_valid_slices = b_tensor_or_tma.shape[0] if ragged_dimension == "M" else n_slices
     # split-k scratchpad is fp32/fp16 accumulation, not the final output dtype.
-    # output flex scaling is applied in the reduce step.
+    # output scaling is applied in the reduce step.
     out_global_scale = None if has_scratchpad else c_scale_global
-    out_absmax = None if has_scratchpad else c_absmax
     (kernels._p_matmul if opt_flags.is_persistent else kernels._matmul)[(grid,)](
         c_tensor_or_tma, c.storage.data, *out_matmul.stride(),
-        *((None, out_matmul_scale, None) if out_matmul_has_mx else (out_global_scale, out_absmax, None)),
+        *((None, out_matmul_scale, None) if out_matmul_has_mx else (out_global_scale, None, None)),
         *out_matmul_scale_strides[-4:],
         a_tensor_or_tma, a.storage.data, *a_strides, a_transpose,
         a.scale_global,
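The comment in this hunk states the key invariant: with split-K, partial products live in an fp32/fp16 scratchpad, so the global output scale must not be applied per-partial; it is applied once, in the reduce step (hence `out_global_scale = None if has_scratchpad else c_scale_global`). A minimal sketch of that scheme in plain torch, with illustrative names rather than the kernel's API:

    import torch

    def splitk_matmul_sketch(a, b, out_global_scale=None, split_k=4):
        # scratchpad: one fp32 partial product per K-slice, unscaled
        k_slices = torch.chunk(torch.arange(a.shape[-1]), split_k)
        partials = [a[:, idx].float() @ b[idx, :].float() for idx in k_slices]
        acc = torch.stack(partials).sum(dim=0)  # the reduce step
        if out_global_scale is not None:
            acc = acc / out_global_scale  # output scaling applied only here
        return acc

Dividing by the scale in the reduction matches the `scale()` reference helper later in this file, which also divides by `scale_global`.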
@@ -564,8 +562,6 @@ def matmul(a, b, bias,
         y=memory["output"].view(-1, memory["output"].shape[-1]),
         y_dtype=memory["output"].dtype,
         y_scale_global=c_scale_global,
-        y_absmax=c_absmax,
-        y_saturate_inf=precision_config.flexpoint_saturate_inf,
         y_has_mx=c_scale_mx is not None,
         # fused functions
         postprocess_fn1=postprocess_fn1,
@@ -639,17 +635,6 @@ def scale(val, scal):
     assert val.ndim == 3
     return val / scal[:, None, None]
 
-def compute_actual_scale(x, dtype, per_batch_scale=False):
-    from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
-    max_finite = {
-        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
-        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
-        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
-    }[dtype]
-    maxvals = x.abs().amax(dim=tuple(range(1, x.ndim))) if per_batch_scale else x.abs().max()
-    return maxvals / max_finite
-
-
 def matmul_torch(a, b, bias,
                  a_ragged_metadata: RaggedTensorMetadata | None = None,
                  b_ragged_metadata: RaggedTensorMetadata | None = None,
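For reference, the deleted `compute_actual_scale` derived the flexpoint scale as the observed absmax divided by the largest finite value of the target fp8 dtype. An equivalent standalone sketch, substituting `torch.finfo` for the library's MAX_FINITE_* constants:

    import torch

    def compute_actual_scale(x, dtype, per_batch_scale=False):
        # largest finite value representable in the target fp8 dtype
        # (448 for float8_e4m3fn, 57344 for float8_e5m2, 240 for float8_e4m3fnuz)
        max_finite = torch.finfo(dtype).max
        # absmax per batch (reduce all dims but the first) or over the whole tensor
        if per_batch_scale:
            maxvals = x.abs().amax(dim=tuple(range(1, x.ndim)))
        else:
            maxvals = x.abs().max()
        return maxvals / max_finite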
@@ -660,7 +645,6 @@ def matmul_torch(a, b, bias,
                 gammas=None,
                 round_x=None, round_y=None,
                 c: torch.Tensor | Tensor | None = None,
-                c_absmax: torch.Tensor | None = None,
                 ):
     if precision_config is None:
         precision_config = PrecisionConfig()
@@ -696,8 +680,6 @@ def matmul_torch(a, b, bias,
                 round_y=round_y,
             )
             out[expt] = out_expt.to(out.dtype)
-            if c_absmax is not None:
-                c_absmax.copy_(compute_actual_scale(out, precision_config.out_dtype))
         return scale(out, None if c is None else c.scale_global)
 
     is_input_batched = a.ndim == 3
@@ -748,8 +730,6 @@ def matmul_torch(a, b, bias,
     out = torch.zeros((scatter_indx.shape[0], y.shape[-1]), dtype=y.dtype, device=a.device)
     msk = scatter_indx != -1
     out[scatter_indx[msk], :] = y[msk, :]
-    if c_absmax is not None:
-        c_absmax.copy_(compute_actual_scale(out, precision_config.out_dtype))
     return scale(out, None if c is None else c.scale_global)
 
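The scatter epilogue in this last hunk is self-contained enough to demonstrate directly: rows of `y` are written to the positions named by `scatter_indx`, and entries of -1 mark rows that are dropped. A small runnable example of just that step, with made-up inputs:

    import torch

    y = torch.arange(12, dtype=torch.float32).reshape(4, 3)
    scatter_indx = torch.tensor([2, -1, 0, 3])
    out = torch.zeros((scatter_indx.shape[0], y.shape[-1]), dtype=y.dtype)
    msk = scatter_indx != -1
    # rows 0, 2, 3 of y land at rows 2, 0, 3 of out; row 1 of y is dropped
    out[scatter_indx[msk], :] = y[msk, :]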
755735