[TorchOnnxToTorch] Fix LayerNormalization scale/bias dtype mismatch with stash_type#4498
Open
Conversation
When ONNX LayerNormalization has stash_type != input dtype (e.g., stash_type=float32 with float16 inputs), the existing code correctly casts x to the stash dtype but passes scale and bias through unchanged. This causes torch.aten.native_layer_norm to receive mixed dtypes (f32 input with f16 scale/bias), which fails verification or produces incorrect results. Fix by also casting scale and bias to the stash dtype when it differs from the input dtype. The output of native_layer_norm is then cast back to the original result dtype. This complements the stash_type support added in llvm#3888 and is the same class of fix as llvm#4474 (which addressed SimplifiedLayerNormalization). Fixes llvm#3725
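The intended semantics can be sketched in Python (a hedged illustration of the behavior described above, not the actual torch-mlir lowering, which is written in C++; function name and shapes here are made up for the example): all three operands are cast to the stash dtype before layer normalization, and the result is cast back to the original dtype.

```python
import torch
import torch.nn.functional as F

def layer_norm_with_stash_type(x, scale, bias, eps=1e-5):
    # stash_type=1 corresponds to float32 in ONNX.
    stash_dtype = torch.float32
    orig_dtype = x.dtype
    # Cast x, scale, AND bias to the stash dtype, so the layer norm
    # never sees mixed dtypes (the bug this PR fixes: previously only
    # x was cast, leaving scale/bias in f16).
    y = F.layer_norm(
        x.to(stash_dtype),
        x.shape[-1:],
        weight=scale.to(stash_dtype),
        bias=bias.to(stash_dtype),
        eps=eps,
    )
    # Cast the result back to the original dtype.
    return y.to(orig_dtype)

x = torch.randn(2, 4, dtype=torch.float16)
scale = torch.ones(4, dtype=torch.float16)
bias = torch.zeros(4, dtype=torch.float16)
out = layer_norm_with_stash_type(x, scale, bias)
print(out.dtype)  # torch.float16
```

The normalization itself runs in float32 for numerical stability, while the function's input/output contract stays float16.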
Author
@zjgarvey @Groverkss Would you mind taking a look at this? This is the standard
Summary
When ONNX
`LayerNormalization` has `stash_type` != input dtype (e.g., `stash_type=1` (float32) with float16 inputs), the existing lowering correctly casts `x` to the stash dtype but passes `scale` and `bias` through unchanged. This causes `torch.aten.native_layer_norm` to receive mixed dtypes (f32 input with f16 scale/bias), which fails verification or produces incorrect results at runtime.

This PR fixes the issue by:

- Casting `scale` and `bias` to the stash dtype when it differs from the input dtype
- Casting the `native_layer_norm` output back to the original result dtype

This complements the stash_type support added in #3888 and is the same class of fix as #4474 (which addressed `SimplifiedLayerNormalization`).

Fixes #3725
Test plan
New tests in `simple_ops_g_to_p.mlir`:

- `test_layer_norm_stash_type_f16` (single-result case): verifies `x`, `scale`, and `bias` are all cast to f32, `native_layer_norm` runs in f32, and the result is cast back to f16
- `test_layer_norm_stash_type_f16_3results` (three-result case): verifies all three outputs (`y`, `mean`, `invStdDev`) are cast back to f16

Existing `test_layer_norm` and `test_layer_norm_single_result` tests (f32 input, no stash_type mismatch) remain unchanged and continue to pass.
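The three-result case the new tests cover can be sketched in PyTorch (a hedged illustration of the expected dtype flow, not the lit tests themselves; shapes and eps are arbitrary): `torch.native_layer_norm` returns `(y, mean, rstd)`, and all three results are cast back to the input dtype.

```python
import torch

# f16 operands, as in the stash_type mismatch scenario.
x = torch.randn(2, 4, dtype=torch.float16)
scale = torch.ones(4, dtype=torch.float16)
bias = torch.zeros(4, dtype=torch.float16)

f32 = torch.float32
# Compute in the stash dtype (f32), casting all three operands.
y, mean, rstd = torch.native_layer_norm(
    x.to(f32), (4,), scale.to(f32), bias.to(f32), 1e-5
)
# Cast all three results back to f16, matching the three-result
# case described in the test plan (y, mean, invStdDev).
y, mean, rstd = (t.to(torch.float16) for t in (y, mean, rstd))
print(y.dtype, mean.dtype, rstd.dtype)
```

Only `y` feeds the single-result case; the three-result lit test additionally checks the casts on `mean` and `rstd`.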