46 changes: 45 additions & 1 deletion docs/diffusion/quantization.md
@@ -48,6 +48,27 @@ backend.
| `nunchaku-svdq` | Pre-quantized Nunchaku transformer weights, usually named `svdq-{int4\|fp4}_r{rank}-...` | `--transformer-weights-path` | Model-specific support such as Qwen-Image, FLUX, and Z-Image | `nunchaku` | SGLang can infer precision and rank from the filename and supports both `int4` and `nvfp4` |
| `msmodelslim` | Pre-quantized msmodelslim transformer weights | `--model-path` | Wan2.2 family | None | Currently only compatible with the Ascend NPU family and supports both `w8a8` and `w4a4` |

## Validated ModelOpt Checkpoints

This section is the canonical support matrix for the diffusion ModelOpt
checkpoints that have been validated end to end in SGLang.

### FP8

| Base Model | Validated Scope | HF DiT Weights | Notes |
| --- | --- | --- | --- |
| `black-forest-labs/FLUX.1-dev` | single-transformer override, deterministic latent/image comparison, H100 benchmark, torch-profiler trace | `BBuf/flux1-dev-modelopt-fp8-sglang-transformer` | SGLang converter keeps a validated BF16 fallback set for modulation and FF projection layers; use `--model-id FLUX.1-dev` for local mirrors |
| `black-forest-labs/FLUX.2-dev` | single-transformer override load and generation path | `BBuf/flux2-dev-modelopt-fp8-sglang-transformer` | published SGLang-ready transformer override |
| `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | primary `transformer` quantized, `transformer_2` kept BF16 | `BBuf/wan22-t2v-a14b-modelopt-fp8-sglang-transformer` | do not describe this as dual-transformer full-model FP8 unless that path is validated separately |

### NVFP4

| Base Model | Validated Scope | HF DiT Weights | Notes |
| --- | --- | --- | --- |
| `black-forest-labs/FLUX.1-dev` | mixed BF16+NVFP4 transformer override, correctness validation, 4x RTX 5090 benchmark, torch-profiler trace | `unpublished` | use `build_modelopt_nvfp4_transformer.py`; validated builder keeps selected FLUX.1 modules in BF16 and sets `swap_weight_nibbles=false` |
| `black-forest-labs/FLUX.2-dev` | packed-QKV load path | `black-forest-labs/FLUX.2-dev-NVFP4` | validated packed export detection and runtime layout handling |
| `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | primary `transformer` quantized with official ModelOpt FP4 export, `transformer_2` kept BF16 | `unpublished` | global `--transformer-weights-path` targets only the primary `transformer`; keep `transformer_2` on the base checkpoint unless you pass a per-component override; validated on B200 with `SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND=cudnn` |
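
The `cudnn` override in the Wan2.2 row short-circuits the default FP4 GEMM dispatch on Blackwell. The sketch below is a simplified illustration of that decision, not SGLang's actual dispatch code; the backend names and the large-M threshold are assumptions:

```python
import os


def pick_fp4_gemm_backend(m: int, *, large_m_limit: int = 8192) -> str:
    """Simplified illustration of the documented backend choice."""
    override = os.environ.get("SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND")
    if override:
        # e.g. "cudnn" routes FP4 GEMM through FlashInfer's cuDNN backend.
        return f"flashinfer-{override}"
    if m > large_m_limit:
        # Stand-in for the sm100 CUTLASS kernel rejecting the shape at
        # can_implement(); the real limit is shape- and kernel-dependent.
        raise RuntimeError("cutlass sm100 rejected large-M shape")
    return "cutlass-sm100"
```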

## ModelOpt FP8

### Usage Examples
@@ -83,7 +104,7 @@ sglang generate \
- The layerwise offload path now preserves the non-contiguous FP8 weight stride
expected by the runtime FP8 GEMM path.
- To build the converted checkpoint yourself from a ModelOpt diffusers export,
- use `python -m sglang.multimodal_gen.tools.convert_modelopt_fp8_checkpoint`.
+ use `python -m sglang.multimodal_gen.tools.build_modelopt_fp8_transformer`.

## NVFP4

@@ -110,10 +131,33 @@ sglang generate \
--save-output
```

For a dual-transformer Wan2.2 export where only the primary `transformer`
was quantized:

```bash
SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND=cudnn \
sglang generate \
--model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \
--transformer-weights-path /path/to/wan22-nvfp4-export/transformer \
--prompt "a fox walking through neon rain" \
--save-output
```
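
The per-component behavior behind this command can be captured by a small resolver. The helper below is hypothetical (SGLang's real logic lives in its loader code), but it illustrates the documented rule that the global override targets only the primary `transformer`:

```python
from typing import Optional


def resolve_component_weights(
    base_model_path: str,
    transformer_weights_path: Optional[str] = None,
    transformer_2_path: Optional[str] = None,
) -> dict:
    """Map CLI overrides onto a dual-transformer pipeline's components."""
    return {
        # The global override applies only to the primary transformer.
        "transformer": transformer_weights_path
        or f"{base_model_path}/transformer",
        # transformer_2 stays on the base checkpoint unless explicitly
        # overridden per component.
        "transformer_2": transformer_2_path
        or f"{base_model_path}/transformer_2",
    }
```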

### Notes

- `--transformer-weights-path` is still the canonical CLI for NVFP4
transformer checkpoints.
- For dual-transformer pipelines such as `Wan2.2-T2V-A14B-Diffusers`, the
global `--transformer-weights-path` applies only to the primary
`transformer`. Use a per-component override such as `--transformer-2-path`
only when you intentionally want a non-default `transformer_2`.
- On Blackwell, the validated Wan2.2 ModelOpt NVFP4 path currently prefers
FlashInfer FP4 GEMM via
`SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND=cudnn`.
- This environment-variable override is a current workaround for NVFP4 cases
where the default sglang JIT/CUTLASS `sm100` path rejects a large-M shape at
`can_implement()`. The intended long-term fix is to add a validated CUTLASS
fallback for those shapes rather than rely on the override.
- Direct `--model-path` loading is a compatibility path for FLUX.2 NVFP4-style
repos or local directories.
- If `--transformer-weights-path` is provided explicitly, it takes precedence
13 changes: 13 additions & 0 deletions python/sglang/jit_kernel/nvfp4.py
@@ -47,6 +47,19 @@ def _nvfp4_arch_env():
return override_jit_cuda_arch(major, minor, suffix="a")


@torch.compiler.disable
def prewarm_nvfp4_jit_modules(
*, include_expert_quant: bool = False, include_blockwise_moe: bool = False
) -> None:
"""Materialize NVFP4 JIT modules before torch.compile traces the model."""
_jit_nvfp4_quant_module()
_jit_nvfp4_scaled_mm_module()
if include_expert_quant:
_jit_nvfp4_expert_quant_module()
if include_blockwise_moe:
_jit_nvfp4_blockwise_moe_module()


@cache_once
def _jit_nvfp4_quant_module() -> Module:
with _nvfp4_arch_env():
@@ -23,8 +23,8 @@ This skill owns the ModelOpt-to-SGLang bridge. It is not a generic kernel-tuning
- Benchmark only when BF16 and quantized commands are identical except for the checkpoint override being tested.
- For diffusion FP8, keep `dit_cpu_offload=false`. `dit_layerwise_offload=true` is valid on the fixed path when you want lower DiT residency.
- For multi-transformer pipelines, use per-component overrides when different components need different checkpoints.
- - When a branch is missing the validated helper tools, refresh `python/sglang/multimodal_gen/tools/convert_modelopt_fp8_checkpoint.py` and `python/sglang/multimodal_gen/tools/compare_diffusion_trajectory_similarity.py` instead of inventing one-off scripts elsewhere.
- - After validating a new ModelOpt quant path, update both the FP8 and NVFP4 support tables in this skill before closing the task.
+ - When a branch is missing the validated helper tools, refresh `python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py`, `python/sglang/multimodal_gen/tools/build_modelopt_nvfp4_transformer.py`, and `python/sglang/multimodal_gen/tools/compare_diffusion_trajectory_similarity.py` instead of inventing one-off scripts elsewhere.
+ - After validating a new ModelOpt quant path, update the ModelOpt support matrix in `docs/diffusion/quantization.md` before closing the task.

## Read First

@@ -38,7 +38,8 @@ Read these sources before changing code:
- `python/sglang/multimodal_gen/runtime/utils/quantization_utils.py`
- `python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py`
- Helper tools in this repo:
-   - [`python/sglang/multimodal_gen/tools/convert_modelopt_fp8_checkpoint.py`](../../../tools/convert_modelopt_fp8_checkpoint.py)
+   - [`python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py`](../../../tools/build_modelopt_fp8_transformer.py)
+   - [`python/sglang/multimodal_gen/tools/build_modelopt_nvfp4_transformer.py`](../../../tools/build_modelopt_nvfp4_transformer.py)
- [`python/sglang/multimodal_gen/tools/compare_diffusion_trajectory_similarity.py`](../../../tools/compare_diffusion_trajectory_similarity.py)

If you are working on a new model family, inspect the transformer's config and tensor naming before changing the generic converter.
@@ -52,32 +53,17 @@ This repo now contains:
- diffusion-side NVFP4 loading from ModelOpt exports
- FLUX.2 packed-QKV detection that distinguishes packed NVFP4 checkpoints from standard diffusers exports
- automatic protection against incompatible FP8 CPU offload while keeping layerwise DiT offload available
- - FP8 export conversion:
-   [`python/sglang/multimodal_gen/tools/convert_modelopt_fp8_checkpoint.py`](../../../tools/convert_modelopt_fp8_checkpoint.py)
+ - FP8 transformer build:
+   [`python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py`](../../../tools/build_modelopt_fp8_transformer.py)
- trajectory similarity validation:
[`python/sglang/multimodal_gen/tools/compare_diffusion_trajectory_similarity.py`](../../../tools/compare_diffusion_trajectory_similarity.py)

## Documentation Maintenance

- - Keep two separate support tables in this skill: one for FP8 and one for NVFP4.
- - After finishing a new quant support path, update both tables in every mirrored copy of this skill.
- - Each row must record the validated scope, the Hugging Face repo or path for the quantized DiT weights, and the key caveats.
+ - Keep the validated ModelOpt support matrix in `docs/diffusion/quantization.md`.
+ - Each row should record the validated scope, the Hugging Face repo or path for the quantized DiT weights, and the key caveats.
- If the quantized DiT weights are not published yet, write `unpublished` explicitly instead of leaving the field blank.

- ## FP8 Supported Models
-
- | Base Model | Validated Scope | HF DiT Weights | Notes |
- | --- | --- | --- | --- |
- | `black-forest-labs/FLUX.1-dev` | single-transformer override, deterministic latent/image comparison, H100 benchmark, torch-profiler trace | `BBuf/flux1-dev-modelopt-fp8-sglang-transformer` | SGLang converter keeps a validated BF16 fallback set for modulation and FF projection layers; use `--model-id FLUX.1-dev` for local mirrors |
- | `black-forest-labs/FLUX.2-dev` | single-transformer override load and generation path | `BBuf/flux2-dev-modelopt-fp8-sglang-transformer` | published SGLang-ready transformer override |
- | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | primary `transformer` quantized, `transformer_2` kept BF16 | `BBuf/wan22-t2v-a14b-modelopt-fp8-sglang-transformer` | do not describe this as dual-transformer full-model FP8 unless that path is validated separately |
-
- ## NVFP4 Supported Models
-
- | Base Model | Validated Scope | HF DiT Weights | Notes |
- | --- | --- | --- | --- |
- | `black-forest-labs/FLUX.2-dev` | packed-QKV load path | `black-forest-labs/FLUX.2-dev-NVFP4` | validated packed export detection and runtime layout handling |

## FP8 Vs NVFP4

FP8 and NVFP4 are not wired into SGLang in exactly the same way.
@@ -120,7 +106,7 @@ python quantize.py \
--model <model-name> \
--override-model-path <hf-repo-or-local-model> \
--model-dtype <Half|BFloat16> \
- --format <fp8|nvfp4> \
+ --format <fp8|fp4> \
--batch-size 1 \
--calib-size <calib-size> \
--n-steps <calib-steps> \
@@ -130,6 +116,9 @@ python quantize.py \
--hf-ckpt-dir <out>/hf
```

For current ModelOpt diffusion examples, use `--format fp4` for NVFP4 exports.
Do not assume the checked-out ModelOpt version accepts a literal `nvfp4` format string unless you verified it locally.
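
A small guard like the following can catch the naming mismatch before a long calibration run. It is a sketch, and the accepted set is an assumption to verify against your ModelOpt checkout:

```python
# Hypothetical pre-flight check; the accepted format strings are an
# assumption based on current ModelOpt diffusion examples.
ASSUMED_FORMATS = {"fp8", "fp4"}


def normalize_quant_format(fmt: str) -> str:
    """Map user-facing names onto what quantize.py is assumed to accept."""
    fmt = fmt.lower()
    if fmt == "nvfp4":
        # Current examples spell NVFP4 exports as 'fp4'.
        fmt = "fp4"
    if fmt not in ASSUMED_FORMATS:
        raise ValueError(
            f"unsupported --format {fmt!r}; try one of {sorted(ASSUMED_FORMATS)}"
        )
    return fmt
```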

For multi-transformer models:

- quantize each backbone deliberately
@@ -141,7 +130,7 @@ For multi-transformer models:
FP8 requires an extra conversion step:

```bash
- PYTHONPATH=python python3 -m sglang.multimodal_gen.tools.convert_modelopt_fp8_checkpoint \
+ PYTHONPATH=python python3 -m sglang.multimodal_gen.tools.build_modelopt_fp8_transformer \
--modelopt-hf-dir <out>/hf \
--modelopt-backbone-ckpt <out>/ckpt/backbone.pt \
--base-transformer-dir <base-model-transformer-dir> \
@@ -166,9 +155,25 @@ For `FLUX.1-dev`, the validated fallback set currently keeps these modules in BF
- `transformer_blocks.*.ff_context.net.0.proj`
- `transformer_blocks.*.ff_context.net.2`
- `single_transformer_blocks.*.norm.linear`
- `single_transformer_blocks.*.proj_mlp`

Use `--model-type flux1` to force that profile, or rely on `--model-type auto` when the export config identifies `FluxTransformer2DModel`.

For FLUX.1-dev-style NVFP4 exports that need a mixed BF16+NVFP4 checkpoint, build the merged transformer explicitly:

```bash
PYTHONPATH=python python3 -m sglang.multimodal_gen.tools.build_modelopt_nvfp4_transformer \
--base-transformer-dir <base-model-transformer-dir> \
--modelopt-hf-dir <out>/hf/transformer \
--output-dir <out>/transformer-mixed \
--pattern-preset flux1-nvfp4
```

The validated FLUX.1-dev mixed builder also needs to preserve:

- `quant_type: NVFP4` in `config.json`
- `swap_weight_nibbles: false` for the validated diffusers export
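
A quick post-build sanity check for those two flags might look like this. It is a sketch; the top-level key placement mirrors the wording above and should be confirmed against the builder's actual output:

```python
import json


def check_mixed_nvfp4_config(config_path: str) -> None:
    """Fail fast if the merged transformer config lost either flag."""
    with open(config_path) as f:
        cfg = json.load(f)
    if cfg.get("quant_type") != "NVFP4":
        raise ValueError("config.json must keep quant_type: NVFP4")
    if cfg.get("swap_weight_nibbles") is not False:
        raise ValueError(
            "validated diffusers export expects swap_weight_nibbles: false"
        )
```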

### 4. Load The Quantized Checkpoint In SGLang

Single-transformer example:
@@ -304,5 +309,5 @@ When documenting results:
| `runtime/utils/quantization_utils.py` | resolves flat ModelOpt configs and reconstructs NVFP4 config from metadata |
| `runtime/loader/transformer_load_utils.py` | guards incompatible FP8 offload modes |
| `runtime/models/dits/flux_2.py` | packed-QKV handling for the packed FLUX.2 NVFP4 family |
- | `tools/convert_modelopt_fp8_checkpoint.py` | FP8 offline conversion into SGLang-native layout |
+ | `tools/build_modelopt_fp8_transformer.py` | Build an SGLang-loadable FP8 transformer from a ModelOpt export |
| `tools/compare_diffusion_trajectory_similarity.py` | reduced deterministic BF16-vs-quantized validation |
@@ -47,7 +47,7 @@ These options **trade output quality** for speed or VRAM savings. Results will d
| **Approximate Attention** | `--attention-backend sage_attn` / `sage_attn_3` / `sliding_tile_attn` / `video_sparse_attn` / `sparse_video_gen_2_attn` / `vmoba_attn` / `sla_attn` / `sage_sla_attn` | Replaces exact attention with approximate or sparse variants. `sage_attn`: INT8/FP8 quantized Q·K; `sliding_tile_attn`: spatial-temporal tile skipping; others: model-specific sparse patterns. | ~1.5–2x on attention (varies by backend) | Quality degradation varies by backend and model. `sage_attn` is the most general; sparse backends (`sliding_tile_attn`, `video_sparse_attn`, etc.) are video-model-specific and may require config files (e.g. `--mask-strategy-file-path` for STA). Requires corresponding packages installed. |
| **Cache-DiT** | `SGLANG_CACHE_DIT_ENABLED=true` + `--cache-dit-config <path>` | Caches intermediate residuals across denoising steps and skips redundant computations via a Selective Computation Mask (SCM). | ~1.5–2x on supported models | Quality depends on SCM config. Incompatible with `--dit-layerwise-offload`. Requires correct per-model config YAML. |
| **Quantized Models (Nunchaku / SVDQuant)** | `--enable-svdquant --transformer-weights-path <path>` + optional `--quantization-precision int4\|nvfp4`, `--quantization-rank 32` | W4A4-style quantization via [Nunchaku](https://nunchaku.tech). Reduces DiT weight memory by ~4x. Precision/rank can be auto-inferred from weight filename or set explicitly. | ~1.5–2x compute speedup | Lossy quantization; quality depends on rank and precision. Requires pre-quantized weights. Ampere (SM8x) or SM12x only (no Hopper SM90). Higher rank = better quality but more memory. |
- | **Pre-quantized Weights** | `--transformer-weights-path <path>` | Load any pre-quantized transformer weights (FP8, INT8, etc.) from a single `.safetensors` file, a directory, or a HuggingFace repo ID. | ~1.3–1.5x compute (dtype dependent) | Requires pre-converted weights (e.g. via `tools/convert_hf_to_fp8.py` for FP8). Quality slightly worse than BF16; varies by quantization format. |
+ | **Pre-quantized Weights** | `--transformer-weights-path <path>` | Load any pre-quantized transformer weights (FP8, INT8, etc.) from a single `.safetensors` file, a directory, or a HuggingFace repo ID. | ~1.3–1.5x compute (dtype dependent) | Requires a validated quantized transformer override, such as one produced by `tools/build_modelopt_fp8_transformer.py` for ModelOpt FP8. Quality slightly worse than BF16; varies by quantization format. |
| **Component Precision Override** | `--dit-precision fp16`, `--vae-precision fp16\|bf16` | On-the-fly dtype conversion for individual components. E.g. convert a BF16 model to FP16 at load time, or run VAE in BF16 instead of FP32. | Reduces memory; FP16 can be faster on some GPUs | May affect numerical stability. VAE is FP32 by default for accuracy; lowering it is lossy. DiT defaults to BF16. |
| **Fewer Inference Steps** | `--num-inference-steps N` (sampling param) | Reduces the number of denoising steps. Fewer steps = faster. | Linear speedup | Quality degrades with too few steps. Model-dependent optimal range. |

@@ -52,6 +52,15 @@ def _get_fp4_gemm_op():
return current_platform.get_modelopt_fp4_gemm_op()


def _prepare_nvfp4_weight_bytes(
weight: torch.Tensor, *, swap_weight_nibbles: bool
) -> torch.Tensor:
"""Normalize serialized NVFP4 bytes before padding for the runtime kernel."""
if not swap_weight_nibbles:
return weight.contiguous()
return ((weight >> 4) | (weight << 4)).contiguous()


class ModelOptQuantConfig(QuantizationConfig):
def __init__(
self,
@@ -180,6 +189,7 @@ def __init__(
exclude_modules: List[str] = None,
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
checkpoint_uses_packed_qkv: bool = False,
swap_weight_nibbles: bool = True,
) -> None:
super().__init__(exclude_modules, packed_modules_mapping)
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
Expand All @@ -190,6 +200,7 @@ def __init__(
)
self.group_size = group_size
self.checkpoint_uses_packed_qkv = checkpoint_uses_packed_qkv
self.swap_weight_nibbles = swap_weight_nibbles

@classmethod
def get_name(cls) -> str:
@@ -237,6 +248,7 @@ def _add_group_size_from_dict(config: dict):
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
group_size = None
exclude_modules = []
swap_weight_nibbles = True

# Flat format (config.json quantization_config)
quant_method = config.get("quant_algo")
@@ -248,13 +260,18 @@ def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
first_group = next(iter(config_groups.values()), {})
group_size = first_group.get("weights", {}).get("group_size")
exclude_modules = config.get("ignore", [])
swap_weight_nibbles = config.get("swap_weight_nibbles", True)
else:
# Nested format (hf_quant_config.json)
try:
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
group_size = ModelOptFp4Config.common_group_size(config)
exclude_modules = quant_config.get("exclude_modules", [])
swap_weight_nibbles = quant_config.get(
"swap_weight_nibbles",
config.get("swap_weight_nibbles", True),
)
except (ValueError, KeyError):
raise ValueError("Cannot find 'quant_algo' in quantization config.")

@@ -274,6 +291,7 @@ def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
exclude_modules=exclude_modules,
packed_modules_mapping=config.get("packed_modules_mapping"),
checkpoint_uses_packed_qkv=config.get("checkpoint_uses_packed_qkv", False),
swap_weight_nibbles=swap_weight_nibbles,
)

def get_quant_method(self, layer: torch.nn.Module, prefix: str):
@@ -459,9 +477,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

layer.output_size_per_partition = layer.weight.shape[0]

- # Swap nibbles: (byte >> 4) | (byte << 4).
  w = layer.weight.data
- w_swapped = ((w >> 4) | (w << 4)).contiguous()
+ w_swapped = _prepare_nvfp4_weight_bytes(
+     w, swap_weight_nibbles=self.quant_config.swap_weight_nibbles
+ )
weight, weights_padding_cols = pad_nvfp4_weight(w_swapped)
layer.weights_padding_cols = weights_padding_cols
copy_or_rebind_param(layer, "weight", weight)