

 from __future__ import annotations
-from typing import Optional
+
 from dataclasses import dataclass

 from hidet.ir.dtypes import uint64
 from hidet.ir.expr import Expr
-from hidet.ir.primitives.debug import printf

 from tilus.backends.codegen import BaseInstEmitter, register_emitter
-from tilus.backends.emitters.cuda.tcgen05.allocation import COLUMN_STRIDE, ROW_STRIDE
+from tilus.backends.emitters.cuda.tcgen05.allocation import COLUMN_STRIDE, LANE_STRIDE
 from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import (
     Tcgen05CopyMulticastKind,
     Tcgen05CopyShapeKind,
     Tcgen05CtaGroupKind,
     tcgen05_copy,
     tcgen05_encode_smem_descriptor,
 )
+from tilus.extensions.hidet.ir.utils.index_transform import index_deserialize
 from tilus.ir.instructions.cuda.tmem import Tcgen05CopyInst
-from tilus.ir.layout.cuda.tcgen05_smem import CanonicalSharedLayout, canonicalize_shared_layout, Tcgen05SwizzleMode
+from tilus.ir.layout.cuda.tcgen05_smem import CanonicalSharedLayout, Tcgen05SwizzleMode, canonicalize_shared_layout
 from tilus.ir.tensor import SharedTensor, TMemoryTensor
 from tilus.target import nvgpu_sm100

+
 class GenerationFailedError(Exception):
     pass

+
 @dataclass
 class SharedMatrixDescriptor:
     start_addr: Expr | int
@@ -78,19 +80,24 @@ class Tcgen05CopyInstMeta:
     tmem_offset: int
     shared_descriptor: SharedMatrixDescriptor

+    def __str__(self) -> str:
+        items = []
+        for key, value in self.__dict__.items():
+            items.append(f"{key}: {value}")
+        return "Tcgen05CopyInstMeta(" + ",\n".join(items) + "\n)"
+

 @register_emitter(Tcgen05CopyInst, target=nvgpu_sm100)
 class Tcgen05CopyEmitter(BaseInstEmitter):
     def split_canonical_layout(
-        self,
-        canonical: CanonicalSharedLayout,
-        shape_kind: Tcgen05CopyShapeKind
-    ) -> Optional[list[tuple[int, SharedMatrixDescriptor]]]:
+        self, smem_addr: Expr, canonical: CanonicalSharedLayout, shape_kind: Tcgen05CopyShapeKind
+    ) -> list[Tcgen05CopyInstMeta]:
         """
         A shared memory tensor might be so large that we need to split it into multiple sub-tensors,
         each copied by its own tcgen05.copy instruction. The start_addr in each returned
         SharedMatrixDescriptor is the shared-space byte address of the sub-tensor.

+        Each tcgen05.copy instruction copies a sub-tensor with the following layout:
         +----------------+--------------------------+--------------------------------------+-------------------------------------+
         | Major-ness     | Swizzling mode           | Canonical Layout without swizzling   | Swizzling on the previous column    |
         +================+==========================+======================================+=====================================+
@@ -118,58 +125,101 @@ def split_canonical_layout(
         - k represents the number of repeating patterns across columns.
         (The table is from: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-canonical-layouts.)

+        The canonical layout in Tilus is similar to the table above, but differs because we want
+        to represent layouts for larger tensors in a more natural and extensible way. See the
+        docstring of CanonicalSharedLayout for more details.
+
         Returns
         -------
-        ret: Optional[list[tuple[int, SharedMatrixDescriptor]]]
-            The list of instructions, each instruction contains the tmem_offset and shared matrix descriptor for each sub-tensor.
+        ret: list[Tcgen05CopyInstMeta]
+            The list of instruction metadata; each entry carries the tmem_offset and the shared
+            matrix descriptor for one sub-tensor.
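+
+        For example (hypothetical numbers, following the arithmetic below): with a 16-bit dtype
+        and shape kind R128x256B (inst_m = 128 lanes, inst_n = 256 // 16 = 16 elements), a
+        canonical layout with (m, n) = (256, 64) splits into num_m * num_n = 2 * 4 = 8
+        tcgen05.copy instructions.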
         """
         cute_layout = canonical.swizzled_cute_layout.layout
-        m, n = cute_layout.shape
+        m, n = cute_layout.flattened_shape

         if shape_kind.n % canonical.dtype_nbits != 0:
-            raise GenerationFailedError("The number of columns in the shape kind must be divisible by the number of bits in the data type")
+            raise GenerationFailedError(
+                "The number of column bits in the shape kind must be divisible by the number of bits in the data type"
+            )

         inst_m, inst_n = shape_kind.m, shape_kind.n // canonical.dtype_nbits

         if m % inst_m != 0 or n % inst_n != 0:
-            raise GenerationFailedError("The number of rows or columns in the shape kind must be divisible by the number of rows or columns in the canonical layout")
+            raise GenerationFailedError(
+                "The canonical layout shape (m, n) must be divisible by the instruction tile (inst_m, inst_n)"
+            )
+        if canonical.major_kind == "MN" and (inst_m % (canonical.T * canonical.S) != 0 or inst_n % 8 != 0):
+            raise GenerationFailedError(
+                "For MN-major layouts, inst_m must be divisible by T * S and inst_n by 8"
+            )
+        if canonical.major_kind == "K" and (inst_m % 8 != 0 or inst_n % (canonical.T * 2) != 0):
+            raise GenerationFailedError(
+                "For K-major layouts, inst_m must be divisible by 8 and inst_n by 2 * T"
+            )

         num_m, num_n = m // inst_m, n // inst_n
         nbytes = canonical.dtype_nbits // 8

+        instructions: list[Tcgen05CopyInstMeta] = []
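+        # The (m, n) canonical layout is tiled into a num_m x num_n grid of inst_m x inst_n
+        # sub-tensors; each grid cell becomes one tcgen05.copy instruction below.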
         for i in range(num_m):
             for j in range(num_n):
-                tmem_offset = i * ROW_STRIDE + j * COLUMN_STRIDE
+                tmem_offset = i * inst_m * LANE_STRIDE + j * inst_n * COLUMN_STRIDE
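+                # Sub-tensor (i, j) starts inst_m lanes down and inst_n element-columns across;
+                # LANE_STRIDE and COLUMN_STRIDE (from allocation.py) scale lane and column
+                # indices into tensor-memory address offsets.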
                 if canonical.major_kind == "MN":
-                    assert inst_m % (canonical.T * canonical.S) == 0 and inst_n % 8 == 0
+                    if canonical.swizzle_mode == Tcgen05SwizzleMode.NO_SWIZZLE:
+                        smem_offset = (
+                            i * inst_m // (canonical.T * canonical.S) * canonical.SBO + j * inst_n // 8 * canonical.LBO
+                        ) * nbytes
+                    else:
+                        smem_offset = (
+                            i * inst_m // (canonical.T * canonical.S) * canonical.LBO + j * inst_n // 8 * canonical.SBO
+                        ) * nbytes
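+                    # Note: SBO and LBO swap roles between the non-swizzled and swizzled
+                    # MN-major canonical layouts (see the PTX table above), hence the branches.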
                     s_desc = SharedMatrixDescriptor(
-                        start_addr=(i * inst_m * canonical.SBO + j * inst_n * canonical.LBO) * nbytes,
+                        start_addr=smem_addr + smem_offset,
                         lbo=canonical.LBO * nbytes,
                         sbo=canonical.SBO * nbytes,
                         base_offset=0,
                         stride_mode=0,
                         swizzle_mode=canonical.swizzle_mode.encode(),
                     )
                 elif canonical.major_kind == "K":
-                    assert inst_m % 8 == 0 and inst_n % (canonical.T * canonical.S) == 0
                     if canonical.swizzle_mode == Tcgen05SwizzleMode.NO_SWIZZLE:
-                        s_desc = SharedMatrixDescriptor(
-                            start_addr=(i * inst_m * canonical.SBO + j * inst_n * canonical.LBO) * nbytes,
-                            lbo=canonical.LBO * nbytes,
-                            sbo=canonical.SBO * nbytes,
-                            base_offset=0,
-                            stride_mode=0,
-                            swizzle_mode=canonical.swizzle_mode.encode(),
-                        )
+                        smem_offset = (
+                            i * inst_m // 8 * canonical.SBO + j * inst_n // (canonical.T * canonical.S) * canonical.LBO
+                        ) * nbytes
+                        lbo = canonical.LBO * nbytes
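+                        # Non-swizzled K-major: rows advance by SBO every 8 lanes, columns by
+                        # LBO every T * S elements; LBO is encoded into the descriptor as-is.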
                     else:
-                        pass
+                        # j0, j1, j2 for shape (T, S, k)
+                        _, j1, j2 = index_deserialize(
+                            j * inst_n,
+                            (canonical.T, canonical.S, canonical.k // (canonical.T * canonical.S)),
+                            ranks=[2, 1, 0],
+                        )
+                        smem_offset = (i * inst_m // 8 * canonical.SBO + j1 * canonical.T + j2 * canonical.LBO) * nbytes
+                        lbo = 1 << 4  # assume lbo is 16 so that lbo >> 4 == 1, as required by the documentation
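+                        # Swizzled K-major (reading off the stride terms above): j1 steps within
+                        # a swizzle atom by T elements, j2 steps across atoms by LBO, and rows
+                        # advance by SBO every 8 lanes.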
+                    s_desc = SharedMatrixDescriptor(
+                        start_addr=smem_addr + smem_offset,
+                        lbo=lbo,
+                        sbo=canonical.SBO * nbytes,
+                        base_offset=0,
+                        stride_mode=0,
+                        swizzle_mode=canonical.swizzle_mode.encode(),
+                    )
+
+                instructions.append(
+                    Tcgen05CopyInstMeta(
+                        shape_kind=shape_kind,
+                        multicast=Tcgen05CopyMulticastKind.NONE,
+                        cta_group=Tcgen05CtaGroupKind.CTA_1,
+                        tmem_offset=tmem_offset,
+                        shared_descriptor=s_desc,
+                    )
+                )
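+                # Only the single-CTA, no-multicast form of tcgen05.copy is generated for now.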

+        return instructions

     def generate_instructions(
         self, tmem_tensor: TMemoryTensor, shared_tensor: SharedTensor
     ) -> list[Tcgen05CopyInstMeta]:
-        dtype = shared_tensor.dtype
-        shape = shared_tensor.shape
         canonical_layout: CanonicalSharedLayout | None = canonicalize_shared_layout(
             shared_tensor.layout, tmem_tensor.dtype
         )
@@ -180,52 +230,18 @@ def generate_instructions(
             f" shared_layout: {shared_tensor.layout}",
         ]
         raise ValueError("\n".join(msg))
-        print(f"canonical_layout: {canonical_layout}")
-        print(f"canonical_layout.swizzled_cute_layout: {canonical_layout.swizzled_cute_layout}")
-        print(f"canonical_layout.atom_shape: {canonical_layout.atom_shape}")
-        print(f"canonical_layout.atom_strides: {canonical_layout.atom_strides}")
         smem_addr = self.shared_tensor_shared_space_addr[shared_tensor]
-        ret = []
+
         for shape_kind in [
             Tcgen05CopyShapeKind.R128x256B,
             Tcgen05CopyShapeKind.R128x128B,
         ]:
-            column_bits = shape_kind.as_int_tuple()[1]
-            assert column_bits % dtype.nbits == 0
-            column_elements = column_bits // dtype.nbits
-            if shape[1] % column_elements != 0:
-                continue
-            if shape[0] != 128:
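+            # Try the widest shape kind first and fall back to narrower ones when the
+            # canonical layout does not tile evenly into the instruction shape.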
+            try:
+                return self.split_canonical_layout(smem_addr, canonical_layout, shape_kind)
+            except GenerationFailedError:
                 continue
-            num_inst_columns = shape[1] // column_elements
-            for inst_column in range(num_inst_columns):
-                tmem_offset = inst_column * (column_bits // 32 * COLUMN_STRIDE)
-                smem_offset = inst_column * (
-                    column_elements // canonical_layout.atom_shape[1] * canonical_layout.atom_strides[1] * dtype.nbytes
-                )

-                shared_descriptor = SharedMatrixDescriptor(
-                    start_addr=(smem_addr + smem_offset),
-                    lbo=(canonical_layout.LBO * dtype.nbytes),
-                    sbo=(canonical_layout.SBO * dtype.nbytes),
-                    base_offset=0,
-                    stride_mode=0,  # 0 for relative mode and 1 for absolute mode
-                    swizzle_mode=canonical_layout.swizzle_mode.encode(),
-                )
-                print(f"shared_descriptor: {shared_descriptor}")
-
-                inst_meta = Tcgen05CopyInstMeta(
-                    shape_kind=shape_kind,
-                    multicast=Tcgen05CopyMulticastKind.NONE,
-                    cta_group=Tcgen05CtaGroupKind.CTA_1,
-                    tmem_offset=tmem_offset,
-                    shared_descriptor=shared_descriptor,
-                )
-                ret.append(inst_meta)
-            break
-        else:
-            raise ValueError("No valid instructions generated")
-        return ret
+        raise ValueError("No valid instructions generated")

     def check_warp_group(self) -> None:
         begin = self.current_thread_group_begin
@@ -254,8 +270,6 @@ def emit(self, inst: Tcgen05CopyInst) -> None:
         s_desc = self.declare_var("s_desc", tp=uint64, init=inst_meta.shared_descriptor.encoded())
         t_addr = tmem_base_addr + inst_meta.tmem_offset

-        self.append(printf("taddr: %#08x, sdesc: %#016lx\n", t_addr, s_desc))
-
         self.append(
             tcgen05_copy(
                 taddr=t_addr,