Skip to content

Commit 7b126a6

Browse files
committed
move tensor accessor ops to function level
1 parent 930c9eb commit 7b126a6

17 files changed

+512
-316
lines changed

lib/Dialect/TTL/Transforms/ConvertTTLToTTKernel.cpp

Lines changed: 157 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" // IWYU pragma: keep
3131
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" // IWYU pragma: keep
3232
#include "llvm/ADT/BitVector.h"
33+
#include "llvm/ADT/DenseMap.h"
3334
#include "llvm/ADT/STLExtras.h"
3435
#include "llvm/Support/Casting.h"
3536

@@ -45,12 +46,14 @@ using mlir::RewritePatternSet;
4546
using mlir::TypeConverter;
4647
using mlir::UnrealizedConversionCastOp;
4748
using mlir::ValueRange;
49+
using mlir::WalkResult;
4850
using mlir::func::FuncOp;
4951
namespace ttk = mlir::tt::ttkernel;
5052

5153
// Start index in compile-time args for TA static metadata (is_sharded,
5254
// is_dram). CTA layout is [CBs, TAs], so this is the number of CBs.
5355
constexpr llvm::StringLiteral kBaseCTAIndexAttr = "ttl.base_cta_index";
56+
5457
// Maps local args to global tensor indices for common runtime args (buffer
5558
// addresses). CRTA is filtered per-thread, containing only addresses for
5659
// tensors this thread uses.
@@ -118,15 +121,14 @@ static FailureOr<unsigned> getTensorFuncArgIndex(Value tensor) {
118121
/// Get the L1 buffer address from runtime args for a tensor function argument.
119122
/// Runtime args are indexed by the tensor's function argument position.
120123
static FailureOr<Value>
121-
getBufferAddressFromRuntimeArg(Value tensor, Location loc,
122-
ConversionPatternRewriter &rewriter) {
124+
getBufferAddressFromRuntimeArg(Value tensor, Location loc, OpBuilder &builder) {
123125
auto argIdx = getTensorFuncArgIndex(tensor);
124126
if (failed(argIdx)) {
125127
return failure();
126128
}
127-
auto idxConst = rewriter.create<arith::ConstantIndexOp>(loc, *argIdx);
128-
return rewriter
129-
.create<ttk::GetCommonArgValOp>(loc, rewriter.getI32Type(), idxConst)
129+
auto idxConst = builder.create<arith::ConstantIndexOp>(loc, *argIdx);
130+
return builder
131+
.create<ttk::GetCommonArgValOp>(loc, builder.getI32Type(), idxConst)
130132
.getResult();
131133
}
132134

@@ -188,17 +190,16 @@ static Value computeCBTileIndexFromLoops(Operation *op, OpBuilder &builder) {
188190
/// Build a TensorAccessor from CTA/CRTA indices, bank base, and page size.
189191
/// ctaIndex: Index into compile-time args where tensor config starts.
190192
/// crtaIndex: Index into common runtime args (typically 0).
191-
static Value buildTensorAccessor(Location loc,
192-
ConversionPatternRewriter &rewriter,
193+
static Value buildTensorAccessor(Location loc, OpBuilder &builder,
193194
int32_t ctaIndex, int32_t crtaIndex,
194195
Value bankBase, Value pageSize) {
195-
auto ctaConst = rewriter.create<arith::ConstantIntOp>(loc, ctaIndex, 32);
196-
auto crtaConst = rewriter.create<arith::ConstantIntOp>(loc, crtaIndex, 32);
197-
auto args = rewriter.create<ttk::TensorAccessorArgsOp>(
196+
auto ctaConst = builder.create<arith::ConstantIntOp>(loc, ctaIndex, 32);
197+
auto crtaConst = builder.create<arith::ConstantIntOp>(loc, crtaIndex, 32);
198+
auto args = builder.create<ttk::TensorAccessorArgsOp>(
198199
loc, ctaConst.getResult(), crtaConst.getResult(),
199200
/*prev_args=*/Value(), /*cta_expr=*/nullptr, /*crta_expr=*/nullptr);
200-
auto accessor = rewriter.create<ttk::TensorAccessorOp>(loc, args.getResult(),
201-
bankBase, pageSize);
201+
auto accessor = builder.create<ttk::TensorAccessorOp>(loc, args.getResult(),
202+
bankBase, pageSize);
202203
return accessor.getResult();
203204
}
204205

@@ -489,6 +490,9 @@ static FailureOr<int32_t> computeCTAIndex(Value tensor, Operation *op) {
489490
}
490491

491492
auto parentFunc = op->getParentOfType<func::FuncOp>();
493+
if (!parentFunc) {
494+
parentFunc = llvm::dyn_cast<func::FuncOp>(op);
495+
}
492496
if (!parentFunc) {
493497
return op->emitError("operation must be inside a function");
494498
}
@@ -529,9 +533,9 @@ static FailureOr<int32_t> computeCTAIndex(Value tensor, Operation *op) {
529533
/// Unsupported layouts will emit errors referencing the appropriate GH issues:
530534
/// - Sharded layouts: See GH issue #118
531535
/// - Row-major (non-tiled): See GH issue #173
532-
static FailureOr<Value>
533-
materializeTensorAccessor(Value tensor, Value bankBase, Operation *op,
534-
ConversionPatternRewriter &rewriter) {
536+
static FailureOr<Value> materializeTensorAccessor(Value tensor, Value bankBase,
537+
Operation *op,
538+
OpBuilder &builder) {
535539
auto tensorTy = llvm::dyn_cast<RankedTensorType>(tensor.getType());
536540
if (!tensorTy) {
537541
return op->emitError("expected RankedTensorType for tensor accessor");
@@ -578,9 +582,9 @@ materializeTensorAccessor(Value tensor, Value bankBase, Operation *op,
578582
}
579583
int32_t crtaIndex = static_cast<int32_t>(*argIdx);
580584

581-
auto pageSize = rewriter.create<arith::ConstantIntOp>(loc, pageSizeBytes, 32);
585+
auto pageSize = builder.create<arith::ConstantIntOp>(loc, pageSizeBytes, 32);
582586

583-
return buildTensorAccessor(loc, rewriter, *ctaIndex, crtaIndex, bankBase,
587+
return buildTensorAccessor(loc, builder, *ctaIndex, crtaIndex, bankBase,
584588
pageSize);
585589
}
586590

@@ -649,25 +653,111 @@ emitTileLoop(ConversionPatternRewriter &rewriter, Location loc, int64_t tilesY,
649653
}
650654
}
651655

656+
/// Maps each function operation to its tensor accessors.
657+
/// The inner map uses the function argument index (unsigned) as the key
658+
/// to look up the pre-materialized TensorAccessor Value for that tensor arg.
659+
using FuncAccessorMapsType = DenseMap<func::FuncOp, DenseMap<unsigned, Value>>;
660+
661+
/// Look up a pre-materialized TensorAccessor for a tensor argument.
662+
static FailureOr<Value>
663+
lookupTensorAccessor(Value tensor,
664+
const DenseMap<unsigned, Value> &tensorToAccessor) {
665+
auto argIdx = getTensorFuncArgIndex(tensor);
666+
if (failed(argIdx)) {
667+
return failure();
668+
}
669+
670+
auto it = tensorToAccessor.find(*argIdx);
671+
if (it == tensorToAccessor.end()) {
672+
return failure();
673+
}
674+
675+
return it->second;
676+
}
677+
678+
/// Materialize TensorAccessor ops at function entry for tensor arguments used
679+
/// by ttl.copy. Returns a map used later by CopyLowering.
680+
static FailureOr<FuncAccessorMapsType>
681+
materializeFuncTensorAccessors(ModuleOp mod, MLIRContext &ctx) {
682+
FuncAccessorMapsType funcAccessorMaps;
683+
684+
auto walkResult = mod.walk([&](func::FuncOp funcOp) -> WalkResult {
685+
if (!isNocKernel(funcOp.getOperation())) {
686+
return WalkResult::advance();
687+
}
688+
689+
if (funcOp.isExternal() || funcOp.getBody().empty()) {
690+
return WalkResult::advance();
691+
}
692+
693+
DenseMap<unsigned, Value> tensorAccessors;
694+
695+
Block &entryBlock = funcOp.getBody().front();
696+
OpBuilder builder(&ctx);
697+
builder.setInsertionPointToStart(&entryBlock);
698+
699+
for (unsigned argIdx = 0; argIdx < funcOp.getNumArguments(); ++argIdx) {
700+
auto arg = funcOp.getArgument(argIdx);
701+
auto tensorTy = llvm::dyn_cast<RankedTensorType>(arg.getType());
702+
if (!tensorTy) {
703+
continue;
704+
}
705+
706+
bool usedByCopy = llvm::any_of(arg.getUses(), [](OpOperand &use) {
707+
return llvm::isa<CopyOp>(use.getOwner());
708+
});
709+
if (!usedByCopy) {
710+
continue;
711+
}
712+
713+
auto bankBase =
714+
getBufferAddressFromRuntimeArg(arg, arg.getLoc(), builder);
715+
if (failed(bankBase)) {
716+
funcOp.emitError(
717+
"tensor must be a function argument for runtime arg mapping");
718+
return WalkResult::interrupt();
719+
}
720+
721+
auto accessor = materializeTensorAccessor(arg, *bankBase,
722+
funcOp.getOperation(), builder);
723+
if (failed(accessor)) {
724+
return WalkResult::interrupt();
725+
}
726+
727+
tensorAccessors.try_emplace(argIdx, *accessor);
728+
}
729+
730+
if (!tensorAccessors.empty()) {
731+
funcAccessorMaps.try_emplace(funcOp, std::move(tensorAccessors));
732+
}
733+
734+
return WalkResult::advance();
735+
});
736+
737+
if (walkResult.wasInterrupted()) {
738+
return failure();
739+
}
740+
741+
return funcAccessorMaps;
742+
}
743+
652744
/// Lower tensor->CB copy: read from DRAM/L1 tensor into circular buffer.
653-
static LogicalResult lowerTensorToCB(CopyOp op, Value srcTensor, Value dstCB,
654-
ConversionPatternRewriter &rewriter,
655-
const TypeConverter &typeConverter) {
745+
static LogicalResult
746+
lowerTensorToCB(CopyOp op, Value srcTensor, Value dstCB,
747+
ConversionPatternRewriter &rewriter,
748+
const TypeConverter &typeConverter,
749+
const DenseMap<unsigned, Value> *tensorAccessors) {
656750
auto loc = op.getLoc();
657751

658-
// Get tensor L1 address from runtime args.
659-
auto bankBase = getBufferAddressFromRuntimeArg(srcTensor, loc, rewriter);
660-
if (failed(bankBase)) {
752+
if (!tensorAccessors) {
661753
return rewriter.notifyMatchFailure(
662-
op, "tensor must be a function argument for runtime arg mapping");
754+
op, "no tensor accessor map for parent function");
663755
}
664756

665-
// Create tensor accessor with actual buffer address.
666-
// This derives page size from TTNNLayoutAttr encoding.
667-
auto srcAccessor =
668-
materializeTensorAccessor(srcTensor, *bankBase, op, rewriter);
757+
auto srcAccessor = lookupTensorAccessor(srcTensor, *tensorAccessors);
669758
if (failed(srcAccessor)) {
670-
return failure(); // Error already emitted by materializeTensorAccessor
759+
return rewriter.notifyMatchFailure(
760+
op, "no pre-materialized tensor accessor found for src tensor");
671761
}
672762

673763
// Convert CB to TTKernel type and get write pointer.
@@ -719,24 +809,22 @@ static LogicalResult lowerTensorToCB(CopyOp op, Value srcTensor, Value dstCB,
719809
}
720810

721811
/// Lower CB->tensor copy: write from circular buffer to DRAM/L1 tensor.
722-
static LogicalResult lowerCBToTensor(CopyOp op, Value srcCB, Value dstTensor,
723-
ConversionPatternRewriter &rewriter,
724-
const TypeConverter &typeConverter) {
812+
static LogicalResult
813+
lowerCBToTensor(CopyOp op, Value srcCB, Value dstTensor,
814+
ConversionPatternRewriter &rewriter,
815+
const TypeConverter &typeConverter,
816+
const DenseMap<unsigned, Value> *tensorAccessors) {
725817
auto loc = op.getLoc();
726818

727-
// Get tensor L1 address from runtime args.
728-
auto bankBase = getBufferAddressFromRuntimeArg(dstTensor, loc, rewriter);
729-
if (failed(bankBase)) {
819+
if (!tensorAccessors) {
730820
return rewriter.notifyMatchFailure(
731-
op, "tensor must be a function argument for runtime arg mapping");
821+
op, "no tensor accessor map for parent function");
732822
}
733823

734-
// Create tensor accessor with actual buffer address.
735-
// This derives page size from TTNNLayoutAttr encoding.
736-
auto dstAccessor =
737-
materializeTensorAccessor(dstTensor, *bankBase, op, rewriter);
824+
auto dstAccessor = lookupTensorAccessor(dstTensor, *tensorAccessors);
738825
if (failed(dstAccessor)) {
739-
return failure(); // Error already emitted by materializeTensorAccessor
826+
return rewriter.notifyMatchFailure(
827+
op, "no pre-materialized tensor accessor found for dst tensor");
740828
}
741829

742830
// Convert CB to TTKernel type and get read pointer.
@@ -788,7 +876,12 @@ static LogicalResult lowerCBToTensor(CopyOp op, Value srcCB, Value dstTensor,
788876
}
789877

790878
struct CopyLowering : OpConversionPattern<CopyOp> {
791-
using OpConversionPattern::OpConversionPattern;
879+
CopyLowering(const TypeConverter &typeConverter, MLIRContext *context,
880+
const FuncAccessorMapsType *funcAccessorMaps)
881+
: OpConversionPattern(typeConverter, context),
882+
funcAccessorMaps(funcAccessorMaps) {}
883+
884+
const FuncAccessorMapsType *funcAccessorMaps;
792885

793886
LogicalResult
794887
matchAndRewrite(CopyOp op, OpAdaptor adaptor,
@@ -798,6 +891,16 @@ struct CopyLowering : OpConversionPattern<CopyOp> {
798891
return rewriter.notifyMatchFailure(op, "no type converter");
799892
}
800893

894+
const DenseMap<unsigned, Value> *tensorAccessors = nullptr;
895+
if (funcAccessorMaps) {
896+
if (auto parentFunc = op->getParentOfType<func::FuncOp>()) {
897+
auto it = funcAccessorMaps->find(parentFunc);
898+
if (it != funcAccessorMaps->end()) {
899+
tensorAccessors = &it->second;
900+
}
901+
}
902+
}
903+
801904
// Use original operands for classification since lowering functions
802905
// handle type conversion internally.
803906
Value src = op.getSrc();
@@ -809,14 +912,14 @@ struct CopyLowering : OpConversionPattern<CopyOp> {
809912
if (srcKind == CopySourceKind::TensorAccessor &&
810913
dstKind == CopyDestKind::CircularBuffer) {
811914
return lowerTensorToCB(op, src, adaptor.getDst(), rewriter,
812-
*typeConverter);
915+
*typeConverter, tensorAccessors);
813916
}
814917

815918
// CB -> Tensor: write from circular buffer to tensor.
816919
if (srcKind == CopySourceKind::CircularBuffer &&
817920
dstKind == CopyDestKind::TensorAccessor) {
818921
return lowerCBToTensor(op, adaptor.getSrc(), dst, rewriter,
819-
*typeConverter);
922+
*typeConverter, tensorAccessors);
820923
}
821924

822925
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
@@ -920,6 +1023,12 @@ static LogicalResult
9201023
lowerTTLOpsToTTKernel(ModuleOp mod, MLIRContext &ctx,
9211024
TTLToTTKernelTypeConverter &typeConverter,
9221025
StringRef passName) {
1026+
auto accessorMapsOrFailure = materializeFuncTensorAccessors(mod, ctx);
1027+
if (failed(accessorMapsOrFailure)) {
1028+
return failure();
1029+
}
1030+
FuncAccessorMapsType funcAccessorMaps = *accessorMapsOrFailure;
1031+
9231032
ConversionTarget target(ctx);
9241033
target.addIllegalDialect<tt::ttl::TTLDialect>();
9251034
target.addLegalDialect<arith::ArithDialect, BuiltinDialect, scf::SCFDialect,
@@ -951,9 +1060,10 @@ lowerTTLOpsToTTKernel(ModuleOp mod, MLIRContext &ctx,
9511060
});
9521061

9531062
RewritePatternSet patterns(&ctx);
954-
patterns.add<BindCBLowering, CopyLowering, WaitLowering, CBReserveLowering,
955-
CBPushLowering, CBWaitLowering, CBPopLowering, StoreLowering>(
956-
typeConverter, &ctx);
1063+
patterns.add<BindCBLowering, WaitLowering, CBReserveLowering, CBPushLowering,
1064+
CBWaitLowering, CBPopLowering, StoreLowering>(typeConverter,
1065+
&ctx);
1066+
patterns.add<CopyLowering>(typeConverter, &ctx, &funcAccessorMaps);
9571067
populateFunctionOpInterfaceTypeConversionPattern(
9581068
func::FuncOp::getOperationName(), patterns, typeConverter);
9591069

python/ttlang/ttl_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
register_tensor_source,
3636
)
3737
from ._src.ttl_ast import TTLGenericCompiler
38-
from .circular_buffer import CircularBuffer
38+
from .circular_buffer import CircularBuffer, get_cb_count
3939
from .constants import SUPPORTED_MEMORY_SPACES
4040
from .diagnostics import (
4141
TTLangCompileError,

python/ttlang/utils/block_allocation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# SPDX-License-Identifier: Apache-2.0
44
import itertools
55
import math
6+
import sympy
7+
from collections import namedtuple
68
from typing import List, Tuple
79

810

0 commit comments

Comments
 (0)