Skip to content

Commit 1f2abc8

Browse files
rkayaith authored and claude committed
[Torch] Fix batch norm decomposition dtype for mixed-precision inputs
Use the running stats dtype instead of the input dtype when reshaping `running_mean`/`running_var` in `DecomposeAtenNativeBatchNormOp`. When the input is e.g. bf16 with f32 running stats, the reshape was producing `aten.view` ops with a bf16 result type from an f32 input, which is invalid. Fixes #4480 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c748d48 commit 1f2abc8

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8592,14 +8592,17 @@ class DecomposeAtenNativeBatchNormOp
85928592
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
85938593
runningStatsShapeInt[1] =
85948594
cast<BaseTensorType>(runningMean.getType()).getSizes()[0];
8595-
Type dtype = cast<ValueTensorType>(input.getType()).getOptionalDtype();
8596-
Type reshapeType = ValueTensorType::get(
8597-
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
85988595

8599-
runningMean = AtenViewOp::create(rewriter, loc, reshapeType, runningMean,
8600-
runningStatsSizeList);
8601-
runningVar = AtenViewOp::create(rewriter, loc, reshapeType, runningVar,
8602-
runningStatsSizeList);
8596+
auto reshapeType = [&](Value v) {
8597+
auto dtype = cast<ValueTensorType>(v.getType()).getOptionalDtype();
8598+
return ValueTensorType::get(context, llvm::ArrayRef(runningStatsShapeInt),
8599+
dtype);
8600+
};
8601+
8602+
runningMean = AtenViewOp::create(rewriter, loc, reshapeType(runningMean),
8603+
runningMean, runningStatsSizeList);
8604+
runningVar = AtenViewOp::create(rewriter, loc, reshapeType(runningVar),
8605+
runningVar, runningStatsSizeList);
86038606

86048607
// normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)).
86058608
Value inputSubMean = AtenSubTensorOp::create(
@@ -8621,7 +8624,7 @@ class DecomposeAtenNativeBatchNormOp
86218624
std::optional<unsigned> weightRank = getTensorRank(weight);
86228625
if (!weightRank || *weightRank != 1)
86238626
return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
8624-
weight = AtenViewOp::create(rewriter, loc, reshapeType, weight,
8627+
weight = AtenViewOp::create(rewriter, loc, reshapeType(weight), weight,
86258628
runningStatsSizeList);
86268629
batchNormOutput = AtenMulTensorOp::create(
86278630
rewriter, loc, batchNormOutput.getType(), batchNormOutput, weight);
@@ -8631,7 +8634,7 @@ class DecomposeAtenNativeBatchNormOp
86318634
std::optional<unsigned> biasRank = getTensorRank(bias);
86328635
if (!biasRank || *biasRank != 1)
86338636
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
8634-
bias = AtenViewOp::create(rewriter, loc, reshapeType, bias,
8637+
bias = AtenViewOp::create(rewriter, loc, reshapeType(bias), bias,
86358638
runningStatsSizeList);
86368639
batchNormOutput =
86378640
AtenAddTensorOp::create(rewriter, loc, batchNormOutput.getType(),

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,3 +1013,27 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
10131013
%0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
10141014
return %0 : !torch.vtensor<[1,8,4,4],f32>
10151015
}
1016+
1017+
// -----
1018+
1019+
// CHECK-LABEL: func.func @native_batch_norm_mixed_precision(
1020+
// CHECK-SAME: %{{.*}}: !torch.vtensor<[1,3,4,4],bf16>, %{{.*}}: !torch.vtensor<[3],f32>, %{{.*}}: !torch.vtensor<[3],f32>, %[[MEAN:.*]]: !torch.vtensor<[3],f32>, %[[VAR:.*]]: !torch.vtensor<[3],f32>)
1021+
// Verify that the running stats reshape uses f32 (matching running stats dtype),
1022+
// not bf16 (the input dtype).
1023+
// CHECK: torch.aten.view %[[MEAN]], %{{.*}} : !torch.vtensor<[3],f32>, !torch.list<int> -> !torch.vtensor<[1,3,1,1],f32>
1024+
// CHECK: torch.aten.view %[[VAR]], %{{.*}} : !torch.vtensor<[3],f32>, !torch.list<int> -> !torch.vtensor<[1,3,1,1],f32>
1025+
func.func @native_batch_norm_mixed_precision(
1026+
%input: !torch.vtensor<[1,3,4,4],bf16>,
1027+
%weight: !torch.vtensor<[3],f32>,
1028+
%bias: !torch.vtensor<[3],f32>,
1029+
%running_mean: !torch.vtensor<[3],f32>,
1030+
%running_var: !torch.vtensor<[3],f32>
1031+
) -> (!torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>) {
1032+
%false = torch.constant.bool false
1033+
%float1e-5 = torch.constant.float 1.000000e-05
1034+
%float0.1 = torch.constant.float 1.000000e-01
1035+
%out:3 = torch.aten.native_batch_norm %input, %weight, %bias, %running_mean, %running_var, %false, %float0.1, %float1e-5 :
1036+
!torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float
1037+
-> !torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>
1038+
return %out#0, %out#1, %out#2 : !torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>
1039+
}

0 commit comments

Comments (0)