@@ -947,6 +947,173 @@ struct TTLTileReduceToTTKernel : OpConversionPattern<TileReduceOp> {
947947 }
948948};
949949
950+ // ===----------------------------------------------------------------------===//
951+ // Transpose Tile Op Lowering
952+ // ===----------------------------------------------------------------------===//
953+
954+ // / Lower ttl.tile_transpose to TTKernel transpose_wh_init + transpose_wh_tile.
955+ // / Reads from input CB, writes to DST.
956+ // / For multi-tile transpose, computes transposed input CB index:
957+ // / Output position (i, j) reads from input position (j, i).
958+ struct TTLTileTransposeToTTKernel : OpConversionPattern<TileTransposeOp> {
959+ using OpConversionPattern<TileTransposeOp>::OpConversionPattern;
960+
961+ LogicalResult
962+ matchAndRewrite (TileTransposeOp op, TileTransposeOp::Adaptor adaptor,
963+ ConversionPatternRewriter &rewriter) const override {
964+ Location loc = op.getLoc ();
965+
966+ auto funcOp = op->getParentOfType <func::FuncOp>();
967+ if (!funcOp) {
968+ return rewriter.notifyMatchFailure (op, " op not in function" );
969+ }
970+
971+ auto *typeConverter = this ->getTypeConverter ();
972+ auto inCB =
973+ lookupAndConvertCB (op.getInput (), funcOp, typeConverter, rewriter, loc);
974+ if (failed (inCB)) {
975+ return rewriter.notifyMatchFailure (op, " cannot find/convert input CB" );
976+ }
977+
978+ auto outCB = lookupAndConvertCB (op.getOutput (), funcOp, typeConverter,
979+ rewriter, loc);
980+ if (failed (outCB)) {
981+ funcOp->walk ([&](InitSFPUOp initOp) {
982+ outCB = utils::convertTTLCBToTTKernel (initOp.getOcb (), rewriter, loc,
983+ typeConverter);
984+ return WalkResult::interrupt ();
985+ });
986+ if (failed (outCB)) {
987+ return rewriter.notifyMatchFailure (op, " cannot find/convert output CB" );
988+ }
989+ }
990+
991+ // Get DST index from attribute (assigned by TTLAssignDST pass).
992+ auto dstIdxAttr = op->getAttrOfType <IntegerAttr>(kDstIdxAttrName );
993+ if (!dstIdxAttr) {
994+ return rewriter.notifyMatchFailure (op, " missing dst_idx attribute" );
995+ }
996+ int64_t dstIdxVal = dstIdxAttr.getInt ();
997+ Value dstIdx = rewriter.create <arith::ConstantIndexOp>(loc, dstIdxVal);
998+
999+ // Get input CB shape to compute transposed index.
1000+ auto inShape = getCBTileGridShape (op.getInput (), funcOp);
1001+ if (!inShape) {
1002+ return rewriter.notifyMatchFailure (op, " cannot determine input CB shape" );
1003+ }
1004+ int64_t inCols = inShape->second ; // Input is [M, N], so N columns
1005+
1006+ // Compute transposed input CB index.
1007+ // Loop iterates over output shape [N, M]. For output position (i, j),
1008+ // we read from input position (j, i).
1009+ // Input CB index = j * N + i (linearized for input shape [M, N]).
1010+ SmallVector<scf::ForOp> loops = utils::collectEnclosingLoops (op);
1011+ Value zero = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
1012+ Value inCBIdx = zero;
1013+
1014+ if (loops.size () >= 2 ) {
1015+ // loops[0] = j (innermost, cols), loops[1] = i (outer, rows)
1016+ Value colIdx = loops[0 ].getInductionVar (); // j
1017+ Value rowIdx = loops[1 ].getInductionVar (); // i
1018+ // inCBIdx = colIdx * inCols + rowIdx (transposed indexing)
1019+ Value inColsVal = rewriter.create <arith::ConstantIndexOp>(loc, inCols);
1020+ Value colMulN = rewriter.create <arith::MulIOp>(loc, colIdx, inColsVal);
1021+ inCBIdx = rewriter.create <arith::AddIOp>(loc, colMulN, rowIdx);
1022+ } else if (loops.size () == 1 ) {
1023+ inCBIdx = loops[0 ].getInductionVar ();
1024+ }
1025+
1026+ rewriter.create <ttk::TransposeInitOp>(loc, *inCB, *outCB);
1027+ rewriter.create <ttk::TransposeTileOp>(loc, *inCB, inCBIdx, dstIdx);
1028+
1029+ rewriter.replaceOp (op, adaptor.getInput ());
1030+ return success ();
1031+ }
1032+ };
1033+
1034+ // ===----------------------------------------------------------------------===//
1035+ // Power Tile Op Lowering
1036+ // ===----------------------------------------------------------------------===//
1037+
1038+ // / Lower ttl.tile_power to TTKernel power_tile_init + power_tile.
1039+ // / Operates in-place in DST with an integer exponent.
1040+ struct TTLTilePowerToTTKernel : OpConversionPattern<TilePowerOp> {
1041+ using OpConversionPattern<TilePowerOp>::OpConversionPattern;
1042+
1043+ LogicalResult
1044+ matchAndRewrite (TilePowerOp op, TilePowerOp::Adaptor adaptor,
1045+ ConversionPatternRewriter &rewriter) const override {
1046+ Location loc = op.getLoc ();
1047+
1048+ auto dstIdxAttr = op->getAttrOfType <IntegerAttr>(kDstIdxAttrName );
1049+ if (!dstIdxAttr) {
1050+ return rewriter.notifyMatchFailure (op, " missing dst_idx attribute" );
1051+ }
1052+ int64_t dstIdx = dstIdxAttr.getInt ();
1053+ Value dstIdxVal = rewriter.create <arith::ConstantIndexOp>(loc, dstIdx);
1054+
1055+ // Get the exponent from the op's attribute.
1056+ Value exponent = rewriter.create <arith::ConstantOp>(
1057+ loc, rewriter.getI32IntegerAttr (op.getExponent ()));
1058+
1059+ rewriter.create <ttk::PowerTileInitOp>(loc);
1060+ rewriter.create <ttk::PowUnaryTileOp>(loc, dstIdxVal, exponent);
1061+
1062+ rewriter.replaceOp (op, adaptor.getInput ());
1063+ return success ();
1064+ }
1065+ };
1066+
1067+ // ===----------------------------------------------------------------------===//
1068+ // Where Tile Op Lowering
1069+ // ===----------------------------------------------------------------------===//
1070+
1071+ // / Lower ttl.tile_where to TTKernel where_tile_init + where_tile.
1072+ // / Ternary DST-based op: cond ? true : false.
1073+ struct TTLTileWhereToTTKernel : OpConversionPattern<TileWhereOp> {
1074+ using OpConversionPattern<TileWhereOp>::OpConversionPattern;
1075+
1076+ LogicalResult
1077+ matchAndRewrite (TileWhereOp op, TileWhereOp::Adaptor adaptor,
1078+ ConversionPatternRewriter &rewriter) const override {
1079+ Location loc = op.getLoc ();
1080+
1081+ auto dstIdxAttr = op->getAttrOfType <IntegerAttr>(kDstIdxAttrName );
1082+ if (!dstIdxAttr) {
1083+ return rewriter.notifyMatchFailure (op, " missing dst_idx attribute" );
1084+ }
1085+ int64_t odstIdx = dstIdxAttr.getInt ();
1086+
1087+ auto condIdxOpt = getDstIndexFromValue (op.getCondition ());
1088+ auto trueIdxOpt = getDstIndexFromValue (op.getTrueValue ());
1089+ auto falseIdxOpt = getDstIndexFromValue (op.getFalseValue ());
1090+
1091+ if (!condIdxOpt) {
1092+ return rewriter.notifyMatchFailure (
1093+ op, " failed to extract dst_idx from condition operand" );
1094+ }
1095+ if (!trueIdxOpt) {
1096+ return rewriter.notifyMatchFailure (
1097+ op, " failed to extract dst_idx from true_value operand" );
1098+ }
1099+ if (!falseIdxOpt) {
1100+ return rewriter.notifyMatchFailure (
1101+ op, " failed to extract dst_idx from false_value operand" );
1102+ }
1103+
1104+ Value condIdx = rewriter.create <arith::ConstantIndexOp>(loc, *condIdxOpt);
1105+ Value trueIdx = rewriter.create <arith::ConstantIndexOp>(loc, *trueIdxOpt);
1106+ Value falseIdx = rewriter.create <arith::ConstantIndexOp>(loc, *falseIdxOpt);
1107+ Value odst = rewriter.create <arith::ConstantIndexOp>(loc, odstIdx);
1108+
1109+ rewriter.create <ttk::WhereTileInitOp>(loc);
1110+ rewriter.create <ttk::WhereTileOp>(loc, condIdx, trueIdx, falseIdx, odst);
1111+
1112+ rewriter.replaceOp (op, adaptor.getCondition ());
1113+ return success ();
1114+ }
1115+ };
1116+
9501117// ===----------------------------------------------------------------------===//
9511118// Tile Op Lowerings - Generated from TTLElementwiseOps.def
9521119// ===----------------------------------------------------------------------===//
@@ -1008,6 +1175,11 @@ void populateTTLTileOpsToTTKernelPatterns(TypeConverter *typeConverter,
10081175 patterns.add <TTLTileBcastToTTKernel>(*typeConverter, ctx);
10091176 patterns.add <TTLTileMatmulToTTKernel>(*typeConverter, ctx);
10101177 patterns.add <TTLTileReduceToTTKernel>(*typeConverter, ctx);
1178+ patterns.add <TTLTileTransposeToTTKernel>(*typeConverter, ctx);
1179+
1180+ // DST-based ops.
1181+ patterns.add <TTLTilePowerToTTKernel>(ctx);
1182+ patterns.add <TTLTileWhereToTTKernel>(ctx);
10111183
10121184 // TODO(#124): Add DST lifecycle wrapper pattern for loop iterations
10131185 // (acquire/commit/wait/release + copy_tile/pack_tile)
0 commit comments