Fix bug for qk norm in attentions.py #2604
Conversation
Force-pushed from 7a03c53 to 4a32d8b
LGTM
Do you know why this was not caught in the existing tests, e.g., the one related to qwen3next? (non-urgent, can discuss later)
I think it happens in
This issue was also caught by Airflow tests: https://b.corp.google.com/issues/447464486#comment10
What error were you seeing without this? Mind adding it to the PR description?
- redo qk norm initialization
- remove testing code change
- Add aot test for non-llama qk norm
Force-pushed from 042c583 to 9d13e0d
Feel free to add this information to your PR description. I will close my identical fix, and let's merge yours. Fixes: b/458142671
Description
This PR fixes a bug in attentions.py where query_norm and key_norm were initialized to None and only later assigned nnx modules. Because of the initial None assignment, nnx treated query_norm and key_norm as static attributes, and the later module assignment raised an error.
Also, added an AOT test that tests this code block to prevent such issues from happening again:
maxtext/src/MaxText/layers/attentions.py
Line 488 in cb136bc
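The failure mode and the fix can be sketched in plain Python. This is an illustrative stand-in, not the actual MaxText/nnx code: DummyNorm and make_qk_norms are hypothetical names, and the real code uses flax nnx modules.

```python
class DummyNorm:
    """Stand-in for an nnx normalization module (illustrative only)."""

    def __call__(self, x):
        return x


def make_qk_norms(use_qk_norm: bool):
    """Build query/key norm layers with the fixed single-assignment pattern.

    The buggy pattern assigned None first and reassigned a module later:

        self.query_norm = None            # nnx pins the attribute as static
        if use_qk_norm:
            self.query_norm = SomeNorm()  # later assignment raises an error

    The fix is to choose the final value once, at construction time, so the
    attribute never changes from None to a module after registration.
    """
    query_norm = DummyNorm() if use_qk_norm else None
    key_norm = DummyNorm() if use_qk_norm else None
    return query_norm, key_norm
```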
FIXES: #2602
Tests
A full run succeeded with the changes from this PR: https://paste.googleplex.com/6601423567585280
The AOT test also passes locally on a CPU VM: https://paste.googleplex.com/5073430399549440