@@ -158,13 +158,10 @@ def test_consan_uses_profile_scratch(device, fresh_knobs, num_ctas):
 
 @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
-@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True])
-def test_async_tma_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas):
-    if TWO_CTA_BARRIER and num_ctas == 1:
-        pytest.skip("Need at least 2 CTAs for a two-CTA barrier")
+def test_async_tma_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas):
     if run_wrapper:
-        result = run_in_process(test_async_tma_kernel, (FAILURE, TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas))
-        if FAILURE or TWO_CTA_BARRIER:
+        result = run_in_process(test_async_tma_kernel, (FAILURE, device, False, monkeypatch, num_ctas))
+        if FAILURE:
             assert_expected_cuda_failure(result.exc)
             assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output
         else:
@@ -177,13 +174,13 @@ def test_async_tma_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeyp
     knobs.refresh_knobs()
 
     @gluon.jit
-    def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.constexpr):
+    def kernel(input_desc, out, FAILURE: ttgl.constexpr):
         block_m: ttgl.constexpr = XBLOCK * ttgl.num_ctas()
         cga_layout: ttgl.constexpr = default_cga_layout(ttgl.num_ctas(), 2)
         blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1],
                                                             warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout)
         smem = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout)
-        bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER)
+        bar = mbarrier.allocate_mbarrier()
         mbarrier.init(bar, count=1)
         mbarrier.expect(bar, input_desc.nbytes_per_cta)
         tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
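
Taken together, the two hunks above reduce the hopper test to a plain single-CTA barrier lifecycle: allocate, init, expect, copy, wait, invalidate. A minimal sketch of that pattern, assembled only from calls that appear in this diff (XBLOCK, `blocked_layout`, and the gluon/ttgl/mbarrier/tma imports are assumed to be defined as in the surrounding file; the kernel name is hypothetical):

```python
@gluon.jit
def tma_load_sketch(input_desc, out, blocked_layout: ttgl.constexpr):
    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
    bar = mbarrier.allocate_mbarrier()               # plain barrier; no two_ctas flag
    mbarrier.init(bar, count=1)                      # a single arrival completes a phase
    mbarrier.expect(bar, input_desc.nbytes_per_cta)  # bytes the barrier should account for
    tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
    mbarrier.wait(bar, 0, deps=[smem])               # phase 0: block until the copy lands
    val = smem.load(blocked_layout)                  # smem is safe to read after the wait
    mbarrier.invalidate(bar)                         # release the barrier slot
    out_m = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, blocked_layout))[:, None]
    out_n = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, blocked_layout))[None, :]
    ttgl.store(out + out_m * XBLOCK + out_n, val)    # write the tile back out
```

Per the assertions above, the `FAILURE=True` variants are expected to trip ConSan's "Buffer being accessed has outstanding writes" check rather than complete this sequence cleanly.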
@@ -203,19 +200,17 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.const
     shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2,
                                            cga_layout=default_cga_layout(num_ctas, 2))
     input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [block_m, XBLOCK.value], shared_layout)
-    kernel[(1, )](input_desc, output, FAILURE=FAILURE, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas)
+    kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas)
 
 
 @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
-@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True])
-def test_async_tma_multicast_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas):
+def test_async_tma_multicast_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas):
     if num_ctas == 1:
         pytest.skip("Need at least 2 CTAs for multicast in this test")
     if run_wrapper:
-        result = run_in_process(test_async_tma_multicast_kernel,
-                                (FAILURE, TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas))
-        if FAILURE or TWO_CTA_BARRIER:
+        result = run_in_process(test_async_tma_multicast_kernel, (FAILURE, device, False, monkeypatch, num_ctas))
+        if FAILURE:
             assert_expected_cuda_failure(result.exc)
             assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output
         else:
@@ -228,12 +223,12 @@ def test_async_tma_multicast_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrappe
     knobs.refresh_knobs()
 
     @gluon.jit
-    def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.constexpr):
+    def kernel(input_desc, out, FAILURE: ttgl.constexpr):
         cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 2)
         blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1],
                                                             warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout)
         smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
-        bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER)
+        bar = mbarrier.allocate_mbarrier()
         mbarrier.init(bar, count=1)
         mbarrier.expect(bar, input_desc.nbytes_per_cta)
         tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True)
@@ -252,54 +247,7 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.const
     shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2,
                                            cga_layout=multicast_cga_layout(num_ctas, 2))
     input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [XBLOCK.value, XBLOCK.value], shared_layout)
-    kernel[(1, )](input_desc, output, FAILURE=FAILURE, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas)
-
-
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
-@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True])
-def test_async_tma_multicast_kernel_two_cta_barrier(TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas):
-    if num_ctas != 2:
-        pytest.skip("This test covers a single 2-CTA multicast group")
-    if run_wrapper:
-        result = run_in_process(test_async_tma_multicast_kernel_two_cta_barrier,
-                                (TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas))
-        if TWO_CTA_BARRIER:
-            assert_expected_cuda_failure(result.exc)
-            assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output
-        else:
-            assert result.exc is None
-            assert result.driver_stderr_output == ""
-        return
-
-    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
-    monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
-    knobs.refresh_knobs()
-
-    @gluon.jit
-    def kernel(input_desc, out, TWO_CTA_BARRIER: ttgl.constexpr):
-        cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 2)
-        blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1],
-                                                            warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout)
-        smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
-        bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER)
-        mbarrier.init(bar, count=1)
-        mbarrier.expect(bar, input_desc.nbytes_per_cta)
-        tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True)
-        mbarrier.wait(bar, 0, deps=[smem])
-        val = smem.load(blocked_layout)
-        mbarrier.invalidate(bar)
-
-        out_m = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, blocked_layout))[:, None]
-        out_n = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, blocked_layout))[None, :]
-        out_ptr = out + out_m * XBLOCK + out_n
-        ttgl.store(out_ptr, val)
-
-    input = torch.randn((XBLOCK.value, XBLOCK.value), device=device, dtype=torch.float16)
-    output = torch.empty((XBLOCK.value, XBLOCK.value), device=device, dtype=torch.float16)
-    shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2,
-                                           cga_layout=multicast_cga_layout(num_ctas, 2))
-    input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [XBLOCK.value, XBLOCK.value], shared_layout)
-    kernel[(1, )](input_desc, output, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas)
+    kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas)
 
 
 @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
@@ -741,21 +689,24 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt
             ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16,
                                                    cga_layout=mma_cga_layout(ttgl.num_ctas(), 1, TWO_CTAS)),
         )
-        bar = mbarrier.allocate_mbarrier(batch=2)
+        mma_bar = mbarrier.allocate_mbarrier()
         acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout)
-        mbarrier.init(bar.index(0), count=1)
-        mbarrier.init(bar.index(1), count=1)
+        mbarrier.init(mma_bar, count=1)
+        if MEM_ACCESS_KIND == "tma_cp":
+            tma_bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTAS)
+            mbarrier.init(tma_bar, count=1)
 
         blackwell.tcgen05_mma(smemA, smemB, acc)
-        blackwell.tcgen05_commit(bar.index(0))
+        blackwell.tcgen05_commit(mma_bar)
 
         if not FAILURE:
-            mbarrier.wait(bar.index(0), 0)
+            mbarrier.wait(mma_bar, 0)
 
         if MEM_ACCESS_KIND == "tma_cp":
-            mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta)
-            tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(1), smemA)
-            mbarrier.wait(bar.index(1), 0)
+            mbarrier.expect(tma_bar, input_desc.nbytes_per_cta)
+            tma.async_copy_global_to_shared(input_desc, [0, 0], tma_bar, smemA)
+            mbarrier.wait(tma_bar, 0)
+            mbarrier.invalidate(tma_bar)
         elif MEM_ACCESS_KIND == "local_store":
             smemA.store(ttgl.full([block_m, XBLOCK], 42, ttgl.float16, smem_a_blocked_layout))
         elif MEM_ACCESS_KIND == "tmem_load":
@@ -769,8 +720,7 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt
         elif MEM_ACCESS_KIND == "tmem_store":
             acc.store(ttgl.full([block_m, block_n], 42, ttgl.float32, acc_blocked_layout))
 
-        mbarrier.invalidate(bar.index(0))
-        mbarrier.invalidate(bar.index(1))
+        mbarrier.invalidate(mma_bar)
 
     block_m = mma_block_m(num_ctas)
     block_n = mma_block_n(num_ctas)
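
The last two hunks replace the batched allocation (`allocate_mbarrier(batch=2)` addressed via `bar.index(0)`/`bar.index(1)`) with two independently allocated barriers, which lets only the TMA barrier carry the `two_ctas=TWO_CTAS` property and scopes it to the one branch that uses it. A condensed sketch of the resulting structure, using only names and calls from the hunks above (`smemA`, `smemB`, `acc`, `input_desc`, and `TWO_CTAS` as defined in the full test):

```python
# Barrier for MMA completion: always a plain single-CTA barrier.
mma_bar = mbarrier.allocate_mbarrier()
mbarrier.init(mma_bar, count=1)

blackwell.tcgen05_mma(smemA, smemB, acc)
blackwell.tcgen05_commit(mma_bar)        # MMA completion is signaled on mma_bar
mbarrier.wait(mma_bar, 0)

# Barrier for the TMA copy: the only one that may span two CTAs.
tma_bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTAS)
mbarrier.init(tma_bar, count=1)
mbarrier.expect(tma_bar, input_desc.nbytes_per_cta)
tma.async_copy_global_to_shared(input_desc, [0, 0], tma_bar, smemA)
mbarrier.wait(tma_bar, 0)

# Each barrier is torn down on its own instead of via bar.index(0)/bar.index(1).
mbarrier.invalidate(tma_bar)
mbarrier.invalidate(mma_bar)
```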