Skip to content

Commit e20cd5f

Browse files
authored
D2M: cleanup and preparation for bcast (#5375)
### Ticket

None

### Problem description

The bcast PR is being broken down into smaller ones; this one is mainly for cleanups.

### What's changed

* Remove the redundant `bcast_type_` prefix from the `BcastType` enum mnemonics (e.g. `bcast_type_col` becomes `col`), which clutters the IR.
* Extract the logic that generates the compute region of `linalg.generic` into its own function.
* Remove support for `FuncOp` from the insert DST register pass, as it's getting in the way more and more.
* Update the simple `FuncOp` tests to full `d2m.generic` equivalents (OMG so painful, but it's a one-time cost).

### Checklist

- [x] New/Existing tests provide coverage for changes

Signed-off-by: wenbinlyuTT <[email protected]>
1 parent aa5d5b0 commit e20cd5f

File tree

12 files changed

+1190
-927
lines changed

12 files changed

+1190
-927
lines changed

include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsEnums.td

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -105,22 +105,6 @@ def TTKernel_ReduceType : I32EnumAttr<"ReduceType", "TTKernel Reduce Types",
105105
let cppNamespace = "::mlir::tt::ttkernel";
106106
}
107107

108-
def TTKernel_BcastTypeNone : I32EnumAttrCase<"None", 0, "bcast_type_none">;
109-
def TTKernel_BcastTypeCol : I32EnumAttrCase<"Col", 1, "bcast_type_col">;
110-
def TTKernel_BcastTypeRow : I32EnumAttrCase<"Row", 2, "bcast_type_row">;
111-
def TTKernel_BcastTypeScalar : I32EnumAttrCase<"Scalar", 3, "bcast_type_scalar">;
112-
113-
def TTKernel_BcastType : I32EnumAttr<"BcastType", "TTKernel broadcast types",
114-
[
115-
TTKernel_BcastTypeNone,
116-
TTKernel_BcastTypeCol,
117-
TTKernel_BcastTypeRow,
118-
TTKernel_BcastTypeScalar
119-
]> {
120-
let genSpecializedAttr = 0;
121-
let cppNamespace = "::mlir::tt::ttkernel";
122-
}
123-
124108
def TTKernel_ReduceDimRow : I32EnumAttrCase<"Row", 0, "reduce_dim_row">;
125109
def TTKernel_ReduceDimCol : I32EnumAttrCase<"Col", 1, "reduce_dim_col">;
126110
def TTKernel_ReduceDimScalar
@@ -134,6 +118,22 @@ def TTKernel_ReduceDim
134118
let cppNamespace = "::mlir::tt::ttkernel";
135119
}
136120

121+
def TTKernel_BcastTypeNone : I32EnumAttrCase<"None", 0, "none">;
122+
def TTKernel_BcastTypeCol : I32EnumAttrCase<"Col", 1, "col">;
123+
def TTKernel_BcastTypeRow : I32EnumAttrCase<"Row", 2, "row">;
124+
def TTKernel_BcastTypeScalar : I32EnumAttrCase<"Scalar", 3, "scalar">;
125+
126+
def TTKernel_BcastType : I32EnumAttr<"BcastType", "TTKernel Broadcast Types",
127+
[
128+
TTKernel_BcastTypeNone,
129+
TTKernel_BcastTypeCol,
130+
TTKernel_BcastTypeRow,
131+
TTKernel_BcastTypeScalar
132+
]> {
133+
let genSpecializedAttr = 0;
134+
let cppNamespace = "::mlir::tt::ttkernel";
135+
}
136+
137137
def TTKernel_ArgTypeCBPort : I32EnumAttrCase<"CBPort", 0, "cb_port">;
138138
def TTKernel_ArgTypeBufferAddress : I32EnumAttrCase<"BufferAddress", 1, "buffer_address">;
139139
def TTKernel_ArgTypeSemaphore : I32EnumAttrCase<"Semaphore", 2, "semaphore">;

lib/Conversion/TTIRToD2M/TTIRToD2M.cpp

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,13 @@
1616
#include "ttmlir/Dialect/D2M/IR/D2MGenericRegionOps.h"
1717
#include "ttmlir/Dialect/TTCore/IR/TTCoreOpsTypes.h"
1818
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
19-
#include "ttmlir/Dialect/TTIR/Utils/Utils.h"
2019
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
2120

