Skip to content

Commit 83be687

Browse files
authored
Fusing on by default + Multiply commute pattern rewrite (#3946)
Extending multiply commute to support commuting multiply when scale is not coming from block argument directly. Before we only supported this pattern: ``` constant_argument conv2d | | | | +-------multiply-----+ ``` In mobilenet there are bunch of layers which have this pattern: ``` constant_argument | | | transpose | | | transpose | | | broadcast conv2d | | | | +-------multiply-----+ ``` This PR adds small extension to `Conv2dWithMultiply` which can match scale coming directly from block argument or scale coming from broadcast where subgraph which is input into broadcast is const eval. For example above graph can be commuted since input into graph is constant but something like below can't: ``` constant_argument input | | | | +--------add---------+ | | | broadcast conv2d | | | | +-------multiply-----+ ``` ``` constant_argument | | | transpose | | | transpose conv2d | | | | +-------multiply-----+ ``` To check if subraph is fusable we start from`scale` argument in `isCommutable` and we construct [UD chain](https://en.wikipedia.org/wiki/Use-define_chain) and we use it to check if inputs into this subgraph are constants. When we determine that subgraph is const eval we commute whole subgraph before conv2d and apply reshape like we did before to align channel dim with weight. So resulting graph after commute would become: ``` constant_argument | | | transpose | | | transpose | | | reshape | | | broadcast weight | | | | | | multiply--------+ | | | conv2d ``` Or in no broadcast case: ``` constant_argument | | | reshape weight | | | | | | multiply--------+ | | | conv2d ``` In addition this PR tags clamp scalar with eltwise unary trait which would enable TM to commute through it.
1 parent b824083 commit 83be687

File tree

10 files changed

+225
-42
lines changed

10 files changed

+225
-42
lines changed

include/ttmlir/Dialect/StableHLO/Transforms/ShardyUtils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
#include "ttmlir/Dialect/StableHLO/Transforms/ShardyCCLToStableHLOCCL.h"
99

10-
#include "mlir/Analysis/TopologicalSortUtils.h"
1110
#include "mlir/Dialect/Func/IR/FuncOps.h"
1211
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1312
#include "mlir/IR/Builders.h"

include/ttmlir/Dialect/TTIR/IR/TTIROps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4075,7 +4075,7 @@ def TTIR_UnsqueezeOp : TTIR_NamedOp<"unsqueeze"> {
40754075
let hasVerifier = 1;
40764076
}
40774077

4078-
def TTIR_ClampScalarOp : TTIR_NamedOp<"clamp_scalar"> {
4078+
def TTIR_ClampScalarOp : TTIR_NamedOp<"clamp_scalar", [TTIR_ElementwiseUnary]> {
40794079
let summary = "Scalar value clamping operation.";
40804080
let description = [{
40814081
The `clamp_scalar` operation constrains all elements of a tensor to be within a specified range.

include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ struct TTIRToTTNNBackendPipelineOptions
222222

223223
Option<bool> enableFusing{*this, "enable-fusing-pass",
224224
llvm::cl::desc("Enable fusing pass."),
225-
llvm::cl::init(false)};
225+
llvm::cl::init(true)};
226226

227227
Option<ttcore::TTArgumentTypeMap, ttcore::ArgumentTypeMapParser>
228228
argumentTypeMap{

include/ttmlir/Utils.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,63 @@ OpType getOutermostLoopNest(mlir::ValueRange values) {
486486

487487
} // namespace loop
488488

489+
// Given a startng mlir::Value return a set of all values in the use-def
490+
// chain. This chain is not topologically sorted, so the order of values in the
491+
// result is not guaranteed. If you want to topologically sort the chain
492+
// use topologicalSort.
493+
inline llvm::SetVector<mlir::Value> getUseDefChain(mlir::Value start) {
494+
llvm::SetVector<mlir::Value> useDefChain;
495+
llvm::SmallVector<mlir::Value> worklist{start};
496+
llvm::SmallPtrSet<mlir::Value, 4> visited;
497+
498+
while (!worklist.empty()) {
499+
mlir::Value value = worklist.pop_back_val();
500+
useDefChain.insert(value);
501+
502+
mlir::Operation *defOp = value.getDefiningOp();
503+
if (!defOp) {
504+
continue;
505+
}
506+
507+
for (mlir::OpOperand &operand : defOp->getOpOperands()) {
508+
mlir::Value operandValue = operand.get();
509+
if (visited.contains(operandValue)) {
510+
continue;
511+
}
512+
visited.insert(operandValue);
513+
worklist.push_back(operandValue);
514+
}
515+
}
516+
517+
return useDefChain;
518+
}
519+
520+
// Given list of mlir::Value filter out block arguments.
521+
inline llvm::SetVector<mlir::BlockArgument>
522+
filterBlockArguments(llvm::ArrayRef<mlir::Value> values) {
523+
llvm::SetVector<mlir::BlockArgument> blockArgs;
524+
for (mlir::Value value : values) {
525+
if (auto blockArg = llvm::dyn_cast<mlir::BlockArgument>(value)) {
526+
blockArgs.insert(blockArg);
527+
}
528+
}
529+
530+
return blockArgs;
531+
}
532+
533+
// Given list of mlir::Value filter out operations that define them.
534+
// If value is not operation it is ignored.
535+
inline llvm::SetVector<mlir::Operation *>
536+
filterOperations(llvm::ArrayRef<mlir::Value> values) {
537+
llvm::SetVector<mlir::Operation *> ops;
538+
for (mlir::Value value : values) {
539+
if (auto *op = value.getDefiningOp()) {
540+
ops.insert(op);
541+
}
542+
}
543+
544+
return ops;
545+
}
489546
} // namespace ttmlir::utils
490547

491548
#endif // TTMLIR_UTILS_H

lib/Dialect/TTIR/Transforms/ExplicateTMs.cpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ class ExplicateBroadcastsRewriter
102102
llvm::ArrayRef<int64_t> operandShape = operandType.getShape();
103103

104104
llvm::SmallVector<int64_t> broadcastDimensions =
105-
getBroadcastDimensions(operandShape, broadcastedShape);
105+
ttmlir::utils::getBroadcastDimensions<int64_t>(operandShape,
106+
broadcastedShape);
106107
if (llvm::all_of(broadcastDimensions, [](int64_t i) { return i == 1; })) {
107108
continue;
108109
}
@@ -157,19 +158,6 @@ class ExplicateBroadcastsRewriter
157158

158159
return broadcastedShape;
159160
}
160-
161-
llvm::SmallVector<int64_t>
162-
getBroadcastDimensions(llvm::ArrayRef<int64_t> operandShape,
163-
llvm::ArrayRef<int64_t> targetShape) const {
164-
llvm::SmallVector<int64_t> broadcastDimensions(operandShape.size(), 1);
165-
for (size_t dim = 0; dim < operandShape.size(); dim++) {
166-
if (operandShape[dim] < targetShape[dim]) {
167-
broadcastDimensions[dim] = targetShape[dim];
168-
}
169-
}
170-
171-
return broadcastDimensions;
172-
}
173161
};
174162
} // namespace
175163

lib/Dialect/TTIR/Transforms/TTIRFusing.cpp

Lines changed: 94 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "ttmlir/Dialect/TTIR/Utils/Utils.h"
77
#include "ttmlir/Utils.h"
88

9-
#include "mlir/IR/Value.h"
9+
#include "mlir/Analysis/TopologicalSortUtils.h"
1010
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1111

1212
namespace mlir::tt::ttir {
@@ -251,16 +251,30 @@ class Conv2dWithMultiply : public mlir::OpRewritePattern<MultiplyOp> {
251251
Conv2dOp conv2dOp = components->first;
252252
Value scaleValue = components->second;
253253

254-
// Insert before conv2d op.
255-
rewriter.setInsertionPoint(conv2dOp);
256-
257254
// Reshape scale to match weight dimensions and pre-multiply weights.
258255
Value reshapedScale =
259-
createReshapedScale(rewriter, conv2dOp.getLoc(), scaleValue);
256+
createReshapedScale(rewriter, conv2dOp.getLoc(), scaleValue,
257+
conv2dOp.getWeight().getType());
258+
259+
// Get UD chain starting from the reshaped scale. This chain will be
260+
// moved before the conv2dOp to ensure that weight scale can be
261+
// const-evaled.
262+
SetVector<Value> udChain = ttmlir::utils::getUseDefChain(reshapedScale);
263+
SetVector<Operation *> udChainOps =
264+
ttmlir::utils::filterOperations(udChain.getArrayRef());
265+
SetVector<Operation *> udChainSorted = topologicalSort(udChainOps);
266+
for (auto *op : udChainSorted) {
267+
op->moveBefore(conv2dOp);
268+
}
269+
270+
rewriter.setInsertionPoint(conv2dOp);
271+
272+
// Create scaled weights by multiplying the original weights with the
273+
// resshaped scale.
260274
Value scaledWeights = createScaledWeights(
261275
rewriter, conv2dOp.getLoc(), conv2dOp.getWeight(), reshapedScale);
262276

263-
// Update conv2d to use scaled weights and replace multiply operation
277+
// Update conv2d to use scaled weights and replace multiply operation.
264278
rewriter.modifyOpInPlace(
265279
conv2dOp, [&]() { conv2dOp.getWeightMutable().assign(scaledWeights); });
266280
rewriter.replaceAllOpUsesWith(multiplyOp, conv2dOp);
@@ -295,22 +309,50 @@ class Conv2dWithMultiply : public mlir::OpRewritePattern<MultiplyOp> {
295309
mlir::func::FuncOp funcOp = conv2dOp->getParentOfType<mlir::func::FuncOp>();
296310
llvm::SmallPtrSet<BlockArgument, 4> constParams =
297311
mlir::tt::ttcore::getConstsAndParams(funcOp);
298-
auto isConstant = [&constParams, conv2dOp](mlir::Value value) {
312+
auto isConstant = [&constParams](mlir::Value value) {
299313
if (auto blockArg = mlir::dyn_cast<BlockArgument>(value)) {
300314
return constParams.contains(blockArg);
301315
}
302316

303-
Operation *op = value.getDefiningOp();
304-
return op->hasTrait<mlir::tt::ttcore::Trait::TTCoreCreationOpTrait>() &&
305-
op->isBeforeInBlock(conv2dOp);
317+
Operation *defOp = value.getDefiningOp();
318+
return defOp->hasTrait<mlir::tt::ttcore::Trait::TTCoreCreationOpTrait>();
306319
};
307320

308-
// Both scale and weight must be constant.
309-
if (!isConstant(scale) || !isConstant(conv2dOp.getWeight())) {
321+
// If weight is not constant, we cannot commute.
322+
if (!isConstant(conv2dOp.getWeight())) {
323+
return false;
324+
}
325+
326+
RankedTensorType scaleType = scale.getType();
327+
// If scale is comming from broadcast then we want to use the input type
328+
// to the broadcast to check the shape.
329+
if (auto bcastOp =
330+
mlir::dyn_cast_if_present<BroadcastOp>(scale.getDefiningOp())) {
331+
scaleType = bcastOp.getInput().getType();
332+
}
333+
334+
// Check if scale shape is with conv2d weight.
335+
if (!hasValidScaleShape(conv2dOp, scaleType)) {
336+
return false;
337+
}
338+
339+
// Now we want to check if operations which produce scale are
340+
// const-evalable. We do this by getting UD chain of the scale and then
341+
// checking if all inputs into this chain are constants.
342+
SetVector<Value> useDefChain = ttmlir::utils::getUseDefChain(scale);
343+
SetVector<BlockArgument> useDefChainBlockArgs =
344+
ttmlir::utils::filterBlockArguments(useDefChain.getArrayRef());
345+
if (!all_of(useDefChainBlockArgs, isConstant)) {
310346
return false;
311347
}
312348

313-
return hasValidScaleShape(conv2dOp, scale.getType());
349+
// Since we want to move the scale chain before conv2dOp we want to make
350+
// sure that the scale chain does not contain conv2dOp.
351+
if (useDefChain.contains(conv2dOp)) {
352+
return false;
353+
}
354+
355+
return true;
314356
}
315357

316358
// Scale must have rank 4 and shape (1, 1, 1, out_channels).
@@ -320,11 +362,31 @@ class Conv2dWithMultiply : public mlir::OpRewritePattern<MultiplyOp> {
320362
scaleType.getDimSize(3) == convOp.getOutputChannelSize();
321363
}
322364

365+
// There are two cases we want to handle here:
366+
// 1. Input scale is a constant tensor that only neeeds reshaping
367+
// 2. Input scale is a broadcast operation that needs reshaping
368+
//
369+
// In case of 1 we just add reshape operation to the scale tensor such that
370+
// it has shape (out_channels, 1, 1, 1).
371+
//
372+
// In case of 2 we need to add reshape operation to the input of the of bcast
373+
// and then we create new broadcast operation with the new reshaped scale
374+
// which broadcasts the reshaped scale to the shape of the weight tensor.
323375
static Value createReshapedScale(mlir::PatternRewriter &rewriter,
324-
Location loc, Value scaleValue) {
376+
Location loc, Value scaleValue,
377+
RankedTensorType weightType) {
378+
// If scaleValue is broadcast operation we want to reshape its input.
379+
// Otherwise we reshape the scaleValue itself.
380+
Value reshapeInput = scaleValue;
381+
if (auto bcastOp = mlir::dyn_cast_if_present<BroadcastOp>(
382+
scaleValue.getDefiningOp())) {
383+
rewriter.setInsertionPoint(bcastOp);
384+
reshapeInput = bcastOp.getInput();
385+
}
386+
325387
// Get the scale's type.
326388
RankedTensorType scaleType =
327-
mlir::cast<RankedTensorType>(scaleValue.getType());
389+
mlir::cast<RankedTensorType>(reshapeInput.getType());
328390

329391
// Create a new shape (out_channels, 1, 1, 1) from (1, 1, 1, out_channels).
330392
llvm::SmallVector<int64_t> newShape(scaleType.getShape());
@@ -335,11 +397,25 @@ class Conv2dWithMultiply : public mlir::OpRewritePattern<MultiplyOp> {
335397
// Convert to int32 for the reshape operation.
336398
llvm::SmallVector<int32_t> newShapeI32(newShape.begin(), newShape.end());
337399

338-
// Create and return the reshape operation.
339-
return ttir::utils::createDPSOp<ttir::ReshapeOp>(
400+
// Create the reshape operation.
401+
auto reshapedScale = ttir::utils::createDPSOp<ttir::ReshapeOp>(
340402
rewriter, ttmlir::utils::appendLocationSuffix(loc, "_reshape"),
341403
newShape, scaleType.getElementType(), scaleType.getEncoding(),
342-
scaleValue, rewriter.getI32ArrayAttr(newShapeI32));
404+
reshapeInput, rewriter.getI32ArrayAttr(newShapeI32));
405+
406+
// If scale value is not a broadcast operation we can return reshapedScale.
407+
if (!isa_and_present<ttir::BroadcastOp>(scaleValue.getDefiningOp())) {
408+
return reshapedScale;
409+
}
410+
411+
// Otherwise we need to create a new broadcast operation that will take
412+
// reshaped scale and brroadcast it to the shape of the weight tensor.
413+
SmallVector<int64_t> broadcastDims =
414+
ttmlir::utils::getBroadcastDimensions<int64_t>(
415+
reshapedScale.getType().getShape(), weightType.getShape());
416+
return ttir::utils::createDPSOp<ttir::BroadcastOp>(
417+
rewriter, scaleValue.getLoc(), weightType, reshapedScale,
418+
broadcastDims);
343419
}
344420

345421
/// Create pre-multiplied weights.

test/ttmlir/Dialect/TTIR/fusing/conv2d_multiply_commute.mlir

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,11 @@ module {
134134
}
135135

136136
// Check that we can't commute since %scale is not before %conv in block.
137-
// CHECK-LABEL: func.func @conv2d_creation_op_non_commutable
138-
func.func @conv2d_creation_op_non_commutable(%input: tensor<1x32x32x64xbf16>) -> tensor<1x30x30x64xbf16> {
139-
// CHECK-NOT: "ttir.reshape"
137+
// CHECK-LABEL: func.func @conv2d_creation_op_commutable
138+
func.func @conv2d_creation_op_commutable(%input: tensor<1x32x32x64xbf16>) -> tensor<1x30x30x64xbf16> {
139+
// CHECK: "ttir.ones"
140+
// CHECK: "ttir.reshape"
141+
// CHECK: "ttir.multiply"
140142
// CHECK: "ttir.conv2d"
141143
%0 = ttir.empty() : tensor<1x30x30x64xbf16>
142144
%weight = "ttir.zeros"() <{shape = array<i32: 64, 64, 3, 3>}> : () -> tensor<64x64x3x3xbf16>
@@ -147,11 +149,72 @@ module {
147149
dilation = 1: i32,
148150
groups = 1: i32
149151
}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16>
150-
// CHECK: "ttir.multiply"
151152
%scale = "ttir.ones"() <{shape = array<i32: 1, 1, 1, 64>}> : () -> tensor<1x1x1x64xbf16>
152153
%1 = ttir.empty() : tensor<1x30x30x64xbf16>
153154
%2 = "ttir.multiply"(%conv, %scale, %1) : (tensor<1x30x30x64xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16>
154155

155156
return %2: tensor<1x30x30x64xbf16>
156157
}
158+
159+
// Check that we can commute const-eval subgraph which generates scale for conv2d output.
160+
// CHECK-LABEL: func.func @conv2d_subgraph_commute
161+
func.func @conv2d_subgraph_commute(%arg0: tensor<1x3x224x224xbf16> {ttcore.argument_type = #ttcore.argument_type<input>}, %arg1: tensor<1x32x1x1xbf16> {ttcore.argument_type = #ttcore.argument_type<constant>}, %arg2: tensor<1x32x1x1xbf16> {ttcore.argument_type = #ttcore.argument_type<constant>}, %arg3: tensor<32x3x3x3xbf16> {ttcore.argument_type = #ttcore.argument_type<parameter>, ttir.conv2d_weight}, %arg4: tensor<32x1x3x3xbf16> {ttcore.argument_type = #ttcore.argument_type<parameter>, ttir.conv2d_weight}) -> tensor<1x112x112x32xbf16> {
162+
// Ignore first reshape which is for conv input.
163+
// CHECK: "ttir.reshape"
164+
// CHECK: %[[RESHAPE:.*]] = "ttir.reshape"
165+
// CHECK: %[[BCAST:.*]] = "ttir.broadcast"
166+
// CHECK-SAME: (%[[RESHAPE]]
167+
// CHECK: %[[MUL:.*]] = "ttir.multiply"
168+
// CHECK-SAME: (%arg3, %[[BCAST]]
169+
// CHECK: "ttir.conv2d"
170+
// CHECK-SAME: ([[X:.*]], %[[MUL]]
171+
%0 = ttir.empty() : tensor<1x224x3x224xbf16>
172+
%1 = "ttir.transpose"(%arg0, %0) <{dim0 = 1 : si32, dim1 = 2 : si32}> : (tensor<1x3x224x224xbf16>, tensor<1x224x3x224xbf16>) -> tensor<1x224x3x224xbf16>
173+
%2 = ttir.empty() : tensor<1x224x224x3xbf16>
174+
%3 = "ttir.transpose"(%1, %2) <{dim0 = 2 : si32, dim1 = 3 : si32}> : (tensor<1x224x3x224xbf16>, tensor<1x224x224x3xbf16>) -> tensor<1x224x224x3xbf16>
175+
%4 = ttir.empty() : tensor<1x1x50176x3xbf16>
176+
%5 = "ttir.reshape"(%3, %4) <{shape = [1 : i32, 1 : i32, 50176 : i32, 3 : i32]}> : (tensor<1x224x224x3xbf16>, tensor<1x1x50176x3xbf16>) -> tensor<1x1x50176x3xbf16>
177+
%6 = ttir.empty() : tensor<1x1x12544x32xbf16>
178+
%7 = "ttir.conv2d"(%5, %arg3, %6) <{dilation = array<i32: 1, 1>, flattened_compat_info = #ttir<flattened_compat batch_size = 1, input_height = 224, input_width = 224>, groups = 1 : i32, padding = array<i32: 1, 1, 1, 1>, stride = array<i32: 2, 2>}> : (tensor<1x1x50176x3xbf16>, tensor<32x3x3x3xbf16>, tensor<1x1x12544x32xbf16>) -> tensor<1x1x12544x32xbf16>
179+
%8 = ttir.empty() : tensor<1x1x32x1xbf16>
180+
%9 = "ttir.transpose"(%arg1, %8) <{dim0 = 1 : si32, dim1 = 2 : si32}> : (tensor<1x32x1x1xbf16>, tensor<1x1x32x1xbf16>) -> tensor<1x1x32x1xbf16>
181+
%10 = ttir.empty() : tensor<1x1x1x32xbf16>
182+
%11 = "ttir.transpose"(%9, %10) <{dim0 = 2 : si32, dim1 = 3 : si32}> : (tensor<1x1x32x1xbf16>, tensor<1x1x1x32xbf16>) -> tensor<1x1x1x32xbf16>
183+
%12 = ttir.empty() : tensor<1x1x12544x32xbf16>
184+
%13 = "ttir.broadcast"(%11, %12) <{broadcast_dimensions = array<i64: 1, 1, 12544, 1>}> : (tensor<1x1x1x32xbf16>, tensor<1x1x12544x32xbf16>) -> tensor<1x1x12544x32xbf16>
185+
%14 = ttir.empty() : tensor<1x1x12544x32xbf16>
186+
%15 = "ttir.multiply"(%7, %13, %14) : (tensor<1x1x12544x32xbf16>, tensor<1x1x12544x32xbf16>, tensor<1x1x12544x32xbf16>) -> tensor<1x1x12544x32xbf16>
187+
%16 = ttir.empty() : tensor<1x112x112x32xbf16>
188+
%17 = "ttir.reshape"(%15, %16) <{shape = [1 : i32, 112 : i32, 112 : i32, 32 : i32]}> : (tensor<1x1x12544x32xbf16>, tensor<1x112x112x32xbf16>) -> tensor<1x112x112x32xbf16>
189+
return %17 : tensor<1x112x112x32xbf16>
190+
}
191+
192+
// Check that we can't commute const-eval since arg1 is not constant.
193+
// CHECK-LABEL: func.func @conv2d_subgraph_not_commuteable
194+
func.func @conv2d_subgraph_not_commuteable(%arg0: tensor<1x3x224x224xbf16> {ttcore.argument_type = #ttcore.argument_type<input>}, %arg1: tensor<1x32x1x1xbf16> {ttcore.argument_type = #ttcore.argument_type<input>}, %arg2: tensor<1x32x1x1xbf16> {ttcore.argument_type = #ttcore.argument_type<constant>}, %arg3: tensor<32x3x3x3xbf16> {ttcore.argument_type = #ttcore.argument_type<parameter>, ttir.conv2d_weight}, %arg4: tensor<32x1x3x3xbf16> {ttcore.argument_type = #ttcore.argument_type<parameter>, ttir.conv2d_weight}) -> tensor<1x112x112x32xbf16> {
195+
// Ignore first reshape which is for conv input.
196+
// CHECK: "ttir.reshape"
197+
// CHECK: "ttir.conv2d"
198+
// CHECK: "ttir.broadcast"
199+
// CHECK: "ttir.multiply"
200+
%0 = ttir.empty() : tensor<1x224x3x224xbf16>
201+
%1 = "ttir.transpose"(%arg0, %0) <{dim0 = 1 : si32, dim1 = 2 : si32}> : (tensor<1x3x224x224xbf16>, tensor<1x224x3x224xbf16>) -> tensor<1x224x3x224xbf16>
202+
%2 = ttir.empty() : tensor<1x224x224x3xbf16>
203+
%3 = "ttir.transpose"(%1, %2) <{dim0 = 2 : si32, dim1 = 3 : si32}> : (tensor<1x224x3x224xbf16>, tensor<1x224x224x3xbf16>) -> tensor<1x224x224x3xbf16>
204+
%4 = ttir.empty() : tensor<1x1x50176x3xbf16>
205+
%5 = "ttir.reshape"(%3, %4) <{shape = [1 : i32, 1 : i32, 50176 : i32, 3 : i32]}> : (tensor<1x224x224x3xbf16>, tensor<1x1x50176x3xbf16>) -> tensor<1x1x50176x3xbf16>
206+
%6 = ttir.empty() : tensor<1x1x12544x32xbf16>
207+
%7 = "ttir.conv2d"(%5, %arg3, %6) <{dilation = array<i32: 1, 1>, flattened_compat_info = #ttir<flattened_compat batch_size = 1, input_height = 224, input_width = 224>, groups = 1 : i32, padding = array<i32: 1, 1, 1, 1>, stride = array<i32: 2, 2>}> : (tensor<1x1x50176x3xbf16>, tensor<32x3x3x3xbf16>, tensor<1x1x12544x32xbf16>) -> tensor<1x1x12544x32xbf16>
208+
%8 = ttir.empty() : tensor<1x1x32x1xbf16>
209+
%9 = "ttir.transpose"(%arg1, %8) <{dim0 = 1 : si32, dim1 = 2 : si32}> : (tensor<1x32x1x1xbf16>, tensor<1x1x32x1xbf16>) -> tensor<1x1x32x1xbf16>
210+
%10 = ttir.empty() : tensor<1x1x1x32xbf16>
211+
%11 = "ttir.transpose"(%9, %10) <{dim0 = 2 : si32, dim1 = 3 : si32}> : (tensor<1x1x32x1xbf16>, tensor<1x1x1x32xbf16>) -> tensor<1x1x1x32xbf16>
212+
%12 = ttir.empty() : tensor<1x1x12544x32xbf16>
213+
%13 = "ttir.broadcast"(%11, %12) <{broadcast_dimensions = array<i64: 1, 1, 12544, 1>}> : (tensor<1x1x1x32xbf16>, tensor<1x1x12544x32xbf16>) -> tensor<1x1x12544x32xbf16>
214+
%14 = ttir.empty() : tensor<1x1x12544x32xbf16>
215+
%15 = "ttir.multiply"(%7, %13, %14) : (tensor<1x1x12544x32xbf16>, tensor<1x1x12544x32xbf16>, tensor<1x1x12544x32xbf16>) -> tensor<1x1x12544x32xbf16>
216+
%16 = ttir.empty() : tensor<1x112x112x32xbf16>
217+
%17 = "ttir.reshape"(%15, %16) <{shape = [1 : i32, 112 : i32, 112 : i32, 32 : i32]}> : (tensor<1x1x12544x32xbf16>, tensor<1x112x112x32xbf16>) -> tensor<1x112x112x32xbf16>
218+
return %17 : tensor<1x112x112x32xbf16>
219+
}
157220
}

test/ttmlir/Dialect/TTNN/fusing/resnet_pattern_fusing.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-fusing-pass=true" %s | FileCheck %s
1+
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
22

33
// This is common pattern throught Resnet. We have conv2d with constant weight, followed by multiply with constant input. This will be commuted through conv2d.
44
// Then we fuse add into conv2d with bias and lastly we fuse conv2d and relu into conv2d with activation.

test/ttmlir/Silicon/TTNN/n150/fusing/resnet_pattern_fusing.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-fusing-pass=true" %s > %t.mlir
1+
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
22
// RUN: FileCheck %s --input-file=%t.mlir
33
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
44

test/ttmlir/Silicon/TTNN/n150/fusing/softmax_fusing.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-fusing-pass=true" %s > %t.mlir
1+
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
22
// RUN: FileCheck %s --input-file=%t.mlir
33
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
44

0 commit comments

Comments
 (0)