Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7353,6 +7353,38 @@ def Torch_AtenNativeBatchNormOp : Torch_Op<"aten.native_batch_norm", [
}];
}

// Op for `aten::miopen_batch_norm`, the batch-norm overload PyTorch dispatches
// to on ROCm/MIOpen. Unlike `aten.native_batch_norm`, `weight` is a required
// (non-optional) tensor; the remaining operands mirror the native variant.
// NOTE(review): this file is generated (GeneratedTorchOps.td); keep this def in
// sync with the corresponding emit() entry in torch_ods_gen rather than
// hand-editing — TODO confirm regeneration workflow.
def Torch_AtenMiopenBatchNormOp : Torch_Op<"aten.miopen_batch_norm", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::miopen_batch_norm : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    AnyTorchTensorType:$weight,
    AnyTorchOptionalTensorType:$bias,
    AnyTorchOptionalTensorType:$running_mean,
    AnyTorchOptionalTensorType:$running_var,
    Torch_BoolType:$training,
    Torch_FloatType:$exponential_average_factor,
    Torch_FloatType:$epsilon
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result0,
    AnyTorchOptionalTensorType:$result1,
    AnyTorchOptionalTensorType:$result2
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    // Generic parser/printer: 8 operands, 3 results.
    ParseResult AtenMiopenBatchNormOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 8, 3);
    }
    void AtenMiopenBatchNormOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 8, 3);
    }
  }];
}

def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
23 changes: 23 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8526,6 +8526,28 @@ class DecomposeAtenNativeGroupNormOp
};
} // namespace

// Decompose aten.miopen_batch_norm to aten.native_batch_norm.
// PyTorch dispatches to miopen_batch_norm on ROCm in some cases:
// https://github.com/pytorch/pytorch/blob/2f0a6bd3f5efddf0eadf6252924ac4bd0276ee71/aten/src/ATen/native/Normalization.cpp#L626
// We copy the inductor decomposition here which just lowers to
// `aten.native_batch_norm` (inference mode returning empty tensors is an
// optimization we can skip here):
// https://github.com/pytorch/pytorch/blob/2f0a6bd3f5efddf0eadf6252924ac4bd0276ee71/torch/_inductor/decomposition.py#L888-L917
namespace {
// Rewrites aten.miopen_batch_norm into aten.native_batch_norm. Both ops take
// the same eight operands in the same order and produce three results, so the
// rewrite is a pure one-to-one forwarding of operands and result types.
class DecomposeAtenMiopenBatchNormOp
    : public OpRewritePattern<AtenMiopenBatchNormOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMiopenBatchNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Build the native batch-norm with identical result types so downstream
    // users of all three results remain valid.
    auto nativeOp = rewriter.create<AtenNativeBatchNormOp>(
        loc, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
        op.getRunningMean(), op.getRunningVar(), op.getTraining(),
        op.getExponentialAverageFactor(), op.getEpsilon());
    rewriter.replaceOp(op, nativeOp->getResults());
    return success();
  }
};
} // namespace

namespace {
class DecomposeAtenNativeBatchNormOp
: public OpRewritePattern<AtenNativeBatchNormOp> {
Expand Down Expand Up @@ -13380,6 +13402,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMiopenBatchNormOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
addPatternIfTargetOpIsIllegal<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::miopen_batch_norm : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
)
Expand Down
24 changes: 24 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1013,3 +1013,27 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
%0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
return %0 : !torch.vtensor<[1,8,4,4],f32>
}

// -----

// Verifies that aten.miopen_batch_norm (inference mode: training = false)
// decomposes into the elementwise normalization sequence: reshape mean/var to
// broadcastable shape, subtract mean, add epsilon to var, rsqrt, multiply,
// then scale by weight and shift by bias.
// CHECK-LABEL: func.func @miopen_batch_norm(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[1,3,4,4],f32>, %[[WEIGHT:.*]]: !torch.vtensor<[3],f32>, %[[BIAS:.*]]: !torch.vtensor<[3],f32>, %[[MEAN:.*]]: !torch.vtensor<[3],f32>, %[[VAR:.*]]: !torch.vtensor<[3],f32>)
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct {{.*}} : ({{.*}}) -> !torch.list<int>
// CHECK: %[[MEAN_VIEW:.*]] = torch.aten.view %[[MEAN]], %[[SHAPE]]
// CHECK: %[[VAR_VIEW:.*]] = torch.aten.view %[[VAR]], %[[SHAPE]]
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN_VIEW]]
// CHECK: %[[VAR_EPS:.*]] = torch.aten.add.Scalar %[[VAR_VIEW]], %{{.*}}, %{{.*}}
// CHECK: %[[RSQRT:.*]] = torch.aten.rsqrt %[[VAR_EPS]]
// CHECK: %[[NORMED:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[RSQRT]]
// CHECK: %[[WEIGHT_VIEW:.*]] = torch.aten.view %[[WEIGHT]], %[[SHAPE]]
// CHECK: %[[SCALED:.*]] = torch.aten.mul.Tensor %[[NORMED]], %[[WEIGHT_VIEW]]
// CHECK: %[[BIAS_VIEW:.*]] = torch.aten.view %[[BIAS]], %[[SHAPE]]
// CHECK: %[[RESULT:.*]] = torch.aten.add.Tensor %[[SCALED]], %[[BIAS_VIEW]]
// CHECK: return %[[RESULT]]
func.func @miopen_batch_norm(%input: !torch.vtensor<[1,3,4,4],f32>, %weight: !torch.vtensor<[3],f32>, %bias: !torch.vtensor<[3],f32>, %mean: !torch.vtensor<[3],f32>, %var: !torch.vtensor<[3],f32>) -> (!torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) {
  %false = torch.constant.bool false
  %momentum = torch.constant.float 1.000000e-01
  %eps = torch.constant.float 1.000000e-05
  %0:3 = torch.aten.miopen_batch_norm %input, %weight, %bias, %mean, %var, %false, %momentum, %eps : !torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float -> !torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
  return %0#0, %0#1, %0#2 : !torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
}
Loading