@@ -887,6 +887,89 @@ def kernel(c_ptr, a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, #
887887 torch .testing .assert_close (c .cpu (), c_torch , atol = 1e-5 , rtol = 2e-5 )
888888
889889
@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires GFX1250")
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("M, N, K", get_test_mxfp_block_mnk())
@pytest.mark.parametrize("a_type, b_type", get_test_mxfp_variants())
def test_amd_wmma_scaled_batched(B, M, N, K, a_type, b_type):
    """Batched (rank-3) scaled-WMMA test for gfx1250.

    Builds B independent MXFP operand/scale pairs, runs them through a single
    Gluon kernel launch that performs one ``wmma_scaled`` over the whole
    (B, M, N, K) problem, and compares against a dequantized torch matmul
    reference (``(a_ref * a_scale_ref) @ (b_ref * b_scale_ref)``).
    """

    @gluon.constexpr_function
    def _slice_layout(layout, indices):
        # Peel the listed dims off `layout` (highest index first, so earlier
        # slices don't shift later dim indices); the result is the layout of
        # a 1-D arange along the one remaining dimension.
        for i in reversed(indices):
            layout = ttgl.SliceLayout(i, layout)
        return layout

    @gluon.jit
    def _offsets(dim0, dim1, dim2, layout):
        # Row-major linear offsets for a contiguous (dim0, dim1, dim2) tensor:
        # three broadcast aranges, each carrying the slice of `layout` that
        # matches its axis, combined with C-order strides.
        return ttgl.arange(0, dim0, layout=_slice_layout(layout, [1, 2]))[:, None, None] * (dim1 * dim2) + \
            ttgl.arange(0, dim1, layout=_slice_layout(layout, [0, 2]))[None, :, None] * dim2 + \
            ttgl.arange(0, dim2, layout=_slice_layout(layout, [0, 1]))[None, None, :]

    @gluon.jit
    def kernel(c_ptr, a_ptr, a_scale_ptr, b_ptr, b_scale_ptr,  #
               a_type: ttgl.constexpr, b_type: ttgl.constexpr,  #
               BLOCK_B: ttgl.constexpr, BLOCK_M: ttgl.constexpr,  #
               BLOCK_N: ttgl.constexpr, BLOCK_K: ttgl.constexpr):
        # e2m1 (fp4) packs two values per byte, so the K extent of the raw
        # operand tensor is halved for that element type.
        DIV_FACTOR_A: ttgl.constexpr = 2 if a_type == "e2m1" else 1
        DIV_FACTOR_B: ttgl.constexpr = 2 if b_type == "e2m1" else 1

        # Rank-3 WMMA layouts. The bases place warps along dim 0 —
        # presumably distributing the 4 warps over the batch dim; confirm
        # against AMDWMMALayout docs. The "packed" variant (K instr extent
        # 64 instead of 128) is selected for byte-packed e2m1 operands.
        warp_bases: ttgl.constexpr = [[1, 0, 0], [2, 0, 0]]
        wmma_layout: ttgl.constexpr = \
            ttgl.amd.AMDWMMALayout(3, True, warp_bases, instr_shape=[16, 16, 128], rank=3)
        wmma_layout_packed: ttgl.constexpr = \
            ttgl.amd.AMDWMMALayout(3, True, warp_bases, instr_shape=[16, 16, 64], rank=3)
        a_layout: ttgl.constexpr = \
            ttgl.DotOperandLayout(0, wmma_layout_packed if a_type == "e2m1" else wmma_layout, 16)
        b_layout: ttgl.constexpr = \
            ttgl.DotOperandLayout(1, wmma_layout_packed if b_type == "e2m1" else wmma_layout, 16)
        # One scale per 32-element K group (the // 32 here and in the offsets
        # below is the MXFP block size).
        a_scale_layout: ttgl.constexpr = \
            get_wmma_scale_layout(a_layout, [BLOCK_B, BLOCK_M, BLOCK_K // 32])
        b_scale_layout: ttgl.constexpr = \
            get_wmma_scale_layout(b_layout, [BLOCK_B, BLOCK_N, BLOCK_K // 32])

        # A is loaded as (B, M, K/pack); B as (B, K/pack, N).
        a_offs = _offsets(BLOCK_B, BLOCK_M, BLOCK_K // DIV_FACTOR_A, a_layout)
        a = ttgl.load(a_ptr + a_offs)
        b_offs = _offsets(BLOCK_B, BLOCK_K // DIV_FACTOR_B, BLOCK_N, b_layout)
        b = ttgl.load(b_ptr + b_offs)

        # Both scale tensors are loaded K-minor: (B, M, K/32) and (B, N, K/32)
        # — the host side transposes b_scale to match (see below).
        a_scale_offs = _offsets(BLOCK_B, BLOCK_M, BLOCK_K // 32, a_scale_layout)
        a_scale = ttgl.load(a_scale_ptr + a_scale_offs)
        b_scale_offs = _offsets(BLOCK_B, BLOCK_N, BLOCK_K // 32, b_scale_layout)
        b_scale = ttgl.load(b_scale_ptr + b_scale_offs)

        # Single scaled-MMA with a zero accumulator, then cast to the output
        # element type for the store.
        zero = ttgl.zeros([BLOCK_B, BLOCK_M, BLOCK_N], dtype=ttgl.float32, layout=wmma_layout)
        c = ttgl.amd.gfx1250.wmma_scaled(a, a_scale, a_type, b, b_scale, b_type, zero)
        c = c.to(c_ptr.dtype.element_ty)

        c_offs = _offsets(BLOCK_B, BLOCK_M, BLOCK_N, wmma_layout)
        ttgl.store(c_ptr + c_offs, c)

    torch.manual_seed(42)
    # Each helper returns (device_tensor, dequantized_reference); build B
    # independent batches and split the pairs.
    a, a_ref = zip(*[create_mxfp_operand(0, M, K, a_type) for _ in range(B)])
    b, b_ref = zip(*[create_mxfp_operand(1, K, N, b_type) for _ in range(B)])
    a_scale, a_scale_ref = zip(*[create_mxfp_scale(0, M, K) for _ in range(B)])
    b_scale, b_scale_ref = zip(*[create_mxfp_scale(1, K, N) for _ in range(B)])

    a = torch.stack(a, dim=0)
    b = torch.stack(b, dim=0)
    a_scale = torch.stack(a_scale, dim=0)
    # The kernel reads b_scale as (B, N, K/32), so swap the last two dims;
    # .contiguous() because the offsets assume a dense row-major buffer.
    b_scale = torch.stack(b_scale, dim=0).permute(0, 2, 1).contiguous()

    a_ref = torch.stack(a_ref, dim=0)
    b_ref = torch.stack(b_ref, dim=0)
    a_scale_ref = torch.stack(a_scale_ref, dim=0)
    b_scale_ref = torch.stack(b_scale_ref, dim=0)

    a, a_scale = a.cuda(), a_scale.cuda()
    b, b_scale = b.cuda(), b_scale.cuda()

    # One program covers the whole problem; num_warps=4 matches the two
    # warp bases above (2^2 warps) — TODO confirm.
    c = torch.zeros((B, M, N), dtype=torch.float32).cuda()
    kernel[(1, )](c, a, a_scale, b, b_scale, a_type, b_type, B, M, N, K, num_warps=4)

    # Reference: dequantize (elementwise scale) then batched matmul.
    c_torch = (a_ref * a_scale_ref) @ (b_ref * b_scale_ref)
    torch.testing.assert_close(c.cpu(), c_torch, atol=1e-5, rtol=2e-5)
972+
890973@pytest .mark .skipif (not is_hip_gfx1250 (), reason = "Requires GFX1250" )
891974@pytest .mark .parametrize ("M, N, K" , [(16 , 16 , 128 ), (32 , 32 , 128 ), (32 , 32 , 256 ), (32 , 32 , 512 ), (64 , 64 , 128 ),
892975 (128 , 128 , 256 )])
0 commit comments