Skip to content

Commit b754194

Browse files
rkayaith authored and claude committed
Add decomposition for aten.miopen_batch_norm to aten.native_batch_norm
On ROCm 7.1+ with MIOpen 3.5+, PyTorch decomposes BatchNorm2d with channels-last layout to aten.miopen_batch_norm instead of aten._native_batch_norm_legit_functional. torch-mlir didn't support this op, causing legalization to fail. Fixes: #4476

- Register aten.miopen_batch_norm in torch_ods_gen.py (signature: (Tensor, Tensor, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor))
- Add a decomposition pattern to DecomposeComplexOps.cpp that rewrites miopen_batch_norm to native_batch_norm by passing weight directly (type-compatible, since native_batch_norm accepts AnyTorchOptionalTensorType)
- Add a lit test verifying the op is decomposed away

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 17ae730 commit b754194

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7353,6 +7353,38 @@ def Torch_AtenNativeBatchNormOp : Torch_Op<"aten.native_batch_norm", [
73537353
}];
73547354
}
73557355

7356+
def Torch_AtenMiopenBatchNormOp : Torch_Op<"aten.miopen_batch_norm", [
7357+
AllowsTypeRefinement,
7358+
HasValueSemantics,
7359+
ReadOnly
7360+
]> {
7361+
let summary = "Generated op for `aten::miopen_batch_norm : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)`";
7362+
let arguments = (ins
7363+
AnyTorchTensorType:$input,
7364+
AnyTorchTensorType:$weight,
7365+
AnyTorchOptionalTensorType:$bias,
7366+
AnyTorchOptionalTensorType:$running_mean,
7367+
AnyTorchOptionalTensorType:$running_var,
7368+
Torch_BoolType:$training,
7369+
Torch_FloatType:$exponential_average_factor,
7370+
Torch_FloatType:$epsilon
7371+
);
7372+
let results = (outs
7373+
AnyTorchOptionalTensorType:$result0,
7374+
AnyTorchOptionalTensorType:$result1,
7375+
AnyTorchOptionalTensorType:$result2
7376+
);
7377+
let hasCustomAssemblyFormat = 1;
7378+
let extraClassDefinition = [{
7379+
ParseResult AtenMiopenBatchNormOp::parse(OpAsmParser &parser, OperationState &result) {
7380+
return parseDefaultTorchOp(parser, result, 8, 3);
7381+
}
7382+
void AtenMiopenBatchNormOp::print(OpAsmPrinter &printer) {
7383+
printDefaultTorchOp(printer, *this, 8, 3);
7384+
}
7385+
}];
7386+
}
7387+
73567388
def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
73577389
AllowsTypeRefinement,
73587390
HasValueSemantics,

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8526,6 +8526,28 @@ class DecomposeAtenNativeGroupNormOp
85268526
};
85278527
} // namespace
85288528

8529+
// Decompose aten.miopen_batch_norm to aten.native_batch_norm.
8530+
// PyTorch dispatches to miopen_batch_norm on ROCm in some cases:
8531+
// https://github.com/pytorch/pytorch/blob/2f0a6bd3f5efddf0eadf6252924ac4bd0276ee71/aten/src/ATen/native/Normalization.cpp#L626
8532+
// We copy the inductor decomposition here which just lowers to
8533+
// `aten.native_batch_norm` (inference mode returning empty tensors is an
8534+
// optimization we can skip here):
8535+
// https://github.com/pytorch/pytorch/blob/2f0a6bd3f5efddf0eadf6252924ac4bd0276ee71/torch/_inductor/decomposition.py#L888-L917
8536+
namespace {
8537+
class DecomposeAtenMiopenBatchNormOp
8538+
: public OpRewritePattern<AtenMiopenBatchNormOp> {
8539+
using OpRewritePattern::OpRewritePattern;
8540+
LogicalResult matchAndRewrite(AtenMiopenBatchNormOp op,
8541+
PatternRewriter &rewriter) const override {
8542+
rewriter.replaceOpWithNewOp<AtenNativeBatchNormOp>(
8543+
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
8544+
op.getRunningMean(), op.getRunningVar(), op.getTraining(),
8545+
op.getExponentialAverageFactor(), op.getEpsilon());
8546+
return success();
8547+
}
8548+
};
8549+
} // namespace
8550+
85298551
namespace {
85308552
class DecomposeAtenNativeBatchNormOp
85318553
: public OpRewritePattern<AtenNativeBatchNormOp> {
@@ -13380,6 +13402,7 @@ class DecomposeComplexOpsPass
1338013402
addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
1338113403
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
1338213404
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
13405+
addPatternIfTargetOpIsIllegal<DecomposeAtenMiopenBatchNormOp>(patterns);
1338313406
addPatternIfTargetOpIsIllegal<
1338413407
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
1338513408
addPatternIfTargetOpIsIllegal<

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,9 @@ def emit_with_mutating_variants(key, **kwargs):
631631
emit(
632632
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
633633
)
634+
emit(
635+
"aten::miopen_batch_norm : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
636+
)
634637
emit(
635638
"aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
636639
)

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,3 +1013,15 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
10131013
%0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
10141014
return %0 : !torch.vtensor<[1,8,4,4],f32>
10151015
}
1016+
1017+
// -----
1018+
1019+
// CHECK-LABEL: func.func @miopen_batch_norm(
1020+
// CHECK-NOT: torch.aten.miopen_batch_norm
1021+
func.func @miopen_batch_norm(%input: !torch.vtensor<[1,3,4,4],f32>, %weight: !torch.vtensor<[3],f32>, %bias: !torch.vtensor<[3],f32>, %mean: !torch.vtensor<[3],f32>, %var: !torch.vtensor<[3],f32>) -> (!torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) {
1022+
%false = torch.constant.bool false
1023+
%momentum = torch.constant.float 1.000000e-01
1024+
%eps = torch.constant.float 1.000000e-05
1025+
%0:3 = torch.aten.miopen_batch_norm %input, %weight, %bias, %mean, %var, %false, %momentum, %eps : !torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float -> !torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
1026+
return %0#0, %0#1, %0#2 : !torch.vtensor<[1,3,4,4],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
1027+
}

0 commit comments

Comments
 (0)