
Implementation of Qwen3-VL#1177

Open
ridcl wants to merge 5 commits into google:main from ridcl:qwen3-vl

Conversation


@ridcl ridcl commented Feb 28, 2026

This PR implements Qwen3-VL and (partially) resolves #1063.

What's included:

  • Vision and text model
  • Parameter loading from Huggingface endpoint
  • Model config for Qwen3-VL 4B

What's not included yet:

  • Sampling. The current sampler in Tunix doesn't support vision input. Moreover, Qwen3-VL has a notion of grid_thw (spatial information about image and video inputs) that requires additional handling, so sampling is out of scope for now.
  • Sharding for the vision model.
  • Mapping to SGLang and vLLM. The regular Qwen3 model includes this mapping, but I haven't found any documentation on the topic.
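For context on why grid_thw needs special handling in the sampler: it encodes each image or video as a (temporal, height, width) patch grid, and the number of vision tokens the decoder sees follows from that grid after spatial patch merging. A hypothetical sketch (function name, merge size, and the 2x2 merging convention are assumptions, not taken from this PR's code):

```python
# Hypothetical sketch: how many vision tokens one grid_thw entry yields.
# grid_thw = (temporal, height, width) in patch units; after spatial
# merging, each merge_size x merge_size block of patches becomes one token.

def num_vision_tokens(grid_thw, merge_size=2):
    t, h, w = grid_thw
    return t * (h // merge_size) * (w // merge_size)

# E.g. a single image whose patch grid is 32x32 (one temporal frame):
print(num_vision_tokens((1, 32, 32)))  # 256 tokens after 2x2 merging
```

A sampler would need this count to know how many placeholder positions in the prompt are filled by vision embeddings rather than text tokens.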

Correctness check

Since we don't have sampling yet, I checked layerwise matching against the Hugging Face transformers implementation. The script is here. The output:

Embedding max diff: 0.629883
Vision tokens  max=0.6299  mean=0.0270
Deepstack[0]   max=0.0469  mean=0.0027
Deepstack[1]   max=0.7568  mean=0.0083
Deepstack[2]   max=0.4941  mean=0.0127
tf_attention_mask is None: False  (will also run with explicit JAX-derived mask to isolate masking)
Layer  0  max=0.9570  mean=0.0281  worst=(seq=93, dim=731)  is_vision=True
Layer  1  max=2.1875  mean=0.0311  worst=(seq=97, dim=0)  is_vision=True
Layer  2  max=2.0000  mean=0.0344  worst=(seq=98, dim=0)  is_vision=True
Layer  3  max=2.3750  mean=0.0346  worst=(seq=98, dim=0)  is_vision=True
Layer  4  max=3.2500  mean=0.0343  worst=(seq=98, dim=0)  is_vision=True
Layer  5  max=4.0000  mean=0.0340  worst=(seq=97, dim=0)  is_vision=True
Layer  6  max=64.0000  mean=0.0341  worst=(seq=1, dim=4)  is_vision=False
Layer  7  max=64.0000  mean=0.0340  worst=(seq=1, dim=4)  is_vision=False
Layer  8  max=64.0000  mean=0.0355  worst=(seq=1, dim=4)  is_vision=False
Layer  9  max=64.0000  mean=0.0344  worst=(seq=1, dim=4)  is_vision=False
Layer 10  max=64.0000  mean=0.0353  worst=(seq=1, dim=4)  is_vision=False
Layer 11  max=64.0000  mean=0.0359  worst=(seq=1, dim=4)  is_vision=False
Layer 12  max=64.0000  mean=0.0385  worst=(seq=1, dim=4)  is_vision=False
Layer 13  max=64.0000  mean=0.0401  worst=(seq=1, dim=4)  is_vision=False
Layer 14  max=64.0000  mean=0.0416  worst=(seq=1, dim=4)  is_vision=False
Layer 15  max=64.0000  mean=0.0435  worst=(seq=1, dim=4)  is_vision=False
Layer 16  max=64.0000  mean=0.0450  worst=(seq=1, dim=4)  is_vision=False
Layer 17  max=64.0000  mean=0.0477  worst=(seq=1, dim=4)  is_vision=False
Layer 18  max=64.0000  mean=0.0520  worst=(seq=1, dim=4)  is_vision=False
Layer 19  max=64.0000  mean=0.0580  worst=(seq=1, dim=4)  is_vision=False
Layer 20  max=64.0000  mean=0.0642  worst=(seq=1, dim=4)  is_vision=False
Layer 21  max=64.0000  mean=0.0700  worst=(seq=1, dim=4)  is_vision=False
Layer 22  max=64.0000  mean=0.0832  worst=(seq=1, dim=4)  is_vision=False
Layer 23  max=64.0000  mean=0.1035  worst=(seq=1, dim=4)  is_vision=False
Layer 24  max=64.0000  mean=0.1362  worst=(seq=1, dim=4)  is_vision=False
Layer 25  max=64.0000  mean=0.1587  worst=(seq=1, dim=4)  is_vision=False
Layer 26  max=64.0000  mean=0.1861  worst=(seq=1, dim=4)  is_vision=False
Layer 27  max=64.0000  mean=0.2099  worst=(seq=1, dim=4)  is_vision=False
Layer 28  max=64.0000  mean=0.2432  worst=(seq=1, dim=4)  is_vision=False
Layer 29  max=64.0000  mean=0.2894  worst=(seq=1, dim=4)  is_vision=False
Layer 30  max=64.0000  mean=0.3507  worst=(seq=1, dim=4)  is_vision=False
Layer 31  max=64.0000  mean=0.4176  worst=(seq=1, dim=4)  is_vision=False
Layer 32  max=64.0000  mean=0.4967  worst=(seq=1, dim=4)  is_vision=False
Layer 33  max=64.0000  mean=0.5858  worst=(seq=1, dim=4)  is_vision=False
Layer 34  max=84.5000  mean=0.6787  worst=(seq=98, dim=4)  is_vision=True
Layer 35  max=64.0000  mean=0.9139  worst=(seq=1, dim=4)  is_vision=False
Final norm max diff: 43.500000
Logits  max=492.4688  mean=7.2643
JAX top-5: [1986, 334, 95456, 32, 108893]
 PT top-5: [1986, 32, 785, 95456, 8420]
Top-1 match: True

The discrepancy between activations comes from differences between the XLA and cuDNN matmul implementations (a discrepancy in the last digit of bfloat16) and is amplified by the MLP layers. Despite the seemingly large difference in the output logits, I verified that the top-1 token matches between the JAX and PyTorch implementations.
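The point of the top-1 check is that a large max-abs-diff in logits can coexist with identical greedy decoding. A minimal NumPy illustration with made-up logit values (not taken from the run above):

```python
import numpy as np

# Illustrative only: two "implementations" whose logits differ by a lot
# in absolute terms, yet whose argmax (greedy top-1 token) agrees.
jax_logits = np.array([492.5, 100.0, 95.0, 80.0, 60.0])
pt_logits  = np.array([450.0,  60.0, 99.0, 85.0, 58.0])

max_diff = np.abs(jax_logits - pt_logits).max()                 # 42.5
same_top1 = int(jax_logits.argmax()) == int(pt_logits.argmax())  # True
print(f"max diff {max_diff}, top-1 match: {same_top1}")
```

Because softmax is shift-invariant and the winning logit dominates, differences concentrated in already-large logits often leave the ranking of the top token intact.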

Reference

The vision encoder is based on sdan/Qwen3-VL-JAX, erfanzar/EasyDeL, and huggingface/transformers.

The text decoder is mostly a copy of tunix/models/qwen3, with RoPE replaced by M-RoPE and deepstack integration added.
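The RoPE -> M-RoPE change amounts to giving every token three position indices (temporal, height, width): text tokens advance all three together, while vision tokens take their height/width coordinates from grid_thw, offset by the position reached so far. A simplified sketch of that indexing scheme (function and variable names are mine, and real M-RoPE also handles multiple images and video frames):

```python
# Simplified sketch of 3D M-RoPE position ids for a sequence of the form
# [prefix text..., one image's patches..., suffix text...].
# Text tokens get identical (t, h, w) indices; image patches share one
# temporal index and spread over their grid, all offset by the running
# text position; text after the image resumes past the largest index used.

def mrope_position_ids(n_prefix_text, grid_hw, n_suffix_text):
    pos = [(i, i, i) for i in range(n_prefix_text)]
    offset = n_prefix_text
    h, w = grid_hw
    for r in range(h):
        for c in range(w):
            pos.append((offset, offset + r, offset + c))
    nxt = offset + max(h, w)  # continue after the image's extent
    pos.extend(((nxt + i,) * 3 for i in range(n_suffix_text)))
    return pos

print(mrope_position_ids(2, (2, 2), 1))
# [(0,0,0), (1,1,1), (2,2,2), (2,2,3), (2,3,2), (2,3,3), (4,4,4)]
```

Each of the three index streams then drives its own slice of the rotary frequency dimensions, which is what lets causal masking follow the temporal index while height/width stay spatially meaningful.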

Colab Notebook

I'm not used to Colab notebooks, but here's a gist showing usage with image features:
https://gist.github.com/ridcl/9adb25ecf5a843c3cfae1a9285cf4473

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the Qwen3-VL model, a multimodal large language model, into the tunix framework. The implementation provides the core architecture for both the vision encoder and the text decoder, along with the necessary utilities for loading parameters from Huggingface. A key feature is the adaptation of Rotary Positional Embeddings (RoPE) to a 3D multimodal variant (M-RoPE) to correctly handle combined text and image inputs. The PR also includes robust testing to ensure the correctness of parameter loading, LoRA integration, and the integrity of vision model weights through save/load cycles.

Highlights

  • Qwen3-VL Model Implementation: Implemented the Qwen3-VL multimodal model, integrating both vision and text components.
  • Parameter Loading and Configuration: Enabled parameter loading for Qwen3-VL 4B from Huggingface endpoints and defined its model configuration.
  • Multimodal Rotary Positional Embeddings (M-RoPE): Introduced a 3D M-RoPE mechanism for handling combined text and vision sequences, allowing for causal masking based on temporal positions and spatial awareness.
  • Comprehensive Testing: Developed comprehensive unit tests for Qwen3-VL model parameters, including LoRA merging for the text decoder and vision weight round-trip verification.
  • Deepstack Vision Feature Integration: Integrated deepstack vision features into the text decoder layers to enrich multimodal understanding.
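The deepstack integration mentioned in the last highlight can be pictured as adding one level of intermediate vision-encoder features to the hidden states at the vision-token positions after each of the first few decoder layers. A hypothetical NumPy sketch (names and shapes are assumptions; the actual implementation lives in tunix/models/qwen3vl/model.py):

```python
import numpy as np

# Hypothetical sketch: after decoder layer `layer_idx` (while deepstack
# levels remain), add deepstack feature level `layer_idx` to the hidden
# states at the positions occupied by vision tokens.

def inject_deepstack(hidden, deepstack_feats, vision_mask, layer_idx):
    # hidden: (seq, dim); deepstack_feats: (levels, n_vision, dim)
    # vision_mask: (seq,) bool marking vision-token positions
    if layer_idx >= len(deepstack_feats):
        return hidden  # past the deepstack levels: no-op
    out = hidden.copy()
    out[vision_mask] += deepstack_feats[layer_idx]
    return out

hidden = np.zeros((5, 4))
feats = np.ones((2, 2, 4))  # 2 deepstack levels, 2 vision tokens
mask = np.array([False, True, True, False, False])
print(inject_deepstack(hidden, feats, mask, 0)[1])  # vision row gets +1
```

The additive form means early decoder layers see progressively richer visual detail without changing the sequence length or attention pattern.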


Changelog
  • tests/models/qwen3vl/qwen3vl_params_test.py
    • Added a new test suite to validate Qwen3-VL model parameters, including LoRA merging for the text decoder and round-trip tests for vision encoder weights.
  • tunix/models/qwen3vl/__init__.py
    • Created the __init__.py file to establish the qwen3vl package and expose its model and params modules.
  • tunix/models/qwen3vl/model.py
    • Implemented the main Qwen3VL model, including its ModelConfig, Embedder, RMSNorm, Attention, MLP, and DecoderLayer components, with specific logic for vision integration and 3D M-RoPE.
  • tunix/models/qwen3vl/params.py
    • Provided utilities for loading and converting Qwen3-VL PyTorch weights from safetensors, handling expert stacking, and facilitating LoRA merged model saving.
  • tunix/models/qwen3vl/vision.py
    • Defined the Qwen3-VL vision encoder components, such as VisionEmbeddings, VisionGridData, VisionModelConfig, VisionPatchEmbed, VisionAttention, VisionMLP, VisionBlock, and VisionPatchMerger, along with rotary positional embeddings for vision tokens.
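The weight conversion described for params.py largely comes down to renaming Hugging Face safetensors keys onto the framework's module paths (plus transposes and expert stacking). A hypothetical regex-based sketch of that mapping style; the specific key patterns and target names here are illustrative, not copied from the PR:

```python
import re

# Hypothetical sketch of HF -> tunix key renaming. The real mapping in
# params.py is more involved (expert stacking, kernel transposes, sharding).
RENAMES = [
    (r"^model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.weight$",
     r"decoder.layers.\1.attn.q_proj.kernel"),
    (r"^model\.visual\.blocks\.(\d+)\.mlp\.fc1\.weight$",
     r"vision.blocks.\1.mlp.fc1.kernel"),
]

def rename_key(hf_key):
    for pat, repl in RENAMES:
        if re.match(pat, hf_key):
            return re.sub(pat, repl, hf_key)
    return hf_key  # pass unmapped keys through unchanged

print(rename_key("model.language_model.layers.3.self_attn.q_proj.weight"))
```

Keeping the mapping as a flat table of (pattern, replacement) pairs makes it easy to audit against the checkpoint's key list and to extend for new submodules.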
Activity
  • The author ridcl initiated the implementation of the Qwen3-VL model.
  • This pull request partially addresses issue Qwen3-VL #1063, focusing on the core model implementation.
  • Correctness was verified through layerwise matching against Huggingface transformers, confirming top-1 token prediction accuracy despite bfloat16 activation differences.
  • The implementation draws inspiration from sdan/Qwen3-VL-JAX, erfanzar/EasyDeL, and huggingface/transformers.
  • A Colab Notebook gist demonstrates the model's usage with image features.
  • All required checklist items, including unit tests, documentation, and adherence to contribution guidelines, have been completed.


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces the implementation for the Qwen3-VL model, including the vision and text components, parameter loading from Huggingface checkpoints, and associated tests. The implementation is comprehensive and follows the repository's JAX-native and NNX-first philosophy. The code is well-structured into separate modules for the model, parameters, and vision encoder. The tests are also thorough, covering both LoRA parameter merging and round-trip weight loading for the vision components.

My review has identified a few critical issues related to model logic, particularly in the deepstack feature injection and a reference to an undefined MoELayer. I've also found some issues in the test coverage and parameter mappings that should be addressed to improve correctness and maintainability. Overall, this is a great contribution, and with these fixes, it will be a solid implementation.
