
Commit 2b74996

[Example] Add example of blackwell matmul v3 (#66)
This PR adds the tilus matmul implementation for the Blackwell architecture, with warp specialization optimization. This version of matmul achieves over 85% of cuBLAS performance.

```
       m      n      k   name  latency (ms)       tflops
0   4096   4096   4096  torch      0.160752   854.975085
1   4096   4096   4096  tilus      0.179200   766.958441
2   4096   4096  14336  torch      0.460816  1043.879426
3   4096   4096  14336  tilus      0.494528   972.718110
4   8192   8192   8192  torch      0.909792  1208.530764
5   8192   8192   8192  tilus      1.022016  1075.826186
6  10240  10240  10240  torch      1.710224  1255.673881
7  10240  10240  10240  tilus      1.996880  1075.419481
```

```python
@tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256]])
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("stages", [2, 3, 4])
class BlackwellMatmul(tilus.Script):
    def __init__(self, block_m: int, block_n: int, block_k: int, stages: int):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.stages = stages

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        s_a = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_m, self.block_k]
        )
        s_b = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_n, self.block_k]
        )

        # allocate a tensor in tensor memory (tmem)
        t_acc = self.tcgen05.alloc(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        # allocate barriers and the initial phases
        consumer_barriers = self.mbarrier.alloc(
            count=[1 for _ in range(self.stages)]
        )  # whether the data is ready for consumption
        producer_barriers = self.mbarrier.alloc(
            count=[1 for _ in range(self.stages)]
        )  # whether the data is ready to be filled

        with self.thread_group(group_index=0, group_size=32):
            # tma warp
            stage: int32 = 0
            producer_phases = self.register_tensor(
                dtype=uint32, shape=[self.stages], init=1
            )  # all stages are ready to be filled at the beginning
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                self.mbarrier.wait(
                    producer_barriers[stage], phase=producer_phases[stage]
                )  # wait until the stage is ready to be filled
                producer_phases[stage] ^= 1
                with self.single_thread():
                    self.tma.global_to_shared(
                        src=g_a,
                        dst=s_a[stage],
                        offsets=[offset_m, offset_k],
                        mbarrier=consumer_barriers[stage],
                    )
                    self.tma.global_to_shared(
                        src=g_b,
                        dst=s_b[stage],
                        offsets=[offset_n, offset_k],
                        mbarrier=consumer_barriers[stage],
                    )
                    self.mbarrier.arrive(consumer_barriers[stage])
                stage = (stage + 1) % self.stages

            # remaining mma stages to wait for completion
            for _ in self.range(min(self.stages, cdiv(k_size, self.block_k))):
                self.mbarrier.wait(
                    producer_barriers[stage], phase=producer_phases[stage]
                )  # wait until the stage is ready to be filled
                producer_phases[stage] ^= 1
                stage = (stage + 1) % self.stages

        with self.thread_group(group_index=1, group_size=32):
            # mma warp
            consumer_phases = self.register_tensor(
                dtype=uint32, shape=[self.stages], init=0
            )  # all stages are not ready for consumption at the beginning
            stage: int32 = 0
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                self.mbarrier.wait(
                    consumer_barriers[stage], phase=consumer_phases[stage]
                )  # wait until the stage is ready for consumption
                consumer_phases[stage] ^= 1
                with self.single_thread():
                    self.tcgen05.mma(s_a[stage], s_b[stage].transpose(), t_acc)
                    self.tcgen05.commit(mbarrier=producer_barriers[stage])
                stage = (stage + 1) % self.stages

        self.sync()

        # load the result from tensor memory to register
        r_acc = self.tcgen05.load(
            t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
        )

        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

        # all allocated tensor memory must be deallocated
        self.sync()
        self.tcgen05.dealloc(t_acc)
```

Minors:
1. Enhance the `scripts/sign-commits.py` utility to rebase only from the unsigned commit, instead of from the main branch.
2. Add `BarrierAllocContext` and `SyncContext` in codegen to unify mbarrier allocation and sub-group synchronization (which is based on mbarrier).

---------

Signed-off-by: Yaoyao Ding <[email protected]>
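For readers new to the phase-bit handshake used in the kernel above, here is a minimal CPU-side sketch of the same producer/consumer protocol in plain Python. It is not tilus or CUDA code: `MBarrier`, `run_pipeline`, `filled`, and `empty` are made-up names, and it assumes the convention that `wait(phase)` returns once the barrier's parity no longer equals `phase`, which is consistent with the producer phases starting at 1 (slots initially free) and the consumer phases starting at 0 (no data yet) in the kernel.

```python
import threading


class MBarrier:
    """Toy stand-in for an mbarrier: a parity bit that flips on each arrival."""

    def __init__(self):
        self.parity = 0
        self.cond = threading.Condition()

    def arrive(self):
        with self.cond:
            self.parity ^= 1
            self.cond.notify_all()

    def wait(self, phase):
        # Block while the barrier is still in the phase we already observed.
        with self.cond:
            self.cond.wait_for(lambda: self.parity != phase)


def run_pipeline(stages=2, k_tiles=5):
    filled = [MBarrier() for _ in range(stages)]  # like consumer_barriers: data ready
    empty = [MBarrier() for _ in range(stages)]   # like producer_barriers: slot free
    buffers = [None] * stages
    consumed = []

    def producer():
        phases = [1] * stages  # all slots start out free, like init=1 above
        for k in range(k_tiles):
            s = k % stages
            empty[s].wait(phases[s])
            phases[s] ^= 1
            buffers[s] = f"tile-{k}"   # stands in for the TMA copy into s_a / s_b
            filled[s].arrive()         # data is now ready for consumption

    def consumer():
        phases = [0] * stages  # no slot has data yet, like init=0 above
        for k in range(k_tiles):
            s = k % stages
            filled[s].wait(phases[s])
            phases[s] ^= 1
            consumed.append(buffers[s])  # stands in for the tcgen05 MMA
            empty[s].arrive()            # slot may now be refilled by the producer

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return consumed


print(run_pipeline())  # ['tile-0', 'tile-1', 'tile-2', 'tile-3', 'tile-4']
```

Running it prints the tiles in order, showing that each of the `stages` buffers is reused only after its consumer has signalled completion, which is exactly the role of the two barrier arrays in the kernel.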
1 parent 4c33850 commit 2b74996

File tree

17 files changed (+488, -171 lines)

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os

import pandas
import tilus
import torch
from tilus import float16, float32, int32, uint32
from tilus.utils import benchmark_func, cdiv

if not tilus.target.get_current_target().supports(tilus.target.nvgpu_sm100a):
    # skip this example if the current target does not support nvgpu_sm100a
    exit(0)

tilus.option.cache_dir(os.path.join(os.path.dirname(__file__), "cache"))
tilus.option.debug.dump_ir()


@tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256]])
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("stages", [2, 3, 4])
class BlackwellMatmul(tilus.Script):
    def __init__(self, block_m: int, block_n: int, block_k: int, stages: int):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.stages = stages

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        s_a = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_m, self.block_k]
        )
        s_b = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_n, self.block_k]
        )

        # allocate a tensor in tensor memory (tmem)
        t_acc = self.tcgen05.alloc(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        # allocate barriers and the initial phases
        consumer_barriers = self.mbarrier.alloc(
            count=[1 for _ in range(self.stages)]
        )  # whether the data is ready for consumption
        producer_barriers = self.mbarrier.alloc(
            count=[1 for _ in range(self.stages)]
        )  # whether the data is ready to be filled

        with self.thread_group(group_index=0, group_size=32):
            # tma warp
            stage: int32 = 0
            producer_phases = self.register_tensor(
                dtype=uint32, shape=[self.stages], init=1
            )  # all stages are ready to be filled at the beginning
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                self.mbarrier.wait(
                    producer_barriers[stage], phase=producer_phases[stage]
                )  # wait until the stage is ready to be filled
                producer_phases[stage] ^= 1
                with self.single_thread():
                    self.tma.global_to_shared(
                        src=g_a,
                        dst=s_a[stage],
                        offsets=[offset_m, offset_k],
                        mbarrier=consumer_barriers[stage],
                    )
                    self.tma.global_to_shared(
                        src=g_b,
                        dst=s_b[stage],
                        offsets=[offset_n, offset_k],
                        mbarrier=consumer_barriers[stage],
                    )
                    self.mbarrier.arrive(consumer_barriers[stage])
                stage = (stage + 1) % self.stages

            # remaining mma stages to wait for completion
            for _ in self.range(min(self.stages, cdiv(k_size, self.block_k))):
                self.mbarrier.wait(
                    producer_barriers[stage], phase=producer_phases[stage]
                )  # wait until the stage is ready to be filled
                producer_phases[stage] ^= 1
                stage = (stage + 1) % self.stages

        with self.thread_group(group_index=1, group_size=32):
            # mma warp
            consumer_phases = self.register_tensor(
                dtype=uint32, shape=[self.stages], init=0
            )  # all stages are not ready for consumption at the beginning
            stage: int32 = 0
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                self.mbarrier.wait(
                    consumer_barriers[stage], phase=consumer_phases[stage]
                )  # wait until the stage is ready for consumption
                consumer_phases[stage] ^= 1
                with self.single_thread():
                    self.tcgen05.mma(s_a[stage], s_b[stage].transpose(), t_acc)
                    self.tcgen05.commit(mbarrier=producer_barriers[stage])
                stage = (stage + 1) % self.stages

        self.sync()

        # load the result from tensor memory to register
        r_acc = self.tcgen05.load(
            t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
        )

        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

        # all allocated tensor memory must be deallocated
        self.sync()
        self.tcgen05.dealloc(t_acc)


def main(bench=True):
    matmul = BlackwellMatmul()

    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    rows = []

    for m_size, n_size, k_size in [
        [4096, 4096, 4096],
        [4096, 4096, 14336],
        [8192, 8192, 8192],
        [10240, 10240, 10240],
    ]:
        print(f"Running with m_size={m_size}, n_size={n_size}, k_size={k_size}")
        a = torch.randn(m_size, k_size, dtype=torch.float16, device="cuda")
        b = torch.randn(n_size, k_size, dtype=torch.float16, device="cuda")
        c = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda")

        matmul(m_size, n_size, k_size, a, b, c)
        torch.cuda.synchronize()

        c_ref = a @ b.T

        torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=1e-2)

        # benchmark
        if bench:
            for name, func in [
                ("torch", lambda: a @ b.T),
                ("tilus", lambda: matmul(m_size, n_size, k_size, a, b, c)),
            ]:
                latency = benchmark_func(func, warmup=5, repeat=20)
                tflops = 2 * m_size * n_size * k_size / latency * 1e-9
                rows.append([m_size, n_size, k_size, name, latency, tflops])

    if bench:
        df = pandas.DataFrame(rows, columns=headers)
        print(df)


if __name__ == "__main__":
    main(bench=True)
    # ncu_run(main, bench=False, kernel_regex="hidet|nvjet")
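As a quick sanity check on the benchmark formula in `main` (latency is measured in milliseconds, so the `1e-9` factor converts 2·m·n·k operations per millisecond into TFLOPS), the first row of the table in the commit message can be reproduced by hand:

```python
# Reproduce the "torch" row for m = n = k = 4096 using the same formula as main():
m = n = k = 4096
latency_ms = 0.160752
tflops = 2 * m * n * k / latency_ms * 1e-9
print(tflops)  # ~854.975, matching the 854.975085 reported in the table above
```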

python/tilus/backends/context.py

Lines changed: 6 additions & 2 deletions
@@ -35,6 +35,10 @@ def __init__(self, codegen):
     def __post_init__(self):
         pass

+    @property
+    def contexts(self):
+        return self.codegen.contexts
+
     def host_prepend(self, stmt: Expr | HidetStmt) -> None:
         """Prepend a statement to the host function.

@@ -43,7 +47,7 @@ def host_prepend(self, stmt: Expr | HidetStmt) -> None:
         stmt: Expr or HidetStmt
             The statement to be prepended.
         """
-        self.codegen.host_builder.scope_stack[-1].insert(0, stmt)
+        self.codegen.host_builder.scope_stack[0].insert(0, stmt)

     def host_append(self, stmt: Expr | HidetStmt) -> None:
         """Append a statement to the host function.
@@ -63,7 +67,7 @@ def kernel_prepend(self, stmt: Expr | HidetStmt) -> None:
         stmt: Expr or HidetStmt
             The statement to be prepended.
         """
-        self.codegen.builder.scope_stack[-1].insert(0, stmt)
+        self.codegen.builder.scope_stack[0].insert(0, stmt)

     def kernel_append(self, stmt: Expr | HidetStmt) -> None:
         """Append a statement to the kernel function.

python/tilus/backends/contexts/contexts.py

Lines changed: 5 additions & 1 deletion
@@ -16,7 +16,9 @@
 from tilus.backends.contexts.global_view_ctx import GlobalTensorViewContext
 from tilus.backends.contexts.gmem_alloc_ctx import GlobalMemoryAllocationContext
 from tilus.backends.contexts.invariant_ctx import InvariantTrackingContext
+from tilus.backends.contexts.mbarrier_alloc_ctx import BarrierAllocContext
 from tilus.backends.contexts.smem_alloc_ctx import SharedMemoryAllocationContext
+from tilus.backends.contexts.sync_ctx import SyncContext
 from tilus.backends.contexts.tcgen05_ctx import Tcgen05EmitContext


@@ -32,6 +34,8 @@ def __init__(self, codegen):
         self.invariant_ctx: InvariantTrackingContext = InvariantTrackingContext(codegen)
         self.smem_alloc_ctx: SharedMemoryAllocationContext = SharedMemoryAllocationContext(codegen)
         self.tcgen05_ctx: Tcgen05EmitContext = Tcgen05EmitContext(codegen)
+        self.barrier_alloc_ctx: BarrierAllocContext = BarrierAllocContext(codegen)
+        self.sync_ctx: SyncContext = SyncContext(codegen)

     def contexts(self) -> list[BaseEmitContext]:
         """Get all contexts as a list.
@@ -56,5 +60,5 @@ def finalize(self):

         This method is called when the codegen is finished for all instructions.
         """
-        for ctx in self.contexts():
+        for ctx in reversed(self.contexts()):
             ctx.finalize()
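The `reversed(self.contexts())` change finalizes contexts in the reverse of their construction order. A plausible reading, sketched below with made-up classes (`Ctx` is illustrative, not the tilus codegen), is that a context constructed later may still use an earlier one inside its own `finalize`; for example, `BarrierAllocContext.finalize` calls `smem_alloc_ctx.allocate_shared_tensor`, so it presumably needs to run before the shared-memory context wraps up, mirroring destructor/defer semantics.

```python
# Illustrative only (Ctx is a made-up class, not the tilus codegen): contexts are torn
# down in the reverse of their construction order, so a context constructed later can
# still use an earlier one inside its own finalize().
class Ctx:
    def __init__(self, name, log, uses=None):
        self.name, self.log, self.uses = name, log, uses
        self.closed = False
        log.append(f"init {name}")

    def finalize(self):
        if self.uses is not None:
            assert not self.uses.closed, f"{self.uses.name} finalized too early"
        self.closed = True
        self.log.append(f"finalize {self.name}")


log = []
smem = Ctx("smem_alloc_ctx", log)
barrier = Ctx("barrier_alloc_ctx", log, uses=smem)  # its finalize still allocates shared memory
for ctx in reversed([smem, barrier]):               # LIFO, as in the diff above
    ctx.finalize()
print(log)
# ['init smem_alloc_ctx', 'init barrier_alloc_ctx',
#  'finalize barrier_alloc_ctx', 'finalize smem_alloc_ctx']
```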
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Sequence

from hidet.ir.builders import StmtBuilder
from hidet.ir.dtypes import uint32, uint64
from hidet.ir.expr import Expr, Var
from hidet.ir.primitives.cuda.barrier import fence_view_async_shared
from hidet.ir.primitives.cuda.cvta import cvta_generic_to_shared
from hidet.ir.primitives.cuda.smem import dynamic_shared_memory

from tilus.backends.context import BaseEmitContext
from tilus.extensions.hidet.ir.primitives.cuda.mbarrier import mbarrier_init_shared
from tilus.ir.layout import ops
from tilus.ir.tensor import SharedTensor


class BarrierAllocContext(BaseEmitContext):
    """Context used to manage the allocation of barriers."""

    def __post_init__(self):
        self.counts: list[Expr] = []
        self.barriers: list[Var] = []
        self.barrier_addr: Var = Var("barriers", type=uint32)

    def finalize(self):
        # allocate shared memory for all barriers
        num_barriers = len(self.counts)

        if num_barriers == 0:
            # No barriers to allocate
            return

        tensor = SharedTensor(dtype=uint64, shape=(num_barriers,), optional_layout=ops.shared_row_major(num_barriers))
        virtual_smem_addr = self.contexts.smem_alloc_ctx.allocate_shared_tensor(tensor, nbytes=tensor.nbytes)
        sb = StmtBuilder()
        sb.declare(
            v=self.barrier_addr,
            init=cvta_generic_to_shared(dynamic_shared_memory(byte_offset=virtual_smem_addr, dtype=uint64)),
        )

        for i in range(num_barriers):
            sb.declare(v=self.barriers[i], init=self.barrier_addr + uint32(i * uint64.nbytes))
            sb.append(mbarrier_init_shared(mbarrier_addr=self.barriers[i], arrive_count=uint32(self.counts[i])))
        sb.append(fence_view_async_shared())
        self.kernel_prepend(sb.finish())

    def allocate_barriers(self, counts: Sequence[Expr | int]) -> list[Var]:
        """
        Allocate a list of barriers with given counts. Each barrier is a 64-bit data structure stored in shared memory.
        This function returns one variable per barrier, holding that barrier's address in the shared space.
        """
        barrier_vars = [Var("barrier_{}".format(c), type=uint32) for c in counts]
        self.counts.extend([uint32(c) if isinstance(c, int) else c for c in counts])
        self.barriers.extend(barrier_vars)
        return barrier_vars