Skip to content

[Torch] Fix batch norm decomposition dtype for mixed-precision inputs#4483

Draft
rkayaith wants to merge 2 commits intollvm:mainfrom
iree-org:bn-decompose-dtype-fix
Draft

[Torch] Fix batch norm decomposition dtype for mixed-precision inputs#4483
rkayaith wants to merge 2 commits intollvm:mainfrom
iree-org:bn-decompose-dtype-fix

Conversation

@rkayaith
Copy link
Member

@rkayaith rkayaith commented Mar 3, 2026

DecomposeAtenNativeBatchNormOp used the input dtype when reshaping running_mean/running_var from [C] to [1,C,1,...]. When the input has a different dtype from the running stats (e.g. bf16 input with f32 running stats), this produced invalid aten.view ops that change the element type.

Fix by using each value's own dtype for its reshape result type. Also add a verifier to aten.view that rejects element type mismatches between input and output.

Fixes #4480
Fixes #4479

@rkayaith rkayaith force-pushed the bn-decompose-dtype-fix branch from d444578 to 3c4f05c Compare March 3, 2026 19:51
Use the running stats dtype instead of the input dtype when reshaping
`running_mean`/`running_var` in `DecomposeAtenNativeBatchNormOp`. When
the input is e.g. bf16 with f32 running stats, the reshape was producing
`aten.view` ops with a bf16 result type from an f32 input, which is
invalid.

Fixes llvm#4480

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@rkayaith rkayaith force-pushed the bn-decompose-dtype-fix branch from 3c4f05c to 1f2abc8 Compare March 3, 2026 19:55
`aten.view` maps to the shape-only overload (`aten::view(Tensor,
SymInt[])`), which preserves dtype. Without a verifier, invalid IR with
mismatched input/output element types reaches `genericViewLikeFold` and
crashes with an assertion failure in `DenseElementsAttr::get`.

Fixes llvm#4479

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@rkayaith rkayaith force-pushed the bn-decompose-dtype-fix branch from 1fab32b to fb59044 Compare March 4, 2026 17:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

DecomposeAtenNativeBatchNormOp uses wrong dtype for running stats reshape aten.view fold crashes on element type mismatch instead of verifying

1 participant