5 | 5 | import tilus |
6 | 6 | import torch |
7 | 7 | from tilus import float16, float32, int32, uint32 |
8 | | -from tilus.ir.tensor import RegisterTensor |
| 8 | +from tilus.ir.tensor import GlobalTensor, RegisterTensor |
9 | 9 | from tilus.utils import benchmark_func, cdiv |
10 | 10 |
11 | 11 |
12 | | -class PipelineState(tilus.State): |
| 12 | +class Pipeline(tilus.Class): |
13 | 13 | def __init__( |
14 | 14 | self, num_stages: int, producer_arrive_count: int, consumer_arrive_count: int |
15 | 15 | ): |
@@ -50,6 +50,113 @@ def consumer_release_barrier(self) -> RegisterTensor: |
50 | 50 | return self.full_barriers[self.consumer_stage] |
51 | 51 |
52 | 52 |
| 53 | +class BlockInfo(tilus.Class): |
| 54 | + def __init__( |
| 55 | + self, |
| 56 | + m_size: int32, |
| 57 | + n_size: int, |
| 58 | + k_size: int, |
| 59 | + block_m: int, |
| 60 | + block_n: int, |
| 61 | + block_k: int, |
| 62 | + offset_m: int32, |
| 63 | + offset_n: int32, |
| 64 | + ): |
| 65 | + self.m_size: int32 = m_size |
| 66 | + self.n_size: int = n_size |
| 67 | + self.k_size: int = k_size |
| 68 | + self.block_m: int = block_m |
| 69 | + self.block_n: int = block_n |
| 70 | + self.block_k: int = block_k |
| 71 | + self.offset_m: int32 = offset_m |
| 72 | + self.offset_n: int32 = offset_n |
| 73 | + |
| 74 | + |
| 75 | +class LoadPipeline(Pipeline): |
| 76 | + def __init__( |
| 77 | + self, |
| 78 | + num_stages: int, |
| 79 | + info: BlockInfo, |
| 80 | + ): |
| 81 | + super().__init__( |
| 82 | + num_stages=num_stages, producer_arrive_count=2, consumer_arrive_count=1 |
| 83 | + ) |
| 84 | + self.info: BlockInfo = info |
| 85 | + self.s_a = self.shared_tensor( |
| 86 | + dtype=float16, shape=[num_stages, info.block_m, info.block_k] |
| 87 | + ) |
| 88 | + self.s_b = self.shared_tensor( |
| 89 | + dtype=float16, shape=[num_stages, info.block_n, info.block_k] |
| 90 | + ) |
| 91 | + |
| 92 | + |
| 93 | +class LoadWorker(tilus.Class): |
| 94 | + def __init__( |
| 95 | + self, pipe: LoadPipeline, g_a: GlobalTensor, g_b: GlobalTensor, info: BlockInfo |
| 96 | + ): |
| 97 | + self.pipe: LoadPipeline = pipe |
| 98 | + self.g_a: GlobalTensor = g_a |
| 99 | + self.g_b: GlobalTensor = g_b |
| 100 | + self.info: BlockInfo = info |
| 101 | + |
| 102 | + def async_run(self): |
| 103 | + pipe, g_a, g_b, info = self.pipe, self.g_a, self.g_b, self.info |
| 104 | + s_a, s_b = pipe.s_a, pipe.s_b |
| 105 | + num_stages: int = pipe.num_stages |
| 106 | + with self.thread_group(thread_begin=0, num_threads=32): |
| 107 | + for offset_k in self.range(0, info.k_size, info.block_k, unroll=num_stages): |
| 108 | + pipe.producer_acquire() |
| 109 | + with self.single_thread(): |
| 110 | + self.tma.global_to_shared( |
| 111 | + src=g_a, |
| 112 | + dst=s_a[pipe.producer_stage], |
| 113 | + offsets=[info.offset_m, offset_k], |
| 114 | + mbarrier=pipe.producer_release_barrier(), |
| 115 | + ) |
| 116 | + self.tma.global_to_shared( |
| 117 | + src=g_b, |
| 118 | + dst=s_b[pipe.producer_stage], |
| 119 | + offsets=[info.offset_n, offset_k], |
| 120 | + mbarrier=pipe.producer_release_barrier(), |
| 121 | + ) |
| 122 | + pipe.producer_advance() |
| 123 | + |
| 124 | + # drain loop: wait for the remaining in-flight mma stages to complete |
| 125 | + for _ in self.range(min(num_stages, cdiv(info.k_size, info.block_k))): |
| 126 | + pipe.producer_acquire() |
| 127 | + pipe.producer_advance() |
| 128 | + |
| 129 | + |
| 130 | +class MmaWorker(tilus.Class): |
| 131 | + def __init__(self, pipe: LoadPipeline, info: BlockInfo): |
| 132 | + self.pipe: LoadPipeline = pipe |
| 133 | + self.info: BlockInfo = info |
| 134 | + self.t_acc = self.tcgen05.alloc( |
| 135 | + dtype=float32, shape=[info.block_m, info.block_n], init=0.0 |
| 136 | + ) |
| 137 | + |
| 138 | + def async_run(self): |
| 139 | + pipe = self.pipe |
| 140 | + s_a, s_b = pipe.s_a, pipe.s_b |
| 141 | + num_stages: int = pipe.num_stages |
| 142 | + with self.thread_group(thread_begin=32, num_threads=32): |
| 143 | + for offset_k in self.range( |
| 144 | + 0, self.info.k_size, self.info.block_k, unroll=num_stages |
| 145 | + ): |
| 146 | + pipe.consumer_acquire() |
| 147 | + with self.single_thread(): |
| 148 | + self.tcgen05.mma( |
| 149 | + s_a[pipe.consumer_stage], |
| 150 | + s_b[pipe.consumer_stage].transpose(), |
| 151 | + self.t_acc, |
| 152 | + ) |
| 153 | + self.tcgen05.commit(mbarrier=pipe.consumer_release_barrier()) |
| 154 | + pipe.consumer_advance() |
| 155 | + |
| 156 | + def dealloc(self): |
| 157 | + self.tcgen05.dealloc(self.t_acc) |
| 158 | + |
| 159 | + |
53 | 160 | @tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256]]) |
54 | 161 | @tilus.autotune("block_k", [16, 32, 64]) |
55 | 162 | @tilus.autotune("stages", [2, 3, 4]) |
@@ -78,79 +185,41 @@ def __call__( |
78 | 185 |
79 | 186 | g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) |
80 | 187 | g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) |
81 | | - s_a = self.shared_tensor( |
82 | | - dtype=float16, shape=[self.stages, self.block_m, self.block_k] |
83 | | - ) |
84 | | - s_b = self.shared_tensor( |
85 | | - dtype=float16, shape=[self.stages, self.block_n, self.block_k] |
86 | | - ) |
87 | | - |
88 | | - # allocate a tensor in tensor memory (tmem) |
89 | | - t_acc = self.tcgen05.alloc( |
90 | | - dtype=float32, shape=[self.block_m, self.block_n], init=0.0 |
91 | | - ) |
92 | 188 |
93 | | - state = PipelineState( |
94 | | - num_stages=self.stages, |
95 | | - producer_arrive_count=2, |
96 | | - consumer_arrive_count=1, |
| 189 | + info = BlockInfo( |
| 190 | + m_size=m_size, |
| 191 | + n_size=n_size, |
| 192 | + k_size=k_size, |
| 193 | + block_m=self.block_m, |
| 194 | + block_n=self.block_n, |
| 195 | + block_k=self.block_k, |
| 196 | + offset_m=self.block_m * self.blockIdx.x, |
| 197 | + offset_n=self.block_n * self.blockIdx.y, |
97 | 198 | ) |
98 | 199 |
99 | | - with self.thread_group(thread_begin=0, num_threads=32): |
100 | | - # producer |
101 | | - for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages): |
102 | | - # self.printf("[%d][%d] producer acquring offset_k=%d\n", self.blockIdx.x, state.producer_stage, offset_k) |
103 | | - state.producer_acquire() |
104 | | - # self.printf("[%d][%d] producer acquired offset_k=%d\n", self.blockIdx.x, state.producer_stage, offset_k) |
105 | | - with self.single_thread(): |
106 | | - self.tma.global_to_shared( |
107 | | - src=g_a, |
108 | | - dst=s_a[state.producer_stage], |
109 | | - offsets=[offset_m, offset_k], |
110 | | - mbarrier=state.producer_release_barrier(), |
111 | | - ) |
112 | | - self.tma.global_to_shared( |
113 | | - src=g_b, |
114 | | - dst=s_b[state.producer_stage], |
115 | | - offsets=[offset_n, offset_k], |
116 | | - mbarrier=state.producer_release_barrier(), |
117 | | - ) |
118 | | - # self.printf("[%d][%d] producer produced offset_k=%d\n", self.blockIdx.x, state.producer_stage, offset_k) |
119 | | - state.producer_advance() |
| 200 | + pipe = LoadPipeline(num_stages=self.stages, info=info) |
| 201 | + load_worker = LoadWorker(pipe, g_a, g_b, info) |
| 202 | + mma_worker = MmaWorker(pipe, info) |
120 | 203 |
121 | | - # remaining mma stages to wait for completion |
122 | | - for _ in self.range(min(self.stages, cdiv(k_size, self.block_k))): |
123 | | - state.producer_acquire() |
124 | | - state.producer_advance() |
| 204 | + # producer |
| 205 | + load_worker.async_run() |
125 | 206 |
126 | | - with self.thread_group(thread_begin=32, num_threads=32): |
127 | | - for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages): |
128 | | - # self.printf("[%d][%d] consumer acquring offset_k=%d\n", self.blockIdx.x, state.consumer_stage, offset_k) |
129 | | - state.consumer_acquire() |
130 | | - # self.printf("[%d][%d] consumer acquired offset_k=%d\n", self.blockIdx.x, state.consumer_stage, offset_k) |
131 | | - with self.single_thread(): |
132 | | - self.tcgen05.mma( |
133 | | - s_a[state.consumer_stage], |
134 | | - s_b[state.consumer_stage].transpose(), |
135 | | - t_acc, |
136 | | - ) |
137 | | - self.tcgen05.commit(mbarrier=state.consumer_release_barrier()) |
138 | | - # self.printf("[%d][%d] consumer consumed offset_k=%d\n", self.blockIdx.x, state.consumer_stage, offset_k) |
139 | | - state.consumer_advance() |
| 207 | + # consumer |
| 208 | + mma_worker.async_run() |
140 | 209 |
141 | 210 | self.sync() |
142 | 211 |
143 | 212 | # load the result from tensor memory to register |
144 | 213 | r_acc = self.tcgen05.load( |
145 | | - t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n] |
| 214 | + mma_worker.t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n] |
146 | 215 | ) |
147 | 216 |
148 | 217 | g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) |
149 | 218 | self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n]) |
150 | 219 |
151 | 220 | # all allocated tensor memory must be deallocated |
152 | 221 | self.sync() |
153 | | - self.tcgen05.dealloc(t_acc) |
| 222 | + mma_worker.dealloc() |
154 | 223 |
155 | 224 |
156 | 225 | def main(bench=True): |