Conversation
Summary of Changes

Hello @zcnrex, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates NVIDIA FP4 (NVFP4) quantization into the multimodal generation runtime. It provides the necessary infrastructure to configure, load, and execute models using NVFP4 weights, aiming to enhance performance and reduce resource consumption. The changes involve adding a new quantization method and its corresponding linear layer implementation, which handles the intricacies of 4-bit weight representation and optimized matrix multiplication.
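As background for reviewers unfamiliar with NVFP4: each weight element is a 4-bit code, so two elements share one byte and per-block scale factors are stored alongside the packed weights. The sketch below only illustrates that nibble packing with a simple low/high-nibble convention; it is not the layout used by ModelOpt or by the flashinfer/cutlass kernels in this PR.

import torch

def pack_fp4_pairs(codes: torch.Tensor) -> torch.Tensor:
    # Pack pairs of 4-bit codes (stored one per uint8) into a single byte each:
    # even-indexed element -> low nibble, odd-indexed element -> high nibble.
    lo = codes[..., 0::2] & 0x0F
    hi = codes[..., 1::2] & 0x0F
    return (hi << 4) | lo

def unpack_fp4_pairs(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_fp4_pairs: expand every byte back into two 4-bit codes.
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    return torch.stack((lo, hi), dim=-1).flatten(start_dim=-2)

codes = torch.randint(0, 16, (2, 8), dtype=torch.uint8)
assert torch.equal(unpack_fp4_pairs(pack_fp4_pairs(codes)), codes)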
Code Review
This pull request adds support for nvfp4 quantization by introducing a new modelopt_quant.py file. The implementation is largely adapted from existing code, which is a good approach. However, I've identified a critical issue with how flashinfer dependencies are handled, which could lead to runtime crashes if the library isn't installed. I've also suggested improvements to make the code more robust by replacing a risky assert with a proper conditional check, and to clarify a validation message in the configuration loading logic. Overall, the changes are in a good direction, and addressing these points will enhance the code's reliability.
try:
    from flashinfer import mm_fp4 as flashinfer_fp4_gemm

    enable_flashinfer_fp4_gemm = True
except ImportError:
    if current_platform.is_cuda():
        from sgl_kernel import cutlass_scaled_fp4_mm as cutlass_fp4_gemm

    enable_flashinfer_fp4_gemm = False
    reorder_rows_for_gated_act_gemm = None
    shuffle_matrix_a = None
    shuffle_matrix_sf_a = None
The try-except block for flashinfer imports is incomplete. It's missing shuffle_matrix_a and shuffle_matrix_sf_a, which are used later in process_weights_after_loading. This will cause a runtime NameError or ImportError if flashinfer is not installed. To make the import handling more robust, these functions should be imported here and set to None in the except block, similar to other flashinfer functions.
Suggested change:

try:
    from flashinfer import mm_fp4 as flashinfer_fp4_gemm
    from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

    enable_flashinfer_fp4_gemm = True
except ImportError:
    if current_platform.is_cuda():
        from sgl_kernel import cutlass_scaled_fp4_mm as cutlass_fp4_gemm

    enable_flashinfer_fp4_gemm = False
    reorder_rows_for_gated_act_gemm = None
    shuffle_matrix_a = None
    shuffle_matrix_sf_a = None
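One more note on the fallback-to-None approach: it only prevents crashes if every call site checks availability before using the symbol. A minimal standalone sketch of that pattern follows; the module and function names are placeholders, not flashinfer's actual API.

try:
    from some_optional_lib import fast_kernel  # placeholder optional dependency
    HAS_FAST_KERNEL = True
except ImportError:
    fast_kernel = None
    HAS_FAST_KERNEL = False

def apply_kernel(x):
    # Guard at the call site so a missing optional dependency degrades to the
    # fallback path instead of failing with a TypeError on the None symbol.
    if HAS_FAST_KERNEL:
        return fast_kernel(x)
    return x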
if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
    # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
    # FlashInfer provides nvfp4_quantize to quantize + shuffle the
    # layout but we use our own quantization so we have to call
    # shuffles ourselves.
    from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

    weight = layer.weight
    scale = layer.weight_scale
    epilogue_tile_m = 128
    weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
    scale = (
        shuffle_matrix_sf_a(scale.view(torch.uint8), epilogue_tile_m)
        .reshape(scale.shape)
        .view(torch.float8_e4m3fn)
    )

    layer.weight_scale_interleaved = Parameter(scale, requires_grad=False)
    layer.weight = Parameter(weight, requires_grad=False)
    return
This block will crash with an ImportError if flashinfer is not installed but FLASHINFER_FP4_GEMM_BACKEND is set to "trtllm". The from flashinfer import ... statement is not guarded. You should add a check for enable_flashinfer_fp4_gemm at the beginning of this block to ensure flashinfer is available. With the suggested change to the top-level imports, the local import statement here will also become unnecessary.
if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
    if not enable_flashinfer_fp4_gemm:
        raise ImportError(
            "flashinfer is not installed, but FLASHINFER_FP4_GEMM_BACKEND is set to 'trtllm'."
        )
    # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
    # FlashInfer provides nvfp4_quantize to quantize + shuffle the
    # layout but we use our own quantization so we have to call
    # shuffles ourselves.
    weight = layer.weight
    scale = layer.weight_scale
    epilogue_tile_m = 128
    weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
    scale = (
        shuffle_matrix_sf_a(scale.view(torch.uint8), epilogue_tile_m)
        .reshape(scale.shape)
        .view(torch.float8_e4m3fn)
    )
    layer.weight_scale_interleaved = Parameter(scale, requires_grad=False)
    layer.weight = Parameter(weight, requires_grad=False)
    return

if not quant_method in ["FP8", "NVFP4"]:
    raise ValueError(
        f"ModelOpt currently only supports: FP8, NVFP4"
        " quantizations in sglang. Please check the "
        "quantization config for your model's configuration."
    )
The validation for quant_method is a bit misleading. Since this class is ModelOptFp4Config, it should ideally only handle NVFP4 quantization. Allowing FP8 in the check, even if it's caught and rejected later, can be confusing. It would be clearer to raise an error immediately if the quant_method is not NVFP4.
if quant_method != "NVFP4":
    raise ValueError(
        f"ModelOptFp4Config only supports NVFP4, but got {quant_method}. "
        "Please check the quantization config for your model's configuration."
    )

elif (
    pattern_split[-1] in fused_patterns
    and pattern_split[-1] in prefix_split[-1]
):
    # Check if the last part of the excluded pattern is contained in the last part of the prefix
    # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
    # e.g., model.layers.{i}.self_attn.{fused_weight_name}
    assert len(prefix_split) == 5 and len(pattern_split) == 5
    return True
Using assert here to validate the length of prefix_split and pattern_split is risky. An assert should be used for conditions that should never be false in a correct program, not for validating inputs like layer names, which can vary between models. If a model with a different layer naming structure is used, this will cause a crash. It's safer to convert this into a conditional check within the elif statement.
elif (
    pattern_split[-1] in fused_patterns
    and pattern_split[-1] in prefix_split[-1]
    and len(prefix_split) == 5
    and len(pattern_split) == 5
):
    # Check if the last part of the excluded pattern is contained in the last part of the prefix
    # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
    # e.g., model.layers.{i}.self_attn.{fused_weight_name}
    return True
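For context on the length check, a small standalone illustration of the naming assumption; the layer and pattern names below are examples only, not taken from the PR.

# Example names following the model.layers.{i}.self_attn.{weight_name} scheme.
prefix = "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa"
pattern = "model.layers.0.self_attn.kv_a_proj_with_mqa"

prefix_split = prefix.split(".")
pattern_split = pattern.split(".")

print(len(prefix_split), len(pattern_split))   # 5 5 for this naming scheme
print(pattern_split[-1] in prefix_split[-1])   # True: the fused module name contains the pattern
# A deeper name such as "model.layers.0.mlp.experts.0.down_proj" splits into
# 7 parts, so gating on the length (rather than asserting) avoids a crash there.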
is it ready to be reviewed?
Motivation
Modifications
Accuracy Tests
Benchmarking and Profiling
Original on H200
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci