diff --git a/docs/diffusion/quantization.md b/docs/diffusion/quantization.md index b325458b9b42..f1e111e0e04c 100644 --- a/docs/diffusion/quantization.md +++ b/docs/diffusion/quantization.md @@ -48,6 +48,27 @@ backend. | `nunchaku-svdq` | Pre-quantized Nunchaku transformer weights, usually named `svdq-{int4\|fp4}_r{rank}-...` | `--transformer-weights-path` | Model-specific support such as Qwen-Image, FLUX, and Z-Image | `nunchaku` | SGLang can infer precision and rank from the filename and supports both `int4` and `nvfp4` | | `msmodelslim` | Pre-quantized msmodelslim transformer weights | `--model-path` | Wan2.2 family | None | Currently only compatible with the Ascend NPU family and supports both `w8a8` and `w4a4` | +## Validated ModelOpt Checkpoints + +This section is the canonical support matrix for diffusion ModelOpt checkpoints +that have been brought up and verified in SGLang. + +### FP8 + +| Base Model | Validated Scope | HF DiT Weights | Notes | +| --- | --- | --- | --- | +| `black-forest-labs/FLUX.1-dev` | single-transformer override, deterministic latent/image comparison, H100 benchmark, torch-profiler trace | `BBuf/flux1-dev-modelopt-fp8-sglang-transformer` | SGLang converter keeps a validated BF16 fallback set for modulation and FF projection layers; use `--model-id FLUX.1-dev` for local mirrors | +| `black-forest-labs/FLUX.2-dev` | single-transformer override load and generation path | `BBuf/flux2-dev-modelopt-fp8-sglang-transformer` | published SGLang-ready transformer override | +| `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | primary `transformer` quantized, `transformer_2` kept BF16 | `BBuf/wan22-t2v-a14b-modelopt-fp8-sglang-transformer` | do not describe this as dual-transformer full-model FP8 unless that path is validated separately | + +### NVFP4 + +| Base Model | Validated Scope | HF DiT Weights | Notes | +| --- | --- | --- | --- | +| `black-forest-labs/FLUX.1-dev` | mixed BF16+NVFP4 transformer override, correctness validation, 4x RTX 5090 benchmark, torch-profiler trace | `unpublished` | use `build_modelopt_nvfp4_transformer.py`; validated builder keeps selected FLUX.1 modules in BF16 and sets `swap_weight_nibbles=false` | +| `black-forest-labs/FLUX.2-dev` | packed-QKV load path | `black-forest-labs/FLUX.2-dev-NVFP4` | validated packed export detection and runtime layout handling | +| `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | primary `transformer` quantized with official ModelOpt FP4 export, `transformer_2` kept BF16 | `unpublished` | global `--transformer-weights-path` targets only the primary `transformer`; keep `transformer_2` on the base checkpoint unless you pass a per-component override; validated on B200 with `SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND=cudnn` | + ## ModelOpt FP8 ### Usage Examples @@ -83,7 +104,7 @@ sglang generate \ - The layerwise offload path now preserves the non-contiguous FP8 weight stride expected by the runtime FP8 GEMM path. - To build the converted checkpoint yourself from a ModelOpt diffusers export, - use `python -m sglang.multimodal_gen.tools.convert_modelopt_fp8_checkpoint`. + use `python -m sglang.multimodal_gen.tools.build_modelopt_fp8_transformer`. ## NVFP4 @@ -110,10 +131,33 @@ sglang generate \ --save-output ``` +For a dual-transformer Wan2.2 export where only the primary `transformer` +was quantized: + +```bash +SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND=cudnn \ +sglang generate \ + --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --transformer-weights-path /path/to/wan22-nvfp4-export/transformer \ + --prompt "a fox walking through neon rain" \ + --save-output +``` + ### Notes - `--transformer-weights-path` is still the canonical CLI for NVFP4 transformer checkpoints. +- For dual-transformer pipelines such as `Wan2.2-T2V-A14B-Diffusers`, the + global `--transformer-weights-path` applies only to the primary + `transformer`. Use a per-component override such as `--transformer-2-path` + only when you intentionally want a non-default `transformer_2`. +- On Blackwell, the validated Wan2.2 ModelOpt NVFP4 path currently prefers + FlashInfer FP4 GEMM via + `SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND=cudnn`. +- This environment-variable override is a current workaround for NVFP4 cases + where the default sglang JIT/CUTLASS `sm100` path rejects a large-M shape at + `can_implement()`. The intended long-term fix is to add a validated CUTLASS + fallback for those shapes rather than rely on the override. - Direct `--model-path` loading is a compatibility path for FLUX.2 NVFP4-style repos or local directories. - If `--transformer-weights-path` is provided explicitly, it takes precedence diff --git a/python/sglang/jit_kernel/nvfp4.py b/python/sglang/jit_kernel/nvfp4.py index f54e6414b9e7..ff3c20072668 100644 --- a/python/sglang/jit_kernel/nvfp4.py +++ b/python/sglang/jit_kernel/nvfp4.py @@ -47,6 +47,19 @@ def _nvfp4_arch_env(): return override_jit_cuda_arch(major, minor, suffix="a") +@torch.compiler.disable +def prewarm_nvfp4_jit_modules( + *, include_expert_quant: bool = False, include_blockwise_moe: bool = False +) -> None: + """Materialize NVFP4 JIT modules before torch.compile traces the model.""" + _jit_nvfp4_quant_module() + _jit_nvfp4_scaled_mm_module() + if include_expert_quant: + _jit_nvfp4_expert_quant_module() + if include_blockwise_moe: + _jit_nvfp4_blockwise_moe_module() + + @cache_once def _jit_nvfp4_quant_module() -> Module: with _nvfp4_arch_env(): diff --git a/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-modelopt-quant/SKILL.md b/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-modelopt-quant/SKILL.md index 6336984fd0d3..4e81ef4eb2f7 100644 --- a/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-modelopt-quant/SKILL.md +++ b/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-modelopt-quant/SKILL.md @@ -23,8 +23,8 @@ This skill owns the ModelOpt-to-SGLang bridge. It is not a generic kernel-tuning - Benchmark only when BF16 and quantized commands are identical except for the checkpoint override being tested. - For diffusion FP8, keep `dit_cpu_offload=false`. `dit_layerwise_offload=true` is valid on the fixed path when you want lower DiT residency. - For multi-transformer pipelines, use per-component overrides when different components need different checkpoints. -- When a branch is missing the validated helper tools, refresh `python/sglang/multimodal_gen/tools/convert_modelopt_fp8_checkpoint.py` and `python/sglang/multimodal_gen/tools/compare_diffusion_trajectory_similarity.py` instead of inventing one-off scripts elsewhere. -- After validating a new ModelOpt quant path, update both the FP8 and NVFP4 support tables in this skill before closing the task. +- When a branch is missing the validated helper tools, refresh `python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py`, `python/sglang/multimodal_gen/tools/build_modelopt_nvfp4_transformer.py`, and `python/sglang/multimodal_gen/tools/compare_diffusion_trajectory_similarity.py` instead of inventing one-off scripts elsewhere. +- After validating a new ModelOpt quant path, update the ModelOpt support matrix in `docs/diffusion/quantization.md` before closing the task. ## Read First @@ -38,7 +38,8 @@ Read these sources before changing code: - `python/sglang/multimodal_gen/runtime/utils/quantization_utils.py` - `python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py` - Helper tools in this repo: - - [`python/sglang/multimodal_gen/tools/convert_modelopt_fp8_checkpoint.py`](../../../tools/convert_modelopt_fp8_checkpoint.py) + - [`python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py`](../../../tools/build_modelopt_fp8_transformer.py) + - [`python/sglang/multimodal_gen/tools/build_modelopt_nvfp4_transformer.py`](../../../tools/build_modelopt_nvfp4_transformer.py) - [`python/sglang/multimodal_gen/tools/compare_diffusion_trajectory_similarity.py`](../../../tools/compare_diffusion_trajectory_similarity.py) If you are working on a new model family, inspect the transformer's config and tensor naming before changing the generic converter. @@ -52,32 +53,17 @@ This repo now contains: - diffusion-side NVFP4 loading from ModelOpt exports - FLUX.2 packed-QKV detection that distinguishes packed NVFP4 checkpoints from standard diffusers exports - automatic protection against incompatible FP8 CPU offload while keeping layerwise DiT offload available -- FP8 export conversion: - [`python/sglang/multimodal_gen/tools/convert_modelopt_fp8_checkpoint.py`](../../../tools/convert_modelopt_fp8_checkpoint.py) +- FP8 transformer build: + [`python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py`](../../../tools/build_modelopt_fp8_transformer.py) - trajectory similarity validation: [`python/sglang/multimodal_gen/tools/compare_diffusion_trajectory_similarity.py`](../../../tools/compare_diffusion_trajectory_similarity.py) ## Documentation Maintenance -- Keep two separate support tables in this skill: one for FP8 and one for NVFP4. -- After finishing a new quant support path, update both tables in every mirrored copy of this skill. -- Each row must record the validated scope, the Hugging Face repo or path for the quantized DiT weights, and the key caveats. +- Keep the validated ModelOpt support matrix in `docs/diffusion/quantization.md`. +- Each row should record the validated scope, the Hugging Face repo or path for the quantized DiT weights, and the key caveats. - If the quantized DiT weights are not published yet, write `unpublished` explicitly instead of leaving the field blank. -## FP8 Supported Models - -| Base Model | Validated Scope | HF DiT Weights | Notes | -| --- | --- | --- | --- | -| `black-forest-labs/FLUX.1-dev` | single-transformer override, deterministic latent/image comparison, H100 benchmark, torch-profiler trace | `BBuf/flux1-dev-modelopt-fp8-sglang-transformer` | SGLang converter keeps a validated BF16 fallback set for modulation and FF projection layers; use `--model-id FLUX.1-dev` for local mirrors | -| `black-forest-labs/FLUX.2-dev` | single-transformer override load and generation path | `BBuf/flux2-dev-modelopt-fp8-sglang-transformer` | published SGLang-ready transformer override | -| `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | primary `transformer` quantized, `transformer_2` kept BF16 | `BBuf/wan22-t2v-a14b-modelopt-fp8-sglang-transformer` | do not describe this as dual-transformer full-model FP8 unless that path is validated separately | - -## NVFP4 Supported Models - -| Base Model | Validated Scope | HF DiT Weights | Notes | -| --- | --- | --- | --- | -| `black-forest-labs/FLUX.2-dev` | packed-QKV load path | `black-forest-labs/FLUX.2-dev-NVFP4` | validated packed export detection and runtime layout handling | - ## FP8 Vs NVFP4 FP8 and NVFP4 are not wired into SGLang in exactly the same way. @@ -120,7 +106,7 @@ python quantize.py \ --model \ --override-model-path \ --model-dtype \ - --format \ + --format \ --batch-size 1 \ --calib-size \ --n-steps \ @@ -130,6 +116,9 @@ python quantize.py \ --hf-ckpt-dir /hf ``` +For current ModelOpt diffusion examples, use `--format fp4` for NVFP4 exports. +Do not assume the checked-out ModelOpt version accepts a literal `nvfp4` format string unless you verified it locally. + For multi-transformer models: - quantize each backbone deliberately @@ -141,7 +130,7 @@ For multi-transformer models: FP8 requires an extra conversion step: ```bash -PYTHONPATH=python python3 -m sglang.multimodal_gen.tools.convert_modelopt_fp8_checkpoint \ +PYTHONPATH=python python3 -m sglang.multimodal_gen.tools.build_modelopt_fp8_transformer \ --modelopt-hf-dir /hf \ --modelopt-backbone-ckpt /ckpt/backbone.pt \ --base-transformer-dir \ @@ -166,9 +155,25 @@ For `FLUX.1-dev`, the validated fallback set currently keeps these modules in BF - `transformer_blocks.*.ff_context.net.0.proj` - `transformer_blocks.*.ff_context.net.2` - `single_transformer_blocks.*.norm.linear` +- `single_transformer_blocks.*.proj_mlp` Use `--model-type flux1` to force that profile, or rely on `--model-type auto` when the export config identifies `FluxTransformer2DModel`. +For FLUX.1-dev NVFP4 model families that need a mixed BF16+NVFP4 checkpoint, build the merged transformer explicitly: + +```bash +PYTHONPATH=python python3 -m sglang.multimodal_gen.tools.build_modelopt_nvfp4_transformer \ + --base-transformer-dir \ + --modelopt-hf-dir /hf/transformer \ + --output-dir /transformer-mixed \ + --pattern-preset flux1-nvfp4 +``` + +The validated FLUX.1-dev mixed builder also needs to preserve: + +- `quant_type: NVFP4` in `config.json` +- `swap_weight_nibbles: false` for the validated diffusers export + ### 4. Load The Quantized Checkpoint In SGLang Single-transformer example: @@ -304,5 +309,5 @@ When documenting results: | `runtime/utils/quantization_utils.py` | resolves flat ModelOpt configs and reconstructs NVFP4 config from metadata | | `runtime/loader/transformer_load_utils.py` | guards incompatible FP8 offload modes | | `runtime/models/dits/flux_2.py` | packed-QKV handling for the packed FLUX.2 NVFP4 family | -| `tools/convert_modelopt_fp8_checkpoint.py` | FP8 offline conversion into SGLang-native layout | +| `tools/build_modelopt_fp8_transformer.py` | Build an SGLang-loadable FP8 transformer from a ModelOpt export | | `tools/compare_diffusion_trajectory_similarity.py` | reduced deterministic BF16-vs-quantized validation | diff --git a/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-performance/SKILL.md b/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-performance/SKILL.md index 5c8d91e01f1f..eadc83899c9e 100644 --- a/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-performance/SKILL.md +++ b/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-performance/SKILL.md @@ -47,7 +47,7 @@ These options **trade output quality** for speed or VRAM savings. Results will d | **Approximate Attention** | `--attention-backend sage_attn` / `sage_attn_3` / `sliding_tile_attn` / `video_sparse_attn` / `sparse_video_gen_2_attn` / `vmoba_attn` / `sla_attn` / `sage_sla_attn` | Replaces exact attention with approximate or sparse variants. `sage_attn`: INT8/FP8 quantized Q·K; `sliding_tile_attn`: spatial-temporal tile skipping; others: model-specific sparse patterns. | ~1.5–2x on attention (varies by backend) | Quality degradation varies by backend and model. `sage_attn` is the most general; sparse backends (`sliding_tile_attn`, `video_sparse_attn`, etc.) are video-model-specific and may require config files (e.g. `--mask-strategy-file-path` for STA). Requires corresponding packages installed. | | **Cache-DiT** | `SGLANG_CACHE_DIT_ENABLED=true` + `--cache-dit-config ` | Caches intermediate residuals across denoising steps and skips redundant computations via a Selective Computation Mask (SCM). | ~1.5–2x on supported models | Quality depends on SCM config. Incompatible with `--dit-layerwise-offload`. Requires correct per-model config YAML. | | **Quantized Models (Nunchaku / SVDQuant)** | `--enable-svdquant --transformer-weights-path ` + optional `--quantization-precision int4\|nvfp4`, `--quantization-rank 32` | W4A4-style quantization via [Nunchaku](https://nunchaku.tech). Reduces DiT weight memory by ~4x. Precision/rank can be auto-inferred from weight filename or set explicitly. | ~1.5–2x compute speedup | Lossy quantization; quality depends on rank and precision. Requires pre-quantized weights. Ampere (SM8x) or SM12x only (no Hopper SM90). Higher rank = better quality but more memory. | -| **Pre-quantized Weights** | `--transformer-weights-path ` | Load any pre-quantized transformer weights (FP8, INT8, etc.) from a single `.safetensors` file, a directory, or a HuggingFace repo ID. | ~1.3–1.5x compute (dtype dependent) | Requires pre-converted weights (e.g. via `tools/convert_hf_to_fp8.py` for FP8). Quality slightly worse than BF16; varies by quantization format. | +| **Pre-quantized Weights** | `--transformer-weights-path ` | Load any pre-quantized transformer weights (FP8, INT8, etc.) from a single `.safetensors` file, a directory, or a HuggingFace repo ID. | ~1.3–1.5x compute (dtype dependent) | Requires a validated quantized transformer override, such as one produced by `tools/build_modelopt_fp8_transformer.py` for ModelOpt FP8. Quality slightly worse than BF16; varies by quantization format. | | **Component Precision Override** | `--dit-precision fp16`, `--vae-precision fp16\|bf16` | On-the-fly dtype conversion for individual components. E.g. convert a BF16 model to FP16 at load time, or run VAE in BF16 instead of FP32. | Reduces memory; FP16 can be faster on some GPUs | May affect numerical stability. VAE is FP32 by default for accuracy; lowering it is lossy. DiT defaults to BF16. | | **Fewer Inference Steps** | `--num-inference-steps N` (sampling param) | Reduces the number of denoising steps. Fewer steps = faster. | Linear speedup | Quality degrades with too few steps. Model-dependent optimal range. | diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py index 627edb93edeb..d7ec21a3570b 100755 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelopt_quant.py @@ -52,6 +52,15 @@ def _get_fp4_gemm_op(): return current_platform.get_modelopt_fp4_gemm_op() +def _prepare_nvfp4_weight_bytes( + weight: torch.Tensor, *, swap_weight_nibbles: bool +) -> torch.Tensor: + """Normalize serialized NVFP4 bytes before padding for the runtime kernel.""" + if not swap_weight_nibbles: + return weight.contiguous() + return ((weight >> 4) | (weight << 4)).contiguous() + + class ModelOptQuantConfig(QuantizationConfig): def __init__( self, @@ -180,6 +189,7 @@ def __init__( exclude_modules: List[str] = None, packed_modules_mapping: Optional[Dict[str, List[str]]] = None, checkpoint_uses_packed_qkv: bool = False, + swap_weight_nibbles: bool = True, ) -> None: super().__init__(exclude_modules, packed_modules_mapping) self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized @@ -190,6 +200,7 @@ def __init__( ) self.group_size = group_size self.checkpoint_uses_packed_qkv = checkpoint_uses_packed_qkv + self.swap_weight_nibbles = swap_weight_nibbles @classmethod def get_name(cls) -> str: @@ -237,6 +248,7 @@ def _add_group_size_from_dict(config: dict): def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config: group_size = None exclude_modules = [] + swap_weight_nibbles = True # Flat format (config.json quantization_config) quant_method = config.get("quant_algo") @@ -248,6 +260,7 @@ def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config: first_group = next(iter(config_groups.values()), {}) group_size = first_group.get("weights", {}).get("group_size") exclude_modules = config.get("ignore", []) + swap_weight_nibbles = config.get("swap_weight_nibbles", True) else: # Nested format (hf_quant_config.json) try: @@ -255,6 +268,10 @@ def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config: quant_method = quant_config["quant_algo"] group_size = ModelOptFp4Config.common_group_size(config) exclude_modules = quant_config.get("exclude_modules", []) + swap_weight_nibbles = quant_config.get( + "swap_weight_nibbles", + config.get("swap_weight_nibbles", True), + ) except (ValueError, KeyError): raise ValueError("Cannot find 'quant_algo' in quantization config.") @@ -274,6 +291,7 @@ def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config: exclude_modules=exclude_modules, packed_modules_mapping=config.get("packed_modules_mapping"), checkpoint_uses_packed_qkv=config.get("checkpoint_uses_packed_qkv", False), + swap_weight_nibbles=swap_weight_nibbles, ) def get_quant_method(self, layer: torch.nn.Module, prefix: str): @@ -459,9 +477,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.output_size_per_partition = layer.weight.shape[0] - # Swap nibbles: (byte >> 4) | (byte << 4). w = layer.weight.data - w_swapped = ((w >> 4) | (w << 4)).contiguous() + w_swapped = _prepare_nvfp4_weight_bytes( + w, swap_weight_nibbles=self.quant_config.swap_weight_nibbles + ) weight, weights_padding_cols = pad_nvfp4_weight(w_swapped) layer.weights_padding_cols = weights_padding_cols copy_or_rebind_param(layer, "weight", weight) diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py index 7c0b6b78851b..ad6294bb7000 100644 --- a/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py +++ b/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py @@ -1,3 +1,4 @@ +import copy import logging from typing import Any @@ -26,6 +27,30 @@ logger = init_logger(__name__) +def _server_args_for_transformer_component( + server_args: ServerArgs, component_name: str +) -> ServerArgs: + """Mask global quantized override flags for secondary transformer components.""" + if component_name != "transformer_2": + return server_args + + if ( + server_args.transformer_weights_path is None + and server_args.nunchaku_config is None + ): + return server_args + + component_server_args = copy.copy(server_args) + component_server_args.transformer_weights_path = None + component_server_args.nunchaku_config = None + logger.info( + "Ignoring global transformer_weights_path for %s; keep it on the base " + "checkpoint unless a per-component override path is provided.", + component_name, + ) + return component_server_args + + class TransformerLoader(ComponentLoader): """Shared loader for (video/audio) DiT transformers.""" @@ -36,11 +61,15 @@ def load_customized( self, component_model_path: str, server_args: ServerArgs, component_name: str ): """Load the transformer based on the model path, and inference args.""" + component_server_args = _server_args_for_transformer_component( + server_args, component_name + ) + # 1. hf config config = get_diffusers_component_config(component_path=component_model_path) safetensors_list = resolve_transformer_safetensors_to_load( - server_args, component_model_path + component_server_args, component_model_path ) # 2. dit config @@ -61,7 +90,7 @@ def load_customized( quant_spec = resolve_transformer_quant_load_spec( hf_config=config, - server_args=server_args, + server_args=component_server_args, safetensors_list=safetensors_list, component_model_path=component_model_path, model_cls=model_cls, @@ -83,7 +112,7 @@ def load_customized( } if ( init_params["quant_config"] is None - and server_args.transformer_weights_path is not None + and component_server_args.transformer_weights_path is not None ): logger.warning( f"transformer_weights_path provided, but quantization config not resolved, which is unexpected and likely to cause errors" @@ -99,9 +128,9 @@ def load_customized( device=get_local_torch_device(), hsdp_replicate_dim=server_args.hsdp_replicate_dim, hsdp_shard_dim=server_args.hsdp_shard_dim, - cpu_offload=server_args.dit_cpu_offload, - pin_cpu_memory=server_args.pin_cpu_memory, - fsdp_inference=server_args.use_fsdp_inference, + cpu_offload=component_server_args.dit_cpu_offload, + pin_cpu_memory=component_server_args.pin_cpu_memory, + fsdp_inference=component_server_args.use_fsdp_inference, param_dtype=quant_spec.param_dtype, reduce_dtype=torch.float32, output_dtype=None, diff --git a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py index 2941f5e24f95..c0375be9b385 100644 --- a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py +++ b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py @@ -410,7 +410,16 @@ def load_model_from_full_model_state_dict( ): requires_grad = False temp_param.requires_grad = requires_grad - weight_loader(temp_param, full_tensor) + try: + weight_loader(temp_param, full_tensor) + except AssertionError as exc: + raise AssertionError( + "Failed to shard/load parameter " + f"{target_param_name}: full_tensor.shape={tuple(full_tensor.shape)}, " + f"meta_sharded_param.shape={tuple(meta_sharded_param.shape)}, " + f"temp_param.shape={tuple(temp_param.shape)}, " + f"param_cls={type(actual_param).__name__}" + ) from exc sharded_tensor = temp_param.data else: # In cases where parts of the model aren't sharded, some parameters will be plain tensors diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py index 6ecfc22503fd..481296a2da99 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py @@ -261,6 +261,7 @@ def __init__( bias=bias, gather_output=True, quant_config=quant_config, + prefix=f"{prefix}.to_q" if prefix else "to_q", ) self.to_k = ColumnParallelLinear( query_dim, @@ -268,6 +269,7 @@ def __init__( bias=bias, gather_output=True, quant_config=quant_config, + prefix=f"{prefix}.to_k" if prefix else "to_k", ) self.to_v = ColumnParallelLinear( query_dim, @@ -275,6 +277,7 @@ def __init__( bias=bias, gather_output=True, quant_config=quant_config, + prefix=f"{prefix}.to_v" if prefix else "to_v", ) if not self.pre_only: self.to_out = torch.nn.ModuleList([]) @@ -310,6 +313,7 @@ def __init__( bias=added_proj_bias, gather_output=True, quant_config=quant_config, + prefix=f"{prefix}.add_q_proj" if prefix else "add_q_proj", ) self.add_k_proj = ColumnParallelLinear( added_kv_proj_dim, @@ -317,6 +321,7 @@ def __init__( bias=added_proj_bias, gather_output=True, quant_config=quant_config, + prefix=f"{prefix}.add_k_proj" if prefix else "add_k_proj", ) self.add_v_proj = ColumnParallelLinear( added_kv_proj_dim, @@ -324,6 +329,7 @@ def __init__( bias=added_proj_bias, gather_output=True, quant_config=quant_config, + prefix=f"{prefix}.add_v_proj" if prefix else "add_v_proj", ) self.to_add_out = ColumnParallelLinear( self.inner_dim, @@ -495,6 +501,7 @@ def __init__( bias=True, gather_output=True, quant_config=quant_config, + prefix=f"{prefix}.proj_mlp" if prefix else "proj_mlp", ) self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = ColumnParallelLinear( @@ -503,6 +510,7 @@ def __init__( bias=True, gather_output=True, quant_config=quant_config, + prefix=f"{prefix}.proj_out" if prefix else "proj_out", ) self.attn = FluxAttention( query_dim=dim, @@ -513,6 +521,7 @@ def __init__( eps=1e-6, pre_only=True, quant_config=quant_config, + prefix=f"{prefix}.attn" if prefix else "attn", ) def forward( diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index 7b868be9faf6..4db8c6816fca 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -19,6 +19,7 @@ import torch.nn as nn from tqdm.auto import tqdm +from sglang.jit_kernel.nvfp4 import prewarm_nvfp4_jit_modules from sglang.multimodal_gen import envs from sglang.multimodal_gen.configs.pipeline_configs.base import ModelTaskType, STA_Mode from sglang.multimodal_gen.configs.pipeline_configs.flux import ( @@ -257,9 +258,26 @@ def _maybe_enable_torch_compile(self, module: object) -> None: compile_kwargs["mode"] = mode logger.info(f"Compiling transformer with mode: {mode}") + if self._needs_nvfp4_jit_prewarm(module): + logger.info( + "Prewarming NVFP4 JIT modules before torch.compile to avoid " + "Dynamo tracing JIT initialization." + ) + prewarm_nvfp4_jit_modules() + # TODO(triple-mu): support customized fullgraph and dynamic in the future module.compile(**compile_kwargs) + @staticmethod + def _needs_nvfp4_jit_prewarm(module: nn.Module) -> bool: + for submodule in module.modules(): + quant_method = getattr(submodule, "quant_method", None) + if quant_method is None: + continue + if type(quant_method).__name__ == "ModelOptFp4LinearMethod": + return True + return False + def _maybe_enable_cache_dit( self, num_inference_steps: int | tuple[int, int], batch: Req ) -> None: diff --git a/python/sglang/multimodal_gen/runtime/platforms/cuda.py b/python/sglang/multimodal_gen/runtime/platforms/cuda.py index f729292752f8..dfd4490eab04 100644 --- a/python/sglang/multimodal_gen/runtime/platforms/cuda.py +++ b/python/sglang/multimodal_gen/runtime/platforms/cuda.py @@ -141,6 +141,24 @@ def get_modelopt_flashinfer_fp4_backend(cls) -> str: @classmethod @lru_cache(maxsize=1) def get_modelopt_fp4_gemm_op(cls) -> tuple[Callable | None, str | None]: + # TODO: Remove this explicit FlashInfer preference once the sm100 CUTLASS + # LargeM dispatch grows a validated fallback for Blackwell NVFP4 shapes + # such as Wan2.2's large-M attention projections. + prefer_flashinfer = ( + envs.SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND is not None + ) + + if prefer_flashinfer: + try: + from flashinfer import mm_fp4 as flashinfer_mm_fp4 + + return flashinfer_mm_fp4, cls.get_modelopt_flashinfer_fp4_backend() + except ImportError: + logger.warning( + "SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND is set, " + "but flashinfer.mm_fp4 is unavailable. Falling back to cutlass." + ) + try: from sgl_kernel import cutlass_scaled_fp4_mm as cutlass_fp4_gemm diff --git a/python/sglang/multimodal_gen/test/unit/test_transformer_quant.py b/python/sglang/multimodal_gen/test/unit/test_transformer_quant.py index 4d37b5bc7d02..18833ccc56f0 100644 --- a/python/sglang/multimodal_gen/test/unit/test_transformer_quant.py +++ b/python/sglang/multimodal_gen/test/unit/test_transformer_quant.py @@ -10,6 +10,8 @@ from types import SimpleNamespace from unittest.mock import patch +import torch + partial_json_parser = types.ModuleType("partial_json_parser") partial_json_parser_core = types.ModuleType("partial_json_parser.core") partial_json_parser_exceptions = types.ModuleType("partial_json_parser.core.exceptions") @@ -41,15 +43,24 @@ def _loads(input_str, _flags=None): ) sys.modules.setdefault("partial_json_parser.core.options", partial_json_parser_options) +from sglang.multimodal_gen.runtime.layers.linear import UnquantizedLinearMethod from sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import ( NunchakuConfig, ) +from sglang.multimodal_gen.runtime.layers.quantization.modelopt_quant import ( + ModelOptFp4Config, + _prepare_nvfp4_weight_bytes, +) from sglang.multimodal_gen.runtime.loader.transformer_load_utils import ( _filter_duplicate_precision_variant_safetensors, _Flux2Nvfp4FallbackAdapter, resolve_transformer_quant_load_spec, resolve_transformer_safetensors_to_load, ) +from sglang.multimodal_gen.runtime.models.dits.flux import FluxSingleTransformerBlock +from sglang.multimodal_gen.tools.build_modelopt_nvfp4_transformer import ( + _updated_quant_config, +) class _FakeFluxTransformer: @@ -148,22 +159,21 @@ def test_resolve_transformer_quant_load_spec_keeps_nunchaku_hook( mock_metadata.return_value = { "config": json.dumps({"_class_name": _FakeFluxTransformer.__name__}) } - nunchaku_config = NunchakuConfig( - transformer_weights_path="/tmp/svdq-int4_r32.safetensors" - ) - server_args = self._make_server_args( - transformer_weights_path=nunchaku_config.transformer_weights_path, - nunchaku_config=nunchaku_config, - ) + with tempfile.NamedTemporaryFile(suffix=".safetensors") as f: + nunchaku_config = NunchakuConfig(transformer_weights_path=f.name) + server_args = self._make_server_args( + transformer_weights_path=nunchaku_config.transformer_weights_path, + nunchaku_config=nunchaku_config, + ) - spec = resolve_transformer_quant_load_spec( - hf_config={}, - server_args=server_args, - safetensors_list=[nunchaku_config.transformer_weights_path], - component_model_path="/unused/component/path", - model_cls=_FakeFluxTransformer, - cls_name=_FakeFluxTransformer.__name__, - ) + spec = resolve_transformer_quant_load_spec( + hf_config={}, + server_args=server_args, + safetensors_list=[nunchaku_config.transformer_weights_path], + component_model_path="/unused/component/path", + model_cls=_FakeFluxTransformer, + cls_name=_FakeFluxTransformer.__name__, + ) self.assertIsNone(spec.quant_config) self.assertIs(spec.nunchaku_config, nunchaku_config) @@ -189,6 +199,114 @@ def test_flux2_mixed_nvfp4_fallback_disables_conflicting_offloads(self): self.assertFalse(server_args.dit_cpu_offload) self.assertFalse(server_args.text_encoder_cpu_offload) + def test_prepare_nvfp4_weight_bytes_swaps_nibbles(self): + weight = torch.tensor([[0xAB, 0x10]], dtype=torch.uint8) + + prepared = _prepare_nvfp4_weight_bytes(weight, swap_weight_nibbles=True) + + self.assertEqual(prepared.tolist(), [[0xBA, 0x01]]) + + def test_prepare_nvfp4_weight_bytes_can_skip_nibble_swap(self): + weight = torch.tensor([[0xAB, 0x10]], dtype=torch.uint8) + + prepared = _prepare_nvfp4_weight_bytes(weight, swap_weight_nibbles=False) + + self.assertEqual(prepared.tolist(), [[0xAB, 0x10]]) + + def test_modelopt_fp4_config_reads_swap_weight_nibbles_from_flat_config(self): + config = ModelOptFp4Config.from_config( + { + "quant_algo": "NVFP4", + "group_size": 16, + "ignore": [], + "swap_weight_nibbles": False, + } + ) + + self.assertFalse(config.swap_weight_nibbles) + + def test_modelopt_fp4_config_reads_swap_weight_nibbles_from_nested_config(self): + config = ModelOptFp4Config.from_config( + { + "quantization": { + "quant_algo": "NVFP4", + "exclude_modules": [], + "swap_weight_nibbles": False, + }, + "config_groups": {"default": {"weights": {"group_size": 16}}}, + } + ) + + self.assertFalse(config.swap_weight_nibbles) + + def test_builder_adds_diffusers_quant_type_for_nvfp4(self): + updated = _updated_quant_config( + { + "quantization_config": { + "quant_method": "modelopt", + "quant_algo": "NVFP4", + "ignore": [], + } + }, + fallback_patterns=["single_transformer_blocks.*.proj_mlp*"], + swap_weight_nibbles=False, + ) + + self.assertEqual(updated["quantization_config"]["quant_type"], "NVFP4") + self.assertEqual( + updated["quantization_config"]["ignore"], + ["single_transformer_blocks.*.proj_mlp*"], + ) + + @patch("sglang.multimodal_gen.runtime.layers.linear.get_group_rank", return_value=0) + @patch("sglang.multimodal_gen.runtime.layers.linear.get_group_size", return_value=1) + @patch( + "sglang.multimodal_gen.runtime.layers.linear.get_tp_group", return_value=None + ) + @patch( + "sglang.multimodal_gen.runtime.layers.attention.layer.get_ring_parallel_world_size", + return_value=1, + ) + @patch( + "sglang.multimodal_gen.runtime.layers.attention.selector.get_global_server_args", + return_value=SimpleNamespace(attention_backend=None), + ) + def test_flux_single_transformer_block_modelopt_excludes_use_full_prefix( + self, + _mock_server_args, + _mock_ring_world_size, + _mock_tp_group, + _mock_group_size, + _mock_group_rank, + ): + quant_config = ModelOptFp4Config( + is_checkpoint_nvfp4_serialized=True, + group_size=16, + exclude_modules=[ + "single_transformer_blocks.*.proj_mlp*", + "single_transformer_blocks.*.proj_out*", + "single_transformer_blocks.*.attn.to_q", + ], + ) + + block = FluxSingleTransformerBlock( + dim=64, + num_attention_heads=4, + attention_head_dim=16, + mlp_ratio=2.0, + quant_config=quant_config, + prefix="single_transformer_blocks.0", + ) + + self.assertEqual(block.proj_mlp.prefix, "single_transformer_blocks.0.proj_mlp") + self.assertEqual(block.proj_out.prefix, "single_transformer_blocks.0.proj_out") + self.assertEqual( + block.attn.to_q.prefix, "single_transformer_blocks.0.attn.to_q" + ) + self.assertIsInstance(block.proj_mlp.quant_method, UnquantizedLinearMethod) + self.assertIsInstance(block.proj_out.quant_method, UnquantizedLinearMethod) + self.assertIsInstance(block.attn.to_q.quant_method, UnquantizedLinearMethod) + if __name__ == "__main__": unittest.main() diff --git a/python/sglang/multimodal_gen/tools/convert_modelopt_fp8_checkpoint.py b/python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py similarity index 66% rename from python/sglang/multimodal_gen/tools/convert_modelopt_fp8_checkpoint.py rename to python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py index 89a731fba863..5d87f5ff1bf2 100644 --- a/python/sglang/multimodal_gen/tools/convert_modelopt_fp8_checkpoint.py +++ b/python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py @@ -1,4 +1,4 @@ -"""Convert a ModelOpt diffusion FP8 export into an SGLang-loadable checkpoint. +"""Build an SGLang-loadable ModelOpt FP8 diffusion transformer. The core conversion path is model-agnostic: - read the ModelOpt diffusers transformer export @@ -12,7 +12,7 @@ Example: - python -m sglang.multimodal_gen.tools.convert_modelopt_fp8_checkpoint \ + python -m sglang.multimodal_gen.tools.build_modelopt_fp8_transformer \ --modelopt-hf-dir /tmp/modelopt_flux2_fp8/hf \ --modelopt-backbone-ckpt /tmp/modelopt_flux2_fp8/ckpt/backbone.pt \ --base-transformer-dir /path/to/FLUX.2-dev/transformer \ @@ -57,6 +57,21 @@ r"^transformer_blocks\.\d+\.ff_context\.net\.2$", r"^single_transformer_blocks\.\d+\.norm\.linear$", ] +DEFAULT_LTX2_KEEP_BF16_PATTERNS = [ + r"^(audio_)?adaln_single\.emb\.timestep_embedder\.linear_[12]$", + r"^(audio_)?adaln_single\.linear$", + r"^audio_caption_projection\.linear_[12]$", + r"^audio_patchify_proj$", + r"^audio_proj_out$", + r"^av_ca_(a2v_gate|audio_scale_shift|v2a_gate|video_scale_shift)_adaln_single\.emb\.timestep_embedder\.linear_[12]$", + r"^av_ca_(a2v_gate|audio_scale_shift|v2a_gate|video_scale_shift)_adaln_single\.linear$", + r"^caption_projection\.linear_[12]$", + r"^patchify_proj$", + r"^proj_out$", + r"^transformer_blocks\.(0|43|44|45|46|47)\.(attn1|attn2|audio_attn1|audio_attn2|audio_to_video_attn|video_to_audio_attn)\.to_(q|k|v)$", + r"^transformer_blocks\.(0|43|44|45|46|47)\.(attn1|attn2|audio_attn1|audio_attn2|audio_to_video_attn|video_to_audio_attn)\.to_out\.0$", + r"^transformer_blocks\.(0|43|44|45|46|47)\.(ff|audio_ff)\.proj_(in|out)$", +] def _resolve_transformer_dir(path: str) -> str: @@ -126,9 +141,116 @@ def _load_config(model_dir: str) -> dict: return json.load(f) +def _load_first_shard_metadata( + model_dir: str, weight_map: Mapping[str, str] +) -> dict[str, str]: + if not weight_map: + return {} + first_shard = next(iter(weight_map.values())) + with safe_open( + os.path.join(model_dir, first_shard), framework="pt", device="cpu" + ) as f: + return dict(f.metadata() or {}) + + +def _module_name_variants(weight_name: str) -> list[str]: + module_name = weight_name[:-7] if weight_name.endswith(".weight") else weight_name + variants = [module_name] + + for prefix in ("model.diffusion_model.", "velocity_model."): + if module_name.startswith(prefix): + variants.append(module_name[len(prefix) :]) + + canonicalized: list[str] = [] + for variant in variants: + canonicalized.append( + re.sub(r"(\.audio_ff|\.ff)\.net\.0\.proj$", r"\1.proj_in", variant) + ) + canonicalized.append( + re.sub(r"(\.audio_ff|\.ff)\.net\.2$", r"\1.proj_out", variant) + ) + variants.extend(canonicalized) + + deduped: list[str] = [] + for variant in variants: + if variant not in deduped: + deduped.append(variant) + return deduped + + +def _preferred_module_name(weight_name: str) -> str: + return _module_name_variants(weight_name)[-1] + + +def _scale_key_candidates(weight_name: str) -> list[str]: + candidates = [weight_name] + if weight_name.startswith("model.diffusion_model."): + candidates.append( + "velocity_model." + weight_name[len("model.diffusion_model.") :] + ) + return candidates + + +def _resolve_scale_key( + weight_name: str, + scale_map: Mapping[str, Mapping[str, torch.Tensor]], +) -> str | None: + for candidate in _scale_key_candidates(weight_name): + if candidate in scale_map: + return candidate + return None + + +def _is_ltx2_x0_export( + *, + config: Mapping[str, object], + source_metadata: Mapping[str, str], + source_weight_map: Mapping[str, str], +) -> bool: + if config.get("_class_name") != "X0Model": + return False + if not any(name.startswith("model.diffusion_model.") for name in source_weight_map): + return False + try: + metadata_config = json.loads(str(source_metadata.get("config", ""))) + except json.JSONDecodeError: + return False + return isinstance(metadata_config.get("transformer"), dict) + + +def _build_output_config( + *, + source_config: Mapping[str, object], + source_metadata: Mapping[str, str], + quant_config: Mapping[str, object], + is_ltx2_x0_export: bool, +) -> dict[str, object]: + if is_ltx2_x0_export: + metadata_config = json.loads(str(source_metadata["config"])) + output_config = dict(metadata_config["transformer"]) + output_config["_class_name"] = "LTX2VideoTransformer3DModel" + else: + output_config = dict(source_config) + + output_config["quantization_config"] = dict(quant_config) + return output_config + + +def _should_keep_ltx2_transformer_key(weight_name: str) -> bool: + if not weight_name.startswith("model.diffusion_model."): + return False + connector_prefixes = ( + "model.diffusion_model.audio_embeddings_connector.", + "model.diffusion_model.video_embeddings_connector.", + ) + return not weight_name.startswith(connector_prefixes) + + def get_default_keep_bf16_patterns( *, model_type: str, class_name: str | None ) -> list[str]: + if model_type == "ltx2": + return list(DEFAULT_LTX2_KEEP_BF16_PATTERNS) if model_type == "flux1": return list(DEFAULT_FLUX1_KEEP_BF16_PATTERNS) if model_type == "flux2": @@ -149,8 +271,11 @@ def should_keep_bf16( if not keep_bf16_patterns: return False - module_name = weight_name[:-7] if weight_name.endswith(".weight") else weight_name - return any(re.search(pattern, module_name) for pattern in keep_bf16_patterns) + return any( + re.search(pattern, module_name) + for pattern in keep_bf16_patterns + for module_name in _module_name_variants(weight_name) + ) def is_ignored_by_modelopt( @@ -160,10 +285,12 @@ def is_ignored_by_modelopt( if not ignore_patterns: return False - module_name = weight_name[:-7] if weight_name.endswith(".weight") else weight_name for pattern in ignore_patterns: regex_str = pattern.replace(".", r"\.").replace("*", r".*") - if re.fullmatch(regex_str, module_name): + if any( + re.fullmatch(regex_str, module_name) + for module_name in _module_name_variants(weight_name) + ): return True return False @@ -242,7 +369,7 @@ def _load_selected_tensors( return tensors -def convert_modelopt_fp8_checkpoint( +def build_modelopt_fp8_transformer( *, modelopt_hf_dir: str, modelopt_backbone_ckpt: str, @@ -265,23 +392,29 @@ def convert_modelopt_fp8_checkpoint( raise ValueError( "Expected a flat quantization_config dict in the ModelOpt export." ) - if ( - quant_config.get("quant_method") != "modelopt" - or "FP8" not in str(quant_config.get("quant_algo", "")).upper() - ): + if quant_config.get("quant_method") != "modelopt": raise ValueError( "This tool only supports ModelOpt diffusers FP8 exports " - "(quant_method=modelopt, quant_algo=FP8)." + "(quant_method=modelopt)." ) + source_weight_map_all, index_filename = _load_weight_map(source_dir) + source_metadata = _load_first_shard_metadata(source_dir, source_weight_map_all) + is_ltx2_export = _is_ltx2_x0_export( + config=config, + source_metadata=source_metadata, + source_weight_map=source_weight_map_all, + ) class_name = config.get("_class_name") ignore_patterns = list(quant_config.get("ignore", []) or []) patterns = list( get_default_keep_bf16_patterns(model_type=model_type, class_name=class_name) ) + if is_ltx2_export and model_type == "auto": + patterns.extend(DEFAULT_LTX2_KEEP_BF16_PATTERNS) if keep_bf16_patterns: patterns.extend(keep_bf16_patterns) - if patterns and base_dir is None: + if patterns and base_dir is None and not is_ltx2_export: raise ValueError( "BF16 fallback patterns are enabled, but --base-transformer-dir was not provided." ) @@ -298,25 +431,73 @@ def convert_modelopt_fp8_checkpoint( _copy_non_shard_files(source_dir, str(output_path)) - source_weight_map, index_filename = _load_weight_map(source_dir) + if is_ltx2_export: + source_weight_map = { + name: filename + for name, filename in source_weight_map_all.items() + if _should_keep_ltx2_transformer_key(name) + } + else: + source_weight_map = source_weight_map_all base_weight_map: dict[str, str] = {} if base_dir is not None: base_weight_map, _ = _load_weight_map(base_dir) + fallback_weight_names = sorted( + weight_name + for weight_name in source_weight_map + if weight_name.endswith(".weight") and should_keep_bf16(weight_name, patterns) + ) + fallback_weight_names_set = set(fallback_weight_names) backbone_state = torch.load(backbone_ckpt_path, map_location="cpu")[ "model_state_dict" ] fp8_scale_map = build_fp8_scale_map(backbone_state, maxbound=maxbound) - serialized_quant_config = json.dumps(quant_config, sort_keys=True) - - fallback_weight_names = sorted( - weight_name - for weight_name in source_weight_map - if weight_name.endswith(".weight") and should_keep_bf16(weight_name, patterns) + quant_algo = str(quant_config.get("quant_algo", "")).upper() + if quant_algo and "FP8" not in quant_algo: + raise ValueError( + "This tool only supports ModelOpt diffusers FP8 exports, " + f"got quant_algo={quant_config.get('quant_algo')!r}." + ) + if not quant_algo and not fp8_scale_map: + raise ValueError( + "Could not infer an FP8 ModelOpt export: quantization_config.quant_algo " + "is missing and backbone.pt does not contain FP8 scale tensors." + ) + effective_quant_config = json.loads(json.dumps(quant_config)) + if not quant_algo: + effective_quant_config["quant_algo"] = "FP8" + + auto_ignore_modules = sorted( + { + _preferred_module_name(weight_name) + for weight_name in source_weight_map + if weight_name.endswith(".weight") + and _resolve_scale_key(weight_name, fp8_scale_map) is None + } + ) + fallback_ignore_modules = sorted( + {_preferred_module_name(weight_name) for weight_name in fallback_weight_names} ) + ignore_patterns = sorted( + { + *ignore_patterns, + *auto_ignore_modules, + *fallback_ignore_modules, + } + ) + effective_quant_config["ignore"] = ignore_patterns + serialized_quant_config = json.dumps(effective_quant_config, sort_keys=True) + output_config = _build_output_config( + source_config=config, + source_metadata=source_metadata, + quant_config=effective_quant_config, + is_ltx2_x0_export=is_ltx2_export, + ) + fallback_tensors = ( _load_selected_tensors(base_dir, base_weight_map, fallback_weight_names) - if fallback_weight_names + if fallback_weight_names and base_dir is not None else {} ) fallback_scale_names = { @@ -340,15 +521,23 @@ def convert_modelopt_fp8_checkpoint( for filename, names in sorted(weights_by_file.items()): shard_path = os.path.join(source_dir, filename) shard_tensors = load_file(shard_path, device="cpu") + selected_names = set(names) with safe_open(shard_path, framework="pt", device="cpu") as f: metadata = dict(f.metadata() or {}) metadata.setdefault("format", "pt") + metadata["_class_name"] = str( + output_config.get("_class_name", metadata.get("_class_name", "")) + ) + metadata["config"] = json.dumps(output_config, sort_keys=True) metadata["quantization_config"] = serialized_quant_config metadata["_quantization_metadata"] = serialized_quant_config for name in list(shard_tensors.keys()): + if name not in selected_names: + del shard_tensors[name] + continue if "_quantizer." in name: del shard_tensors[name] continue @@ -362,12 +551,14 @@ def convert_modelopt_fp8_checkpoint( continue if name in fallback_tensors: shard_tensors[name] = fallback_tensors[name] + scale_key = _resolve_scale_key(name, fp8_scale_map) if ( name.endswith(".weight") - and name in fp8_scale_map + and scale_key is not None and name not in fallback_tensors + and name not in fallback_weight_names_set ): - scale_tensors = fp8_scale_map[name] + scale_tensors = fp8_scale_map[scale_key] shard_tensors[name] = quantize_fp8_weight( shard_tensors[name], scale_tensors["weight_scale"] ) @@ -397,12 +588,15 @@ def convert_modelopt_fp8_checkpoint( sort_keys=True, ) + with open(output_path / "config.json", "w", encoding="utf-8") as f: + json.dump(output_config, f, indent=2, sort_keys=True) + return { "quantized_weights": sum( 1 for name in source_weight_map if name.endswith(".weight") - and name in fp8_scale_map + and _resolve_scale_key(name, fp8_scale_map) is not None and not is_ignored_by_modelopt(name, ignore_patterns) ), "bf16_fallback_weights": len(fallback_weight_names), @@ -415,8 +609,8 @@ def convert_modelopt_fp8_checkpoint( def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description=( - "Inject FP8 scales from ModelOpt backbone.pt into a diffusers export so " - "SGLang diffusion can load it natively." + "Build an SGLang-loadable ModelOpt FP8 diffusion transformer from a " + "ModelOpt diffusers export." ) ) parser.add_argument( @@ -443,11 +637,11 @@ def _parse_args() -> argparse.Namespace: ) parser.add_argument( "--model-type", - choices=["auto", "flux1", "flux2", "none"], + choices=["auto", "flux1", "flux2", "ltx2", "none"], default="auto", help=( "Optional model-family BF16 fallback profile. 'none' uses the generic " - "conversion path. 'auto' enables the validated FLUX.1 / FLUX.2 " + "conversion path. 'auto' enables the validated FLUX.1 / FLUX.2 / LTX-2 " "fallback set when the export config matches those transformer classes." ), ) @@ -477,7 +671,7 @@ def _parse_args() -> argparse.Namespace: def main() -> None: args = _parse_args() - stats = convert_modelopt_fp8_checkpoint( + stats = build_modelopt_fp8_transformer( modelopt_hf_dir=args.modelopt_hf_dir, modelopt_backbone_ckpt=args.modelopt_backbone_ckpt, output_dir=args.output_dir, diff --git a/python/sglang/multimodal_gen/tools/build_modelopt_nvfp4_transformer.py b/python/sglang/multimodal_gen/tools/build_modelopt_nvfp4_transformer.py new file mode 100644 index 000000000000..e6f13b306a41 --- /dev/null +++ b/python/sglang/multimodal_gen/tools/build_modelopt_nvfp4_transformer.py @@ -0,0 +1,402 @@ +"""Build an SGLang-loadable ModelOpt NVFP4 diffusion transformer. + +This tool keeps the ModelOpt-exported NVFP4 tensors for most transformer +modules, but can replace a validated subset of numerically sensitive modules +with their original BF16 tensors from the base transformer checkpoint. + +It is primarily intended for FLUX.1-dev style ModelOpt NVFP4 exports where: +- the base pipeline should remain separate from the quantized transformer +- fallback BF16 modules are model-family specific +- the serialized FP4 weight byte order may already match the runtime kernel +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import shutil +from collections import defaultdict +from pathlib import Path +from typing import Iterable, Mapping, Sequence + +from safetensors import safe_open +from safetensors.torch import load_file, save_file + +INDEX_FILENAMES = [ + "model.safetensors.index.json", + "diffusion_pytorch_model.safetensors.index.json", +] + +DEFAULT_FLUX1_NVFP4_FALLBACK_PATTERNS = [ + "transformer_blocks.*.norm1.linear*", + "transformer_blocks.*.norm1_context.linear*", + "transformer_blocks.*.ff.net.0.proj*", + "transformer_blocks.*.ff.net.2*", + "transformer_blocks.*.ff_context.net.0.proj*", + "transformer_blocks.*.ff_context.net.2*", + "single_transformer_blocks.*.norm.linear*", + "single_transformer_blocks.*.proj_mlp*", +] + +_TENSOR_MODULE_SUFFIXES = ( + ".weight_scale_2", + ".weight_scale", + ".input_scale", + ".weight", + ".bias", +) + + +def _resolve_transformer_dir(path: str) -> str: + candidate = Path(path).expanduser().resolve() + if (candidate / "config.json").is_file(): + return str(candidate) + transformer_dir = candidate / "transformer" + if (transformer_dir / "config.json").is_file(): + return str(transformer_dir) + raise FileNotFoundError(f"Could not resolve a transformer directory from: {path}") + + +def _find_index_file(model_dir: str) -> str | None: + for filename in INDEX_FILENAMES: + candidate = os.path.join(model_dir, filename) + if os.path.isfile(candidate): + return filename + + matches = sorted( + filename + for filename in os.listdir(model_dir) + if filename.endswith(".safetensors.index.json") + ) + return matches[0] if matches else None + + +def _load_weight_map(model_dir: str) -> tuple[dict[str, str], str | None]: + index_filename = _find_index_file(model_dir) + if index_filename is not None: + with open(os.path.join(model_dir, index_filename), encoding="utf-8") as f: + index_data = json.load(f) + return dict(index_data["weight_map"]), index_filename + + safetensors_files = sorted( + filename + for filename in os.listdir(model_dir) + if filename.endswith(".safetensors") + ) + if len(safetensors_files) != 1: + raise ValueError( + f"Expected an index file or a single safetensors shard in {model_dir}, " + f"found {len(safetensors_files)} shard(s)." + ) + + shard_name = safetensors_files[0] + with safe_open( + os.path.join(model_dir, shard_name), framework="pt", device="cpu" + ) as f: + weight_map = {key: shard_name for key in f.keys()} + index_filename = f"{Path(shard_name).stem}.safetensors.index.json" + return weight_map, index_filename + + +def _load_config(model_dir: str) -> dict: + config_path = os.path.join(model_dir, "config.json") + with open(config_path, encoding="utf-8") as f: + return json.load(f) + + +def _write_config(model_dir: Path, config: Mapping[str, object]) -> None: + with open(model_dir / "config.json", "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, sort_keys=True) + f.write("\n") + + +def _copy_non_shard_files(source_dir: str, output_dir: str) -> None: + ignored = set(INDEX_FILENAMES) + for entry in os.listdir(source_dir): + if entry.endswith(".safetensors") or entry in ignored: + continue + source_path = os.path.join(source_dir, entry) + output_path = os.path.join(output_dir, entry) + if os.path.isdir(source_path): + shutil.copytree(source_path, output_path, dirs_exist_ok=True) + else: + shutil.copy2(source_path, output_path) + + +def _load_selected_tensors( + model_dir: str, + weight_map: Mapping[str, str], + tensor_names: Iterable[str], +): + tensors = {} + names_by_file: dict[str, list[str]] = defaultdict(list) + for name in tensor_names: + names_by_file[weight_map[name]].append(name) + + for filename, names in names_by_file.items(): + shard_path = os.path.join(model_dir, filename) + with safe_open(shard_path, framework="pt", device="cpu") as f: + for name in names: + tensors[name] = f.get_tensor(name).contiguous() + return tensors + + +def _module_name_for_tensor(tensor_name: str) -> str: + for suffix in _TENSOR_MODULE_SUFFIXES: + if tensor_name.endswith(suffix): + return tensor_name[: -len(suffix)] + return tensor_name + + +def _matches_any_pattern(module_name: str, patterns: Sequence[str]) -> bool: + if not patterns: + return False + for pattern in patterns: + regex_str = pattern.replace(".", r"\.").replace("*", r".*") + if re.fullmatch(regex_str, module_name): + return True + return False + + +def _preset_patterns(pattern_preset: str) -> list[str]: + if pattern_preset == "none": + return [] + if pattern_preset == "flux1-nvfp4": + return list(DEFAULT_FLUX1_NVFP4_FALLBACK_PATTERNS) + raise ValueError(f"Unsupported pattern preset: {pattern_preset}") + + +def _updated_quant_config( + source_config: Mapping[str, object], + *, + fallback_patterns: Sequence[str], + swap_weight_nibbles: bool, +) -> dict[str, object]: + output_config = json.loads(json.dumps(source_config)) + quant_config = output_config.get("quantization_config") + if not isinstance(quant_config, dict): + raise ValueError("Expected a flat quantization_config dict in config.json.") + if ( + quant_config.get("quant_method") != "modelopt" + or "FP4" not in str(quant_config.get("quant_algo", "")).upper() + ): + raise ValueError( + "This tool only supports ModelOpt diffusion NVFP4 exports " + "(quant_method=modelopt, quant_algo=FP4/NVFP4)." + ) + + ignore_patterns = list(quant_config.get("ignore", []) or []) + for pattern in fallback_patterns: + if pattern not in ignore_patterns: + ignore_patterns.append(pattern) + + quant_config["ignore"] = ignore_patterns + quant_config.setdefault( + "quant_type", str(quant_config.get("quant_algo", "")).upper() + ) + quant_config["swap_weight_nibbles"] = swap_weight_nibbles + return output_config + + +def build_modelopt_nvfp4_transformer( + *, + base_transformer_dir: str, + modelopt_hf_dir: str, + output_dir: str, + pattern_preset: str = "none", + keep_bf16_patterns: Sequence[str] | None = None, + swap_weight_nibbles: bool | None = None, + overwrite: bool = False, +) -> dict[str, int | bool]: + source_dir = _resolve_transformer_dir(modelopt_hf_dir) + base_dir = _resolve_transformer_dir(base_transformer_dir) + + patterns = _preset_patterns(pattern_preset) + if keep_bf16_patterns: + patterns.extend(keep_bf16_patterns) + + resolved_swap_weight_nibbles = ( + swap_weight_nibbles + if swap_weight_nibbles is not None + else (False if pattern_preset == "flux1-nvfp4" else True) + ) + output_config = _updated_quant_config( + _load_config(source_dir), + fallback_patterns=patterns, + swap_weight_nibbles=resolved_swap_weight_nibbles, + ) + quant_config = output_config["quantization_config"] + serialized_quant_config = json.dumps(quant_config, sort_keys=True) + + output_path = Path(output_dir).expanduser().resolve() + if output_path.exists(): + if not overwrite: + raise FileExistsError( + f"Output directory already exists: {output_path}. " + "Use --overwrite to replace it." + ) + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + _copy_non_shard_files(source_dir, str(output_path)) + _write_config(output_path, output_config) + + source_weight_map, index_filename = _load_weight_map(source_dir) + base_weight_map, _ = _load_weight_map(base_dir) + + fallback_tensor_names = sorted( + name + for name in base_weight_map + if name in source_weight_map + and _matches_any_pattern(_module_name_for_tensor(name), patterns) + ) + fallback_tensors = _load_selected_tensors( + base_dir, + base_weight_map, + fallback_tensor_names, + ) + fallback_modules = { + _module_name_for_tensor(tensor_name) for tensor_name in fallback_tensor_names + } + + weights_by_file: dict[str, list[str]] = defaultdict(list) + for tensor_name, filename in source_weight_map.items(): + weights_by_file[filename].append(tensor_name) + + updated_weight_map: dict[str, str] = {} + total_size = 0 + replaced_tensor_count = 0 + removed_aux_tensor_count = 0 + + for filename, tensor_names in sorted(weights_by_file.items()): + shard_path = os.path.join(source_dir, filename) + shard_tensors = load_file(shard_path, device="cpu") + + with safe_open(shard_path, framework="pt", device="cpu") as f: + metadata = dict(f.metadata() or {}) + + metadata.setdefault("format", "pt") + metadata["quantization_config"] = serialized_quant_config + metadata["_quantization_metadata"] = serialized_quant_config + + for name in list(shard_tensors.keys()): + if "_quantizer." in name: + del shard_tensors[name] + removed_aux_tensor_count += 1 + continue + + module_name = _module_name_for_tensor(name) + if module_name not in fallback_modules: + continue + + if name in fallback_tensors: + shard_tensors[name] = fallback_tensors[name] + replaced_tensor_count += 1 + else: + del shard_tensors[name] + removed_aux_tensor_count += 1 + + save_file(shard_tensors, os.path.join(output_path, filename), metadata=metadata) + + for name, tensor in shard_tensors.items(): + updated_weight_map[name] = filename + total_size += tensor.element_size() * tensor.numel() + + if index_filename is None: + raise ValueError( + "Expected a sharded or indexed ModelOpt HF export, but no index file was found." + ) + + with open(output_path / index_filename, "w", encoding="utf-8") as f: + json.dump( + { + "metadata": {"total_size": total_size}, + "weight_map": updated_weight_map, + }, + f, + indent=2, + sort_keys=True, + ) + f.write("\n") + + return { + "fallback_modules": len(fallback_modules), + "replaced_tensors": replaced_tensor_count, + "removed_aux_tensors": removed_aux_tensor_count, + "output_shards": len(weights_by_file), + "swap_weight_nibbles": resolved_swap_weight_nibbles, + } + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Build an SGLang-loadable ModelOpt NVFP4 diffusion transformer and " + "optionally keep selected modules in BF16." + ) + ) + parser.add_argument( + "--base-transformer-dir", + required=True, + help="Original BF16 transformer directory, or a parent model directory.", + ) + parser.add_argument( + "--modelopt-hf-dir", + required=True, + help="ModelOpt --hf-ckpt-dir output, or its transformer subdirectory.", + ) + parser.add_argument( + "--output-dir", + required=True, + help="Directory to write the mixed transformer checkpoint.", + ) + parser.add_argument( + "--pattern-preset", + choices=["none", "flux1-nvfp4"], + default="none", + help="Optional model-family BF16 fallback preset.", + ) + parser.add_argument( + "--keep-bf16-pattern", + action="append", + default=[], + help=( + "Glob-style pattern matched against module names without trailing tensor " + "suffixes such as .weight or .bias." + ), + ) + parser.add_argument( + "--swap-weight-nibbles", + action=argparse.BooleanOptionalAction, + default=None, + help=( + "Whether the runtime should swap packed FP4 nibbles before padding. " + "Defaults to false for --pattern-preset flux1-nvfp4 and true otherwise." + ), + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Replace --output-dir if it already exists.", + ) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + stats = build_modelopt_nvfp4_transformer( + base_transformer_dir=args.base_transformer_dir, + modelopt_hf_dir=args.modelopt_hf_dir, + output_dir=args.output_dir, + pattern_preset=args.pattern_preset, + keep_bf16_patterns=args.keep_bf16_pattern, + swap_weight_nibbles=args.swap_weight_nibbles, + overwrite=args.overwrite, + ) + print(json.dumps(stats, indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main()