2121from triton_kernels .swiglu import swiglu , swiglu_fn
2222from triton_kernels .swiglu import PrecisionConfig as SwiGLUPrecisionConfig
2323from triton_kernels .tensor_details import layout
24+ from triton_kernels .tensor_details .dtype import FP32
25+
2426# ---------------
2527# numerics stuff
2628# ---------------
@@ -134,9 +136,9 @@ def _build_test_op_cases():
134136 Case (1024 , 1024 , 1024 , "batched" , "float8_e5m2" , "mxfloat4_e2m1" ),
135137 Case (1024 , 1024 , 1024 , "ragged" , "float8_e5m2" , "mxfloat4_e2m1" , split_k = 9 ),
136138 Case (1024 , 1024 , 1024 , "ragged" , "float8_e5m2" , "mxfloat4_e2m1" , split_k = 9 , b_hbm_swizzling = True ),
137- Case (300 , 400 , 400 , "ragged" , "float8_e5m2" , "mxfloat8_e4m3fn" ),
139+ Case (300 , 400 , 416 , "ragged" , "float8_e5m2" , "mxfloat8_e4m3fn" ),
138140 Case (300 , 400 , 832 , "ragged" , "float8_e5m2" , "mxfloat4_e2m1" ),
139- Case (300 , 400 , 400 , "batched" , "float8_e5m2" , "mxfloat8_e4m3fn" ),
141+ Case (300 , 400 , 416 , "batched" , "float8_e5m2" , "mxfloat8_e4m3fn" ),
140142 ])
141143 # mxfloat x mxfloat
142144 test_cases .extend ([
@@ -145,11 +147,11 @@ def _build_test_op_cases():
145147 Case (1024 , 1024 , 1024 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , split_k = 9 , colmajor_mxfp_weight = False ),
146148 Case (1000 , 704 , 800 , "batched" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True , a_hbm_swizzling = True ),
147149 Case (1000 , 704 , 800 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True , a_hbm_swizzling = True ),
148- Case (300 , 400 , 400 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True , a_hbm_swizzling = True ),
150+ Case (300 , 400 , 416 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True , a_hbm_swizzling = True ),
149151 Case (256 , 1024 , 512 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True , a_hbm_swizzling = True ),
150- Case (300 , 400 , 400 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" ),
151- Case (300 , 400 , 400 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" , b_hbm_swizzling = True ),
152- Case (300 , 400 , 400 , "batched" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" ),
152+ Case (300 , 400 , 416 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" ),
153+ Case (300 , 400 , 416 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" , b_hbm_swizzling = True ),
154+ Case (300 , 400 , 416 , "batched" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" ),
153155 Case (1024 , 1024 , 1024 , "batched" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True ),
154156 ])
155157 # amd-specific float8
@@ -340,9 +342,7 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
340342 ragged_padding = inner_expt_opt is not None and "pad_a" in inner_expt_opt ,
341343 squeeze_batch_dim = mode == "plain" ,
342344 scale_hbm_swizzling = layout .make_default_matmul_mxfp8_act_scale_layout if a_hbm_swizzling else None ,
343- scale_hbm_swizzling_args = {"ragged_metadata" : None }, # ragged_metadata will be set in the make_random_tensor function
344345 )
345-
346346 b , b_scale_tri , b_ragged_metadata = make_random_tensor (
347347 shape = (k , n ),
348348 n_slices = n_slices ,
@@ -354,10 +354,8 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
354354 ragged_padding = inner_expt_opt is not None and "pad_b" in inner_expt_opt ,
355355 squeeze_batch_dim = mode == "plain" ,
356356 is_mx_rowmajor = not colmajor_mxfp_weight ,
357- value_hbm_swizzling = layout .make_default_matmul_mxfp4_w_layout if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype .is_mxfloat4 else None ,
358- value_hbm_swizzling_args = {"mx_axis" :- 2 },
359- scale_hbm_swizzling = layout .make_default_matmul_mxfp4_w_scale_layout if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype .is_mxfloat4 else None ,
360- scale_hbm_swizzling_args = dict (mx_axis = - 2 , num_warps = num_warps ),
357+ value_hbm_swizzling = layout .make_default_matmul_mxfp4_w_layout (mx_axis = - 2 ) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype .is_mxfloat4 else None ,
358+ scale_hbm_swizzling = layout .make_default_matmul_mxfp4_w_scale_layout (mx_axis = - 2 , num_warps = num_warps ) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype .is_mxfloat4 else None ,
361359 )
362360 gather_indx = None if not do_gather else torch .randint (0 , max (m , 1 ), (m , ), dtype = torch .int32 , device = device )
363361 scatter_indx = None if not do_scatter else torch .randperm (m , dtype = torch .int32 , device = device )
@@ -442,6 +440,6 @@ def test_set_idle_sms():
442440 from triton_kernels .matmul_details .opt_flags import make_opt_flags
443441 num_idle_sms = 24
444442 matmul_set_idle_sms (num_idle_sms )
445- flags = make_opt_flags (torch . float32 , torch . float32 , torch . float32 , PrecisionConfig (), \
443+ flags = make_opt_flags (FP32 , FP32 , FP32 , PrecisionConfig (), \
446444 1 , 1024 , 1024 , 1024 , None , True , False , 1 , False , False , None )
447445 assert flags .idle_sms == num_idle_sms
0 commit comments