
[GPT-OSS] support fp8 online quantization for gpt-oss bf16 #18988

Open
zminglei wants to merge 2 commits into sgl-project:main from zminglei:gpt-oss-fp8

Conversation

zminglei (Collaborator) commented Feb 18, 2026

Motivation

  1. Keep moe_runner_backend as auto when launching gpt-oss bf16 with online quantization (e.g. fp8), so that either the deep_gemm or the triton MoE backend is picked up; the triton_kernels MoE backend does not support online quantization such as fp8 yet (see the sketch below this list).
  2. Updated Fp8MoEMethod to accept with_bias, so it supports models with biases in their MoE projections, like GPT-OSS.
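
As a rough illustration of point 1, the selection rule boils down to "only prefer triton_kernel when no online quantization is requested". The sketch below is hypothetical (the function and argument names are made up for illustration) and is not the actual server_args.py change:

```python
# Hypothetical sketch of the backend-selection rule described above; the
# function and argument names are illustrative, not the sglang API.
from typing import Optional


def pick_moe_runner_backend(is_gpt_oss: bool, quantization: Optional[str]) -> str:
    """Choose the MoE runner backend for a GPT-OSS checkpoint."""
    if is_gpt_oss and quantization is None:
        # Unquantized bf16 path: the triton_kernels backend works fine.
        return "triton_kernel"
    # With online quantization (e.g. --quantization fp8), stay on "auto" so
    # either deep_gemm or the regular triton MoE backend gets picked up.
    return "auto"


assert pick_moe_runner_backend(True, None) == "triton_kernel"
assert pick_moe_runner_backend(True, "fp8") == "auto"
```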

Modifications

Accuracy Tests

Before:

# with fp8 online quantization:
SGLANG_TORCH_PROFILER_DIR=/home/jobuser/zminglei/sglang TIKTOKEN_RS_CACHE_DIR=/shared/public/sharing/inseek/gpt-oss-vocab python3 -m sglang.launch_server --model-path '/shared/public/elr-models/openai/gpt-oss-120b-bf16/' --reasoning-parser gpt-oss --tp 4 --quantization fp8

  self.load_weights_and_postprocess(
  File "/home/jobuser/zminglei/sglang/python/sglang/srt/model_loader/loader.py", line 686, in load_weights_and_postprocess
    model.load_weights(weights)
  File "/home/jobuser/zminglei/sglang/python/sglang/srt/models/gpt_oss.py", line 749, in load_weights
    self._load_normal_weights(
  File "/home/jobuser/zminglei/sglang/python/sglang/srt/models/gpt_oss.py", line 1082, in _load_normal_weights
    weight_loader(
  File "/home/jobuser/zminglei/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 949, in weight_loader_fused
    self._load_model_weight_or_group_weight_scale(
  File "/home/jobuser/zminglei/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 358, in _load_model_weight_or_group_weight_scale
    self._load_w2(
  File "/home/jobuser/zminglei/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 527, in _load_w2
    loaded_weight = loaded_weight.narrow(
IndexError: start out of range (expected to be in range of [-2880, 2880], but got 5760)


# without online quantization
SGLANG_TORCH_PROFILER_DIR=/home/jobuser/zminglei/sglang TIKTOKEN_RS_CACHE_DIR=/shared/public/sharing/inseek/gpt-oss-vocab python3 -m sglang.launch_server --model-path '/shared/public/elr-models/openai/gpt-oss-120b-bf16/' --reasoning-parser gpt-oss --tp 4

python benchmark/gsm8k/bench_sglang.py --data-path /shared/public/data/gsm8k/test.jsonl --num-questions 1319 --parallel 1319
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:47<00:00, 27.70it/s]
Accuracy: 0.848
Invalid: 0.016
Latency: 47.626 s
Output throughput: 9070.203 token/s

After:

# with fp8 online quantization:
SGLANG_TORCH_PROFILER_DIR=/home/jobuser/zminglei/sglang TIKTOKEN_RS_CACHE_DIR=/shared/public/sharing/inseek/gpt-oss-vocab python3 -m sglang.launch_server --model-path '/shared/public/elr-models/openai/gpt-oss-120b-bf16/' --reasoning-parser gpt-oss --tp 4 --quantization fp8

python benchmark/gsm8k/bench_sglang.py --data-path /shared/public/data/gsm8k/test.jsonl --num-questions 1319 --parallel 1319
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:31<00:00, 41.33it/s]
Accuracy: 0.834
Invalid: 0.016
Latency: 31.922 s
Output throughput: 14066.212 token/s


# without online quantization:
SGLANG_TORCH_PROFILER_DIR=/home/jobuser/zminglei/sglang TIKTOKEN_RS_CACHE_DIR=/shared/public/sharing/inseek/gpt-oss-vocab python3 -m sglang.launch_server --model-path '/shared/public/elr-models/openai/gpt-oss-120b-bf16/' --reasoning-parser gpt-oss --tp 4

python benchmark/gsm8k/bench_sglang.py --data-path /shared/public/data/gsm8k/test.jsonl --num-questions 1319 --parallel 1319
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:45<00:00, 29.13it/s]
Accuracy: 0.848
Invalid: 0.014
Latency: 45.280 s
Output throughput: 9611.029 token/s

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @zminglei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for FP8 online quantization in GPT-OSS bf16 models. It resolves compatibility issues by ensuring that the appropriate MOE backend is selected when quantization is enabled, specifically preventing the use of triton_kernels which lacks FP8 support. Furthermore, the changes extend the FP8 quantization method to properly integrate bias parameters within MoE layers, critical for maintaining model integrity and performance for models that utilize them.

Highlights

  • FP8 Online Quantization Support: Enabled FP8 online quantization for GPT-OSS bf16 models, allowing for more efficient inference.
  • MOE Backend Selection Logic: Adjusted the moe_runner_backend selection to default to auto when online quantization (e.g., FP8) is active, as the triton_kernels backend does not yet support FP8.
  • Bias Support in FP8 Quantization: Updated the FP8MoeMethod to correctly handle and register bias terms within Mixture-of-Experts (MoE) projections, which is essential for models like GPT-OSS that include biases.


Changelog
  • python/sglang/srt/layers/quantization/fp8.py
    • Initialized a with_bias attribute to False in the FP8MoeMethod constructor.
    • Modified the create_weights method to accept a with_bias argument and store its value in the instance.
    • Added conditional logic within create_weights to register w13_weight_bias and w2_weight_bias parameters when self.with_bias is True (see the sketch just below this changelog).
    • Updated the apply method to pass optional b13 and b2 bias parameters to TritonMoeQuantInfo.
  • python/sglang/srt/server_args.py
    • Modified the condition for automatically setting moe_runner_backend to "triton_kernel" for GPT-OSS models, adding an explicit check to ensure self.quantization is None. This prevents triton_kernel from being selected when FP8 quantization is active.
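
The following is an illustrative sketch of the bias registration described in the changelog; the helper name, tensor shapes, and dtype handling are assumptions, not the actual Fp8MoEMethod.create_weights implementation.

```python
# Illustrative sketch only: register per-expert bias tensors so the weight
# loader can fill them. Shapes and the helper name are assumptions.
import torch
from torch import nn


def register_moe_bias_params(layer: nn.Module, num_experts: int,
                             intermediate_size: int, hidden_size: int,
                             params_dtype: torch.dtype, with_bias: bool) -> None:
    """Register w13/w2 bias parameters when the model has biased MoE projections."""
    if not with_bias:
        return
    # Gate/up projection bias: 2 * intermediate_size values per expert.
    layer.register_parameter(
        "w13_weight_bias",
        nn.Parameter(
            torch.zeros(num_experts, 2 * intermediate_size, dtype=params_dtype),
            requires_grad=False,
        ),
    )
    # Down projection bias: hidden_size values per expert.
    layer.register_parameter(
        "w2_weight_bias",
        nn.Parameter(
            torch.zeros(num_experts, hidden_size, dtype=params_dtype),
            requires_grad=False,
        ),
    )
```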
Activity
  • The pull request author provided a detailed motivation for the changes, outlining the need for FP8 online quantization and bias support.
  • Accuracy test results for GSM8K with FP8 quantization were included, showing an accuracy of 0.834.
  • No human comments or reviews have been posted on this pull request yet.

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request successfully adds support for FP8 online quantization for GPT-OSS models, which is a valuable enhancement. The changes are well-implemented, particularly the addition of bias handling in Fp8MoEMethod and the adjustment of the MoE backend selection logic in server_args.py. The code is clear, follows existing patterns, and correctly uses getattr for safe access to optional bias parameters. The logic to prevent the use of the triton_kernel backend with quantization is also sound. Overall, this is a solid contribution.
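
As a hedged sketch of the getattr pattern the review refers to: the bias parameters are looked up only if they were registered and forwarded as optional fields. MoeQuantArgs below is an illustrative stand-in, not the actual TritonMoeQuantInfo class.

```python
# Hedged sketch of safe optional-bias lookup; the dataclass and function
# names are illustrative, not the sglang implementation.
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn


@dataclass
class MoeQuantArgs:
    b13: Optional[torch.Tensor] = None
    b2: Optional[torch.Tensor] = None


def build_quant_args(layer: nn.Module) -> MoeQuantArgs:
    # getattr(..., None) keeps the plain bf16 / no-bias path unchanged.
    return MoeQuantArgs(
        b13=getattr(layer, "w13_weight_bias", None),
        b2=getattr(layer, "w2_weight_bias", None),
    )
```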

zminglei marked this pull request as ready for review on February 19, 2026 00:11
zminglei (Collaborator, Author) commented Feb 19, 2026

/tag-and-rerun-ci again

jasperjiaguo added a commit to jasperjiaguo/vllm that referenced this pull request Feb 19, 2026
GPT-OSS-120B has biased MoE layers (gate_up_proj_bias, down_proj_bias).
When serving the BF16 model with `--quantization fp8`, the Fp8MoEMethod
does not register bias parameters, causing weight loading failures.

This adds bias support to Fp8MoEMethod:
- Register w13_bias/w2_bias in create_weights() when moe.has_bias is set
- Pass biases through to fused_experts() in apply()
- Guard against unsupported FusedMoEModularKernel + bias combination

Tested on 4xH200 with GPT-OSS-120B BF16:
- vllm serve --quantization fp8 loads successfully with bias
- GSM8K accuracy maintained (0.834 FP8 vs 0.848 BF16)

Companion PR: sgl-project/sglang#18988

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
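
The commit above mentions guarding against an unsupported FusedMoEModularKernel + bias combination. A hedged sketch of such a guard follows; the function and flag names are illustrative, not the vllm implementation.

```python
# Hedged sketch of a bias-support guard; names are illustrative only.
from typing import Optional

import torch


def check_moe_bias_support(kernel_name: str, kernel_supports_bias: bool,
                           w13_bias: Optional[torch.Tensor],
                           w2_bias: Optional[torch.Tensor]) -> None:
    """Fail fast when biases are present but the chosen kernel cannot apply them."""
    has_bias = w13_bias is not None or w2_bias is not None
    if has_bias and not kernel_supports_bias:
        raise NotImplementedError(
            f"MoE kernel {kernel_name!r} does not support expert biases; "
            "fall back to the fused-experts path instead."
        )
```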
jasperjiaguo added a commit to jasperjiaguo/vllm that referenced this pull request Feb 19, 2026
GPT-OSS-120B has biased MoE layers (gate_up_proj_bias, down_proj_bias).
When serving the BF16 model with `--quantization fp8`, Fp8MoEMethod and
Fp8OnlineMoEMethod don't register bias parameters, causing weight
loading failures.

This adds bias support to both FP8 MoE method classes:
- Register w13_bias/w2_bias in Fp8MoEMethod.create_weights() when
  moe.has_bias is set
- Inject biases into quant_config via get_fused_moe_quant_config()
- Register biases in Fp8OnlineMoEMethod.create_weights() using the
  original (unpatched) weight_loader

Tested on 4xH200 with GPT-OSS-120B BF16 + vllm 0.15.1:
- vllm serve --quantization fp8 loads and serves successfully
- TRITON Fp8 MoE backend selected correctly
- GSM8K accuracy: 0.834 (FP8) vs 0.848 (BF16)
- 1.5x throughput improvement with FP8

Companion PR: sgl-project/sglang#18988

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>