
Conversation

@xyang16 (Contributor) commented Dec 13, 2025

Purpose

This PR sets the default MXFP4 LoRA backend to Marlin, because the Triton path has accuracy issues and Marlin performs slightly better. The resulting selection logic (sketched in code after the list below) is:

  • Use Triton only if Marlin is explicitly disabled (VLLM_MXFP4_USE_MARLIN=0) and triton_kernels is supported.
  • Use Marlin by default:
    • if VLLM_MXFP4_USE_MARLIN is not set
    • if VLLM_MXFP4_USE_MARLIN=1
    • if triton_kernels is not supported
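
To make the fallback order concrete, here is a minimal sketch of the selection rules above. The helper name and its argument are hypothetical; the real check lives in vLLM's MXFP4 quantization/LoRA code and may differ in detail.

import os

def use_marlin_for_mxfp4_lora(triton_kernels_supported: bool) -> bool:
    """Hypothetical helper: decide whether Marlin handles MXFP4 LoRA."""
    flag = os.environ.get("VLLM_MXFP4_USE_MARLIN")  # None, "0", or "1"
    if flag == "0" and triton_kernels_supported:
        # Triton is used only when Marlin is explicitly disabled and
        # triton_kernels is available.
        return False
    # Flag unset, flag == "1", or triton_kernels unsupported: use Marlin.
    return True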

Benchmarking

Marlin:

VLLM_MXFP4_USE_MARLIN=1 vllm serve openai/gpt-oss-20b \
  --tensor-parallel-size 1 \
  --max-num-seqs 16 \
  --compilation-config '{"cudagraph_mode": "PIECEWISE", "compile_sizes": [1, 2, 4, 8, 16]}' \
  --enable-lora \
  --max-loras 1 \
  --lora-modules lora1=/opt/dlami/nvme/models/gpt-oss-20b-lora-gsm8k \
  --max-lora-rank 32 \
  --no-enable-prefix-caching
vllm bench serve \
  --model openai/gpt-oss-20b \
  --lora-modules lora1 \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --sharegpt-output-len 800 \
  --max-concurrency 16 \
  --num-prompts 1000 \
  --num-warmups 60 \
  --ignore-eos
============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  418.75    
Total input tokens:                      226792    
Total generated tokens:                  800000    
Request throughput (req/s):              2.39      
Output token throughput (tok/s):         1910.47   
Peak output token throughput (tok/s):    2053.00   
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          2452.06   
---------------Time to First Token----------------
Mean TTFT (ms):                          70.91     
Median TTFT (ms):                        60.91     
P99 TTFT (ms):                           204.49    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.24      
Median TPOT (ms):                        8.25      
P99 TPOT (ms):                           8.50      
---------------Inter-token Latency----------------
Mean ITL (ms):                           8.24      
Median ITL (ms):                         8.11      
P99 ITL (ms):                            9.11      
==================================================

Triton:

vllm serve openai/gpt-oss-20b \
  --tensor-parallel-size 1 \
  --max-num-seqs 16 \
  --compilation-config '{"cudagraph_mode": "PIECEWISE", "compile_sizes": [1, 2, 4, 8, 16]}' \
  --enable-lora \
  --max-loras 1 \
  --lora-modules lora1=/opt/dlami/nvme/models/gpt-oss-20b-lora-gsm8k \
  --max-lora-rank 32 \
  --no-enable-prefix-caching
vllm bench serve \
  --model openai/gpt-oss-20b \
  --lora-modules lora1 \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --sharegpt-output-len 800 \
  --max-concurrency 16 \
  --num-prompts 1000 \
  --num-warmups 60 \
  --ignore-eos
============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  439.55    
Total input tokens:                      226792    
Total generated tokens:                  800000    
Request throughput (req/s):              2.28      
Output token throughput (tok/s):         1820.06   
Peak output token throughput (tok/s):    1968.00   
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          2336.02   
---------------Time to First Token----------------
Mean TTFT (ms):                          75.77     
Median TTFT (ms):                        49.08     
P99 TTFT (ms):                           214.07    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.65      
Median TPOT (ms):                        8.66      
P99 TPOT (ms):                           8.94      
---------------Inter-token Latency----------------
Mean ITL (ms):                           8.65      
Median ITL (ms):                         8.50      
P99 ITL (ms):                            9.49      
==================================================

Marlin is slightly better because the MXFP4 Triton LoRA path is implemented in UnfusedOAITritonExperts: the activation and the reduction have to be unfused so that LoRA modules can be injected between them, which means the path loses triton_kernels' fused-activation and fused moe_sum optimizations.
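
As an illustration only (toy shapes, plain torch ops, not vLLM's actual kernels or the real UnfusedOAITritonExperts code), the unfused path roughly looks like this: the LoRA deltas are added between the two expert GEMMs, so the activation and the final reduction can no longer be fused into them.

import torch
import torch.nn.functional as F

def lora_delta(x, a, b):
    # Low-rank update: (x @ A) @ B
    return (x @ a) @ b

def unfused_expert_with_lora(x, w1, w2, a1, b1, a2, b2):
    h = x @ w1 + lora_delta(x, a1, b1)   # LoRA injected after the first GEMM
    h = F.silu(h)                        # activation runs as a separate step
    y = h @ w2 + lora_delta(h, a2, b2)   # LoRA injected after the second GEMM
    return y                             # per-expert outputs are reduced later (moe_sum)

# Toy shapes just to show the call pattern.
x = torch.randn(4, 64)
w1, w2 = torch.randn(64, 128), torch.randn(128, 64)
a1, b1 = torch.randn(64, 8), torch.randn(8, 128)
a2, b2 = torch.randn(128, 8), torch.randn(8, 64)
print(unfused_expert_with_lora(x, w1, w2, a1, b1, a2, b2).shape)  # torch.Size([4, 64])

In the fused triton_kernels path, by contrast, the activation and the moe_sum reduction are folded into the expert GEMM kernels, which is where the small throughput gap above comes from.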




@gemini-code-assist (bot) left a comment


Code Review

This pull request changes the default MXFP4 backend for LoRA from Triton to Marlin, citing better performance and accuracy. The logic is updated to select Marlin by default, and only fall back to Triton if VLLM_MXFP4_USE_MARLIN is explicitly set to 0 and Triton kernels are supported. The change is correct, well-scoped to LoRA as per the PR title, and aligns with the stated purpose. The implementation is clear and concise.

@yewentao256 (Member) left a comment


Please add metrics in the PR description to show the acc / perf issue you mentioned.

  1. lm_eval for acc
  2. vllm bench... for performance

@jeejeelee (Collaborator) left a comment


After addressing @yewentao256's comment, LGTM

@xyang16 (Contributor, Author) commented Dec 15, 2025

Please add metrics in the PR description to show the acc / perf issue you mentioned.

  1. lm_eval for acc
  2. vllm bench... for performance

@yewentao256 Thanks for review!

  1. Accuracy: Currently gpt-oss mxfp4 with LoRA + triton kernel is generating garbage output, see [Bug]: FULL_AND_PIECEWISE cudagraph mode leading to !!! in generated text #29539 (comment)
  2. Performance: I have pasted vllm bench numbers in the description.

@yewentao256 (Member) commented

Please add metrics in the PR description to show the acc / perf issue you mentioned.

  1. lm_eval for acc
  2. vllm bench... for performance

@yewentao256 Thanks for review!

  1. Accuracy: Currently gpt-oss mxfp4 with LoRA + triton kernel is generating garbage output, see [Bug]: FULL_AND_PIECEWISE cudagraph mode leading to !!! in generated text #29539 (comment)
  2. Performance: I have pasted vllm bench numbers in the description.

Let's fix the issue first then. Do you have time to take a deep look into it?

@xyang16 (Contributor, Author) commented Dec 16, 2025

Let's fix the issue first then, do you have time to take a deep look into this issue?

I raised a PR to fix this: #30585

There's another PR to fix the cudagraph issue: #30650

@jeejeelee enabled auto-merge (squash) December 18, 2025 00:08
@github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Dec 18, 2025
