Skip to content

Commit d0820b7

Browse files
committed
fmt
Signed-off-by: Qidong Su <soodoshll@gmail.com>
1 parent cc7f24c commit d0820b7

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

examples/hopper_matmul/matmul_v3.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,11 @@
1111

1212

1313
@tilus.autotune("num_stages", [2, 3, 4, 5, 6, 7])
14-
@tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256], [256, 128], [256, 256]])
14+
@tilus.autotune(
15+
"block_m, block_n", [[128, 64], [128, 128], [128, 256], [256, 128], [256, 256]]
16+
)
1517
@tilus.autotune("block_k", [16, 32, 64])
1618
class MatmulWGMMAV3(tilus.Script):
17-
1819
def __init__(
1920
self,
2021
num_stages,
@@ -54,11 +55,15 @@ def __call__(
5455
acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0)
5556

5657
consumer_barriers = self.mbarrier.alloc(count=[2 for _ in range(self.num_stages)])
57-
producer_barriers = self.mbarrier.alloc(count=[128 for _ in range(self.num_stages)])
58+
producer_barriers = self.mbarrier.alloc(
59+
count=[128 for _ in range(self.num_stages)]
60+
)
5861

5962
with self.thread_group(thread_begin=128, num_threads=32):
6063
stage: int32 = 0
61-
producer_phases = self.register_tensor(dtype=uint32, shape=[self.num_stages], init=1)
64+
producer_phases = self.register_tensor(
65+
dtype=uint32, shape=[self.num_stages], init=1
66+
)
6267
for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages):
6368
self.mbarrier.wait(producer_barriers[stage], phase=producer_phases[stage])
6469
producer_phases[stage] ^= 1
@@ -76,7 +81,7 @@ def __call__(
7681
mbarrier=consumer_barriers[stage],
7782
)
7883
stage = (stage + 1) % self.num_stages
79-
84+
8085
for _ in self.range(min(self.num_stages, cdiv(k_size, self.block_k))):
8186
self.mbarrier.wait(
8287
producer_barriers[stage], phase=producer_phases[stage]
@@ -85,9 +90,11 @@ def __call__(
8590
stage = (stage + 1) % self.num_stages
8691

8792
with self.thread_group(thread_begin=0, num_threads=128):
88-
consumer_phases = self.register_tensor(dtype=uint32, shape=[self.num_stages], init=0)
93+
consumer_phases = self.register_tensor(
94+
dtype=uint32, shape=[self.num_stages], init=0
95+
)
8996
stage: int32 = 0
90-
for offset_k in self.range(0, k_size, block_k , unroll=self.num_stages):
97+
for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages):
9198
self.mbarrier.wait(consumer_barriers[stage], phase=consumer_phases[stage])
9299
consumer_phases[stage] ^= 1
93100
self.wgmma.fence()

0 commit comments

Comments
 (0)