Skip to content

Feat/support block quant#862

Merged
cjx0709 merged 28 commits into
mainfrom
feat/support_block_quant
Mar 23, 2026
Merged

Feat/support block quant#862
cjx0709 merged 28 commits into
mainfrom
feat/support_block_quant

Conversation

@cjx0709

@cjx0709 cjx0709 commented Mar 5, 2026

Copy link
Copy Markdown
Collaborator

Motivation

Add block-wise quantization support from TPU inference, with coverage for both dense Linear layers and EPMoE.

Closes #879

Modifications

  • Add block-wise quant support for 'QuantizedLinear', basing on adapted quanted linear kernel code from TPU Inference project includeing:
    • per-channel
    • block-channel/ sub-channel
    • 2D block scale formats
  • Add block-wise quant support for 'EPMoE', includeing:
    • dynamic/static scale loading
    • dynamic/static quantization
    • scale normalization to the GMM kernel layout
  • Add block-wise supported quantized matmul kernel for TPU inference
    • Integrate the TPU blockwise quantized matmul kernel with tuned fallback logic and safer validation checks.
  • Move block-quant scale expansion (jnp.repeat) from runtime to init/load time for better inference performance

Accuracy Tests

This PR will not effects for existing non-quantized model or channel-wise kernel. And add test List into CI about:

  • Unit tests:
    • python/sgl_jax/test/kernels/quantized_linear_test.py
    • python/sgl_jax/test/kernels/moe_block_quant_test.py
    • python/sgl_jax/test/test_linear_tp.py
  • TPU E2E tests:
    • test/srt/test_moe_block_quant_e2e.py
    • test/srt/quantization/test_w8_block_dynamic_quantization.py
    • test/srt/quantization/test_w8_moe_block_linear_channel_quantization.py

Qwen3-30B-A3B (tp=4, ep=4)

Eval Set Variant Score Delta vs BF16
GSM8K BF16 0.939 -
GSM8K INT8 mixed (Linear per-channel + MoE block) 0.941 +0.002
GPQA Diamond BF16 0.152 -
GPQA Diamond INT8 mixed (Linear per-channel + MoE block) 0.141 -0.011

Qwen3-8B (tp=4)

Eval Set Variant Score Delta vs BF16
GSM8K BF16 0.861 -
GSM8K INT8 block dynamic 0.897 +0.036
GPQA Diamond BF16 0.066 -
GPQA Diamond INT8 block dynamic 0.121 +0.055

Benchmarking and Profiling

Hardware: TPU v6e-4
Benchmark tool: bench_one_batch_server.py (single-batch offline, input_len=4096, output_len=1024)
Server config: precompilation enabled

Qwen3-8B BF16 (tp=4)

batch size latency (s) input throughput (tok/s) output throughput (tok/s) ITL (ms)
1 5.91 48172.07 175.83 5.69
4 6.48 56459.25 661.91 6.04
8 7.10 57645.86 1254.03 6.38
16 8.77 58387.38 2142.96 7.47
32 12.28 58743.15 3261.23 9.81

Qwen3-8B INT8 block dynamic (tp=4)

Config: int8_block_128_dynamic.yaml

batch size latency (s) input throughput (tok/s) output throughput (tok/s) ITL (ms)
1 6.00 36538.96 173.85 5.75
4 6.29 41405.21 694.67 5.76
8 7.18 42115.50 1279.84 6.25
16 9.15 42494.77 2154.47 7.43
32 13.06 42679.25 3279.24 9.76

Qwen3-30B-A3B BF16 (tp=4, ep=4)

batch size latency (s) input throughput (tok/s) output throughput (tok/s) ITL (ms)
1 4.87 23956.37 217.74 4.59
4 6.10 25592.71 750.14 5.33
8 9.09 25871.30 1046.73 7.64
16 13.54 25951.58 1487.14 10.76
32 22.60 26033.21 1865.62 17.15

Qwen3-30B-A3B INT8 mixed (tp=4, ep=4)

