Fix: Align Qwen3 dtype handling with Qwen2 for TPU #1904

Open
xianglon-commits wants to merge 3 commits into vllm-project:main from xianglon-commits:fix/qwen3-dtype-handling
Conversation

@xianglon-commits

This PR addresses a ValueError caused by a dtype mismatch (bfloat16 vs. float32) in the TPU attention kernels when running the Qwen3 model. The issue was encountered when using Tunix with the vLLM-JAX backend on the lance-ds branch:

Expected kv_cache.dtype=dtype(bfloat16) to be equal to k.dtype=dtype(bfloat16) and v.dtype=dtype(bfloat16), but found v.dtype=dtype('float32').

Following the pattern used to fix similar issues in qwen2.py (as suggested by the Tunix team, see https://screenshot.googleplex.com/8AwrZ44yv57Bt95), this change renames the dtype parameter to param_dtype in the __init__ methods of the layers in tpu_inference/models/jax/qwen3.py.

This prevents naming conflicts and ensures the intended bfloat16 data type is consistently propagated to model parameters on TPU, resolving the mixed-precision error.
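The rename described above can be sketched as follows. The class and field names here are illustrative only, not the actual layers in tpu_inference/models/jax/qwen3.py; the sketch just shows why an ambiguous dtype argument can leave weights in float32 and how the explicit param_dtype name avoids that.

```python
# Hypothetical sketch of the dtype -> param_dtype rename; the real layers
# are JAX/Flax modules, omitted here for brevity.

class Qwen3LayerBefore:
    # Before: a constructor argument named `dtype` is easy to confuse with
    # the *compute* dtype that Flax-style modules also call `dtype`. If a
    # caller forgets to pass it, parameters silently default to float32,
    # producing the bfloat16-vs-float32 mismatch in the attention kernel.
    def __init__(self, hidden_size: int, dtype: str = "float32"):
        self.hidden_size = hidden_size
        self.param_dtype = dtype  # ambiguous: param dtype or compute dtype?

class Qwen3LayerAfter:
    # After: the argument is explicitly `param_dtype`, matching the earlier
    # qwen2.py fix, so bfloat16 propagates unambiguously to the weights.
    def __init__(self, hidden_size: int, param_dtype: str = "bfloat16"):
        self.hidden_size = hidden_size
        self.param_dtype = param_dtype
```

With the explicit name, a model constructed with `Qwen3LayerAfter(4096, param_dtype="bfloat16")` carries the intended parameter dtype through every layer, so the KV cache and the projected k/v tensors agree.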

wang2yn84 and others added 3 commits March 3, 2026 18:31
Aligned Qwen3 model definition with Qwen2 fixes by renaming
__init__ parameter 'dtype' to 'param_dtype' to ensure
correct bfloat16 handling on TPU, as suggested.
@kyuyeunk
Collaborator

Would this be related to this PR? #1771
