Skip to content

Commit 818bd30

Browse files
authored
[Example] Update the example v4 of blackwell matmul (#75)
This PR rewrites the example v4 with `tilus.Class`. Others: 1. rename `tilus.State` to `tilus.Class`, since we also want to represent worker concept with this language construct. Using `Class` (like the `class` in C++) is better. --------- Signed-off-by: Yaoyao Ding <[email protected]>
1 parent 3e300c9 commit 818bd30

File tree

8 files changed

+437
-347
lines changed

8 files changed

+437
-347
lines changed

examples/blackwell_matmul/matmul_v4.py

Lines changed: 127 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import tilus
66
import torch
77
from tilus import float16, float32, int32, uint32
8-
from tilus.ir.tensor import RegisterTensor
8+
from tilus.ir.tensor import GlobalTensor, RegisterTensor
99
from tilus.utils import benchmark_func, cdiv
1010

1111

12-
class PipelineState(tilus.State):
12+
class Pipeline(tilus.Class):
1313
def __init__(
1414
self, num_stages: int, producer_arrive_count: int, consumer_arrive_count: int
1515
):
@@ -50,6 +50,113 @@ def consumer_release_barrier(self) -> RegisterTensor:
5050
return self.full_barriers[self.consumer_stage]
5151

5252

53+
class BlockInfo(tilus.Class):
54+
def __init__(
55+
self,
56+
m_size: int32,
57+
n_size: int,
58+
k_size: int,
59+
block_m: int,
60+
block_n: int,
61+
block_k: int,
62+
offset_m: int32,
63+
offset_n: int32,
64+
):
65+
self.m_size: int32 = m_size
66+
self.n_size: int = n_size
67+
self.k_size: int = k_size
68+
self.block_m: int = block_m
69+
self.block_n: int = block_n
70+
self.block_k: int = block_k
71+
self.offset_m: int32 = offset_m
72+
self.offset_n: int32 = offset_n
73+
74+
75+
class LoadPipeline(Pipeline):
76+
def __init__(
77+
self,
78+
num_stages: int,
79+
info: BlockInfo,
80+
):
81+
super().__init__(
82+
num_stages=num_stages, producer_arrive_count=2, consumer_arrive_count=1
83+
)
84+
self.info: BlockInfo = info
85+
self.s_a = self.shared_tensor(
86+
dtype=float16, shape=[num_stages, info.block_m, info.block_k]
87+
)
88+
self.s_b = self.shared_tensor(
89+
dtype=float16, shape=[num_stages, info.block_n, info.block_k]
90+
)
91+
92+
93+
class LoadWorker(tilus.Class):
94+
def __init__(
95+
self, pipe: LoadPipeline, g_a: GlobalTensor, g_b: GlobalTensor, info: BlockInfo
96+
):
97+
self.pipe: LoadPipeline = pipe
98+
self.g_a: GlobalTensor = g_a
99+
self.g_b: GlobalTensor = g_b
100+
self.info: BlockInfo = info
101+
102+
def async_run(self):
103+
pipe, g_a, g_b, info = self.pipe, self.g_a, self.g_b, self.info
104+
s_a, s_b = pipe.s_a, pipe.s_b
105+
num_stages: int = pipe.num_stages
106+
with self.thread_group(thread_begin=0, num_threads=32):
107+
for offset_k in self.range(0, info.k_size, info.block_k, unroll=num_stages):
108+
self.pipe.producer_acquire()
109+
with self.single_thread():
110+
self.tma.global_to_shared(
111+
src=g_a,
112+
dst=s_a[pipe.producer_stage],
113+
offsets=[info.offset_m, offset_k],
114+
mbarrier=pipe.producer_release_barrier(),
115+
)
116+
self.tma.global_to_shared(
117+
src=g_b,
118+
dst=s_b[pipe.producer_stage],
119+
offsets=[info.offset_n, offset_k],
120+
mbarrier=pipe.producer_release_barrier(),
121+
)
122+
pipe.producer_advance()
123+
124+
# remaining mma stages to wait for completion
125+
for _ in self.range(min(num_stages, cdiv(info.k_size, info.block_k))):
126+
pipe.producer_acquire()
127+
pipe.producer_advance()
128+
129+
130+
class MmaWorker(tilus.Class):
131+
def __init__(self, pipe: LoadPipeline, info: BlockInfo):
132+
self.pipe: LoadPipeline = pipe
133+
self.info: BlockInfo = info
134+
self.t_acc = self.tcgen05.alloc(
135+
dtype=float32, shape=[info.block_m, info.block_n], init=0.0
136+
)
137+
138+
def async_run(self):
139+
pipe = self.pipe
140+
s_a, s_b = pipe.s_a, pipe.s_b
141+
num_stages: int = pipe.num_stages
142+
with self.thread_group(thread_begin=32, num_threads=32):
143+
for offset_k in self.range(
144+
0, self.info.k_size, self.info.block_k, unroll=num_stages
145+
):
146+
pipe.consumer_acquire()
147+
with self.single_thread():
148+
self.tcgen05.mma(
149+
s_a[pipe.consumer_stage],
150+
s_b[pipe.consumer_stage].transpose(),
151+
self.t_acc,
152+
)
153+
self.tcgen05.commit(mbarrier=pipe.consumer_release_barrier())
154+
pipe.consumer_advance()
155+
156+
def dealloc(self):
157+
self.tcgen05.dealloc(self.t_acc)
158+
159+
53160
@tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256]])
54161
@tilus.autotune("block_k", [16, 32, 64])
55162
@tilus.autotune("stages", [2, 3, 4])
@@ -78,79 +185,41 @@ def __call__(
78185

79186
g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
80187
g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
81-
s_a = self.shared_tensor(
82-
dtype=float16, shape=[self.stages, self.block_m, self.block_k]
83-
)
84-
s_b = self.shared_tensor(
85-
dtype=float16, shape=[self.stages, self.block_n, self.block_k]
86-
)
87-
88-
# allocate a tensor in tensor memory (tmem)
89-
t_acc = self.tcgen05.alloc(
90-
dtype=float32, shape=[self.block_m, self.block_n], init=0.0
91-
)
92188

93-
state = PipelineState(
94-
num_stages=self.stages,
95-
producer_arrive_count=2,
96-
consumer_arrive_count=1,
189+
info = BlockInfo(
190+
m_size=m_size,
191+
n_size=n_size,
192+
k_size=k_size,
193+
block_m=self.block_m,
194+
block_n=self.block_n,
195+
block_k=self.block_k,
196+
offset_m=self.block_m * self.blockIdx.x,
197+
offset_n=self.block_n * self.blockIdx.y,
97198
)
98199

99-
with self.thread_group(thread_begin=0, num_threads=32):
100-
# producer
101-
for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
102-
# self.printf("[%d][%d] producer acquring offset_k=%d\n", self.blockIdx.x, state.producer_stage, offset_k)
103-
state.producer_acquire()
104-
# self.printf("[%d][%d] producer acquired offset_k=%d\n", self.blockIdx.x, state.producer_stage, offset_k)
105-
with self.single_thread():
106-
self.tma.global_to_shared(
107-
src=g_a,
108-
dst=s_a[state.producer_stage],
109-
offsets=[offset_m, offset_k],
110-
mbarrier=state.producer_release_barrier(),
111-
)
112-
self.tma.global_to_shared(
113-
src=g_b,
114-
dst=s_b[state.producer_stage],
115-
offsets=[offset_n, offset_k],
116-
mbarrier=state.producer_release_barrier(),
117-
)
118-
# self.printf("[%d][%d] producer produced offset_k=%d\n", self.blockIdx.x, state.producer_stage, offset_k)
119-
state.producer_advance()
200+
pipe = LoadPipeline(num_stages=self.stages, info=info)
201+
load_worker = LoadWorker(pipe, g_a, g_b, info)
202+
mma_worker = MmaWorker(pipe, info)
120203

121-
# remaining mma stages to wait for completion
122-
for _ in self.range(min(self.stages, cdiv(k_size, self.block_k))):
123-
state.producer_acquire()
124-
state.producer_advance()
204+
# producer
205+
load_worker.async_run()
125206

126-
with self.thread_group(thread_begin=32, num_threads=32):
127-
for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
128-
# self.printf("[%d][%d] consumer acquring offset_k=%d\n", self.blockIdx.x, state.consumer_stage, offset_k)
129-
state.consumer_acquire()
130-
# self.printf("[%d][%d] consumer acquired offset_k=%d\n", self.blockIdx.x, state.consumer_stage, offset_k)
131-
with self.single_thread():
132-
self.tcgen05.mma(
133-
s_a[state.consumer_stage],
134-
s_b[state.consumer_stage].transpose(),
135-
t_acc,
136-
)
137-
self.tcgen05.commit(mbarrier=state.consumer_release_barrier())
138-
# self.printf("[%d][%d] consumer consumed offset_k=%d\n", self.blockIdx.x, state.consumer_stage, offset_k)
139-
state.consumer_advance()
207+
# consumer
208+
mma_worker.async_run()
140209

141210
self.sync()
142211

143212
# load the result from tensor memory to register
144213
r_acc = self.tcgen05.load(
145-
t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
214+
mma_worker.t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
146215
)
147216

148217
g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
149218
self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])
150219

151220
# all allocated tensor memory must be deallocated
152221
self.sync()
153-
self.tcgen05.dealloc(t_acc)
222+
mma_worker.dealloc()
154223

155224

156225
def main(bench=True):

python/tilus/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@
106106
from tilus.ir.tensor import RegisterTensor, SharedTensor, GlobalTensor
107107
from tilus.lang.instantiated_script import InstantiatedScript
108108
from tilus.lang.script import Script, autotune
109-
from tilus.lang.constructs.state import State
109+
from tilus.lang.constructs.state import Class
110110
from tilus.tensor import empty, from_torch, full, ones, rand, randint, randn, view_torch, zeros
111111

112112
from . import kernels, logging, option, utils, testing

python/tilus/lang/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from .constructs.state import State
15+
from .constructs.state import Class
1616
from .script import Attributes, Script

python/tilus/lang/constructs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from .state import State
15+
from .state import Class

python/tilus/lang/constructs/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@
1515
from tilus.lang.instructions import InstructionInterface
1616

1717

18-
class State(InstructionInterface):
18+
class Class(InstructionInterface):
1919
pass

0 commit comments

Comments
 (0)