Config: int8_moe_block_128_linear_channel_dynamic.yaml (Linear per-channel + MoE block)

batch size latency (s) input throughput (tok/s) output throughput (tok/s) ITL (ms)
1 5.56 21717.74 190.75 5.24
4 5.88 23351.42 790.99 5.06
8 8.07 23682.42 1224.88 6.53
16 12.87 23867.11 1617.74 9.89
32 22.91 23780.35 1883.15 16.99

Key Observations

  • Qwen3-8B: INT8 block 128 prefill throughput ~27% lower than BF16 (blockwise kernel overhead), but decode throughput is on par or slightly better due to reduced memory bandwidth from INT8 weights. ITL is nearly identical.
  • Qwen3-30B-A3B (MoE): INT8 mixed (Linear per-channel + MoE block) prefill throughput ~9% lower than BF16, decode throughput slightly higher at larger batch sizes. ITL comparable.
  • Accuracy: Both quantized variants maintain accuracy within noise of BF16 baselines on GSM8K and GPQA Diamond.

Checklist

  • Please use English, otherwise it will be closed.
  • The purpose of the PR, or link existing issues this PR will resolve.
  • The test plan, such as providing test command.
  • (Optional) The necessary documentation update.

@gemini-code-assist

Copy link
Copy Markdown

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@cjx0709 cjx0709 force-pushed the feat/support_block_quant branch 2 times, most recently from 1009dbd to 98d6c8f Compare March 9, 2026 09:10
@cjx0709 cjx0709 marked this pull request as ready for review March 9, 2026 11:43
@gemini-code-assist

Copy link
Copy Markdown

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly advances the framework's quantization capabilities by introducing robust support for block quantization. It integrates specialized kernels for efficient blockwise matrix multiplication, particularly optimized for TPU execution. The changes also provide more granular control over quantization settings through extended configuration options, enabling users to specify block sizes and exclude specific layers from quantization. This enhancement is crucial for optimizing the performance and memory footprint of large language models and MoE architectures.

