@@ -25,6 +25,20 @@ namespace ttng = mlir::triton::nvidia_gpu;
2525
2626namespace {
2727
28+ static bool isValueAvailableInScope (Value value, Region *scope) {
29+ if (!scope)
30+ return false ;
31+ if (auto arg = dyn_cast<BlockArgument>(value)) {
32+ Region *argRegion = arg.getOwner ()->getParent ();
33+ return argRegion == scope || scope->isAncestor (argRegion);
34+ }
35+ if (Operation *def = value.getDefiningOp ()) {
36+ Region *defRegion = def->getParentRegion ();
37+ return defRegion == scope || scope->isAncestor (defRegion);
38+ }
39+ return false ;
40+ }
41+
2842constexpr int64_t kTileM = 8 ;
2943constexpr int64_t kTileN = 8 ;
3044
@@ -162,6 +176,8 @@ class TmemScratchManager {
162176 return std::nullopt ;
163177 }
164178
179+ ptr = remapToScope (ptr, rewriter, scope, loc);
180+
165181 ScratchInfo info{ptr, tensorTy};
166182 scratchMap[memdesc][scope] = info;
167183 return info;
@@ -189,8 +205,9 @@ class TmemScratchManager {
189205 rewriter, loc, rewriter.getI32IntegerAttr (stride));
190206 auto offsetEls = arith::MulIOp::create (
191207 rewriter, loc, rewriter.getI32Type (), offsetVal, strideVal);
192- auto ptr = tt::AddPtrOp::create (rewriter, loc, baseInfo->ptr .getType (),
193- baseInfo->ptr , offsetEls);
208+ Value ptr = tt::AddPtrOp::create (rewriter, loc, baseInfo->ptr .getType (),
209+ baseInfo->ptr , offsetEls);
210+ ptr = remapToScope (ptr, rewriter, scope, loc);
194211 auto layout = getScratchEncoding (rewriter, memdesc, memTy);
195212 auto tensorTy = RankedTensorType::get (memTy.getShape (),
196213 memTy.getElementType (), layout);
@@ -218,8 +235,9 @@ class TmemScratchManager {
218235 rewriter, loc, rewriter.getI32IntegerAttr (stride));
219236 auto offset = arith::MulIOp::create (rewriter, loc, rewriter.getI32Type (),
220237 idx, strideVal);
221- auto ptr = tt::AddPtrOp::create (rewriter, loc, baseInfo->ptr .getType (),
222- baseInfo->ptr , offset);
238+ Value ptr = tt::AddPtrOp::create (rewriter, loc, baseInfo->ptr .getType (),
239+ baseInfo->ptr , offset);
240+ ptr = remapToScope (ptr, rewriter, scope, loc);
223241 auto layout = getScratchEncoding (rewriter, memdesc, memTy);
224242 auto tensorTy = RankedTensorType::get (memTy.getShape (),
225243 memTy.getElementType (), layout);
@@ -241,6 +259,7 @@ class TmemScratchManager {
241259 if (ptr.getType () != ptrTy) {
242260 ptr = tt::BitcastOp::create (rewriter, loc, ptrTy, ptr);
243261 }
262+ ptr = remapToScope (ptr, rewriter, scope, loc);
244263
245264 auto layout = getScratchEncoding (rewriter, memdesc, memTy);
246265 auto tensorTy = RankedTensorType::get (memTy.getShape (),
@@ -254,6 +273,36 @@ class TmemScratchManager {
254273 }
255274
256275private:
276+ Value remapToScope (Value value, PatternRewriter &rewriter, Region *scope,
277+ Location loc) {
278+ if (!scope || isValueAvailableInScope (value, scope))
279+ return value;
280+
281+ auto *parentOp = scope->getParentOp ();
282+ auto partitions = dyn_cast_or_null<ttg::WarpSpecializePartitionsOp>(
283+ parentOp ? parentOp : nullptr );
284+ if (!partitions)
285+ return value;
286+
287+ unsigned captureIdx = partitions.getNumOperands ();
288+ for (auto [i, capture] :
289+ llvm::enumerate (partitions.getExplicitCaptures ())) {
290+ if (capture == value) {
291+ captureIdx = i;
292+ break ;
293+ }
294+ }
295+
296+ if (captureIdx == partitions.getNumOperands ()) {
297+ partitions->insertOperands (captureIdx, value);
298+ for (Region ®ion : partitions.getPartitionRegions ()) {
299+ region.addArgument (value.getType (), loc);
300+ }
301+ }
302+
303+ return scope->getArgument (captureIdx);
304+ }
305+
257306 DenseMap<Value, DenseMap<Region *, ScratchInfo>> scratchMap;
258307};
259308
0 commit comments