2121from triton_kernels .swiglu import swiglu , swiglu_fn
2222from triton_kernels .swiglu import PrecisionConfig as SwiGLUPrecisionConfig
2323from triton_kernels .tensor_details import layout
24- from triton_kernels .tensor_details .dtype import FP32
25-
2624# ---------------
2725# numerics stuff
2826# ---------------
@@ -136,9 +134,9 @@ def _build_test_op_cases():
136134 Case (1024 , 1024 , 1024 , "batched" , "float8_e5m2" , "mxfloat4_e2m1" ),
137135 Case (1024 , 1024 , 1024 , "ragged" , "float8_e5m2" , "mxfloat4_e2m1" , split_k = 9 ),
138136 Case (1024 , 1024 , 1024 , "ragged" , "float8_e5m2" , "mxfloat4_e2m1" , split_k = 9 , b_hbm_swizzling = True ),
139- Case (300 , 400 , 416 , "ragged" , "float8_e5m2" , "mxfloat8_e4m3fn" ),
137+ Case (300 , 400 , 400 , "ragged" , "float8_e5m2" , "mxfloat8_e4m3fn" ),
140138 Case (300 , 400 , 832 , "ragged" , "float8_e5m2" , "mxfloat4_e2m1" ),
141- Case (300 , 400 , 416 , "batched" , "float8_e5m2" , "mxfloat8_e4m3fn" ),
139+ Case (300 , 400 , 400 , "batched" , "float8_e5m2" , "mxfloat8_e4m3fn" ),
142140 ])
143141 # mxfloat x mxfloat
144142 test_cases .extend ([
@@ -147,11 +145,11 @@ def _build_test_op_cases():
147145 Case (1024 , 1024 , 1024 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , split_k = 9 , colmajor_mxfp_weight = False ),
148146 Case (1000 , 704 , 800 , "batched" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True , a_hbm_swizzling = True ),
149147 Case (1000 , 704 , 800 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True , a_hbm_swizzling = True ),
150- Case (300 , 400 , 416 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True , a_hbm_swizzling = True ),
148+ Case (300 , 400 , 400 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True , a_hbm_swizzling = True ),
151149 Case (256 , 1024 , 512 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True , a_hbm_swizzling = True ),
152- Case (300 , 400 , 416 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" ),
153- Case (300 , 400 , 416 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" , b_hbm_swizzling = True ),
154- Case (300 , 400 , 416 , "batched" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" ),
150+ Case (300 , 400 , 400 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" ),
151+ Case (300 , 400 , 400 , "ragged" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" , b_hbm_swizzling = True ),
152+ Case (300 , 400 , 400 , "batched" , "mxfloat8_e4m3fn" , "mxfloat8_e4m3fn" ),
155153 Case (1024 , 1024 , 1024 , "batched" , "mxfloat8_e4m3fn" , "mxfloat4_e2m1" , b_hbm_swizzling = True ),
156154 ])
157155 # amd-specific float8
@@ -342,7 +340,9 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
342340 ragged_padding = inner_expt_opt is not None and "pad_a" in inner_expt_opt ,
343341 squeeze_batch_dim = mode == "plain" ,
344342 scale_hbm_swizzling = layout .make_default_matmul_mxfp8_act_scale_layout if a_hbm_swizzling else None ,
343+ scale_hbm_swizzling_args = {"ragged_metadata" : None }, # ragged_metadata will be set in the make_random_tensor function
345344 )
345+
346346 b , b_scale_tri , b_ragged_metadata = make_random_tensor (
347347 shape = (k , n ),
348348 n_slices = n_slices ,
@@ -354,8 +354,10 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
354354 ragged_padding = inner_expt_opt is not None and "pad_b" in inner_expt_opt ,
355355 squeeze_batch_dim = mode == "plain" ,
356356 is_mx_rowmajor = not colmajor_mxfp_weight ,
357- value_hbm_swizzling = layout .make_default_matmul_mxfp4_w_layout (mx_axis = - 2 ) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype .is_mxfloat4 else None ,
358- scale_hbm_swizzling = layout .make_default_matmul_mxfp4_w_scale_layout (mx_axis = - 2 , num_warps = num_warps ) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype .is_mxfloat4 else None ,
357+ value_hbm_swizzling = layout .make_default_matmul_mxfp4_w_layout if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype .is_mxfloat4 else None ,
358+ value_hbm_swizzling_args = {"mx_axis" :- 2 },
359+ scale_hbm_swizzling = layout .make_default_matmul_mxfp4_w_scale_layout if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype .is_mxfloat4 else None ,
360+ scale_hbm_swizzling_args = dict (mx_axis = - 2 , num_warps = num_warps ),
359361 )
360362 gather_indx = None if not do_gather else torch .randint (0 , max (m , 1 ), (m , ), dtype = torch .int32 , device = device )
361363 scatter_indx = None if not do_scatter else torch .randperm (m , dtype = torch .int32 , device = device )
@@ -440,6 +442,6 @@ def test_set_idle_sms():
440442 from triton_kernels .matmul_details .opt_flags import make_opt_flags
441443 num_idle_sms = 24
442444 matmul_set_idle_sms (num_idle_sms )
443- flags = make_opt_flags (FP32 , FP32 , FP32 , PrecisionConfig (), \
445+ flags = make_opt_flags (torch . float32 , torch . float32 , torch . float32 , PrecisionConfig (), \
444446 1 , 1024 , 1024 , 1024 , None , True , False , 1 , False , False , None )
445447 assert flags .idle_sms == num_idle_sms
0 commit comments