Highlights

  • Block Quantization Support: Introduced comprehensive support for block quantization in linear layers and Mixture-of-Experts (MoE) layers, allowing for more efficient weight representation.
  • Optimized Quantization Kernels: Integrated new, optimized kernels for blockwise quantized matrix multiplication, specifically designed for JAX Pallas and TPU environments, to enhance performance.
  • Enhanced Quantization Configuration: Extended the QuantizationConfig with weight_block_size and ignored_layers fields, providing finer-grained control over quantization rules and layer exclusion.
  • Dynamic Block Quantization Schemes: Added new YAML configuration files to define dynamic block quantization, including mixed quantization strategies for MoE layers and standard linear layers.
  • Improved Sharding and Scale Handling: Updated internal utilities to better manage sharding during block quantization and to correctly handle various scale tensor formats for MoE layers, ensuring compatibility with new kernels.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sgl_jax/srt/configs/quantization_config.py
    • Added Integral import for type hinting.
    • Introduced _normalize_weight_block_size function to validate and normalize block size inputs.
    • Updated QuantizationConfig dataclass with is_static_checkpoint, ignored_layers, and weight_block_size fields.
    • Modified from_yaml method to parse the newly added quantization configuration fields.
  • python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/init.py
    • Added a new __init__.py file to expose the quantized_matmul_kernel as quantized_matmul.
  • python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/blockwise_kernel.py
    • Added a new file implementing a blockwise quantized matrix multiplication kernel using JAX Pallas for TPU.
  • python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/kernel.py
    • Added a new file implementing a standard quantized matrix multiplication kernel using JAX Pallas.
  • python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/tuned_block_sizes.py
    • Added a new file defining TunedKey and TunedValue NamedTuples for optimized block sizes.
    • Included a large dictionary (TUNED_BLOCK_SIZES_RAW) containing tuned block sizes for various TPU versions and configurations.
    • Provided utility functions like get_device_vmem_limit, get_tpu_version, get_key, and get_tuned_block_sizes.
  • python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/util.py
    • Added a new file containing utility functions for quantized matrix multiplication kernels, including quantize_tensor, get_vmem_limit, and input validation.
  • python/sgl_jax/srt/kernels/quantized_matmul/kernel.py
    • Modified xla_quantized_matmul_local to support block quantization and conditionally use a third-party blockwise kernel on TPU.
    • Added lazy loading mechanisms for third-party blockwise kernels and their tuning APIs.
    • Introduced helper functions for calculating multiples, powers of two, and safe blockwise tuned values.
  • python/sgl_jax/srt/layers/linear.py
    • Modified LinearBase to remove named_scope and adjust bias sharding.
    • Updated LinearBase.__call__ to use jnp.dot and shard_map for sharded dot products.
    • Modified QuantizedLinear to accept weight_block_size and activation_quant_dtype.
    • Updated QuantizedLinear.from_linear to handle block quantization during conversion and static input scenarios.
    • Adjusted QuantizedLinear.__call__ to pass block quantization parameters to the underlying matmul kernel and adapt scale sharding.
  • python/sgl_jax/srt/layers/moe.py
    • Added weight_block_size attribute to the EPMoE class.
    • Introduced _normalize_scale_for_gmm to standardize scale tensor shapes for GMM operations.
    • Modified quantize_weights to support block quantization for MoE weights and update scale sharding.
    • Updated __call__ to utilize _normalize_scale_for_gmm for wi_0_scale, wi_1_scale, and wo_scale.
  • python/sgl_jax/srt/utils/quantization/configs/int8_block_128_dynamic.yaml
    • Added a new YAML configuration file for INT8 dynamic block quantization with 128x128 block sizes.
  • python/sgl_jax/srt/utils/quantization/configs/int8_moe_block_128_linear_channel_dynamic.yaml
    • Added a new YAML configuration file for INT8 dynamic mixed quantization, applying block quantization to MoE layers and per-channel quantization to linear layers.
  • python/sgl_jax/srt/utils/quantization/quantization_utils.py
    • Imported _normalize_weight_block_size from quantization_config.
    • Added _get_block_reshape_sharding and _get_safe_block_quant_input_sharding for improved sharding handling during block quantization.
    • Modified apply_linear_quantization to parse weight_block_size from rules and ignored_layers from the configuration.
    • Updated _replace_linear_recursive to skip ignored layers and pass weight_block_size to QuantizedLinear.from_linear.
    • Modified quantize_tensor_simple to prevent division by zero in scale calculation.
    • Updated quantize_tensor to handle sharding during block reshaping and final output reshaping.
  • python/sgl_jax/srt/utils/weight_utils.py
    • Added _maybe_convert_epmoe_scale_for_kernel to convert offline EPMoE scales to the 4D layout required by GMM.
    • Integrated _maybe_convert_epmoe_scale_for_kernel into load_weights_from_safetensors for proper scale loading.
  • python/sgl_jax/test/kernels/moe_block_quant_test.py
    • Added a new test file to verify the block quantization logic of EPMoE, including tests for scale shapes and invalid scale layouts.
  • python/sgl_jax/test/kernels/quantized_linear_test.py
    • Added a new test file for quantized linear layers, covering various scale formats, block quantization kernel accuracy, tuning fallback, linear rule overrides, and ignored layers.
  • python/sgl_jax/test/test_linear_tp.py
    • Added a new test file to ensure row-parallel linear layers produce results consistent with dense dot products.
  • test/srt/quantization/test_w8_block_dynamic_quantization.py
    • Added a new end-to-end test for W8 INT8 dynamic block quantization.
  • test/srt/quantization/test_w8_moe_block_linear_channel_quantization.py
    • Added a new end-to-end test for W8 INT8 mixed quantization, specifically for MoE block and linear channel quantization.
  • test/srt/run_suite.py
    • Modified TestFile class to include a runner attribute.
    • Updated run_one_file to use sys.executable instead of uv run python3 and support pytest as a test runner.
    • Added the newly created test files to the unit-test-tpu-v6e-1, unit-test-tpu-v6e-4, and e2e-test-tpu-v6e-4 test suites.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@cjx0709 cjx0709 force-pushed the feat/support_block_quant branch 8 times, most recently from c70a952 to c3e5734 Compare March 12, 2026 06:51
