
Add support for aten.miopen_batch_norm #4476

@rkayaith

Description

On ROCm 7.1+ (MIOpen 3.5+), PyTorch decomposes BatchNorm2d with channels-last inputs to aten.miopen_batch_norm instead of aten._native_batch_norm_legit_functional. torch-mlir doesn't have a lowering for this op, so it gets imported as a generic torch.operator and fails legalization.

Reproducer

```mlir
// repro.mlir
module @module {
  func.func @main(%arg0: !torch.vtensor<[8,64,56,56],bf16>) -> !torch.vtensor<[8,64,56,56],bf16> {
    %weight = torch.vtensor.literal(dense<1.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %running_mean = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %running_var = torch.vtensor.literal(dense<1.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %false = torch.constant.bool false
    %momentum = torch.constant.float 1.000000e-01
    %eps = torch.constant.float 1.000000e-05
    %0:3 = torch.operator "torch.aten.miopen_batch_norm"(
        %arg0, %weight, %bias, %running_mean, %running_var,
        %false, %momentum, %eps
    ) : (!torch.vtensor<[8,64,56,56],bf16>,
         !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>,
         !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>,
         !torch.bool, !torch.float, !torch.float)
      -> (!torch.vtensor<[8,64,56,56],bf16>, !torch.vtensor<[0],bf16>, !torch.vtensor<[0],bf16>)
    return %0#0 : !torch.vtensor<[8,64,56,56],bf16>
  }
}
```
```console
$ torch-mlir-opt repro.mlir --torch-function-to-torch-backend-pipeline
repro.mlir:12:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %0:3 = torch.operator "torch.aten.miopen_batch_norm"(
           ^
```

IREE Turbine reproducer

```python
import torch
import iree.turbine.aot as aot

model = torch.nn.BatchNorm2d(64).to(device="cuda", memory_format=torch.channels_last)
model.eval()
x = torch.randn(8, 64, 56, 56, device="cuda", dtype=torch.bfloat16).to(
    memory_format=torch.channels_last
)

exported = aot.export(model, x)
exported.print_readable()
exported.compile(save_to=None)
```

Suggested fix

Add a decomposition mapping aten.miopen_batch_norm → aten.native_batch_norm. The ops have the same signature and compatible output semantics. PyTorch Inductor has an equivalent decomposition: torch/_inductor/decomposition.py#L867-L896
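A possible shape for the decomposition, sketched in Python in the style of Inductor's version (the torch-mlir fix itself would be a C++ rewrite pattern; the function name here and the zero-sized stats outputs in eval mode are assumptions based on the `!torch.vtensor<[0],bf16>` result types in the reproducer):

```python
import torch
from torch import Tensor
from typing import Optional, Tuple

def miopen_batch_norm_decomp(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    # Forward everything to aten.native_batch_norm, which torch-mlir
    # already knows how to lower.
    out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
        input, weight, bias, running_mean, running_var, training, momentum, eps
    )
    if training:
        return out, save_mean, save_invstd
    # In inference mode MIOpen produces zero-sized save_mean/save_invstd
    # tensors (see the reproducer's result types), so mirror that here.
    return out, input.new_zeros((0,)), input.new_zeros((0,))
```

In training mode the saved statistics pass through unchanged; only the inference path needs the empty-tensor adjustment to match MIOpen's output contract.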

Environment
