
Commit 33b7e09

wip
wip wip wip wip wip
1 parent 818bd30 commit 33b7e09


37 files changed: +459 -217 lines


examples/blackwell_matmul/matmul_v0.py

Lines changed: 1 addition & 3 deletions
@@ -70,9 +70,7 @@ def __call__(
        phase ^= 1

        # load the result from tensor memory to register
-        r_acc = self.tcgen05.load(
-            t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
-        )
+        r_acc = self.tcgen05.load(t_acc)

        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

examples/blackwell_matmul/matmul_v1.py

Lines changed: 1 addition & 3 deletions
@@ -79,9 +79,7 @@ def __call__(
        phase ^= 1

        # load the result from tensor memory to register
-        r_acc = self.tcgen05.load(
-            t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
-        )
+        r_acc = self.tcgen05.load(t_acc)

        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

examples/blackwell_matmul/matmul_v2.py

Lines changed: 1 addition & 3 deletions
@@ -108,9 +108,7 @@ def __call__(
        self.sync()

        # load the result from tensor memory to register
-        r_acc = self.tcgen05.load(
-            t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
-        )
+        r_acc = self.tcgen05.load(t_acc)

        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

examples/blackwell_matmul/matmul_v3.py

Lines changed: 1 addition & 3 deletions
@@ -109,9 +109,7 @@ def __call__(
        self.sync()

        # load the result from tensor memory to register
-        r_acc = self.tcgen05.load(
-            t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
-        )
+        r_acc = self.tcgen05.load(t_acc)

        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

examples/blackwell_matmul/matmul_v4.py

Lines changed: 1 addition & 3 deletions
@@ -210,9 +210,7 @@ def __call__(
        self.sync()

        # load the result from tensor memory to register
-        r_acc = self.tcgen05.load(
-            mma_worker.t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
-        )
+        r_acc = self.tcgen05.load(mma_worker.t_acc)

        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])
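
All five example kernels end with the same epilogue; for reference, a minimal sketch of that epilogue as it now stands (names such as t_acc, c_ptr, m_size/n_size and offset_m/offset_n come from surrounding kernel code that is not shown in this diff):

    # read the whole accumulator tile from tensor memory into registers
    r_acc = self.tcgen05.load(t_acc)
    # view the output matrix in global memory and store the tile as float16
    g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
    self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])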

pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -92,7 +92,9 @@ ignore = [
convention = "numpy"

[tool.ruff.lint.per-file-ignores]
-"__init__.py" = ["F401"]
+"__init__.py" = [
+    "F401" # checks for unused imports.
+]
"examples/**/*.py" = [
    "D400",
    "D205", # 1 blank line required between summary line and description

python/tilus/backends/emitters/cuda/tcgen05/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import allocation, copy, ldst, mma, sync
+from . import alloc, copy, ldst, mma, slice, sync

python/tilus/backends/emitters/cuda/tcgen05/allocation.py renamed to python/tilus/backends/emitters/cuda/tcgen05/alloc.py

Lines changed: 20 additions & 27 deletions
@@ -30,25 +30,27 @@
    Tcgen05AllocInst,
    Tcgen05DeallocInst,
    Tcgen05RelinquishAllocPermitInst,
-    Tcgen05SliceInst,
    Tcgen05ViewInst,
)
from tilus.ir.tensor import TMemoryTensor
from tilus.target import nvgpu_sm100
-
-# tmem addr: 0xAAAABBBB where AAAA is the lane index and BBBB is the column index
-# lane index: 0x0000 to 0x007F
-# column index: 0x0000 to 0x01FF
-LANE_STRIDE = 0x00010000
-COLUMN_STRIDE = 0x00000001
+from tilus.utils import prod, same_list


class Tcgen05AllocDeallocEmitter(BaseInstEmitter):
    def get_num_columns(self, tmem_tensor: TMemoryTensor) -> int:
-        assert tmem_tensor.shape[0] == 128
-        assert tmem_tensor.shape[1] * tmem_tensor.dtype.nbits % 32 == 0
-        num_columns = tmem_tensor.shape[1] * tmem_tensor.dtype.nbits // 32
-        assert num_columns % 32 == 0 and 32 <= num_columns <= 512, num_columns
+        shape = tmem_tensor.shape
+        if shape[-2] != 128:
+            raise NotImplementedError(f"The emitter currently only supports shape[-2] == 128, but got {shape[-2]}")
+        if shape[-1] * tmem_tensor.dtype.nbits % 32 != 0:
+            raise ValueError(
+                f"shape[-1] * dtype.nbits must be divisible by 32, but got {shape[-1]} * {tmem_tensor.dtype.nbits} = {shape[-1] * tmem_tensor.dtype.nbits}"
+            )
+        num_columns = prod(shape[:-2]) * shape[-1] * tmem_tensor.dtype.nbits // 32
+        if not (num_columns % 32 == 0 and 32 <= num_columns <= 512):
+            raise ValueError(
+                f"The number of 32-bit columns must be a multiple of 32 and in range [32, 512], but got {num_columns}"
+            )
        return num_columns


@@ -122,32 +124,23 @@ def emit(self, inst: Tcgen05RelinquishAllocPermitInst) -> None:
        self.append(tcgen05_relinquish_alloc_permit(Tcgen05CtaGroupKind.from_int(inst.cta_group)))


-@register_emitter(Tcgen05SliceInst, target=nvgpu_sm100)
-class TMemorySliceEmitter(BaseInstEmitter):
-    def emit(self, inst: Tcgen05SliceInst) -> None:
-        tmem_tensor = inst.inputs[0].as_tmemory_tensor()
-        output_tmem_tensor = inst.tmemory_output
-        tmem_addr = self.get_or_allocate_var(tmem_tensor)
-
-        sliced_addr = self.get_or_allocate_var(output_tmem_tensor, name="tmem_slice")
-        self.assign(
-            sliced_addr,
-            tmem_addr + inst.offsets[0] * LANE_STRIDE + inst.offsets[1] * COLUMN_STRIDE * tmem_tensor.dtype.nbits // 32,
-        )
-
-
@register_emitter(Tcgen05ViewInst, target=nvgpu_sm100)
class TMemoryViewEmitter(BaseInstEmitter):
    def emit(self, inst: Tcgen05ViewInst) -> None:
        tmem_tensor = inst.inputs[0].as_tmemory_tensor()
        output_tmem_tensor = inst.tmemory_output

        if (
-            tmem_tensor.dtype.nbits * tmem_tensor.shape[1]
-            != output_tmem_tensor.dtype.nbits * output_tmem_tensor.shape[1]
+            tmem_tensor.dtype.nbits * tmem_tensor.shape[-1]
+            != output_tmem_tensor.dtype.nbits * output_tmem_tensor.shape[-1]
        ):
            raise ValueError("The total number of bits must be the same as the original tensor.")

+        if not same_list(tmem_tensor.layout.column_strides[:-2], output_tmem_tensor.layout.column_strides[:-2]):
+            raise ValueError(
+                "The column strides of the leading dimensions (all dimensions except the last two ones) must be the same as the original tensor."
+            )
+
        tmem_addr = self.get_or_allocate_var(tmem_tensor)
        view_addr = self.get_or_allocate_var(output_tmem_tensor, name="tmem_view")
        self.assign(view_addr, tmem_addr)
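
For reference, the address arithmetic behind these emitters can be reproduced in a few lines of plain Python. The constants and formulas are taken from the code shown above; the helper names below are illustrative only and are not part of tilus:

    # tmem addresses encode lane/column as 0xAAAABBBB:
    #   AAAA = lane index (0x0000..0x007F), BBBB = column index (0x0000..0x01FF)
    LANE_STRIDE = 0x00010000
    COLUMN_STRIDE = 0x00000001

    def num_columns(shape: list[int], dtype_nbits: int) -> int:
        """Number of 32-bit tmem columns a tensor occupies, mirroring get_num_columns above."""
        assert shape[-2] == 128 and shape[-1] * dtype_nbits % 32 == 0
        cols = 1
        for dim in shape[:-2]:
            cols *= dim
        cols = cols * shape[-1] * dtype_nbits // 32
        assert cols % 32 == 0 and 32 <= cols <= 512
        return cols

    def slice_addr(base: int, lane_ofs: int, col_ofs: int, dtype_nbits: int) -> int:
        """Address of a [lane_ofs, col_ofs] slice, as computed by the removed TMemorySliceEmitter."""
        return base + lane_ofs * LANE_STRIDE + col_ofs * COLUMN_STRIDE * dtype_nbits // 32

    # e.g. a float32 tensor of shape [128, 64] occupies 64 columns, and the slice
    # starting at lane 64, element offset 32 along the last dimension sits at base + 0x00400020.
    assert num_columns([128, 64], 32) == 64
    assert slice_addr(0x0, 64, 32, 32) == 0x00400020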

python/tilus/backends/emitters/cuda/tcgen05/copy.py

Lines changed: 6 additions & 3 deletions
@@ -22,9 +22,10 @@
from hidet.ir.expr import Expr

from tilus.backends.emitter import BaseInstEmitter, register_emitter
-from tilus.backends.emitters.cuda.tcgen05.allocation import COLUMN_STRIDE, LANE_STRIDE
from tilus.backends.emitters.cuda.tcgen05.smem_desc import SharedMatrixDescriptor
from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import (
+    COLUMN_STRIDE,
+    LANE_STRIDE,
    Tcgen05CopyMulticastKind,
    Tcgen05CopyShapeKind,
    Tcgen05CtaGroupKind,
@@ -200,10 +201,12 @@ def emit(self, inst: Tcgen05CopyInst) -> None:
        self.check_warp_group()

        if len(shared_tensor.shape) != 2:
-            raise ValueError("The shared tensor must be a 2D tensor")
+            raise ValueError("The shared tensor must be a 2D tensor, got shape {}".format(shared_tensor.shape))
+        if len(tmem_tensor.shape) != 2:
+            raise ValueError("The tensor memory tensor must be a 2D tensor, got shape {}".format(tmem_tensor.shape))
        if shared_tensor.shape[0] != 128:
            raise NotImplementedError("The number of rows in the shared tensor must be 128")
-        if tmem_tensor.first_lane != 0:
+        if tmem_tensor.layout.lane_offset != 0:
            raise NotImplementedError("The first lane of the tmem tensor must be 0")

        tmem_base_addr = self.tensor2var[tmem_tensor]

python/tilus/backends/emitters/cuda/tcgen05/ldst.py

Lines changed: 7 additions & 32 deletions
@@ -14,13 +14,14 @@
# limitations under the License.

from dataclasses import dataclass
-from typing import Sequence

from hidet.ir.dtypes import int32, uint32
from hidet.ir.expr import Expr, cast

from tilus.backends.emitter import BaseInstEmitter, register_emitter
from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import (
+    COLUMN_STRIDE,
+    LANE_STRIDE,
    Tcgen05LoadStoreNumKind,
    Tcgen05LoadStorePackKind,
    Tcgen05LoadStoreShapeKind,
@@ -41,12 +42,6 @@
from tilus.target import nvgpu_sm100
from tilus.utils import gcd

-# tmem addr: 0xAAAABBBB where AAAA is the lane index and BBBB is the column index
-# lane index: 0x0000 to 0x007F
-# column index: 0x0000 to 0x01FF
-LANE_STRIDE = 0x00010000
-COLUMN_STRIDE = 0x00000001
-

@dataclass
class LoadStoreWarpInst:
@@ -58,24 +53,6 @@ class LoadStoreWarpInst:


class TMemoryLoadStoreBaseEmitter(BaseInstEmitter):
-    def slice_tmem_tensor(
-        self, tmem_tensor: TMemoryTensor, offsets: Sequence[int], shape: Sequence[int]
-    ) -> tuple[TMemoryTensor, Expr]:
-        if any(not isinstance(ofs, int) for ofs in offsets):
-            raise ValueError("All offsets must be integer constants")
-        if len(offsets) != 2:
-            raise ValueError("The length of offsets must be 2")
-        if len(shape) != 2:
-            raise ValueError("The length of shape must be 2")
-        tmem_addr = self.get_or_allocate_var(tmem_tensor)
-        sliced_tmem_tensor = TMemoryTensor.create(
-            dtype=tmem_tensor.dtype, shape=shape, first_lane=tmem_tensor.first_lane + offsets[0]
-        )
-        sliced_tmem_addr = (
-            tmem_addr + offsets[0] * LANE_STRIDE + offsets[1] * COLUMN_STRIDE * tmem_tensor.dtype.nbits // 32
-        )
-        return sliced_tmem_tensor, sliced_tmem_addr
-
    def emit_tcgen05_inst(self, inst: LoadStoreWarpInst) -> None:
        raise NotImplementedError("Subclasses must implement this method")

@@ -87,11 +64,11 @@ def emit_tcgen05_instructions(
    ) -> None:
        if self.current_num_threads % 32 != 0:
            raise ValueError("The number of threads in the current thread group must be divisible by 32")
-        if self.current_thread_group_begin % 128 != tmem_tensor.first_lane:
+        if self.current_thread_group_begin % 128 != tmem_tensor.layout.lane_offset:
            raise ValueError(
                "Lane mismatch: the first lane of the tmem tensor must be the same as the thread group begin"
            )
-        if self.current_num_threads != tmem_tensor.shape[0]:
+        if self.current_num_threads != tmem_tensor.shape[-2]:
            raise ValueError(
                "The number of threads in the current thread group must be the same as the number of lanes in the tmem tensor"
            )
@@ -174,8 +151,7 @@ class TMemoryLoadEmitter(TMemoryLoadStoreBaseEmitter):
    def emit(self, inst: Tcgen05LoadInst) -> None:
        regs_tensor = inst.register_output
        tmem_tensor = inst.inputs[0].as_tmemory_tensor()
-        sliced_tmem_tensor, sliced_tmem_addr = self.slice_tmem_tensor(tmem_tensor, inst.offsets, regs_tensor.shape)
-        self.emit_tcgen05_instructions(regs_tensor, sliced_tmem_tensor, sliced_tmem_addr)
+        self.emit_tcgen05_instructions(regs_tensor, tmem_tensor, self.tensor2var[tmem_tensor])

    def emit_tcgen05_inst(self, inst: LoadStoreWarpInst) -> None:
        self.append(
@@ -195,11 +171,10 @@ def emit(self, inst: Tcgen05StoreInst) -> None:
        regs_tensor = inst.inputs[1].as_register_tensor()
        tmem_tensor = inst.inputs[0].as_tmemory_tensor()

-        sliced_tmem_tensor, sliced_tmem_addr = self.slice_tmem_tensor(tmem_tensor, inst.offsets, regs_tensor.shape)
        self.emit_tcgen05_instructions(
            regs_tensor,
-            sliced_tmem_tensor,
-            sliced_tmem_addr,
+            tmem_tensor,
+            self.tensor2var[tmem_tensor],
        )

    def emit_tcgen05_inst(self, inst: LoadStoreWarpInst) -> None:
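
As a plain-Python illustration of the preconditions enforced in emit_tcgen05_instructions above (the checker below is a standalone sketch that only restates the three conditions visible in the diff; it is not part of the emitter):

    def check_load_store_preconditions(
        num_threads: int, thread_group_begin: int, lane_offset: int, tmem_shape: list[int]
    ) -> None:
        # the thread group must consist of whole warps
        if num_threads % 32 != 0:
            raise ValueError("number of threads must be divisible by 32")
        # the group's starting lane (mod 128) must match the tensor's first lane
        if thread_group_begin % 128 != lane_offset:
            raise ValueError("first lane of the tmem tensor must equal the thread group begin")
        # one thread per lane: thread count must equal the lane dimension of the tensor
        if num_threads != tmem_shape[-2]:
            raise ValueError("number of threads must equal the number of lanes in the tmem tensor")

    # e.g. a 128-thread group starting at thread 0 can load a [128, 64] tmem tensor with lane offset 0
    check_load_store_preconditions(128, 0, 0, [128, 64])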
