
[diffusion] feat: support nvfp4 #16885

Draft

zcnrex wants to merge 16 commits into sgl-project:main from zcnrex:sgld-quant

Conversation

@zcnrex (Contributor) commented Jan 11, 2026

Motivation

Modifications

Accuracy Tests

Benchmarking and Profiling

sglang generate --model-path black-forest-labs/FLUX.2-dev --prompt "A smiling girl holding a rectangular white signboard with the text \"sGl Diffusion x FLUx.2\", in animate style"

Original on H200

[attached output image: A_smiling_girl_holding_a_rectangular_white_signboard_with_the_text_sGl_Diffusion_x_FLUx 2_in_ani_20260114-171842_5e83b0df]
[01-14 17:18:42] Running pipeline stages: ['input_validation_stage', 'prompt_encoding_stage_primary', 'image_encoding_stage_primary', 'conditioning_stage', 'latent_preparation_stage', 'timestep_preparation_stage', 'denoising_stage', 'decoding_stage']
[01-14 17:18:42] [InputValidationStage] started...
[01-14 17:18:42] [InputValidationStage] finished in 0.0001 seconds
[01-14 17:18:42] [TextEncodingStage] started...
[01-14 17:18:43] [TextEncodingStage] finished in 0.8544 seconds
[01-14 17:18:43] [ImageVAEEncodingStage] started...
[01-14 17:18:43] [ImageVAEEncodingStage] finished in 0.0000 seconds
[01-14 17:18:43] [ConditioningStage] started...
[01-14 17:18:43] [ConditioningStage] finished in 0.0000 seconds
[01-14 17:18:43] [LatentPreparationStage] started...
[01-14 17:18:43] [LatentPreparationStage] finished in 0.0009 seconds
[01-14 17:18:43] [TimestepPreparationStage] started...
[01-14 17:18:43] [TimestepPreparationStage] finished in 0.0008 seconds
[01-14 17:18:43] [DenoisingStage] started...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:22<00:00,  2.25it/s]
[01-14 17:19:05] [DenoisingStage] average time per step: 0.4438 seconds
[01-14 17:19:06] [DenoisingStage] finished in 22.5906 seconds
[01-14 17:19:06] [DecodingStage] started...
[01-14 17:19:06] [DecodingStage] finished in 0.2871 seconds
[01-14 17:19:06] Peak GPU memory: 64.57 GB, Remaining GPU memory at peak: 75.83 GB. Components that can stay resident: ['text_encoder']
[01-14 17:19:07] Output saved to outputs/A_smiling_girl_holding_a_rectangular_white_signboard_with_the_text_sGl_Diffusion_x_FLUx.2_in_ani_20260114-171842_5e83b0df.png
[01-14 17:19:07] Pixel data generated successfully in 24.83 seconds
[01-14 17:19:07] Completed batch processing. Generated 1 outputs in 24.83 seconds
[01-14 17:19:07] Memory usage - Max peak: 66117.49 MB, Avg peak: 66117.49 MB
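
For reference, this unquantized baseline on H200 runs the 50 denoising steps at roughly 0.44 s/step (22.6 s of denoising), generates the image in 24.83 s end to end, and peaks at about 64.6 GB of GPU memory; these figures are the comparison point for the NVFP4 path.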

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions bot added the quant (LLM Quantization) and diffusion (SGLang Diffusion) labels on Jan 11, 2026
@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @zcnrex, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates NVIDIA FP4 (NVFP4) quantization into the multimodal generation runtime. It provides the necessary infrastructure to configure, load, and execute models using NVFP4 weights, aiming to enhance performance and reduce resource consumption. The changes involve adding a new quantization method and its corresponding linear layer implementation, which handles the intricacies of 4-bit weight representation and optimized matrix multiplication.

Highlights

  • NVFP4 Quantization Support: Introduced comprehensive support for NVIDIA FP4 (NVFP4) quantization, enabling more efficient model inference by reducing memory footprint and potentially increasing speed.
  • New Quantization Configuration: Added ModelOptFp4Config to manage NVFP4-specific configurations, including parsing group size, handling module exclusion, and recognizing NVFP4 serialized checkpoints.
  • Efficient Linear Method Implementation: Implemented ModelOptFp4LinearMethod to handle the creation, loading, and processing of NVFP4 quantized weights, including specialized handling for input and weight scales, and blockwise interleaving (a conceptual sketch of the blockwise format follows this list).
  • Optimized GEMM Operations: Integrated fp4_gemm operations, leveraging flashinfer or sgl_kernel for high-performance 4-bit matrix multiplication, crucial for the efficiency of quantized models.
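
For readers less familiar with the format, here is a minimal, illustrative sketch of the blockwise NVFP4 layout the highlights describe: two FP4 (E2M1) values packed per byte, one FP8 (E4M3) scale per group of 16 elements, and a global FP32 scale on top. The tensor names, nibble order, and helper name below are assumptions for illustration only; the PR itself dispatches to fused fp4_gemm kernels (flashinfer or sgl_kernel) and never materializes dequantized weights this way.

import torch

# Magnitudes representable by an FP4 E2M1 code; bit 3 carries the sign.
E2M1_MAGNITUDES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequantize_nvfp4_reference(
    packed: torch.Tensor,        # (N, K // 2) uint8, two FP4 codes per byte
    block_scales: torch.Tensor,  # (N, K // group_size) FP8-E4M3 (or float) scales
    global_scale: float,         # per-tensor scale applied on top of block scales
    group_size: int = 16,
) -> torch.Tensor:
    """Slow reference dequantization of NVFP4-packed weights to float32."""
    # Unpack the two 4-bit codes in each byte (nibble order assumed low-first).
    low = packed & 0x0F
    high = packed >> 4
    codes = torch.stack((low, high), dim=-1).flatten(-2)      # (N, K)

    # Decode E2M1: bits 0-2 select one of 8 magnitudes, bit 3 is the sign.
    sign = torch.where((codes & 0x8) != 0, -1.0, 1.0)
    values = sign * E2M1_MAGNITUDES[(codes & 0x7).long()]

    # Apply one scale per group of `group_size` elements, then the global scale.
    values = values.reshape(packed.shape[0], -1, group_size)
    scales = block_scales.float().unsqueeze(-1) * global_scale
    return (values * scales).flatten(1)


# Toy usage with made-up data: a 64x128 weight packs into 64x64 bytes plus
# 64x8 block scales when group_size is 16.
w = dequantize_nvfp4_reference(
    torch.randint(0, 256, (64, 64), dtype=torch.uint8),
    torch.ones(64, 8),
    global_scale=1.0,
)
print(w.shape)  # torch.Size([64, 128])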

@gemini-code-assist bot left a comment

Code Review

This pull request adds support for nvfp4 quantization by introducing a new modelopt_quant.py file. The implementation is largely adapted from existing code, which is a good approach. However, I've identified a critical issue with how flashinfer dependencies are handled, which could lead to runtime crashes if the library isn't installed. I've also suggested improvements to make the code more robust by replacing a risky assert with a proper conditional check, and to clarify a validation message in the configuration loading logic. Overall, the changes are in a good direction, and addressing these points will enhance the code's reliability.

Comment on lines 36 to 46
try:
    from flashinfer import mm_fp4 as flashinfer_fp4_gemm

    enable_flashinfer_fp4_gemm = True
except ImportError:
    if current_platform.is_cuda():
        from sgl_kernel import cutlass_scaled_fp4_mm as cutlass_fp4_gemm
    enable_flashinfer_fp4_gemm = False
    reorder_rows_for_gated_act_gemm = None
    shuffle_matrix_a = None
    shuffle_matrix_sf_a = None

critical

The try-except block for flashinfer imports is incomplete. It's missing shuffle_matrix_a and shuffle_matrix_sf_a, which are used later in process_weights_after_loading. This will cause a runtime NameError or ImportError if flashinfer is not installed. To make the import handling more robust, these functions should be imported here and set to None in the except block, similar to other flashinfer functions.

Suggested change

-try:
-    from flashinfer import mm_fp4 as flashinfer_fp4_gemm
-    enable_flashinfer_fp4_gemm = True
-except ImportError:
-    if current_platform.is_cuda():
-        from sgl_kernel import cutlass_scaled_fp4_mm as cutlass_fp4_gemm
-    enable_flashinfer_fp4_gemm = False
-    reorder_rows_for_gated_act_gemm = None
-    shuffle_matrix_a = None
-    shuffle_matrix_sf_a = None
+try:
+    from flashinfer import mm_fp4 as flashinfer_fp4_gemm
+    from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
+    enable_flashinfer_fp4_gemm = True
+except ImportError:
+    if current_platform.is_cuda():
+        from sgl_kernel import cutlass_scaled_fp4_mm as cutlass_fp4_gemm
+    enable_flashinfer_fp4_gemm = False
+    reorder_rows_for_gated_act_gemm = None
+    shuffle_matrix_a = None
+    shuffle_matrix_sf_a = None

Comment on lines 423 to 442
        if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight
            scale = layer.weight_scale
            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            scale = (
                shuffle_matrix_sf_a(scale.view(torch.uint8), epilogue_tile_m)
                .reshape(scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale_interleaved = Parameter(scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
            return

critical

This block will crash with an ImportError if flashinfer is not installed but FLASHINFER_FP4_GEMM_BACKEND is set to "trtllm". The from flashinfer import ... statement is not guarded. You should add a check for enable_flashinfer_fp4_gemm at the beginning of this block to ensure flashinfer is available. With the suggested change to the top-level imports, the local import statement here will also become unnecessary.

        if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
            if not enable_flashinfer_fp4_gemm:
                raise ImportError(
                    "flashinfer is not installed, but FLASHINFER_FP4_GEMM_BACKEND is set to 'trtllm'."
                )
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.

            weight = layer.weight
            scale = layer.weight_scale
            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            scale = (
                shuffle_matrix_sf_a(scale.view(torch.uint8), epilogue_tile_m)
                .reshape(scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale_interleaved = Parameter(scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
            return

Comment on lines 268 to 273
        if not quant_method in ["FP8", "NVFP4"]:
            raise ValueError(
                f"ModelOpt currently only supports: FP8, NVFP4"
                " quantizations in sglang. Please check the "
                "quantization config for your model's configuration."
            )

medium

The validation for quant_method is a bit misleading. Since this class is ModelOptFp4Config, it should ideally only handle NVFP4 quantization. Allowing FP8 in the check, even if it's caught and rejected later, can be confusing. It would be clearer to raise an error immediately if the quant_method is not NVFP4.

        if quant_method != "NVFP4":
            raise ValueError(
                f"ModelOptFp4Config only supports NVFP4, but got {quant_method}. "
                "Please check the quantization config for your model's configuration."
            )

Comment on lines +301 to +309
            elif (
                pattern_split[-1] in fused_patterns
                and pattern_split[-1] in prefix_split[-1]
            ):
                # Check if the last part of the excluded pattern is contained in the last part of the prefix
                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
                assert len(prefix_split) == 5 and len(pattern_split) == 5
                return True

medium

Using assert here to validate the length of prefix_split and pattern_split is risky. An assert should be used for conditions that should never be false in a correct program, not for validating inputs like layer names, which can vary between models. If a model with a different layer naming structure is used, this will cause a crash. It's safer to convert this into a conditional check within the elif statement.

            elif (
                pattern_split[-1] in fused_patterns
                and pattern_split[-1] in prefix_split[-1]
                and len(prefix_split) == 5
                and len(pattern_split) == 5
            ):
                # Check if the last part of the excluded pattern is contained in the last part of the prefix
                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
                return True
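
To make the risk concrete, the hypothetical names below (a fused projection nested one level deeper than the assumed model.layers.{i}.self_attn.{fused_weight_name} layout) satisfy both substring conditions yet have six path components, so the original assert would abort model loading instead of simply not excluding the layer:

# Hypothetical prefix/pattern pair; the extra "block" level makes 6 components.
prefix = "model.layers.0.block.self_attn.fused_qkv_a_proj_with_mqa"
pattern = "model.layers.*.self_attn.kv_a_proj_with_mqa"
fused_patterns = ["q_a_proj", "kv_a_proj_with_mqa"]  # assumed contents

prefix_split = prefix.split(".")
pattern_split = pattern.split(".")

# Both conditions of the elif hold...
print(pattern_split[-1] in fused_patterns)    # True
print(pattern_split[-1] in prefix_split[-1])  # True
# ...but the length assert (== 5) would fail: 6 components vs 5.
print(len(prefix_split), len(pattern_split))  # 6 5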

@mickqian (Collaborator) commented

is it ready to be reviewed?
