Commit aa72a53

fix

Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
1 parent 33c4ca5 commit aa72a53

5 files changed: +103 −82 lines changed


python/tilus/backends/emitters/cuda/tcgen05/copy.py

Lines changed: 81 additions & 67 deletions
@@ -15,30 +15,32 @@
 
 
 from __future__ import annotations
-from typing import Optional
+
 from dataclasses import dataclass
 
 from hidet.ir.dtypes import uint64
 from hidet.ir.expr import Expr
-from hidet.ir.primitives.debug import printf
 
 from tilus.backends.codegen import BaseInstEmitter, register_emitter
-from tilus.backends.emitters.cuda.tcgen05.allocation import COLUMN_STRIDE, ROW_STRIDE
+from tilus.backends.emitters.cuda.tcgen05.allocation import COLUMN_STRIDE, LANE_STRIDE
 from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import (
     Tcgen05CopyMulticastKind,
     Tcgen05CopyShapeKind,
     Tcgen05CtaGroupKind,
     tcgen05_copy,
     tcgen05_encode_smem_descriptor,
 )
+from tilus.extensions.hidet.ir.utils.index_transform import index_deserialize
 from tilus.ir.instructions.cuda.tmem import Tcgen05CopyInst
-from tilus.ir.layout.cuda.tcgen05_smem import CanonicalSharedLayout, canonicalize_shared_layout, Tcgen05SwizzleMode
+from tilus.ir.layout.cuda.tcgen05_smem import CanonicalSharedLayout, Tcgen05SwizzleMode, canonicalize_shared_layout
 from tilus.ir.tensor import SharedTensor, TMemoryTensor
 from tilus.target import nvgpu_sm100
 
+
 class GenerationFailedError(Exception):
     pass
 
+
 @dataclass
 class SharedMatrixDescriptor:
     start_addr: Expr | int
@@ -78,19 +80,24 @@ class Tcgen05CopyInstMeta:
     tmem_offset: int
     shared_descriptor: SharedMatrixDescriptor
 
+    def __str__(self) -> str:
+        items = []
+        for key, value in self.__dict__.items():
+            items.append(f"{key}: {value}")
+        return "Tcgen05CopyInstMeta(" + ",\n ".join(items) + "\n)"
+
 
 @register_emitter(Tcgen05CopyInst, target=nvgpu_sm100)
 class Tcgen05CopyEmitter(BaseInstEmitter):
     def split_canonical_layout(
-        self,
-        canonical: CanonicalSharedLayout,
-        shape_kind: Tcgen05CopyShapeKind
-    ) -> Optional[list[tuple[int, SharedMatrixDescriptor]]]:
+        self, smem_addr: Expr, canonical: CanonicalSharedLayout, shape_kind: Tcgen05CopyShapeKind
+    ) -> list[Tcgen05CopyInstMeta]:
         """
         A shared memory tensor might be very large that we need to split it into multiple sub-tensors and
         each sub-tensor is copied by a tcgen05.copy instruction. The smem_addr in returned SharedMatrixDescriptor
         is the offset of the sub-tensor relative to the shared memory tensor in bytes.
 
+        Each tcgen05.copy instruction copies a sub-tensor with the following layout:
         +----------------+--------------------------+--------------------------------------+-------------------------------------+
         | Major-ness     | Swizzling mode           | Canonical Layout without swizzling   | Swizzling on the previous column    |
         +================+==========================+======================================+=====================================+
@@ -118,58 +125,101 @@ def split_canonical_layout(
         - k represents the number of repeating patterns across columns.
         (The table is is from: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-canonical-layouts.)
 
+        The definition of the canonical layout in Tilus is similar to above table, but it's different since we want to represent the layouts
+        in a more natural and extensible way for larger tensors. See the docstring of CanonicalSharedLayout for more details.
+
         Returns
        -------
         ret: Optional[list[tuple[int, SharedMatrixDescriptor]]]
             The list of instructions, each instruction contains the tmem_offset and shared matrix descriptor for each sub-tensor.
         """
         cute_layout = canonical.swizzled_cute_layout.layout
-        m, n = cute_layout.shape
+        m, n = cute_layout.flattened_shape
 
         if shape_kind.n % canonical.dtype_nbits != 0:
-            raise GenerationFailedError("The number of columns in the shape kind must be divisible by the number of bits in the data type")
+            raise GenerationFailedError(
+                "The number of columns in the shape kind must be divisible by the number of bits in the data type"
+            )
 
         inst_m, inst_n = shape_kind.m, shape_kind.n // canonical.dtype_nbits
 
         if m % inst_m != 0 or n % inst_n != 0:
-            raise GenerationFailedError("The number of rows or columns in the shape kind must be divisible by the number of rows or columns in the canonical layout")
+            raise GenerationFailedError(
+                "The number of rows or columns in the shape kind must be divisible by the number of rows or columns in the canonical layout"
+            )
+        if canonical.major_kind == "MN" and (inst_m % (canonical.T * canonical.S) != 0 or inst_n % 8 != 0):
+            raise GenerationFailedError(
+                "The number of rows or columns in the shape kind must be divisible by the number of rows or columns in the canonical layout"
+            )
+        if canonical.major_kind == "K" and (inst_m % 8 != 0 or inst_n % (canonical.T * 2) != 0):
+            raise GenerationFailedError(
+                "The number of rows or columns in the shape kind must be divisible by the number of rows or columns in the canonical layout"
+            )
 
         num_m, num_n = m // inst_m, n // inst_n
         nbytes = canonical.dtype_nbits // 8
 
+        instructions: list[Tcgen05CopyInstMeta] = []
         for i in range(num_m):
             for j in range(num_n):
-                tmem_offset = i * ROW_STRIDE + j * COLUMN_STRIDE
+                tmem_offset = i * inst_m * LANE_STRIDE + j * inst_n * COLUMN_STRIDE
                 if canonical.major_kind == "MN":
-                    assert inst_m % (canonical.T * canonical.S) == 0 and inst_n % 8 == 0
+                    if canonical.swizzle_mode == Tcgen05SwizzleMode.NO_SWIZZLE:
+                        smem_offset = (
+                            i * inst_m // (canonical.T * canonical.S) * canonical.SBO + j * inst_n // 8 * canonical.LBO
+                        ) * nbytes
+                    else:
+                        smem_offset = (
+                            i * inst_m // (canonical.T * canonical.S) * canonical.LBO + j * inst_n // 8 * canonical.SBO
+                        ) * nbytes
                     s_desc = SharedMatrixDescriptor(
-                        start_addr=(i * inst_m * canonical.SBO + j * inst_n * canonical.LBO) * nbytes,
+                        start_addr=smem_addr + smem_offset,
                         lbo=canonical.LBO * nbytes,
                         sbo=canonical.SBO * nbytes,
                         base_offset=0,
-                        stride_mode=0,
+                        stride_mode=0,
                         swizzle_mode=canonical.swizzle_mode.encode(),
                     )
                 elif canonical.major_kind == "K":
-                    assert inst_m % 8 == 0 and inst_n % (canonical.T * canonical.S) == 0
                     if canonical.swizzle_mode == Tcgen05SwizzleMode.NO_SWIZZLE:
-                        s_desc = SharedMatrixDescriptor(
-                            start_addr=(i * inst_m * canonical.SBO + j * inst_n * canonical.LBO) * nbytes,
-                            lbo=canonical.LBO * nbytes,
-                            sbo=canonical.SBO * nbytes,
-                            base_offset=0,
-                            stride_mode=0,
-                            swizzle_mode=canonical.swizzle_mode.encode(),
-                        )
+                        smem_offset = (
+                            i * inst_m // 8 * canonical.SBO + j * inst_n // (canonical.T * canonical.S) * canonical.LBO
+                        ) * nbytes
+                        lbo = canonical.LBO * nbytes
                     else:
-                        pass
+                        # j0, j1, j2 for shape (T, S, k)
+                        _, j1, j2 = index_deserialize(
+                            j * inst_n,
+                            (canonical.T, canonical.S, canonical.k // (canonical.T * canonical.S)),
+                            ranks=[2, 1, 0],
+                        )
+                        smem_offset = (i * inst_m // 8 * canonical.SBO + j1 * canonical.T + j2 * canonical.LBO) * nbytes
+                        lbo = 1 << 4  # assume lbo be 16 so that lbo >> 4 == 1, as required by the documentation
+                    s_desc = SharedMatrixDescriptor(
+                        start_addr=smem_addr + smem_offset,
+                        lbo=lbo,
+                        sbo=canonical.SBO * nbytes,
+                        base_offset=0,
+                        stride_mode=0,
+                        swizzle_mode=canonical.swizzle_mode.encode(),
+                    )
+
+                instructions.append(
+                    Tcgen05CopyInstMeta(
+                        shape_kind=shape_kind,
+                        multicast=Tcgen05CopyMulticastKind.NONE,
+                        cta_group=Tcgen05CtaGroupKind.CTA_1,
+                        tmem_offset=tmem_offset,
+                        shared_descriptor=s_desc,
+                    )
+                )
 
+        return instructions
 
     def generate_instructions(
         self, tmem_tensor: TMemoryTensor, shared_tensor: SharedTensor
     ) -> list[Tcgen05CopyInstMeta]:
         dtype = shared_tensor.dtype
-        shape = shared_tensor.shape
         canonical_layout: CanonicalSharedLayout | None = canonicalize_shared_layout(
             shared_tensor.layout, tmem_tensor.dtype
         )
@@ -180,52 +230,18 @@ def generate_instructions(
             f" shared_layout: {shared_tensor.layout}",
         ]
         raise ValueError("\n".join(msg))
-        print(f"canonical_layout: {canonical_layout}")
-        print(f"canonical_layout.swizzled_cute_layout: {canonical_layout.swizzled_cute_layout}")
-        print(f"canonical_layout.atom_shape: {canonical_layout.atom_shape}")
-        print(f"canonical_layout.atom_strides: {canonical_layout.atom_strides}")
         smem_addr = self.shared_tensor_shared_space_addr[shared_tensor]
-        ret = []
+
         for shape_kind in [
             Tcgen05CopyShapeKind.R128x256B,
             Tcgen05CopyShapeKind.R128x128B,
         ]:
-            column_bits = shape_kind.as_int_tuple()[1]
-            assert column_bits % dtype.nbits == 0
-            column_elements = column_bits // dtype.nbits
-            if shape[1] % column_elements != 0:
-                continue
-            if shape[0] != 128:
+            try:
+                return self.split_canonical_layout(smem_addr, canonical_layout, shape_kind)
+            except GenerationFailedError:
                 continue
-            num_inst_columns = shape[1] // column_elements
-            for inst_column in range(num_inst_columns):
-                tmem_offset = inst_column * (column_bits // 32 * COLUMN_STRIDE)
-                smem_offset = inst_column * (
-                    column_elements // canonical_layout.atom_shape[1] * canonical_layout.atom_strides[1] * dtype.nbytes
-                )
 
-                shared_descriptor = SharedMatrixDescriptor(
-                    start_addr=(smem_addr + smem_offset),
-                    lbo=(canonical_layout.LBO * dtype.nbytes),
-                    sbo=(canonical_layout.SBO * dtype.nbytes),
-                    base_offset=0,
-                    stride_mode=0,  # 0 for relative mode and 1 for absolute mode
-                    swizzle_mode=canonical_layout.swizzle_mode.encode(),
-                )
-                print(f"shared_descriptor: {shared_descriptor}")
-
-                inst_meta = Tcgen05CopyInstMeta(
-                    shape_kind=shape_kind,
-                    multicast=Tcgen05CopyMulticastKind.NONE,
-                    cta_group=Tcgen05CtaGroupKind.CTA_1,
-                    tmem_offset=tmem_offset,
-                    shared_descriptor=shared_descriptor,
-                )
-                ret.append(inst_meta)
-            break
-        else:
-            raise ValueError("No valid instructions generated")
-        return ret
+        raise ValueError("No valid instructions generated")
 
     def check_warp_group(self) -> None:
         begin = self.current_thread_group_begin
@@ -254,8 +270,6 @@ def emit(self, inst: Tcgen05CopyInst) -> None:
             s_desc = self.declare_var("s_desc", tp=uint64, init=inst_meta.shared_descriptor.encoded())
             t_addr = tmem_base_addr + inst_meta.tmem_offset
 
-            self.append(printf("taddr: %#08x, sdesc: %#016lx\n", t_addr, s_desc))
-
            self.append(
                 tcgen05_copy(
                     taddr=t_addr,

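To make the new split_canonical_layout flow easier to follow, here is a minimal standalone sketch of the sub-tensor enumeration it performs. The LANE_STRIDE and COLUMN_STRIDE values below are illustrative assumptions only, not the actual constants from tilus.backends.emitters.cuda.tcgen05.allocation; the instruction-shape arithmetic follows the inst_n = shape_kind.n // dtype_nbits relation shown in the diff.

# A minimal sketch of how the emitter tiles an (m, n) canonical layout into
# inst_m x inst_n sub-tensors and assigns each a tensor-memory offset.
# LANE_STRIDE and COLUMN_STRIDE are assumed values for illustration only.
LANE_STRIDE = 1 << 16   # assumed stride between tensor-memory lanes
COLUMN_STRIDE = 1       # assumed stride between tensor-memory columns

def enumerate_sub_tensors(m: int, n: int, inst_m: int, inst_n: int) -> list[tuple[int, int, int]]:
    assert m % inst_m == 0 and n % inst_n == 0
    tiles = []
    for i in range(m // inst_m):
        for j in range(n // inst_n):
            # same arithmetic as the new tmem_offset computation above
            tmem_offset = i * inst_m * LANE_STRIDE + j * inst_n * COLUMN_STRIDE
            tiles.append((i, j, tmem_offset))
    return tiles

# e.g. a 128x16 tile of 32-bit elements copied with the 128x128B shape kind
# (inst_m = 128, inst_n = 128 bits / 32 bits = 4 columns) needs 4 instructions
print(enumerate_sub_tensors(128, 16, 128, 4))
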
python/tilus/extensions/hidet/ir/primitives/cuda/tcgen05.py

Lines changed: 1 addition & 2 deletions
@@ -138,7 +138,7 @@ def as_int_tuple(self) -> tuple[int, int]:
             Tcgen05CopyShapeKind.R4x128B: (4, 128),
         }
         return table[self]
-
+
     @property
     def n(self) -> int:
         return self.as_int_tuple()[1]
@@ -148,7 +148,6 @@ def m(self) -> int:
         return self.as_int_tuple()[0]
 
 
-
 class Tcgen05CopyMulticastKind(Enum):
     NONE = ""
     WARP_X2_02_13 = ".warpx2_02_13"

python/tilus/ir/layout/cuda/tcgen05_smem.py

Lines changed: 10 additions & 8 deletions
@@ -10,7 +10,7 @@
 from hidet.utils.py import prod
 
 from tilus.ir.layout.shared_layout import SharedLayout
-from tilus.ir.layout.utils.cute import CuteLayout, CuteSwizzle, cute_layout, SwizzledCuteLayout, tuple_product
+from tilus.ir.layout.utils.cute import CuteLayout, CuteSwizzle, IntTuple, SwizzledCuteLayout, cute_layout, tuple_product
 from tilus.ir.utils.veceval import meshgrid, vectorized_evaluate
 from tilus.utils import floor_log2
 
@@ -31,19 +31,19 @@ def encode(self) -> int:
             Tcgen05SwizzleMode.B64_SWIZZLE: 4,
             Tcgen05SwizzleMode.B128_SWIZZLE: 2,
         }[self]
-
+
     @property
     def bbits(self) -> int:
         return self.value[0]
-
+
     @property
     def mbase(self) -> int:
         return self.value[1]
-
+
     @property
     def sshift(self) -> int:
         return self.value[2]
-
+
     def as_cute_swizzle(self) -> CuteSwizzle:
         bbits, mbase, sshift = self.value
         return CuteSwizzle(bbits=bbits, mbase=mbase, sshift=sshift)
@@ -94,19 +94,21 @@ def __post_init__(self):
         atom_size = 2**self.swizzle_mode.bbits * 8 * self.T
         if (self.m > 1 and self.SBO % atom_size != 0) or (self.k > 1 and self.LBO % atom_size != 0):
             raise ValueError(f"SBO {self.SBO} and LBO {self.LBO} must be divisible by atom size: {atom_size}")
-
+
     @property
     def S(self) -> int:
         return 2**self.swizzle_mode.bbits
-
+
     @property
     def dtype_nbits(self) -> int:
         return 128 // self.T
 
     @property
     def swizzled_cute_layout(self) -> SwizzledCuteLayout:
+        shape: IntTuple
+        strides: IntTuple
         if self.major_kind == "MN":
-            shape = ((self.T, S, self.m), (8, self.k))
+            shape = ((self.T, self.S, self.m), (8, self.k))
             if self.swizzle_mode == Tcgen05SwizzleMode.NO_SWIZZLE:
                 strides = ((1, self.T, self.SBO), (self.T, self.LBO))
             else:

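The shape/strides pair built in swizzled_cute_layout can be read as a nested affine layout: an offset is the dot product of nested coordinates with nested strides, which is how CuteLayout.__call__ evaluates coordinates (tuple_sum of tuple_multiply). A small standalone sketch with made-up T, S, m, k, SBO, LBO values:

# Sketch of the MN-major, non-swizzled canonical layout as a plain
# coordinate -> offset mapping. All numeric values are made-up examples.
T, S, m, k = 8, 1, 2, 4
SBO, LBO = 64, 8

shape = ((T, S, m), (8, k))        # nested row mode and column mode, shown for reference
strides = ((1, T, SBO), (T, LBO))  # as built for NO_SWIZZLE in the hunk above

def offset(coords, strides):
    # dot product over the nested (row, column) coordinate tuples
    return sum(c * s for c_mode, s_mode in zip(coords, strides) for c, s in zip(c_mode, s_mode))

# row coordinate (t=3, s=0, i=1), column coordinate (c=5, j=2)
print(offset(((3, 0, 1), (5, 2)), strides))  # 3*1 + 0*8 + 1*64 + 5*8 + 2*8 = 123
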
python/tilus/ir/layout/utils/cute.py

Lines changed: 11 additions & 3 deletions
@@ -67,6 +67,13 @@ def __call__(self, *coords: IntTuple) -> Int:
         ret = tuple_sum(tuple_multiply(coords, self.strides))
         return ret
 
+    @property
+    def flattened_shape(self) -> tuple[Int, ...]:
+        if not isinstance(self.shape, Sequence):
+            return (self.shape,)
+        else:
+            return tuple(tuple_product(item) for item in self.shape)
+
 
 class CuteSwizzle:
     def __init__(self, bbits: int, mbase: int, sshift: int):
@@ -86,14 +93,15 @@ def __call__(self, offset: Int) -> Int:
         y_mask = ((1 << self.bbits) - 1) << (self.mbase + self.sshift)
         return offset ^ ((offset & y_mask) >> self.sshift)
 
+
 class SwizzledCuteLayout:
     def __init__(self, layout: CuteLayout, swizzle: CuteSwizzle):
         self.layout: CuteLayout = layout
         self.swizzle: CuteSwizzle = swizzle
-
+
     def __str__(self) -> str:
-        return str(self.swizzle) + '' + str(self.layout)
-
+        return str(self.swizzle) + "" + str(self.layout)
+
     def __call__(self, *coords: IntTuple) -> Int:
         return self.swizzle(self.layout(*coords))
 
tests/instructions/test_tcgen05_copy.py
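Two small, self-contained examples of the behavior touched in cute.py: the new flattened_shape property collapses each top-level mode to the product of its sizes, and CuteSwizzle XORs selected offset bits. The numbers are illustrative only, and the sketch uses math.prod in place of tuple_product.

from math import prod

def flattened_shape(shape):
    # ((T, S, m), (8, k)) -> (T*S*m, 8*k); a bare int becomes a 1-tuple
    if not isinstance(shape, tuple):
        return (shape,)
    return tuple(prod(mode) if isinstance(mode, tuple) else mode for mode in shape)

def swizzle(offset, bbits, mbase, sshift):
    # same XOR formula as CuteSwizzle.__call__ in the hunk above
    y_mask = ((1 << bbits) - 1) << (mbase + sshift)
    return offset ^ ((offset & y_mask) >> sshift)

print(flattened_shape(((8, 2, 4), (8, 3))))           # (64, 24)
print(swizzle(0b1101000, bbits=2, mbase=2, sshift=3)) # 104 -> 100
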

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ def __call__(self, m_size: int, n_size: int, x_ptr: ~int32, y_ptr: ~int32):
8080
def test_tcgen05_copy(major_kind, swizzle_mode):
8181
if major_kind == "MN":
8282
pytest.xfail("MN is not supported")
83-
if major_kind == "K" and swizzle_mode in [Tcgen05SwizzleMode.B64_SWIZZLE, Tcgen05SwizzleMode.B128_SWIZZLE]:
84-
pytest.xfail("K with swizzle mode B64 and B128 is not supported")
8583
m_size = 128
8684
n_size = 32
8785
x = torch.randint(0, 128, [m_size, n_size], dtype=torch.int32, device="cuda")
