
Conversation


@Rohan-Bierneni Rohan-Bierneni commented Nov 5, 2025

Description

Fixes a bug in attentions.py where query_norm and key_norm were first initialized to None and only later assigned nnx modules. Because of the initial None assignment, nnx treated query_norm and key_norm as static attributes and raised an error when the modules were assigned. This PR fixes that issue.
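A minimal sketch of the failure mode and the fix (toy names, not the actual MaxText code; the NNX behavior is as described above):

```python
from flax import nnx


class SimpleAttention(nnx.Module):
  """Toy stand-in for the attention layer; hypothetical, for illustration only."""

  def __init__(self, features: int, use_qk_norm: bool, rngs: nnx.Rngs):
    # Buggy pattern (per the description above): initializing to None first
    # leads NNX to treat the attribute as static, so the later module
    # assignment raises an error like
    # "Cannot assign data value of type '<class ...RMSNorm'>'".
    #
    #   self.query_norm = None
    #   if use_qk_norm:
    #       self.query_norm = nnx.RMSNorm(features, rngs=rngs)
    #
    # Fixed pattern: assign the submodule (or None) exactly once.
    self.query_norm = nnx.RMSNorm(features, rngs=rngs) if use_qk_norm else None
    self.key_norm = nnx.RMSNorm(features, rngs=rngs) if use_qk_norm else None

  def __call__(self, query, key):
    if self.query_norm is not None:
      query = self.query_norm(query)
    if self.key_norm is not None:
      key = self.key_norm(key)
    return query, key
```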

Also added an AOT test that exercises this code path to prevent such issues from recurring (see the sketch after the snippet below):

if self.use_qk_norm and not is_llama4_decoder_block:
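A hedged sketch of what such an AOT-style regression test can look like, reusing the toy SimpleAttention above (the real test lives in MaxText and likely differs):

```python
import jax
import jax.numpy as jnp
from flax import nnx


def test_qk_norm_attention_compiles():
  """AOT-compile the qk-norm path so trace-time bugs surface even on CPU."""
  attn = SimpleAttention(features=128, use_qk_norm=True, rngs=nnx.Rngs(0))
  query = jnp.zeros((1, 8, 128))
  key = jnp.zeros((1, 8, 128))

  # Standard NNX functional pattern: split into graphdef + state, then
  # merge back inside the jitted function.
  graphdef, state = nnx.split(attn)

  def forward(state, query, key):
    module = nnx.merge(graphdef, state)
    return module(query, key)

  # lower(...).compile() performs ahead-of-time compilation without
  # executing the computation, so no accelerator is needed.
  jax.jit(forward).lower(state, query, key).compile()
```

Because the broken initialization fails as soon as the module is constructed and traced, a test like this catches the bug on a plain CPU VM.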

FIXES: #2602

Tests

Training ran successfully with the changes from this PR: https://paste.googleplex.com/6601423567585280

The AOT test also passes locally on a CPU VM: https://paste.googleplex.com/5073430399549440

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


@parambole parambole left a comment


LGTM


@shuningjin shuningjin left a comment


Do you know why this was not caught in the existing tests, e.g., the one related to qwen3next? (non-urgent, can discuss later)

@eitanporat

> Do you know why this was not caught in the existing tests, e.g., the one related to qwen3next? (non-urgent, can discuss later)

I think it happens in the use_qk_norm branch of the code, and I don't see use_qk_norm set in the config src/MaxText/configs/models/qwen3-next-80b-a3b.yml.
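To make this concrete with the toy SimpleAttention from the PR description above (hypothetical, not the MaxText test): if a model config does not enable use_qk_norm, the norm submodules are never constructed, so the faulty initialization path is never exercised by that model's tests.

```python
from flax import nnx

# With qk-norm disabled, as in the qwen3-next config, the guarded branch
# is never taken and the bad initialization would go unnoticed.
attn = SimpleAttention(features=128, use_qk_norm=False, rngs=nnx.Rngs(0))
assert attn.query_norm is None and attn.key_norm is None
```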

@SurbhiJainUSC

> Do you know why this was not caught in the existing tests, e.g., the one related to qwen3next? (non-urgent, can discuss later)

This issue was also caught by Airflow tests: https://b.corp.google.com/issues/447464486#comment10


@bvandermoon bvandermoon left a comment


What error were you seeing without this? Mind adding it to the PR description?

Commits:
  • redo qk norm initialization
  • remove testing code change
  • Add aot test for non-llama qk norm

@Rohan-Bierneni Rohan-Bierneni force-pushed the rbierneni-qwen3-next-fullattention branch from 042c583 to 9d13e0d on November 5, 2025 23:48
@hengtaoguo

Feel free to add this information to your PR description. I will close my identical fix, and let's merge yours.

  • This is breaking our Gemma3/Qwen3 XLML tests and decoding utils for these two model families.

Fixes: b/458142671

@Rohan-Bierneni Rohan-Bierneni self-assigned this Nov 6, 2025
@copybara-service copybara-service bot merged commit b8fb668 into main Nov 6, 2025
48 of 52 checks passed
@copybara-service copybara-service bot deleted the rbierneni-qwen3-next-fullattention branch on November 6, 2025 16:45


Development

Successfully merging this pull request may close this issue:

Cannot assign data value of type '<class 'MaxText.layers.normalizations.RMSNorm'>'
