[Torch] Fix batch norm decomposition dtype for mixed-precision inputs#4483
Draft
[Torch] Fix batch norm decomposition dtype for mixed-precision inputs#4483
Conversation
d444578 to
3c4f05c
Compare
Use the running stats dtype instead of the input dtype when reshaping `running_mean`/`running_var` in `DecomposeAtenNativeBatchNormOp`. When the input is e.g. bf16 with f32 running stats, the reshape was producing `aten.view` ops with a bf16 result type from an f32 input, which is invalid. Fixes llvm#4480 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
3c4f05c to
1f2abc8
Compare
`aten.view` maps to the shape-only overload (`aten::view(Tensor, SymInt[])`), which preserves dtype. Without a verifier, invalid IR with mismatched input/output element types reaches `genericViewLikeFold` and crashes with an assertion failure in `DenseElementsAttr::get`. Fixes llvm#4479 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1fab32b to
fb59044
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
DecomposeAtenNativeBatchNormOpused the input dtype when reshapingrunning_mean/running_varfrom[C]to[1,C,1,...]. When the input has a different dtype from the running stats (e.g. bf16 input with f32 running stats), this produced invalidaten.viewops that change the element type.Fix by using each value's own dtype for its reshape result type. Also add a verifier to
aten.viewthat rejects element type mismatches between input and output.Fixes #4480
Fixes #4479