3030#include " ttmlir/Dialect/TTNN/IR/TTNNOps.h" // IWYU pragma: keep
3131#include " ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" // IWYU pragma: keep
3232#include " llvm/ADT/BitVector.h"
33+ #include " llvm/ADT/DenseMap.h"
3334#include " llvm/ADT/STLExtras.h"
3435#include " llvm/Support/Casting.h"
3536
@@ -45,12 +46,14 @@ using mlir::RewritePatternSet;
4546using mlir::TypeConverter;
4647using mlir::UnrealizedConversionCastOp;
4748using mlir::ValueRange;
49+ using mlir::WalkResult;
4850using mlir::func::FuncOp;
4951namespace ttk = mlir::tt::ttkernel;
5052
5153// Start index in compile-time args for TA static metadata (is_sharded,
5254// is_dram). CTA layout is [CBs, TAs], so this is the number of CBs.
5355constexpr llvm::StringLiteral kBaseCTAIndexAttr = " ttl.base_cta_index" ;
56+
5457// Maps local args to global tensor indices for common runtime args (buffer
5558// addresses). CRTA is filtered per-thread, containing only addresses for
5659// tensors this thread uses.
@@ -118,15 +121,14 @@ static FailureOr<unsigned> getTensorFuncArgIndex(Value tensor) {
118121// / Get the L1 buffer address from runtime args for a tensor function argument.
119122// / Runtime args are indexed by the tensor's function argument position.
120123static FailureOr<Value>
121- getBufferAddressFromRuntimeArg (Value tensor, Location loc,
122- ConversionPatternRewriter &rewriter) {
124+ getBufferAddressFromRuntimeArg (Value tensor, Location loc, OpBuilder &builder) {
123125 auto argIdx = getTensorFuncArgIndex (tensor);
124126 if (failed (argIdx)) {
125127 return failure ();
126128 }
127- auto idxConst = rewriter .create <arith::ConstantIndexOp>(loc, *argIdx);
128- return rewriter
129- .create <ttk::GetCommonArgValOp>(loc, rewriter .getI32Type (), idxConst)
129+ auto idxConst = builder .create <arith::ConstantIndexOp>(loc, *argIdx);
130+ return builder
131+ .create <ttk::GetCommonArgValOp>(loc, builder .getI32Type (), idxConst)
130132 .getResult ();
131133}
132134
@@ -188,17 +190,16 @@ static Value computeCBTileIndexFromLoops(Operation *op, OpBuilder &builder) {
188190// / Build a TensorAccessor from CTA/CRTA indices, bank base, and page size.
189191// / ctaIndex: Index into compile-time args where tensor config starts.
190192// / crtaIndex: Index into compile-runtime args (typically 0).
191- static Value buildTensorAccessor (Location loc,
192- ConversionPatternRewriter &rewriter,
193+ static Value buildTensorAccessor (Location loc, OpBuilder &builder,
193194 int32_t ctaIndex, int32_t crtaIndex,
194195 Value bankBase, Value pageSize) {
195- auto ctaConst = rewriter .create <arith::ConstantIntOp>(loc, ctaIndex, 32 );
196- auto crtaConst = rewriter .create <arith::ConstantIntOp>(loc, crtaIndex, 32 );
197- auto args = rewriter .create <ttk::TensorAccessorArgsOp>(
196+ auto ctaConst = builder .create <arith::ConstantIntOp>(loc, ctaIndex, 32 );
197+ auto crtaConst = builder .create <arith::ConstantIntOp>(loc, crtaIndex, 32 );
198+ auto args = builder .create <ttk::TensorAccessorArgsOp>(
198199 loc, ctaConst.getResult (), crtaConst.getResult (),
199200 /* prev_args=*/ Value (), /* cta_expr=*/ nullptr , /* crta_expr=*/ nullptr );
200- auto accessor = rewriter .create <ttk::TensorAccessorOp>(loc, args.getResult (),
201- bankBase, pageSize);
201+ auto accessor = builder .create <ttk::TensorAccessorOp>(loc, args.getResult (),
202+ bankBase, pageSize);
202203 return accessor.getResult ();
203204}
204205
@@ -489,6 +490,9 @@ static FailureOr<int32_t> computeCTAIndex(Value tensor, Operation *op) {
489490 }
490491
491492 auto parentFunc = op->getParentOfType <func::FuncOp>();
493+ if (!parentFunc) {
494+ parentFunc = llvm::dyn_cast<func::FuncOp>(op);
495+ }
492496 if (!parentFunc) {
493497 return op->emitError (" operation must be inside a function" );
494498 }
@@ -529,9 +533,9 @@ static FailureOr<int32_t> computeCTAIndex(Value tensor, Operation *op) {
529533// / Unsupported layouts will emit errors referencing the appropriate GH issues:
530534// / - Sharded layouts: See GH issue #118
531535// / - Row-major (non-tiled): See GH issue #173
532- static FailureOr<Value>
533- materializeTensorAccessor (Value tensor, Value bankBase, Operation *op,
534- ConversionPatternRewriter &rewriter ) {
536+ static FailureOr<Value> materializeTensorAccessor (Value tensor, Value bankBase,
537+ Operation *op,
538+ OpBuilder &builder ) {
535539 auto tensorTy = llvm::dyn_cast<RankedTensorType>(tensor.getType ());
536540 if (!tensorTy) {
537541 return op->emitError (" expected RankedTensorType for tensor accessor" );
@@ -578,9 +582,9 @@ materializeTensorAccessor(Value tensor, Value bankBase, Operation *op,
578582 }
579583 int32_t crtaIndex = static_cast <int32_t >(*argIdx);
580584
581- auto pageSize = rewriter .create <arith::ConstantIntOp>(loc, pageSizeBytes, 32 );
585+ auto pageSize = builder .create <arith::ConstantIntOp>(loc, pageSizeBytes, 32 );
582586
583- return buildTensorAccessor (loc, rewriter , *ctaIndex, crtaIndex, bankBase,
587+ return buildTensorAccessor (loc, builder , *ctaIndex, crtaIndex, bankBase,
584588 pageSize);
585589}
586590
@@ -649,25 +653,111 @@ emitTileLoop(ConversionPatternRewriter &rewriter, Location loc, int64_t tilesY,
649653 }
650654}
651655
656+ // / Maps each function operation to its tensor accessors.
657+ // / The inner map uses the function argument index (unsigned) as the key
658+ // / to look up the pre-materialized TensorAccessor Value for that tensor arg.
659+ using FuncAccessorMapsType = DenseMap<func::FuncOp, DenseMap<unsigned , Value>>;
660+
661+ // / Look up a pre-materialized TensorAccessor for a tensor argument.
662+ static FailureOr<Value>
663+ lookupTensorAccessor (Value tensor,
664+ const DenseMap<unsigned , Value> &tensorToAccessor) {
665+ auto argIdx = getTensorFuncArgIndex (tensor);
666+ if (failed (argIdx)) {
667+ return failure ();
668+ }
669+
670+ auto it = tensorToAccessor.find (*argIdx);
671+ if (it == tensorToAccessor.end ()) {
672+ return failure ();
673+ }
674+
675+ return it->second ;
676+ }
677+
678+ // / Materialize TensorAccessor ops at function entry for tensor arguments used
679+ // / by ttl.copy. Returns a map used later by CopyLowering.
680+ static FailureOr<FuncAccessorMapsType>
681+ materializeFuncTensorAccessors (ModuleOp mod, MLIRContext &ctx) {
682+ FuncAccessorMapsType funcAccessorMaps;
683+
684+ auto walkResult = mod.walk ([&](func::FuncOp funcOp) -> WalkResult {
685+ if (!isNocKernel (funcOp.getOperation ())) {
686+ return WalkResult::advance ();
687+ }
688+
689+ if (funcOp.isExternal () || funcOp.getBody ().empty ()) {
690+ return WalkResult::advance ();
691+ }
692+
693+ DenseMap<unsigned , Value> tensorAccessors;
694+
695+ Block &entryBlock = funcOp.getBody ().front ();
696+ OpBuilder builder (&ctx);
697+ builder.setInsertionPointToStart (&entryBlock);
698+
699+ for (unsigned argIdx = 0 ; argIdx < funcOp.getNumArguments (); ++argIdx) {
700+ auto arg = funcOp.getArgument (argIdx);
701+ auto tensorTy = llvm::dyn_cast<RankedTensorType>(arg.getType ());
702+ if (!tensorTy) {
703+ continue ;
704+ }
705+
706+ bool usedByCopy = llvm::any_of (arg.getUses (), [](OpOperand &use) {
707+ return llvm::isa<CopyOp>(use.getOwner ());
708+ });
709+ if (!usedByCopy) {
710+ continue ;
711+ }
712+
713+ auto bankBase =
714+ getBufferAddressFromRuntimeArg (arg, arg.getLoc (), builder);
715+ if (failed (bankBase)) {
716+ funcOp.emitError (
717+ " tensor must be a function argument for runtime arg mapping" );
718+ return WalkResult::interrupt ();
719+ }
720+
721+ auto accessor = materializeTensorAccessor (arg, *bankBase,
722+ funcOp.getOperation (), builder);
723+ if (failed (accessor)) {
724+ return WalkResult::interrupt ();
725+ }
726+
727+ tensorAccessors.try_emplace (argIdx, *accessor);
728+ }
729+
730+ if (!tensorAccessors.empty ()) {
731+ funcAccessorMaps.try_emplace (funcOp, std::move (tensorAccessors));
732+ }
733+
734+ return WalkResult::advance ();
735+ });
736+
737+ if (walkResult.wasInterrupted ()) {
738+ return failure ();
739+ }
740+
741+ return funcAccessorMaps;
742+ }
743+
652744// / Lower tensor->CB copy: read from DRAM/L1 tensor into circular buffer.
653- static LogicalResult lowerTensorToCB (CopyOp op, Value srcTensor, Value dstCB,
654- ConversionPatternRewriter &rewriter,
655- const TypeConverter &typeConverter) {
745+ static LogicalResult
746+ lowerTensorToCB (CopyOp op, Value srcTensor, Value dstCB,
747+ ConversionPatternRewriter &rewriter,
748+ const TypeConverter &typeConverter,
749+ const DenseMap<unsigned , Value> *tensorAccessors) {
656750 auto loc = op.getLoc ();
657751
658- // Get tensor L1 address from runtime args.
659- auto bankBase = getBufferAddressFromRuntimeArg (srcTensor, loc, rewriter);
660- if (failed (bankBase)) {
752+ if (!tensorAccessors) {
661753 return rewriter.notifyMatchFailure (
662- op, " tensor must be a function argument for runtime arg mapping " );
754+ op, " no tensor accessor map for parent function " );
663755 }
664756
665- // Create tensor accessor with actual buffer address.
666- // This derives page size from TTNNLayoutAttr encoding.
667- auto srcAccessor =
668- materializeTensorAccessor (srcTensor, *bankBase, op, rewriter);
757+ auto srcAccessor = lookupTensorAccessor (srcTensor, *tensorAccessors);
669758 if (failed (srcAccessor)) {
670- return failure (); // Error already emitted by materializeTensorAccessor
759+ return rewriter.notifyMatchFailure (
760+ op, " no pre-materialized tensor accessor found for src tensor" );
671761 }
672762
673763 // Convert CB to TTKernel type and get write pointer.
@@ -719,24 +809,22 @@ static LogicalResult lowerTensorToCB(CopyOp op, Value srcTensor, Value dstCB,
719809}
720810
721811// / Lower CB->tensor copy: write from circular buffer to DRAM/L1 tensor.
722- static LogicalResult lowerCBToTensor (CopyOp op, Value srcCB, Value dstTensor,
723- ConversionPatternRewriter &rewriter,
724- const TypeConverter &typeConverter) {
812+ static LogicalResult
813+ lowerCBToTensor (CopyOp op, Value srcCB, Value dstTensor,
814+ ConversionPatternRewriter &rewriter,
815+ const TypeConverter &typeConverter,
816+ const DenseMap<unsigned , Value> *tensorAccessors) {
725817 auto loc = op.getLoc ();
726818
727- // Get tensor L1 address from runtime args.
728- auto bankBase = getBufferAddressFromRuntimeArg (dstTensor, loc, rewriter);
729- if (failed (bankBase)) {
819+ if (!tensorAccessors) {
730820 return rewriter.notifyMatchFailure (
731- op, " tensor must be a function argument for runtime arg mapping " );
821+ op, " no tensor accessor map for parent function " );
732822 }
733823
734- // Create tensor accessor with actual buffer address.
735- // This derives page size from TTNNLayoutAttr encoding.
736- auto dstAccessor =
737- materializeTensorAccessor (dstTensor, *bankBase, op, rewriter);
824+ auto dstAccessor = lookupTensorAccessor (dstTensor, *tensorAccessors);
738825 if (failed (dstAccessor)) {
739- return failure (); // Error already emitted by materializeTensorAccessor
826+ return rewriter.notifyMatchFailure (
827+ op, " no pre-materialized tensor accessor found for dst tensor" );
740828 }
741829
742830 // Convert CB to TTKernel type and get read pointer.
@@ -788,7 +876,12 @@ static LogicalResult lowerCBToTensor(CopyOp op, Value srcCB, Value dstTensor,
788876}
789877
790878struct CopyLowering : OpConversionPattern<CopyOp> {
791- using OpConversionPattern::OpConversionPattern;
879+ CopyLowering (const TypeConverter &typeConverter, MLIRContext *context,
880+ const FuncAccessorMapsType *funcAccessorMaps)
881+ : OpConversionPattern(typeConverter, context),
882+ funcAccessorMaps (funcAccessorMaps) {}
883+
884+ const FuncAccessorMapsType *funcAccessorMaps;
792885
793886 LogicalResult
794887 matchAndRewrite (CopyOp op, OpAdaptor adaptor,
@@ -798,6 +891,16 @@ struct CopyLowering : OpConversionPattern<CopyOp> {
798891 return rewriter.notifyMatchFailure (op, " no type converter" );
799892 }
800893
894+ const DenseMap<unsigned , Value> *tensorAccessors = nullptr ;
895+ if (funcAccessorMaps) {
896+ if (auto parentFunc = op->getParentOfType <func::FuncOp>()) {
897+ auto it = funcAccessorMaps->find (parentFunc);
898+ if (it != funcAccessorMaps->end ()) {
899+ tensorAccessors = &it->second ;
900+ }
901+ }
902+ }
903+
801904 // Use original operands for classification since lowering functions
802905 // handle type conversion internally.
803906 Value src = op.getSrc ();
@@ -809,14 +912,14 @@ struct CopyLowering : OpConversionPattern<CopyOp> {
809912 if (srcKind == CopySourceKind::TensorAccessor &&
810913 dstKind == CopyDestKind::CircularBuffer) {
811914 return lowerTensorToCB (op, src, adaptor.getDst (), rewriter,
812- *typeConverter);
915+ *typeConverter, tensorAccessors );
813916 }
814917
815918 // CB -> Tensor: write from circular buffer to tensor.
816919 if (srcKind == CopySourceKind::CircularBuffer &&
817920 dstKind == CopyDestKind::TensorAccessor) {
818921 return lowerCBToTensor (op, adaptor.getSrc (), dst, rewriter,
819- *typeConverter);
922+ *typeConverter, tensorAccessors );
820923 }
821924
822925 return rewriter.notifyMatchFailure (op, [&](Diagnostic &diag) {
@@ -920,6 +1023,12 @@ static LogicalResult
9201023lowerTTLOpsToTTKernel (ModuleOp mod, MLIRContext &ctx,
9211024 TTLToTTKernelTypeConverter &typeConverter,
9221025 StringRef passName) {
1026+ auto accessorMapsOrFailure = materializeFuncTensorAccessors (mod, ctx);
1027+ if (failed (accessorMapsOrFailure)) {
1028+ return failure ();
1029+ }
1030+ FuncAccessorMapsType funcAccessorMaps = *accessorMapsOrFailure;
1031+
9231032 ConversionTarget target (ctx);
9241033 target.addIllegalDialect <tt::ttl::TTLDialect>();
9251034 target.addLegalDialect <arith::ArithDialect, BuiltinDialect, scf::SCFDialect,
@@ -951,9 +1060,10 @@ lowerTTLOpsToTTKernel(ModuleOp mod, MLIRContext &ctx,
9511060 });
9521061
9531062 RewritePatternSet patterns (&ctx);
954- patterns.add <BindCBLowering, CopyLowering, WaitLowering, CBReserveLowering,
955- CBPushLowering, CBWaitLowering, CBPopLowering, StoreLowering>(
956- typeConverter, &ctx);
1063+ patterns.add <BindCBLowering, WaitLowering, CBReserveLowering, CBPushLowering,
1064+ CBWaitLowering, CBPopLowering, StoreLowering>(typeConverter,
1065+ &ctx);
1066+ patterns.add <CopyLowering>(typeConverter, &ctx, &funcAccessorMaps);
9571067 populateFunctionOpInterfaceTypeConversionPattern (
9581068 func::FuncOp::getOperationName (), patterns, typeConverter);
9591069
0 commit comments