
PCC Drop in SDPA Attention due to Softmax #3301

@sonalibaskaran2499


Describe the bug

  • A PCC drop is observed in the phi2/causal_lm variant (PCC = 0.9310092454805221).
  • The drop originates in SDPA attention, i.e. torch.nn.functional.scaled_dot_product_attention.
  • In TTNN, torch.nn.functional.scaled_dot_product_attention is decomposed into multiple lower-level operations. Debugging traced the PCC degradation to the softmax step: attn_weights = torch.softmax(qk_masked, dim=-1)
  • The following models from issue #2861 also use SDPA attention and show PCC degradation:
test_all_models_torch[gemma/pytorch-google/gemma-1.1-7b-it-single_device-inference]                                                         language    red         p150         INCORRECT_RESULT           PASSING       0.9724180883623303      0.99       PCC_EN   single_device    131.686   N/A
test_all_models_torch[llama/causal_lm/pytorch-llama_3_1_8b_instruct-single_device-inference]                                                language    red         p150         INCORRECT_RESULT           PASSING       0.9702271827420873      0.99       PCC_EN   single_device    171.237   N/A
test_all_models_torch[mistral/pytorch-mistral_nemo_instruct_2407-single_device-inference]                                                   language    red         p150         INCORRECT_RESULT           PASSING       0.9821535219381753      0.99       PCC_EN   single_device    187.107   N/A
test_all_models_torch[phi2/causal_lm/pytorch-microsoft/phi-2-single_device-inference]                                                       language    red         n150         INCORRECT_RESULT           PASSING       0.9344665590421978      0.98       PCC_EN   single_device    142.625   N/A
test_all_models_torch[phi2/causal_lm/pytorch-microsoft/phi-2-single_device-inference]                                                       language    red         p150         INCORRECT_RESULT           PASSING       0.9310092454805221      0.98       PCC_EN   single_device    109.987   N/A
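
The decomposition described above can be sketched in plain PyTorch. This is an illustrative CPU reference, not the actual TTNN lowering: it follows the reference math documented for torch.nn.functional.scaled_dot_product_attention (scale by 1/sqrt(head_dim), lower-triangular causal mask, softmax over the last dimension, multiply by V). The helper name `sdpa_decomposed`, the tensor shapes and the seed are hypothetical choices for illustration. Feeding the same inputs through each step on CPU and on device is one way to isolate which op loses accuracy.

```python
import math

import torch
import torch.nn.functional as F


def sdpa_decomposed(q, k, v, is_causal=True):
    """Hypothetical helper: SDPA written out as explicit lower-level ops.

    q, k, v have shape (batch, heads, seq_len, head_dim).
    Returns the attention output and the softmax weights.
    """
    scale = 1.0 / math.sqrt(q.size(-1))
    # Scaled attention scores: Q @ K^T / sqrt(head_dim)
    qk = (q @ k.transpose(-2, -1)) * scale
    if is_causal:
        # Lower-triangular mask: query position i may attend only to keys <= i
        seq_q, seq_k = qk.shape[-2], qk.shape[-1]
        keep = torch.ones(seq_q, seq_k, dtype=torch.bool, device=q.device).tril()
        qk_masked = qk.masked_fill(~keep, float("-inf"))
    else:
        qk_masked = qk
    # The softmax step the PCC drop was traced to
    attn_weights = torch.softmax(qk_masked, dim=-1)
    return attn_weights @ v, attn_weights


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))

reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
decomposed, weights = sdpa_decomposed(q, k, v)

# On CPU in float32 both paths should agree to within rounding error,
# so a large gap on device points at one of the lowered ops.
print("max abs diff:", (reference - decomposed).abs().max().item())
```

Comparing `weights` from this CPU reference with the device's softmax output for the same `qk_masked` input checks the softmax step on its own, independent of the surrounding matmuls.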

Logs

Softmax sanity - softmax_xla.log

Steps to reproduce the issue

git checkout sonalib/softmax
git submodule update --init --recursive
git lfs pull --include '*.pt'

# model run
pytest "tests/runner/test_models.py::test_all_models_torch[phi2/causal_lm/pytorch-microsoft/phi-2-single_device-inference]" -svv

# sanity run
pytest /home/tt-xla/tests/torch/graphs/test_softmax.py -svv

# emitpy run
pytest /home/tt-xla/examples/pytorch/codegen/python/softmax.py -svv

Expected behavior

PCC should meet or exceed the required threshold: 0.99 for the gemma, llama and mistral tests, and 0.98 for phi-2, per the table above.
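
The pass criterion is a Pearson correlation coefficient (PCC) between the golden CPU output and the device output. A minimal plain-Python sketch of that metric and the threshold check follows, applied to the values from the results table. It assumes the test harness computes standard Pearson correlation over flattened outputs; `pcc` and `meets_threshold` are hypothetical helper names, not the harness's API.

```python
import math


def pcc(golden, observed):
    """Pearson correlation coefficient of two equal-length sequences of numbers."""
    n = len(golden)
    if n != len(observed) or n < 2:
        raise ValueError("inputs must have the same length (at least 2)")
    mean_g = sum(golden) / n
    mean_o = sum(observed) / n
    cov = sum((g - mean_g) * (o - mean_o) for g, o in zip(golden, observed))
    var_g = sum((g - mean_g) ** 2 for g in golden)
    var_o = sum((o - mean_o) ** 2 for o in observed)
    if var_g == 0 or var_o == 0:
        raise ValueError("PCC is undefined for a constant input")
    return cov / math.sqrt(var_g * var_o)


def meets_threshold(measured, required):
    """Pass criterion: measured PCC must meet or exceed the required PCC."""
    return measured >= required


# (measured PCC, required PCC) pairs copied from the results table in this issue
results = {
    "gemma-1.1-7b-it (p150)": (0.9724180883623303, 0.99),
    "llama_3_1_8b_instruct (p150)": (0.9702271827420873, 0.99),
    "mistral_nemo_instruct_2407 (p150)": (0.9821535219381753, 0.99),
    "phi-2 (n150)": (0.9344665590421978, 0.98),
    "phi-2 (p150)": (0.9310092454805221, 0.98),
}

for name, (measured, required) in results.items():
    status = "PASS" if meets_threshold(measured, required) else "FAIL"
    print(f"{name}: PCC={measured:.4f} required={required} -> {status}")

# Sanity check of the metric itself: a perfect linear relation gives PCC = 1
print(pcc([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]))
```

All five entries fall below their thresholds, which matches the INCORRECT_RESULT status reported for each test.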

Metadata


Assignees

No one assigned

Labels

bug (Something isn't working)
