@@ -20,6 +20,10 @@ using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
2020
2121unsigned defaultAllocationAnalysisScratchSizeFn (Operation *op);
2222
23+ unsigned getNumScratchElemsSwizzledCvt (const LinearLayout &srcLayout,
24+ const LinearLayout &dstLayout,
25+ int bitwidth);
26+
2327unsigned getNumScratchElemsSwizzledCvt (RankedTensorType srcTy,
2428 RankedTensorType dstTy);
2529
@@ -70,8 +74,11 @@ class Allocation {
7074 explicit Allocation (Operation *operation) : operation(operation) {}
7175
7276 // / Runs allocation analysis on the given top-level operation.
77+ // / \param sharedMemoryPartitionSize The size of each shared memory partition
78+ // / in bytes. A value of 0 means shared memory is not partitioned.
7379 void run (FuncAllocMapT &funcAllocMap,
74- triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);
80+ triton::AllocationAnalysisScratchSizeFn scratchSizeGetter,
81+ size_t sharedMemoryPartitionSize = 0 );
7582
7683 // / Returns the operation this analysis was constructed from.
7784 Operation *getOperation () const { return operation; }
@@ -92,24 +99,29 @@ class Allocation {
9299 return Interval<size_t >(buffer.offset , buffer.offset + buffer.size );
93100 }
94101
95- // / Returns the buffer id of the given value.
96- // / This interface only returns the allocated buffer id.
97- // / If you want to get all the buffer ids that are associated with the given
98- // / value, including alias buffers, use getBufferIds.
99- BufferId getBufferId (Value value) const {
100- if (valueBuffer.count (value)) {
101- return valueBuffer.lookup (value)->id ;
102- } else {
103- return InvalidBufferId;
102+ // / Returns the ids of all buffers allocated for a value.
103+ // / For partitioned tensors, returns the buffer id of every logical piece.
104+ // / For non-partitioned values, returns a single-element vector.
105+ // / Returns an empty vector if the value has no associated buffer.
106+ SmallVector<BufferId> getBufferIds (Value value) const {
107+ SmallVector<BufferId> bufferIds;
108+ auto it = valueBuffer.find (value);
109+ if (it == valueBuffer.end ())
110+ return bufferIds;
111+
112+ for (auto *buffer : it->second ) {
113+ bufferIds.push_back (buffer->id );
104114 }
115+ return bufferIds;
105116 }
106117
107- // / Returns all the buffer ids of the given value, including alias buffers.
108- BufferIdSetT getBufferIds (Value value) const {
118+ // / Returns the buffer ids of the given value plus the ids of its alias buffers.
119+ // / The result is a superset of what getBufferIds returns for the same value.
120+ BufferIdSetT getAllBufferIdsWithAliases (Value value) const {
109121 BufferIdSetT bufferIds;
110- auto allocBufferId = getBufferId (value);
111- if (allocBufferId != InvalidBufferId)
112- bufferIds. insert (allocBufferId);
122+ for ( auto bufferId : getBufferIds (value)) {
123+ bufferIds. insert (bufferId);
124+ }
113125 for (auto *buffer : aliasBuffer.lookup (value)) {
114126 if (buffer->id != InvalidBufferId)
115127 bufferIds.insert (buffer->id );
@@ -154,6 +166,10 @@ class Allocation {
154166 size_t alignment;
155167 size_t offset;
156168
169+ // / For partitioned tensors: buffers that reside in different physical
170+ // / partitions.
171+ SmallVector<BufferT *> neighbors;
172+
157173 bool operator ==(const BufferT &other) const { return id == other.id ; }
158174 bool operator <(const BufferT &other) const { return id < other.id ; }
159175
@@ -169,8 +185,8 @@ class Allocation {
169185
170186 // / Op -> Scratch Buffer
171187 using OpScratchMapT = llvm::MapVector<Operation *, BufferT *>;
172- // / Value -> Explicit Buffer
173- using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
188+ // / Value -> Explicit Buffers (vector for partitioned tensors)
189+ using ValueBufferMapT = llvm::MapVector<Value, SmallVector< BufferT *> >;
174190 // / Value -> Alias Buffer
175191 using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
176192 // / BufferId -> Buffer
@@ -184,16 +200,28 @@ class Allocation {
184200 nextId, BufferT (Kind, nextId, key, std::forward<Args>(args)...));
185201 BufferT *buffer = &it->second ;
186202 if constexpr (Kind == BufferT::BufferKind::Explicit) {
187- valueBuffer[key] = buffer;
203+ valueBuffer[key]. push_back ( buffer) ;
188204 } else if constexpr (Kind == BufferT::BufferKind::Virtual) {
189205 opVirtual[key] = buffer;
190206 } else {
191207 opScratch[key] = buffer;
192208 }
193209 }
194210
211+ // / Creates one buffer per partition for a partitioned value. Buffers of different
212+ // / partitions are neighbors (they must be placed in different physical shared memory slots).
213+ // /
214+ // / \param key The value that owns these buffers
215+ // / \param numPartitions Number of partition buffers to create
216+ // / \param partitionSize Size of each partition buffer in bytes
217+ // / \param alignment Required alignment for each buffer
218+ void addPartitionBuffers (Value key, unsigned numPartitions,
219+ size_t partitionSize, size_t alignment);
220+
195221 void addAlias (Value value, Value alloc) {
196- aliasBuffer[value].insert (valueBuffer[alloc]);
222+ for (auto *buffer : valueBuffer[alloc]) {
223+ aliasBuffer[value].insert (buffer);
224+ }
197225 }
198226
199227private:
@@ -222,7 +250,8 @@ class ModuleAllocation : public triton::CallGraph<Allocation> {
222250
223251 ModuleAllocation (ModuleOp moduleOp,
224252 triton::AllocationAnalysisScratchSizeFn scratchSizeGetter =
225- triton::defaultAllocationAnalysisScratchSizeFn)
253+ triton::defaultAllocationAnalysisScratchSizeFn,
254+ size_t sharedMemoryPartitionSize = 0 )
226255 : triton::CallGraph<Allocation>(moduleOp) {
227256 walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
228257 // Pre-order edge walk callback
@@ -231,7 +260,8 @@ class ModuleAllocation : public triton::CallGraph<Allocation> {
231260 [&](FunctionOpInterface funcOp) {
232261 auto [iter, inserted] = funcMap.try_emplace (funcOp, funcOp);
233262 if (inserted)
234- iter->second .run (funcMap, scratchSizeGetter);
263+ iter->second .run (funcMap, scratchSizeGetter,
264+ sharedMemoryPartitionSize);
235265 });
236266 }
237267
0 commit comments