Skip to content

Commit d444578

Browse files
rkayaith and claude committed
[Torch] Fix batch norm decomposition dtype for mixed-precision inputs
Use the running stats dtype instead of the input dtype when reshaping `running_mean`/`running_var` in `DecomposeAtenNativeBatchNormOp`. When the input is e.g. bf16 with f32 running stats, the reshape was producing `aten.view` ops with a bf16 result type from an f32 input, which is invalid. Fixes #4480 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c748d48 commit d444578

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8592,7 +8592,8 @@ class DecomposeAtenNativeBatchNormOp
85928592
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
85938593
runningStatsShapeInt[1] =
85948594
cast<BaseTensorType>(runningMean.getType()).getSizes()[0];
8595-
Type dtype = cast<ValueTensorType>(input.getType()).getOptionalDtype();
8595+
Type dtype =
8596+
cast<ValueTensorType>(runningMean.getType()).getOptionalDtype();
85968597
Type reshapeType = ValueTensorType::get(
85978598
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
85988599

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,3 +1013,27 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
10131013
%0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
10141014
return %0 : !torch.vtensor<[1,8,4,4],f32>
10151015
}
1016+
1017+
// -----
1018+
1019+
// CHECK-LABEL: func.func @native_batch_norm_mixed_precision(
1020+
// CHECK-SAME: %{{.*}}: !torch.vtensor<[1,3,4,4],bf16>, %{{.*}}: !torch.vtensor<[3],f32>, %{{.*}}: !torch.vtensor<[3],f32>, %[[MEAN:.*]]: !torch.vtensor<[3],f32>, %[[VAR:.*]]: !torch.vtensor<[3],f32>)
1021+
// Verify that the running stats reshape uses f32 (matching running stats dtype),
1022+
// not bf16 (the input dtype).
1023+
// CHECK: torch.aten.view %[[MEAN]], %{{.*}} : !torch.vtensor<[3],f32>, !torch.list<int> -> !torch.vtensor<[1,3,1,1],f32>
1024+
// CHECK: torch.aten.view %[[VAR]], %{{.*}} : !torch.vtensor<[3],f32>, !torch.list<int> -> !torch.vtensor<[1,3,1,1],f32>
1025+
func.func @native_batch_norm_mixed_precision(
1026+
%input: !torch.vtensor<[1,3,4,4],bf16>,
1027+
%weight: !torch.vtensor<[3],f32>,
1028+
%bias: !torch.vtensor<[3],f32>,
1029+
%running_mean: !torch.vtensor<[3],f32>,
1030+
%running_var: !torch.vtensor<[3],f32>
1031+
) -> (!torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>) {
1032+
%false = torch.constant.bool false
1033+
%float1e-5 = torch.constant.float 1.000000e-05
1034+
%float0.1 = torch.constant.float 1.000000e-01
1035+
%out:3 = torch.aten.native_batch_norm %input, %weight, %bias, %running_mean, %running_var, %false, %float0.1, %float1e-5 :
1036+
!torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float
1037+
-> !torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>
1038+
return %out#0, %out#1, %out#2 : !torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>
1039+
}

0 commit comments

Comments (0)