@@ -104,10 +104,11 @@ def gemm_tdm_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, #
104104# Duplicate warps get pred=0 (hardware no-op), freeing TDM bandwidth.
105105# ---------------------------------------------------------------------------
106106
107+
107108@gluon .jit
108109def issue_loads_specialized (producer , a_desc , b_desc , off_am , off_bn , a_buffer , b_buffer , BLOCK_K : ttgl .constexpr ,
109- NUM_BUFFERS : ttgl .constexpr , TRANSPOSE_B : ttgl .constexpr ,
110- TDM_WARP_BASES : ttgl . constexpr , pred = 1 ):
110+ NUM_BUFFERS : ttgl .constexpr , TRANSPOSE_B : ttgl .constexpr , TDM_WARP_BASES : ttgl . constexpr ,
111+ pred = 1 ):
111112 pred_i32 = pred .to (ttgl .int32 ) if hasattr (pred , 'to' ) else pred
112113 ttgl .amd .gfx1250 .tdm .async_load (a_desc , [off_am , producer * BLOCK_K ], a_buffer .index (producer % NUM_BUFFERS ),
113114 pred = pred_i32 , warp_bases = TDM_WARP_BASES )
@@ -171,8 +172,8 @@ def gemm_tdm_specialized_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, #
171172 with ttgl .amd .warp_pipeline_stage ("stage0" , priority = 1 ):
172173 consumer , a , b = lds_load (consumer , a_buffer , OPERAND_LAYOUT_A , b_buffer , OPERAND_LAYOUT_B , NUM_BUFFERS ,
173174 TRANSPOSE_B )
174- producer = issue_loads_specialized (producer , a_desc , b_desc , 0 , 0 , a_buffer , b_buffer , BLOCK_K ,
175- NUM_BUFFERS , TRANSPOSE_B , TDM_WARP_BASES )
175+ producer = issue_loads_specialized (producer , a_desc , b_desc , 0 , 0 , a_buffer , b_buffer , BLOCK_K , NUM_BUFFERS ,
176+ TRANSPOSE_B , TDM_WARP_BASES )
176177 with ttgl .amd .warp_pipeline_stage ("stage1" , priority = 0 ):
177178 accumulator = issue_wmma_compute (a , b , accumulator )
178179 ttgl .amd .gfx1250 .tdm .async_wait (2 )
@@ -193,6 +194,7 @@ def gemm_tdm_specialized_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, #
193194# Helper
194195# ---------------------------------------------------------------------------
195196
197+
196198def _compute_tdm_warp_bases (block_shape , num_warps , active_warps ):
197199 """Compute warp_bases for partial TDM copy with the given active warp count.
198200
@@ -231,6 +233,7 @@ def _compute_tdm_warp_bases(block_shape, num_warps, active_warps):
231233# Tests
232234# ---------------------------------------------------------------------------
233235
236+
234237@pytest .mark .parametrize ("BLOCK_M,BLOCK_N,BLOCK_K" , [(256 , 256 , 64 )])
235238@pytest .mark .parametrize ("NUM_BUFFERS" , [3 ])
236239@pytest .mark .parametrize ("TRANSPOSE_B" , [True ])
0 commit comments