diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 82d11ec0737a..e416634c1c09 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7353,6 +7353,38 @@ def Torch_AtenNativeBatchNormOp : Torch_Op<"aten.native_batch_norm", [
   }];
 }
 
+def Torch_AtenMiopenBatchNormOp : Torch_Op<"aten.miopen_batch_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::miopen_batch_norm : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchOptionalTensorType:$running_mean,
+    AnyTorchOptionalTensorType:$running_var,
+    Torch_BoolType:$training,
+    Torch_FloatType:$exponential_average_factor,
+    Torch_FloatType:$epsilon
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result0,
+    AnyTorchOptionalTensorType:$result1,
+    AnyTorchOptionalTensorType:$result2
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMiopenBatchNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 3);
+    }
+    void AtenMiopenBatchNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 3);
+    }
+  }];
+}
+
 def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9a386ec35f30..735be30cc48b 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8526,6 +8526,28 @@ class DecomposeAtenNativeGroupNormOp
 };
 } // namespace
 
+// Decompose aten.miopen_batch_norm to aten.native_batch_norm.
+// PyTorch dispatches to miopen_batch_norm on ROCm in some cases:
+// https://github.com/pytorch/pytorch/blob/2f0a6bd3f5efddf0eadf6252924ac4bd0276ee71/aten/src/ATen/native/Normalization.cpp#L626
+// We copy the inductor decomposition here which just lowers to
+// `aten.native_batch_norm` (inference mode returning empty tensors is an
+// optimization we can skip here):
+// https://github.com/pytorch/pytorch/blob/2f0a6bd3f5efddf0eadf6252924ac4bd0276ee71/torch/_inductor/decomposition.py#L888-L917
+namespace {
+class DecomposeAtenMiopenBatchNormOp
+    : public OpRewritePattern<AtenMiopenBatchNormOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenMiopenBatchNormOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<AtenNativeBatchNormOp>(
+        op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
+        op.getRunningMean(), op.getRunningVar(), op.getTraining(),
+        op.getExponentialAverageFactor(), op.getEpsilon());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenNativeBatchNormOp
     : public OpRewritePattern<AtenNativeBatchNormOp> {
@@ -13380,6 +13402,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenInstanceNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenMiopenBatchNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
     addPatternIfTargetOpIsIllegal<
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 7d5e65c21cef..e33bb82f1269 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -631,6 +631,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
     )
+    emit(
+        "aten::miopen_batch_norm : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
+    )
     emit(
         "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
     )
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index d8dc39375bb6..1f6df4562b23 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -1013,3 +1013,27 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
   %0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
   return %0 : !torch.vtensor<[1,8,4,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @miopen_batch_norm(
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[1,3,4,4],f32>, %[[WEIGHT:.*]]: !torch.vtensor<[3],f32>, %[[BIAS:.*]]: !torch.vtensor<[3],f32>, %[[MEAN:.*]]: !torch.vtensor<[3],f32>, %[[VAR:.*]]: !torch.vtensor<[3],f32>)
+// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct {{.*}} : ({{.*}}) -> !torch.list<int>
+// CHECK: %[[MEAN_VIEW:.*]] = torch.aten.view %[[MEAN]], %[[SHAPE]]
+// CHECK: %[[VAR_VIEW:.*]] = torch.aten.view %[[VAR]], %[[SHAPE]]
+// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN_VIEW]]
+// CHECK: %[[VAR_EPS:.*]] = torch.aten.add.Scalar %[[VAR_VIEW]], %{{.*}}, %{{.*}}
+// CHECK: %[[RSQRT:.*]] = torch.aten.rsqrt %[[VAR_EPS]]
+// CHECK: %[[NORMED:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[RSQRT]]
+// CHECK: %[[WEIGHT_VIEW:.*]] = torch.aten.view %[[WEIGHT]], %[[SHAPE]]
+// CHECK: %[[SCALED:.*]] = torch.aten.mul.Tensor %[[NORMED]], %[[WEIGHT_VIEW]]
+// CHECK: %[[BIAS_VIEW:.*]] = torch.aten.view %[[BIAS]], %[[SHAPE]]
+// CHECK: %[[RESULT:.*]] = torch.aten.add.Tensor %[[SCALED]], %[[BIAS_VIEW]]
+// CHECK: return %[[RESULT]]
+func.func @miopen_batch_norm(%input: !torch.vtensor<[1,3,4,4],f32>, %weight: !torch.vtensor<[3],f32>, %bias: !torch.vtensor<[3],f32>, %mean: !torch.vtensor<[3],f32>, %var: !torch.vtensor<[3],f32>) -> (!torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) {
+  %false = torch.constant.bool false
+  %momentum = torch.constant.float 1.000000e-01
+  %eps = torch.constant.float 1.000000e-05
+  %0:3 = torch.aten.miopen_batch_norm %input, %weight, %bias, %mean, %var, %false, %momentum, %eps : !torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float -> !torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
+  return %0#0, %0#1, %0#2 : !torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
+}