1414# limitations under the License.
1515
1616from dataclasses import dataclass
17- from typing import Sequence
1817
1918from hidet .ir .dtypes import int32 , uint32
2019from hidet .ir .expr import Expr , cast
2120
2221from tilus .backends .emitter import BaseInstEmitter , register_emitter
2322from tilus .extensions .hidet .ir .primitives .cuda .tcgen05 import (
23+ COLUMN_STRIDE ,
24+ LANE_STRIDE ,
2425 Tcgen05LoadStoreNumKind ,
2526 Tcgen05LoadStorePackKind ,
2627 Tcgen05LoadStoreShapeKind ,
4142from tilus .target import nvgpu_sm100
4243from tilus .utils import gcd
4344
44- # tmem addr: 0xAAAABBBB where AAAA is the lane index and BBBB is the column index
45- # lane index: 0x0000 to 0x007F
46- # column index: 0x0000 to 0x01FF
47- LANE_STRIDE = 0x00010000
48- COLUMN_STRIDE = 0x00000001
49-
5045
5146@dataclass
5247class LoadStoreWarpInst :
@@ -58,24 +53,6 @@ class LoadStoreWarpInst:
5853
5954
6055class TMemoryLoadStoreBaseEmitter (BaseInstEmitter ):
61- def slice_tmem_tensor (
62- self , tmem_tensor : TMemoryTensor , offsets : Sequence [int ], shape : Sequence [int ]
63- ) -> tuple [TMemoryTensor , Expr ]:
64- if any (not isinstance (ofs , int ) for ofs in offsets ):
65- raise ValueError ("All offsets must be integer constants" )
66- if len (offsets ) != 2 :
67- raise ValueError ("The length of offsets must be 2" )
68- if len (shape ) != 2 :
69- raise ValueError ("The length of shape must be 2" )
70- tmem_addr = self .get_or_allocate_var (tmem_tensor )
71- sliced_tmem_tensor = TMemoryTensor .create (
72- dtype = tmem_tensor .dtype , shape = shape , first_lane = tmem_tensor .first_lane + offsets [0 ]
73- )
74- sliced_tmem_addr = (
75- tmem_addr + offsets [0 ] * LANE_STRIDE + offsets [1 ] * COLUMN_STRIDE * tmem_tensor .dtype .nbits // 32
76- )
77- return sliced_tmem_tensor , sliced_tmem_addr
78-
7956 def emit_tcgen05_inst (self , inst : LoadStoreWarpInst ) -> None :
8057 raise NotImplementedError ("Subclasses must implement this method" )
8158
@@ -87,11 +64,11 @@ def emit_tcgen05_instructions(
8764 ) -> None :
8865 if self .current_num_threads % 32 != 0 :
8966 raise ValueError ("The number of threads in the current thread group must be divisible by 32" )
90- if self .current_thread_group_begin % 128 != tmem_tensor .first_lane :
67+ if self .current_thread_group_begin % 128 != tmem_tensor .layout . lane_offset :
9168 raise ValueError (
9269 "Lane mismatch: the first lane of the tmem tensor must be the same as the thread group begin"
9370 )
94- if self .current_num_threads != tmem_tensor .shape [0 ]:
71+ if self .current_num_threads != tmem_tensor .shape [- 2 ]:
9572 raise ValueError (
9673 "The number of threads in the current thread group must be the same as the number of lanes in the tmem tensor"
9774 )
@@ -174,8 +151,7 @@ class TMemoryLoadEmitter(TMemoryLoadStoreBaseEmitter):
174151 def emit (self , inst : Tcgen05LoadInst ) -> None :
175152 regs_tensor = inst .register_output
176153 tmem_tensor = inst .inputs [0 ].as_tmemory_tensor ()
177- sliced_tmem_tensor , sliced_tmem_addr = self .slice_tmem_tensor (tmem_tensor , inst .offsets , regs_tensor .shape )
178- self .emit_tcgen05_instructions (regs_tensor , sliced_tmem_tensor , sliced_tmem_addr )
154+ self .emit_tcgen05_instructions (regs_tensor , tmem_tensor , self .tensor2var [tmem_tensor ])
179155
180156 def emit_tcgen05_inst (self , inst : LoadStoreWarpInst ) -> None :
181157 self .append (
@@ -195,11 +171,10 @@ def emit(self, inst: Tcgen05StoreInst) -> None:
195171 regs_tensor = inst .inputs [1 ].as_register_tensor ()
196172 tmem_tensor = inst .inputs [0 ].as_tmemory_tensor ()
197173
198- sliced_tmem_tensor , sliced_tmem_addr = self .slice_tmem_tensor (tmem_tensor , inst .offsets , regs_tensor .shape )
199174 self .emit_tcgen05_instructions (
200175 regs_tensor ,
201- sliced_tmem_tensor ,
202- sliced_tmem_addr ,
176+ tmem_tensor ,
177+ self . tensor2var [ tmem_tensor ] ,
203178 )
204179
205180 def emit_tcgen05_inst (self , inst : LoadStoreWarpInst ) -> None :
0 commit comments