@@ -401,6 +401,7 @@ def TensorRT_CastOp : TensorRT_Op<"cast", [
let arguments = (ins TensorRT_RankedTensorOf<[I1, UI8, TensorRT_I8, I32, I64, F16, BF16, F32]>:$input);
let results = (outs TensorRT_RankedTensorOf<[I1, UI8, TensorRT_I8, I32, I64, F16, BF16, F32]>:$result);
let assemblyFormat = "attr-dict $input `:` type($input) `to` type($result)";
let hasFolder = 1;

let extraClassDeclaration = [{
/// Returns true if created op is valid for TensorRT major version.
@@ -68,6 +68,11 @@ TypedValue<RankedTensorType>
scatterShapeTensor(RewriterBase &b, Location loc, ArrayRef<int64_t> baseShape,
int32_t scatterDim, TypedValue<RankedTensorType> update);

/// Get a splat constant's element attribute by walking up a chain of
/// shape-changing (reshape, expand/collapse rank, broadcast, slice) and
/// cast/identity operations to the defining constant. Fails if no such splat
/// constant is found. The returned element may have a different data type
/// than `x` when the chain contains a cast.
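///
/// For example (a sketch, assuming the dialect's usual printed form), given
///
///   %cst = tensorrt.constant dense<9.99999974E-6> : tensor<f32>
///   %0 = tensorrt.cast %cst : tensor<f32> to tensor<f16>
///   %1 = tensorrt.expand_rank %0 : tensor<f16> to tensor<1x1x1xf16>
///
/// calling this on `%1` walks up through the expand_rank and cast ops and
/// returns the constant's splat value as a FloatAttr.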
FailureOr<Attribute> getSplatConstantElementAttribute(Value x);

} // namespace tensorrt
} // namespace mlir

11 changes: 11 additions & 0 deletions mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp
@@ -2380,6 +2380,17 @@ OpFoldResult IdentityOp::fold(FoldAdaptor adaptor) {
return foldIdentity(getType(), getInput(), adaptor);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

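/// Fold a no-op cast: when the operand and result types already match, e.g.
/// `tensorrt.cast %x : tensor<4xf32> to tensor<4xf32>`, the cast is replaced
/// by its operand.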
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//
@@ -23,6 +23,7 @@
//===----------------------------------------------------------------------===//
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h"
#include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/Matchers.h"
@@ -51,6 +52,8 @@ class RaiseNormalizations
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.add<RaiseInstanceNormalization_NCHW>(ctx);
patterns.add<RaisePytorchLayerNorm>(ctx);
patterns.add<RemoveLayerNormCast>(ctx);

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
emitError(getOperation()->getLoc())
@@ -146,6 +146,10 @@ Constraint ReduceSumImpl(val: Value)[{
(reduceOp.getInput().getType().getRank() - 1)));
}];

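// Succeeds if `op` is a `tensorrt.reduce` whose reduce operation is kAVG.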
Constraint AvgImpl(op: Op) [{
  return success(cast<tensorrt::ReduceOp>(op).getReduceOperation() ==
                 ReduceOperation::kAVG);
}];

Constraint CheckRank4(val: Value)[{
RankedTensorType rtt = cast<RankedTensorType>(val.getType());
return success(rtt.getRank() == 4);
@@ -196,10 +200,10 @@ Constraint ReverseSqrt(val : Value) -> Value{
}

Constraint FlattenTailDims(val: Value) -> Value {
CheckRank4(val);
let reshapeRes = op<tensorrt.reshape>(val);
FlattenConstraintImpl(reshapeRes);
return reshapeRes;
  CheckRank4(val);
  let reshapeRes = op<tensorrt.reshape>(val);
  FlattenConstraintImpl(reshapeRes);
  return reshapeRes;
}

Constraint ReduceSum(val: Value) -> Value{
@@ -219,6 +223,35 @@ Constraint Mean(input: Value, numHW: Value){
return Div(ExpandTailDims(ReduceSum(FlattenTailDims(input))), numHW);
}

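// Matches a keep-dimensions `tensorrt.reduce` over `reduceAxes` and checks
// that it performs an average (kAVG) reduction, i.e. E[x] over those axes.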
Constraint ReduceAvg(input: Value, reduceAxes: Attr) {
  let avgOp = op<tensorrt.reduce>(input) {keepDimensions = attr<"true">, reduceAxes = reduceAxes};
  AvgImpl(avgOp);
  return avgOp;
}


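// Returns the splat element attribute of the constant feeding `x` (possibly
// through reshape/broadcast/cast ops). Only call this after HasSplatElements
// has succeeded on `x`.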
Rewrite GetSplatElementAttr(x: Value) -> Attr [{
  return *getSplatConstantElementAttribute(x);
}];

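// Succeeds if `x` is (a reshape/broadcast/cast chain of) a splat constant
// with a float or integer element type.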
Constraint HasSplatElements(x: Value) [{
  return LogicalResult(getSplatConstantElementAttribute(x));
}];

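// Succeeds if `a` and `b` have the same tensor element type.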
Constraint SameElementType(a: Value, b: Value) [{
  return success(cast<RankedTensorType>(a.getType()).getElementType() ==
                 cast<RankedTensorType>(b.getType()).getElementType());
}];

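// Creates (or folds) a `tensorrt.cast` of `x` to the element type of
// `refValue`, keeping the shape of `x`.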
Rewrite CreateCast(x: Value, refValue: Value) -> Value [{
  Type retType =
      RankedTensorType::Builder(cast<RankedTensorType>(x.getType()))
          .setElementType(
              cast<RankedTensorType>(refValue.getType()).getElementType());
  return rewriter.createOrFold<tensorrt::CastOp>(x.getLoc(), retType, x);
}];


Pattern RaiseInstanceNormalization_NCHW {
let inputType : Type;
let input : Value<inputType>;
@@ -240,3 +273,57 @@ Pattern RaiseInstanceNormalization_NCHW {
CheckRank4(addOffset);
replace addOffset with op<tensorrt.normalization>(input, scale, offset){axis = attr<"array<i64: 2,3>">};
}

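// Raise PyTorch's decomposed layer normalization (mean, variance, and epsilon
// applied through reduce and element-wise ops) into a single
// tensorrt.normalization op.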
Pattern RaisePytorchLayerNorm {
  let x: Value;
  let beta: Value;
  let gamma: Value;
  let axis: Attr;
  let epsilon: Value;

  let mean = ReduceAvg(x, axis);
  let diffMean = Sub(x, mean);

  let varianceDenominator: Value;
  // PyTorch's lowering computes the mean in two different ways, so the
  // variance branch matches a second, reduce-sum based mean.
  let varianceMean = Div(ReduceSum(x), varianceDenominator);
  let varianceDiff = Sub(x, varianceMean);
  let varianceDiffSquared = Mul(varianceDiff, varianceDiff);
  let varianceNumerator = ReduceSum(varianceDiffSquared);
  let variance = Div(varianceNumerator, varianceDenominator);
  let varianceEps = Add(variance, epsilon);

  let inverseSqrt = ReverseSqrt(varianceEps);
  let normed = Mul(diffMean, inverseSqrt);
  let prod = Mul(normed, gamma);
  let root = Add(prod, beta);

  HasSplatElements(epsilon);
  HasSplatElements(varianceDenominator);

  rewrite root with {
    let epsilonAttr = GetSplatElementAttr(epsilon);
    let replacement = op<tensorrt.normalization>(x, gamma, beta) {axis = axis, eps = epsilonAttr};
    replace root with replacement;
  };
}

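// Remove the casts around a normalization when the outer cast returns to the
// input's original element type: the normalization is recreated on the
// uncasted input and gamma/beta are cast instead.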
Pattern RemoveLayerNormCast {
  let x: Value;
  let gamma: Value;
  let beta: Value;
  let axis: Attr;
  let epsilonAttr: Attr;

  let castInput = op<tensorrt.cast>(x);
  let norm = op<tensorrt.normalization>(castInput, gamma, beta) {axis = axis, eps = epsilonAttr};
  let root = op<tensorrt.cast>(norm);

  SameElementType(x, root);

  rewrite root with {
    let newGamma = CreateCast(gamma, x);
    let newBeta = CreateCast(beta, x);
    let replacement = op<tensorrt.normalization>(x, newGamma, newBeta) {axis = axis, eps = epsilonAttr};
    replace root with replacement;
  };
}
32 changes: 32 additions & 0 deletions mlir-tensorrt/tensorrt/lib/TensorRT/Utils/Utils.cpp
@@ -22,6 +22,7 @@
#include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"

#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

using namespace mlir;
@@ -158,3 +159,34 @@ tensorrt::scatterShapeTensor(RewriterBase &b, Location loc,

return b.create<tensorrt::ConcatenationOp>(loc, parts, 0);
}

FailureOr<Attribute> tensorrt::getSplatConstantElementAttribute(Value x) {
  while (true) {
    if (auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
      x = expandRank.getInput();
    else if (auto collapseRank = x.getDefiningOp<tensorrt::CollapseRankOp>())
      x = collapseRank.getInput();
    else if (auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
      x = reshape.getInput();
    else if (auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
      x = broadcast.getInput();
    else if (auto cast = x.getDefiningOp<tensorrt::CastOp>())
      x = cast.getInput();
    else if (auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
      x = identity.getInput();
    else if (auto slice = x.getDefiningOp<tensorrt::SliceOp>())
      x = slice.getInput();
    else if (auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
      SplatElementsAttr els{};
      if (!matchPattern(x, m_Constant(&els)))
        return failure();
      Attribute value = els.getSplatValue<Attribute>();
      if (!isa<FloatAttr, IntegerAttr>(value))
        return failure();
      return value;
    } else {
      return failure();
    }
  }
  return failure();
}
@@ -53,3 +53,38 @@ func.func @raise_inst_norm_nchw(%arg0: tensor<1x3x1x1xf32>, %arg1: tensor<1x3x1x

// CHECK-LABEL: @neg_raise_nhwc
// CHECK-NOT: tensorrt.normalization

// -----

// CHECK: @raise_layer_norm_pytorch(%[[arg0:.+]]: tensor<16x1024x1024xf32>)
// CHECK: %[[ret:.+]] = tensorrt.normalization [[attr:.+]](%[[arg0]] : tensor<16x1024x1024xf32>, %[[gamma:.+]] : tensor<1x1x1024xf32>, %[[beta:.+]] : tensor<1x1x1024xf32>)
// CHECK: return %[[ret]]
func.func @raise_layer_norm_pytorch(%arg0: tensor<16x1024x1024xf32>) -> tensor<16x1024x1024xf32> {
  %cst_i64 = tensorrt.constant dense<1024> : tensor<i64>
  %cst_f32 = tensorrt.constant dense<9.99999974E-6> : tensor<1x1x1xf32>

  %cst_bf16_1 = tensorrt.constant dense_resource<__elided__> : tensor<1024xbf16> // beta (added)
  %cst_bf16_2 = tensorrt.constant dense_resource<__elided__> : tensor<1024xbf16> // gamma (multiplied)

  %6 = tensorrt.reduce <kSUM> %arg0 {keepDimensions = true, reduceAxes = array<i64: 2>} : tensor<16x1024x1024xf32> -> tensor<16x1024x1xf32>
  %7 = tensorrt.cast %cst_i64 : tensor<i64> to tensor<f32>
  %8 = tensorrt.expand_rank %7 : tensor<f32> to tensor<1x1x1xf32>
  %9 = tensorrt.element_wise <kDIV>(%6, %8 : tensor<16x1024x1xf32>, tensor<1x1x1xf32>) -> tensor<16x1024x1xf32>
  %10 = tensorrt.element_wise <kSUB>(%arg0, %9 : tensor<16x1024x1024xf32>, tensor<16x1024x1xf32>) -> tensor<16x1024x1024xf32>
  %11 = tensorrt.element_wise <kPROD>(%10, %10 : tensor<16x1024x1024xf32>, tensor<16x1024x1024xf32>) -> tensor<16x1024x1024xf32>
  %12 = tensorrt.reduce <kSUM> %11 {keepDimensions = true, reduceAxes = array<i64: 2>} : tensor<16x1024x1024xf32> -> tensor<16x1024x1xf32>
  %13 = tensorrt.element_wise <kDIV>(%12, %8 : tensor<16x1024x1xf32>, tensor<1x1x1xf32>) -> tensor<16x1024x1xf32> // Var[x]
  %15 = tensorrt.reduce <kAVG> %arg0 {keepDimensions = true, reduceAxes = array<i64: 2>} : tensor<16x1024x1024xf32> -> tensor<16x1024x1xf32> // E[x]
  %16 = tensorrt.element_wise <kSUM>(%13, %cst_f32 : tensor<16x1024x1xf32>, tensor<1x1x1xf32>) -> tensor<16x1024x1xf32> // Var[x] + epsilon
  %17 = tensorrt.unary {unaryOperation = #tensorrt.unary_operation<kRECIP>} %16 : tensor<16x1024x1xf32>
  %18 = tensorrt.unary {unaryOperation = #tensorrt.unary_operation<kSQRT>} %17 : tensor<16x1024x1xf32> // compute 1/sqrt(...)
  %19 = tensorrt.element_wise <kSUB>(%arg0, %15 : tensor<16x1024x1024xf32>, tensor<16x1024x1xf32>) -> tensor<16x1024x1024xf32>
  %20 = tensorrt.element_wise <kPROD>(%19, %18 : tensor<16x1024x1024xf32>, tensor<16x1024x1xf32>) -> tensor<16x1024x1024xf32> // multiply for division
  %21 = tensorrt.cast %cst_bf16_2 : tensor<1024xbf16> to tensor<1024xf32>
  %22 = tensorrt.expand_rank %21 : tensor<1024xf32> to tensor<1x1x1024xf32>
  %23 = tensorrt.element_wise <kPROD>(%20, %22 : tensor<16x1024x1024xf32>, tensor<1x1x1024xf32>) -> tensor<16x1024x1024xf32> // multiply gamma
  %24 = tensorrt.cast %cst_bf16_1 : tensor<1024xbf16> to tensor<1024xf32>
  %25 = tensorrt.expand_rank %24 : tensor<1024xf32> to tensor<1x1x1024xf32>
  %26 = tensorrt.element_wise <kSUM>(%23, %25 : tensor<16x1024x1024xf32>, tensor<1x1x1024xf32>) -> tensor<16x1024x1024xf32> // add beta
  return %26 : tensor<16x1024x1024xf32>
}