Commit d444578
[Torch] Fix batch norm decomposition dtype for mixed-precision inputs
Use the running stats dtype instead of the input dtype when reshaping
`running_mean`/`running_var` in `DecomposeAtenNativeBatchNormOp`. When
the input is e.g. bf16 with f32 running stats, the reshape was producing
`aten.view` ops with a bf16 result type from an f32 input, which is
invalid.
Fixes #4480
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>1 parent c748d48 commit d444578
File tree
2 files changed
+26
-1
lines changed- lib/Dialect/Torch/Transforms
- test/Dialect/Torch
2 files changed
+26
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8592 | 8592 | | |
8593 | 8593 | | |
8594 | 8594 | | |
8595 | | - | |
| 8595 | + | |
| 8596 | + | |
8596 | 8597 | | |
8597 | 8598 | | |
8598 | 8599 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1013 | 1013 | | |
1014 | 1014 | | |
1015 | 1015 | | |
| 1016 | + | |
| 1017 | + | |
| 1018 | + | |
| 1019 | + | |
| 1020 | + | |
| 1021 | + | |
| 1022 | + | |
| 1023 | + | |
| 1024 | + | |
| 1025 | + | |
| 1026 | + | |
| 1027 | + | |
| 1028 | + | |
| 1029 | + | |
| 1030 | + | |
| 1031 | + | |
| 1032 | + | |
| 1033 | + | |
| 1034 | + | |
| 1035 | + | |
| 1036 | + | |
| 1037 | + | |
| 1038 | + | |
| 1039 | + | |
0 commit comments