1111
1212
1313@tilus .autotune ("num_stages" , [2 , 3 , 4 , 5 , 6 , 7 ])
14- @tilus .autotune ("block_m, block_n" , [[128 , 64 ], [128 , 128 ], [128 , 256 ], [256 , 128 ], [256 , 256 ]])
14+ @tilus .autotune (
15+ "block_m, block_n" , [[128 , 64 ], [128 , 128 ], [128 , 256 ], [256 , 128 ], [256 , 256 ]]
16+ )
1517@tilus .autotune ("block_k" , [16 , 32 , 64 ])
1618class MatmulWGMMAV3 (tilus .Script ):
17-
1819 def __init__ (
1920 self ,
2021 num_stages ,
@@ -54,11 +55,15 @@ def __call__(
5455 acc = self .register_tensor (dtype = float32 , shape = [block_m , block_n ], init = 0.0 )
5556
5657 consumer_barriers = self .mbarrier .alloc (count = [2 for _ in range (self .num_stages )])
57- producer_barriers = self .mbarrier .alloc (count = [128 for _ in range (self .num_stages )])
58+ producer_barriers = self .mbarrier .alloc (
59+ count = [128 for _ in range (self .num_stages )]
60+ )
5861
5962 with self .thread_group (thread_begin = 128 , num_threads = 32 ):
6063 stage : int32 = 0
61- producer_phases = self .register_tensor (dtype = uint32 , shape = [self .num_stages ], init = 1 )
64+ producer_phases = self .register_tensor (
65+ dtype = uint32 , shape = [self .num_stages ], init = 1
66+ )
6267 for offset_k in self .range (0 , k_size , block_k , unroll = self .num_stages ):
6368 self .mbarrier .wait (producer_barriers [stage ], phase = producer_phases [stage ])
6469 producer_phases [stage ] ^= 1
@@ -76,7 +81,7 @@ def __call__(
7681 mbarrier = consumer_barriers [stage ],
7782 )
7883 stage = (stage + 1 ) % self .num_stages
79-
84+
8085 for _ in self .range (min (self .num_stages , cdiv (k_size , self .block_k ))):
8186 self .mbarrier .wait (
8287 producer_barriers [stage ], phase = producer_phases [stage ]
@@ -85,9 +90,11 @@ def __call__(
8590 stage = (stage + 1 ) % self .num_stages
8691
8792 with self .thread_group (thread_begin = 0 , num_threads = 128 ):
88- consumer_phases = self .register_tensor (dtype = uint32 , shape = [self .num_stages ], init = 0 )
93+ consumer_phases = self .register_tensor (
94+ dtype = uint32 , shape = [self .num_stages ], init = 0
95+ )
8996 stage : int32 = 0
90- for offset_k in self .range (0 , k_size , block_k , unroll = self .num_stages ):
97+ for offset_k in self .range (0 , k_size , block_k , unroll = self .num_stages ):
9198 self .mbarrier .wait (consumer_barriers [stage ], phase = consumer_phases [stage ])
9299 consumer_phases [stage ] ^= 1
93100 self .wgmma .fence ()
0 commit comments