@cjx0709 cjx0709 requested a review from JamesBrianD March 12, 2026 08:26
@JamesBrianD JamesBrianD requested a review from Bob-Chen222 March 12, 2026 08:33
@JamesBrianD

Copy link
Copy Markdown
Collaborator

@Bob-Chen222 Please help review this pr.

Comment thread python/sgl_jax/srt/layers/linear.py Outdated
Comment thread python/sgl_jax/srt/kernels/quantized_matmul/kernel.py
Comment thread python/sgl_jax/srt/kernels/quantized_matmul/3rd_quantized_matmul/__init__.py Outdated
@cjx0709 cjx0709 force-pushed the feat/support_block_quant branch from d1e6ff7 to 9df8bf3 Compare March 13, 2026 06:06
@cjx0709 cjx0709 force-pushed the feat/support_block_quant branch from 9df8bf3 to cdfba68 Compare March 13, 2026 06:51
cjx0709 added a commit that referenced this pull request Mar 18, 2026
Squashed from PR #862 (feat/support_block_quant)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@JamesBrianD

This comment was marked as resolved.

@cjx0709 cjx0709 force-pushed the feat/support_block_quant branch from be236f6 to 785aafb Compare March 18, 2026 12:57
- Add expand_block_scale() in blockwise_utils.py with channel_to_block
  extensibility for future non-uniform block quant support
- Update kernel.py to detect pre-expanded 3D scale (ndim==3) instead of
  compact 2D scale, removing runtime convert_block_scale_to_kernel_layout
- Update QuantizedLinear to pre-expand scale at init time (both static
  and dynamic paths in from_linear, plus auto-expand in __init__)
- Add _maybe_expand_linear_block_scale() in weight_utils.py to convert
  2D checkpoint scale to 3D kernel-ready layout at load time
- Remove redundant _normalize_scale_for_gmm() calls from EPMoE.__call__
  since scales are already in correct 4D format from quantize_weights()
  or _maybe_convert_epmoe_scale_for_kernel()

2. Fix test to use pre-expanded 3D scale for kernel API

The xla_quantized_matmul_local kernel now expects pre-expanded 3D scale
[in_blocks, 1, n_out] instead of compact 2D [out_blocks, in_blocks].
Update the direct kernel test to call expand_block_scale before invoking
the kernel.

3. Restore _normalize_scale_for_gmm as safety net in EPMoE.__call__

The normalize call is cheap for already-4D scales (just validation)
and handles edge cases where callers bypass quantize_weights() and
directly set 2D/3D scale params. The real perf win is in the linear
layer where jnp.repeat was inside shard_map on every step.
@cjx0709 cjx0709 force-pushed the feat/support_block_quant branch from 785aafb to 0f532a4 Compare March 18, 2026 13:01
@cjx0709

This comment was marked as resolved.

Comment thread python/sgl_jax/srt/layers/moe.py
Comment thread python/sgl_jax/srt/kernels/quantized_matmul/kernel.py Outdated
@cjx0709 cjx0709 force-pushed the feat/support_block_quant branch from c2682dd to f926266 Compare March 19, 2026 11:42
cjx0709 added a commit that referenced this pull request Mar 20, 2026
Squashed from PR #862 (feat/support_block_quant)
@cjx0709 cjx0709 merged commit f009b6e into main Mar 23, 2026
19 checks passed
@cjx0709 cjx0709 deleted the feat/support_block_quant branch March 23, 2026 04:00
@cjx0709 cjx0709 mentioned this pull request May 6, 2026
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature] Support block-wise quantization for Linear and EPMoE on TPU

3 participants