[TorchOnnxToTorch] Fix LayerNormalization scale/bias dtype mismatch with stash_type#4498

Open
bjw-0 wants to merge 3 commits into llvm:main from bjw-0:fix-layernorm-scale-bias-dtype

Conversation

@bjw-0 bjw-0 commented Mar 11, 2026

Summary

When ONNX LayerNormalization has stash_type != input dtype (e.g., stash_type=1 (float32) with float16 inputs), the existing lowering correctly casts x to the stash dtype but passes scale and bias through unchanged. This causes torch.aten.native_layer_norm to receive mixed dtypes (f32 input with f16 scale/bias), which fails verification or produces incorrect results at runtime.

This PR fixes the issue by:

  1. Casting scale and bias to the stash dtype when it differs from the input dtype
  2. Using the stash dtype for the native_layer_norm output type
  3. Casting the result(s) back to the original output dtype

This complements the stash_type support added in #3888 and is the same class of fix as #4474 (which addressed SimplifiedLayerNormalization).

Fixes #3725
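The three steps above can be sketched outside the MLIR lowering. The numpy function below is an illustrative stand-in (not the actual torch-mlir code): it emulates the fixed behavior where x, scale, and bias are all cast to the stash dtype before normalization, and the result is cast back to the original dtype. All names here are hypothetical.

```python
import numpy as np

def layer_norm_with_stash(x, scale, bias, eps=1e-5, stash_dtype=np.float32):
    # Hypothetical sketch of the fixed lowering. Previously only x was cast
    # to the stash dtype; scale and bias passed through as f16, producing
    # mixed-dtype operands for native_layer_norm.
    x_s = x.astype(stash_dtype)
    scale_s = scale.astype(stash_dtype)   # step 1: cast scale to stash dtype
    bias_s = bias.astype(stash_dtype)     # step 1: cast bias to stash dtype
    # Step 2: normalize entirely in the stash dtype.
    mean = x_s.mean(axis=-1, keepdims=True)
    var = x_s.var(axis=-1, keepdims=True)
    y_s = (x_s - mean) / np.sqrt(var + eps) * scale_s + bias_s
    # Step 3: cast the result back to the original output dtype.
    return y_s.astype(x.dtype)

x = np.random.randn(2, 4).astype(np.float16)
scale = np.ones(4, dtype=np.float16)
bias = np.zeros(4, dtype=np.float16)
y = layer_norm_with_stash(x, scale, bias)  # f16 in, f32 compute, f16 out
```

With unit scale and zero bias, each row of `y` is normalized in float32 and only rounded to float16 at the end, which is exactly the numerical behavior stash_type=1 requests.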

Test plan

  • Added two FileCheck tests in simple_ops_g_to_p.mlir:
    • test_layer_norm_stash_type_f16: Single-result case — verifies x, scale, and bias are all cast to f32, native_layer_norm runs in f32, and result is cast back to f16
    • test_layer_norm_stash_type_f16_3results: Three-result case — verifies all three outputs (y, mean, invStdDev) are cast back to f16
  • Existing test_layer_norm and test_layer_norm_single_result tests (f32 input, no stash_type mismatch) remain unchanged and continue to pass
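The three-result case exercised by the second test can be sketched the same way. This numpy fragment is illustrative only (the names are hypothetical, and it assumes, per the test description above, that all three declared outputs are f16): y, mean, and invStdDev are computed in the stash dtype and each cast back.

```python
import numpy as np

def layer_norm_3results(x, scale, bias, eps=1e-5, stash_dtype=np.float32):
    # Hypothetical sketch of the three-result lowering: compute in the
    # stash dtype, then cast all three results back to the output dtype.
    x_s = x.astype(stash_dtype)
    mean = x_s.mean(axis=-1, keepdims=True)
    var = x_s.var(axis=-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(var + eps)
    y_s = (x_s - mean) * inv_std * scale.astype(stash_dtype) \
        + bias.astype(stash_dtype)
    out_dtype = x.dtype
    return (y_s.astype(out_dtype),        # y
            mean.astype(out_dtype),       # mean
            inv_std.astype(out_dtype))    # invStdDev

x = np.random.randn(3, 8).astype(np.float16)
y, mean, inv_std = layer_norm_3results(
    x, np.ones(8, np.float16), np.zeros(8, np.float16))
```

Without the fix, only `y` would round-trip through the stash dtype while `mean` and `inv_std` could come out with mismatched types.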


bjw-0 commented Mar 11, 2026

@zjgarvey @Groverkss Would you mind taking a look at this? This is the standard LayerNormalization counterpart to the SimplifiedLayerNormalization stash_type fix in #4474. Thanks!

@bjw-0 bjw-0 changed the title [onnx] Fix LayerNormalization scale/bias dtype mismatch with stash_type [TorchOnnxToTorch] Fix LayerNormalization scale/bias dtype mismatch with stash_type Mar 11, 2026
Successfully merging this pull request may close these issues.

Is it necessary to verify the stash_type keyword for LayerNormalization op?

1 participant