Description
On ROCm 7.1+ (MIOpen 3.5+), PyTorch decomposes BatchNorm2d with channels-last inputs to aten.miopen_batch_norm instead of aten._native_batch_norm_legit_functional. torch-mlir doesn't have a lowering for this op, so it gets imported as a generic torch.operator and fails legalization.
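For reference, when training is false, both aten.miopen_batch_norm and aten.native_batch_norm compute the standard inference-mode normalization from the running statistics. The channels-last layout changes how the data is stored, not the arithmetic. The following is a plain-Python illustration of that per-channel math, not either op's actual kernel. The function name batch_norm_inference is hypothetical, and the input is assumed to be a nested list indexed logically as [N][C][H][W].

```python
import math


def batch_norm_inference(x, weight, bias, running_mean, running_var, eps=1e-5):
    """Inference-mode batch norm over a nested list indexed as [N][C][H][W].

    Hypothetical helper for illustration. Each channel c is normalized with
    its running statistics:
        y = (x - running_mean[c]) / sqrt(running_var[c] + eps) * weight[c] + bias[c]
    """
    out = []
    for sample in x:
        out_sample = []
        for c, plane in enumerate(sample):
            # Fold normalization and the affine transform into one scale/shift per channel.
            scale = weight[c] / math.sqrt(running_var[c] + eps)
            shift = bias[c] - running_mean[c] * scale
            out_sample.append([[v * scale + shift for v in row] for row in plane])
        out.append(out_sample)
    return out
```

With the reproducer's constants (weight 1, bias 0, running mean 0, running variance 1), the result is x / sqrt(1 + eps), which is nearly the identity.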
Reproducer
```mlir
// repro.mlir
module @module {
  func.func @main(%arg0: !torch.vtensor<[8,64,56,56],bf16>) -> !torch.vtensor<[8,64,56,56],bf16> {
    %weight = torch.vtensor.literal(dense<1.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %running_mean = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %running_var = torch.vtensor.literal(dense<1.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %false = torch.constant.bool false
    %momentum = torch.constant.float 1.000000e-01
    %eps = torch.constant.float 1.000000e-05
    %0:3 = torch.operator "torch.aten.miopen_batch_norm"(
        %arg0, %weight, %bias, %running_mean, %running_var,
        %false, %momentum, %eps
    ) : (!torch.vtensor<[8,64,56,56],bf16>,
         !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>,
         !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>,
         !torch.bool, !torch.float, !torch.float)
      -> (!torch.vtensor<[8,64,56,56],bf16>, !torch.vtensor<[0],bf16>, !torch.vtensor<[0],bf16>)
    return %0#0 : !torch.vtensor<[8,64,56,56],bf16>
  }
}
```

```
$ torch-mlir-opt repro.mlir --torch-function-to-torch-backend-pipeline
repro.mlir:12:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %0:3 = torch.operator "torch.aten.miopen_batch_norm"(
           ^
```
IREE Turbine reproducer
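```python
import torch
import iree.turbine.aot as aot

# ROCm builds of PyTorch expose AMD GPUs through the "cuda" device type.
model = torch.nn.BatchNorm2d(64).to(device="cuda", memory_format=torch.channels_last)
model.eval()  # inference mode: normalize with the running statistics

# On ROCm 7.1+, a channels-last input routes BatchNorm2d to aten.miopen_batch_norm.
x = torch.randn(8, 64, 56, 56, device="cuda", dtype=torch.bfloat16).to(
    memory_format=torch.channels_last
)

exported = aot.export(model, x)
exported.print_readable()
exported.compile(save_to=None)
```

Suggested fix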
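Add a decomposition mapping aten.miopen_batch_norm → aten.native_batch_norm. The two ops take the same arguments in the same order (miopen_batch_norm's exponential_average_factor plays the role of native_batch_norm's momentum) and have compatible output semantics. PyTorch Inductor has an equivalent decomposition: torch/_inductor/decomposition.py#L867-L896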
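A minimal sketch of such a decomposition, written as a Python function over torch.ops.aten. This is not the Inductor code and not a torch-mlir API. The function name miopen_batch_norm_to_native and the MIOPEN_DECOMPS table are hypothetical. Returning empty saved statistics in inference mode is an assumption, chosen to match the [0]-shaped results in the reproducer.

```python
import torch

aten = torch.ops.aten


def miopen_batch_norm_to_native(
    input, weight, bias, running_mean, running_var,
    training, exponential_average_factor, epsilon,
):
    # Hypothetical decomposition: forward the arguments positionally, since
    # exponential_average_factor occupies the same slot as native_batch_norm's momentum.
    out, save_mean, save_invstd = aten.native_batch_norm(
        input, weight, bias, running_mean, running_var,
        training, exponential_average_factor, epsilon,
    )
    if training:
        return out, save_mean, save_invstd
    # Inference: return empty saved statistics (assumption, mirrors the
    # !torch.vtensor<[0],bf16> results in the reproducer).
    empty = input.new_empty((0,))
    return out, empty, empty


# Hypothetical decomposition table keyed by the ATen overload.
MIOPEN_DECOMPS = {aten.miopen_batch_norm.default: miopen_batch_norm_to_native}
```

Assuming the op appears in a torch.export graph, a table like this could be passed to ExportedProgram.run_decompositions before import as a workaround. A proper fix would register the equivalent decomposition in torch-mlir's own pipeline.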
Environment
- PyTorch 2.10.0+rocm7.1
- iree-org/iree-turbine@391729b