Skip to content

Commit 0cb029a

Browse files
committed
Add transpose, power, and where operations with tests
1 parent bc85e71 commit 0cb029a

File tree

5 files changed

+870
-0
lines changed

5 files changed

+870
-0
lines changed

include/ttlang/Dialect/TTL/IR/TTLOps.td

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,36 @@ def TTL_TileMatmulOp : TTL_TileOp<"tile_matmul", [TTL_CBInputTileOpTrait]> {
629629
let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` `(` type($a) `,` type($b) `,` type($c) `)` `->` type($result)";
630630
}
631631

632+
//===----------------------------------------------------------------------===//
633+
// Transpose Operation
634+
//===----------------------------------------------------------------------===//
635+
636+
def TTL_TransposeOp : TTL_Op<"transpose", []> {
637+
let summary = "Transpose a 2D tile tensor (swap width and height)";
638+
let description = [{
639+
Performs width-height transpose on input tiles. Reads from input CB,
640+
writes result to output CB. Both operands must be CB-attached tensors.
641+
642+
This operation transposes each 32x32 tile, swapping rows and columns.
643+
The input tensor shape [M, N] becomes output shape [N, M] in tiles.
644+
}];
645+
let arguments = (ins AnyType:$input, AnyType:$output);
646+
let results = (outs AnyType:$result);
647+
let assemblyFormat = "$input `,` $output attr-dict `:` `(` type($input) `,` type($output) `)` `->` type($result)";
648+
}
649+
650+
def TTL_TileTransposeOp : TTL_TileOp<"tile_transpose", [TTL_CBInputTileOpTrait]> {
651+
let summary = "Tile-level width-height transpose operation";
652+
let description = [{
653+
Transposes a single 32x32 tile, swapping rows and columns.
654+
Maps to ttkernel.transpose_wh_init + ttkernel.transpose_wh_tile.
655+
Reads directly from circular buffer (not DST), writes to DST.
656+
}];
657+
let arguments = (ins AnyType:$input, AnyType:$output);
658+
let results = (outs AnyType:$result);
659+
let assemblyFormat = "$input `,` $output attr-dict `:` `(` type($input) `,` type($output) `)` `->` type($result)";
660+
}
661+
632662
//===----------------------------------------------------------------------===//
633663
// Reduce Operation
634664
//===----------------------------------------------------------------------===//
@@ -684,6 +714,71 @@ def TTL_TileReduceOp : TTL_TileOp<"tile_reduce", [TTL_CBInputTileOpTrait]> {
684714
}];
685715
}
686716

717+
//===----------------------------------------------------------------------===//
718+
// Power Operation
719+
//===----------------------------------------------------------------------===//
720+
721+
def TTL_PowerOp : TTL_Op<"power", []> {
722+
let summary = "Raise tensor elements to an integer power";
723+
let description = [{
724+
Computes element-wise power: output = input ^ exponent.
725+
Both input and output must be CB-attached tensors.
726+
The exponent is a compile-time constant integer.
727+
}];
728+
let arguments = (ins AnyType:$input, I32Attr:$exponent);
729+
let results = (outs AnyType:$result);
730+
let assemblyFormat = "$input `,` $exponent attr-dict `:` type($input) `->` type($result)";
731+
}
732+
733+
def TTL_TilePowerOp : TTL_TileOp<"tile_power", []> {
734+
let summary = "Tile-level power operation";
735+
let description = [{
736+
Raises each element of a tile to an integer power.
737+
Maps to ttkernel.power_tile_init + ttkernel.power_tile.
738+
Operates in-place in DST: DST[dst_idx] = DST[dst_idx] ^ exponent.
739+
}];
740+
let arguments = (ins AnyType:$input, I32Attr:$exponent);
741+
let results = (outs AnyType:$result);
742+
let assemblyFormat = "$input `,` $exponent attr-dict `:` type($input) `->` type($result)";
743+
}
744+
745+
//===----------------------------------------------------------------------===//
746+
// Where (Conditional Select) Operation
747+
//===----------------------------------------------------------------------===//
748+
749+
def TTL_WhereOp : TTL_Op<"where", []> {
750+
let summary = "Element-wise conditional selection";
751+
let description = [{
752+
Performs element-wise conditional selection:
753+
output = condition ? true_value : false_value
754+
755+
All three input tensors and output must be CB-attached tensors with the
756+
same shape. For each element, if condition is non-zero, selects from
757+
true_value, otherwise selects from false_value.
758+
}];
759+
let arguments = (ins AnyType:$condition, AnyType:$true_value, AnyType:$false_value);
760+
let results = (outs AnyType:$result);
761+
let assemblyFormat = [{
762+
$condition `,` $true_value `,` $false_value attr-dict `:`
763+
`(` type($condition) `,` type($true_value) `,` type($false_value) `)` `->` type($result)
764+
}];
765+
}
766+
767+
def TTL_TileWhereOp : TTL_TileOp<"tile_where", []> {
768+
let summary = "Tile-level conditional selection operation";
769+
let description = [{
770+
Performs element-wise conditional selection on tiles.
771+
Maps to ttkernel.where_tile_init + ttkernel.where_tile.
772+
All operands read from DST: DST[odst] = DST[cond] ? DST[true] : DST[false].
773+
}];
774+
let arguments = (ins AnyType:$condition, AnyType:$true_value, AnyType:$false_value);
775+
let results = (outs AnyType:$result);
776+
let assemblyFormat = [{
777+
$condition `,` $true_value `,` $false_value attr-dict `:`
778+
`(` type($condition) `,` type($true_value) `,` type($false_value) `)` `->` type($result)
779+
}];
780+
}
781+
687782
// Circular buffer synchronization operations
688783
//===----------------------------------------------------------------------===//
689784

lib/Dialect/TTL/Transforms/ConvertTTLTileOpsToTTKernel.cpp

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,173 @@ struct TTLTileReduceToTTKernel : OpConversionPattern<TileReduceOp> {
947947
}
948948
};
949949

950+
//===----------------------------------------------------------------------===//
951+
// Transpose Tile Op Lowering
952+
//===----------------------------------------------------------------------===//
953+
954+
/// Lower ttl.tile_transpose to TTKernel transpose_wh_init + transpose_wh_tile.
955+
/// Reads from input CB, writes to DST.
956+
/// For multi-tile transpose, computes transposed input CB index:
957+
/// Output position (i, j) reads from input position (j, i).
958+
struct TTLTileTransposeToTTKernel : OpConversionPattern<TileTransposeOp> {
959+
using OpConversionPattern<TileTransposeOp>::OpConversionPattern;
960+
961+
LogicalResult
962+
matchAndRewrite(TileTransposeOp op, TileTransposeOp::Adaptor adaptor,
963+
ConversionPatternRewriter &rewriter) const override {
964+
Location loc = op.getLoc();
965+
966+
auto funcOp = op->getParentOfType<func::FuncOp>();
967+
if (!funcOp) {
968+
return rewriter.notifyMatchFailure(op, "op not in function");
969+
}
970+
971+
auto *typeConverter = this->getTypeConverter();
972+
auto inCB =
973+
lookupAndConvertCB(op.getInput(), funcOp, typeConverter, rewriter, loc);
974+
if (failed(inCB)) {
975+
return rewriter.notifyMatchFailure(op, "cannot find/convert input CB");
976+
}
977+
978+
auto outCB = lookupAndConvertCB(op.getOutput(), funcOp, typeConverter,
979+
rewriter, loc);
980+
if (failed(outCB)) {
981+
funcOp->walk([&](InitSFPUOp initOp) {
982+
outCB = utils::convertTTLCBToTTKernel(initOp.getOcb(), rewriter, loc,
983+
typeConverter);
984+
return WalkResult::interrupt();
985+
});
986+
if (failed(outCB)) {
987+
return rewriter.notifyMatchFailure(op, "cannot find/convert output CB");
988+
}
989+
}
990+
991+
// Get DST index from attribute (assigned by TTLAssignDST pass).
992+
auto dstIdxAttr = op->getAttrOfType<IntegerAttr>(kDstIdxAttrName);
993+
if (!dstIdxAttr) {
994+
return rewriter.notifyMatchFailure(op, "missing dst_idx attribute");
995+
}
996+
int64_t dstIdxVal = dstIdxAttr.getInt();
997+
Value dstIdx = rewriter.create<arith::ConstantIndexOp>(loc, dstIdxVal);
998+
999+
// Get input CB shape to compute transposed index.
1000+
auto inShape = getCBTileGridShape(op.getInput(), funcOp);
1001+
if (!inShape) {
1002+
return rewriter.notifyMatchFailure(op, "cannot determine input CB shape");
1003+
}
1004+
int64_t inCols = inShape->second; // Input is [M, N], so N columns
1005+
1006+
// Compute transposed input CB index.
1007+
// Loop iterates over output shape [N, M]. For output position (i, j),
1008+
// we read from input position (j, i).
1009+
// Input CB index = j * N + i (linearized for input shape [M, N]).
1010+
SmallVector<scf::ForOp> loops = utils::collectEnclosingLoops(op);
1011+
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1012+
Value inCBIdx = zero;
1013+
1014+
if (loops.size() >= 2) {
1015+
// loops[0] = j (innermost, cols), loops[1] = i (outer, rows)
1016+
Value colIdx = loops[0].getInductionVar(); // j
1017+
Value rowIdx = loops[1].getInductionVar(); // i
1018+
// inCBIdx = colIdx * inCols + rowIdx (transposed indexing)
1019+
Value inColsVal = rewriter.create<arith::ConstantIndexOp>(loc, inCols);
1020+
Value colMulN = rewriter.create<arith::MulIOp>(loc, colIdx, inColsVal);
1021+
inCBIdx = rewriter.create<arith::AddIOp>(loc, colMulN, rowIdx);
1022+
} else if (loops.size() == 1) {
1023+
inCBIdx = loops[0].getInductionVar();
1024+
}
1025+
1026+
rewriter.create<ttk::TransposeInitOp>(loc, *inCB, *outCB);
1027+
rewriter.create<ttk::TransposeTileOp>(loc, *inCB, inCBIdx, dstIdx);
1028+
1029+
rewriter.replaceOp(op, adaptor.getInput());
1030+
return success();
1031+
}
1032+
};
1033+
1034+
//===----------------------------------------------------------------------===//
1035+
// Power Tile Op Lowering
1036+
//===----------------------------------------------------------------------===//
1037+
1038+
/// Lower ttl.tile_power to TTKernel power_tile_init + power_tile.
1039+
/// Operates in-place in DST with an integer exponent.
1040+
struct TTLTilePowerToTTKernel : OpConversionPattern<TilePowerOp> {
1041+
using OpConversionPattern<TilePowerOp>::OpConversionPattern;
1042+
1043+
LogicalResult
1044+
matchAndRewrite(TilePowerOp op, TilePowerOp::Adaptor adaptor,
1045+
ConversionPatternRewriter &rewriter) const override {
1046+
Location loc = op.getLoc();
1047+
1048+
auto dstIdxAttr = op->getAttrOfType<IntegerAttr>(kDstIdxAttrName);
1049+
if (!dstIdxAttr) {
1050+
return rewriter.notifyMatchFailure(op, "missing dst_idx attribute");
1051+
}
1052+
int64_t dstIdx = dstIdxAttr.getInt();
1053+
Value dstIdxVal = rewriter.create<arith::ConstantIndexOp>(loc, dstIdx);
1054+
1055+
// Get the exponent from the op's attribute.
1056+
Value exponent = rewriter.create<arith::ConstantOp>(
1057+
loc, rewriter.getI32IntegerAttr(op.getExponent()));
1058+
1059+
rewriter.create<ttk::PowerTileInitOp>(loc);
1060+
rewriter.create<ttk::PowUnaryTileOp>(loc, dstIdxVal, exponent);
1061+
1062+
rewriter.replaceOp(op, adaptor.getInput());
1063+
return success();
1064+
}
1065+
};
1066+
1067+
//===----------------------------------------------------------------------===//
1068+
// Where Tile Op Lowering
1069+
//===----------------------------------------------------------------------===//
1070+
1071+
/// Lower ttl.tile_where to TTKernel where_tile_init + where_tile.
1072+
/// Ternary DST-based op: cond ? true : false.
1073+
struct TTLTileWhereToTTKernel : OpConversionPattern<TileWhereOp> {
1074+
using OpConversionPattern<TileWhereOp>::OpConversionPattern;
1075+
1076+
LogicalResult
1077+
matchAndRewrite(TileWhereOp op, TileWhereOp::Adaptor adaptor,
1078+
ConversionPatternRewriter &rewriter) const override {
1079+
Location loc = op.getLoc();
1080+
1081+
auto dstIdxAttr = op->getAttrOfType<IntegerAttr>(kDstIdxAttrName);
1082+
if (!dstIdxAttr) {
1083+
return rewriter.notifyMatchFailure(op, "missing dst_idx attribute");
1084+
}
1085+
int64_t odstIdx = dstIdxAttr.getInt();
1086+
1087+
auto condIdxOpt = getDstIndexFromValue(op.getCondition());
1088+
auto trueIdxOpt = getDstIndexFromValue(op.getTrueValue());
1089+
auto falseIdxOpt = getDstIndexFromValue(op.getFalseValue());
1090+
1091+
if (!condIdxOpt) {
1092+
return rewriter.notifyMatchFailure(
1093+
op, "failed to extract dst_idx from condition operand");
1094+
}
1095+
if (!trueIdxOpt) {
1096+
return rewriter.notifyMatchFailure(
1097+
op, "failed to extract dst_idx from true_value operand");
1098+
}
1099+
if (!falseIdxOpt) {
1100+
return rewriter.notifyMatchFailure(
1101+
op, "failed to extract dst_idx from false_value operand");
1102+
}
1103+
1104+
Value condIdx = rewriter.create<arith::ConstantIndexOp>(loc, *condIdxOpt);
1105+
Value trueIdx = rewriter.create<arith::ConstantIndexOp>(loc, *trueIdxOpt);
1106+
Value falseIdx = rewriter.create<arith::ConstantIndexOp>(loc, *falseIdxOpt);
1107+
Value odst = rewriter.create<arith::ConstantIndexOp>(loc, odstIdx);
1108+
1109+
rewriter.create<ttk::WhereTileInitOp>(loc);
1110+
rewriter.create<ttk::WhereTileOp>(loc, condIdx, trueIdx, falseIdx, odst);
1111+
1112+
rewriter.replaceOp(op, adaptor.getCondition());
1113+
return success();
1114+
}
1115+
};
1116+
9501117
//===----------------------------------------------------------------------===//
9511118
// Tile Op Lowerings - Generated from TTLElementwiseOps.def
9521119
//===----------------------------------------------------------------------===//
@@ -1008,6 +1175,11 @@ void populateTTLTileOpsToTTKernelPatterns(TypeConverter *typeConverter,
10081175
patterns.add<TTLTileBcastToTTKernel>(*typeConverter, ctx);
10091176
patterns.add<TTLTileMatmulToTTKernel>(*typeConverter, ctx);
10101177
patterns.add<TTLTileReduceToTTKernel>(*typeConverter, ctx);
1178+
patterns.add<TTLTileTransposeToTTKernel>(*typeConverter, ctx);
1179+
1180+
// DST-based ops.
1181+
patterns.add<TTLTilePowerToTTKernel>(ctx);
1182+
patterns.add<TTLTileWhereToTTKernel>(ctx);
10111183

10121184
// TODO(#124): Add DST lifecycle wrapper pattern for loop iterations
10131185
// (acquire/commit/wait/release + copy_tile/pack_tile)

0 commit comments

Comments
 (0)