2221
#include "mlir/Dialect/Linalg/IR/Linalg.h"
23-
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2422
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2523
#include "mlir/IR/AffineExpr.h"
2624
#include "mlir/IR/AffineMap.h"
2725
#include "mlir/IR/BuiltinTypes.h"
28-
#include "mlir/IR/IRMapping.h"
2926
#include "mlir/IR/TypeRange.h"
3027
#include "mlir/IR/ValueRange.h"
3128
#include "mlir/Transforms/DialectConversion.h"
@@ -404,7 +401,26 @@ class D2MNamedElementwiseRewriter final
404401
std::is_same_v<TileOp, d2m::TileLtzOp> ||
405402
std::is_same_v<TileOp, d2m::TileLezOp>;
406403

407-
private:
404+
void createComputeRegion(mlir::OpBuilder &bbBuilder, mlir::Location bbLoc,
405+
mlir::ValueRange bbArgs,
406+
mlir::ConversionPatternRewriter &rewriter,
407+
mlir::Location loc, const size_t numInputs,
408+
const size_t numOutputs) const {
409+
mlir::ValueRange operands = bbArgs.take_front(numInputs);
410+
mlir::TypeRange resultTypes = bbArgs.take_back(numOutputs);
411+
412+
mlir::Value yield;
413+
if constexpr (isComparisonOp) {
414+
// For comparison ops, first subtract then compare with zero.
415+
yield = bbBuilder.create<d2m::TileSubOp>(loc, resultTypes, operands);
416+
yield = bbBuilder.create<TileOp>(loc, resultTypes, yield);
417+
} else {
418+
yield = bbBuilder.create<TileOp>(loc, resultTypes, operands);
419+
}
420+
421+
bbBuilder.create<mlir::linalg::YieldOp>(bbLoc, yield);
422+
}
423+
408424
LogicalResult
409425
matchAndRewrite(ConcreteOp op, typename ConcreteOp::Adaptor adaptor,
410426
mlir::ConversionPatternRewriter &rewriter) const final {
@@ -467,25 +483,8 @@ class D2MNamedElementwiseRewriter final
467483
linalgIteratorTypes,
468484
[&](mlir::OpBuilder &bbBuilder, mlir::Location bbLoc,
469485
mlir::ValueRange bbArgs) {
470-
mlir::Value yield;
471-
472-
if constexpr (isComparisonOp) {
473-
// For comparison ops, first subtract then compare with zero.
474-
mlir::Value subResult = bbBuilder.create<d2m::TileSubOp>(
475-
loc, /*resultTypes=*/bbArgs.take_back(numOutputs),
476-
/*operands=*/bbArgs.take_front(numInputs));
477-
yield = bbBuilder.create<TileOp>(
478-
loc, /*resultTypes=*/bbArgs.take_back(numOutputs),
479-
/*operands=*/subResult);
480-
} else {
481-
// For regular elementwise ops, create TileOp directly.
482-
yield = bbBuilder.create<TileOp>(
483-
loc,
484-
/* resultTypes */ bbArgs.take_back(numOutputs).getTypes(),
485-
/* operands */ bbArgs.take_front(numInputs));
486-
}
487-
488-
bbBuilder.create<mlir::linalg::YieldOp>(bbLoc, yield);
486+
createComputeRegion(bbBuilder, bbLoc, bbArgs, rewriter, loc,
487+
numInputs, numOutputs);
489488
});
490489

491490
rewriter.create<d2m::YieldOp>(loc, linalgGeneric->getResults());

lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@
88
#include "ttmlir/Dialect/TTIR/Utils/Utils.h"
99

1010
#include "mlir/Dialect/Arith/IR/Arith.h"
11-
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12-
#include "mlir/Dialect/Func/IR/FuncOps.h"
1311
#include "mlir/Dialect/Linalg/IR/Linalg.h"
14-
#include "mlir/Dialect/Math/IR/Math.h"
1512
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1613
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1714
#include "mlir/IR/Attributes.h"
@@ -20,7 +17,6 @@
2017
#include "mlir/IR/Types.h"
2118
#include "mlir/IR/Value.h"
2219
#include "mlir/IR/ValueRange.h"
23-
#include "mlir/Pass/Pass.h"
2420
#include "mlir/Support/LogicalResult.h"
2521
#include "mlir/Transforms/DialectConversion.h"
2622
#include "llvm/ADT/SmallVector.h"
@@ -664,7 +660,7 @@ class BroadcastOpConversionPattern
664660
auto zerosConst =
665661
rewriter.create<tosa::ConstOp>(op.getLoc(), resultType, zerosAttr);
666662

667-
// Multiply by ones to implicitly broadcast
663+
// Add by zeros to implicitly broadcast.
668664
auto result = rewriter.create<tosa::AddOp>(op.getLoc(), resultType, input,
669665
zerosConst);
670666

0 commit comments

Comments
 (0)