Commit 1f2abc8
[Torch] Fix batch norm decomposition dtype for mixed-precision inputs
Use the running stats dtype instead of the input dtype when reshaping
`running_mean`/`running_var` in `DecomposeAtenNativeBatchNormOp`. When
the input is e.g. bf16 with f32 running stats, the reshape was producing
`aten.view` ops with a bf16 result type from an f32 input, which is
invalid.
Fixes #4480
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>1 parent c748d48 commit 1f2abc8
File tree
2 files changed
+36
-9
lines changed- lib/Dialect/Torch/Transforms
- test/Dialect/Torch
2 files changed
+36
-9
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8592 | 8592 | | |
8593 | 8593 | | |
8594 | 8594 | | |
8595 | | - | |
8596 | | - | |
8597 | | - | |
8598 | 8595 | | |
8599 | | - | |
8600 | | - | |
8601 | | - | |
8602 | | - | |
| 8596 | + | |
| 8597 | + | |
| 8598 | + | |
| 8599 | + | |
| 8600 | + | |
| 8601 | + | |
| 8602 | + | |
| 8603 | + | |
| 8604 | + | |
| 8605 | + | |
8603 | 8606 | | |
8604 | 8607 | | |
8605 | 8608 | | |
| |||
8621 | 8624 | | |
8622 | 8625 | | |
8623 | 8626 | | |
8624 | | - | |
| 8627 | + | |
8625 | 8628 | | |
8626 | 8629 | | |
8627 | 8630 | | |
| |||
8631 | 8634 | | |
8632 | 8635 | | |
8633 | 8636 | | |
8634 | | - | |
| 8637 | + | |
8635 | 8638 | | |
8636 | 8639 | | |
8637 | 8640 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1013 | 1013 | | |
1014 | 1014 | | |
1015 | 1015 | | |
| 1016 | + | |
| 1017 | + | |
| 1018 | + | |
| 1019 | + | |
| 1020 | + | |
| 1021 | + | |
| 1022 | + | |
| 1023 | + | |
| 1024 | + | |
| 1025 | + | |
| 1026 | + | |
| 1027 | + | |
| 1028 | + | |
| 1029 | + | |
| 1030 | + | |
| 1031 | + | |
| 1032 | + | |
| 1033 | + | |
| 1034 | + | |
| 1035 | + | |
| 1036 | + | |
| 1037 | + | |
| 1038 | + | |
| 1039 | + | |
0 commit comments