reland [Diffusion] Add FLUX.1-dev ModelOpt NVFP4 support#22672
Conversation
/tag-and-rerun-ci
Code Review
This pull request expands ModelOpt quantization support for diffusion models, introducing FP8 and NVFP4 compatibility for FLUX and LTX-2 families. It adds a new tool for building mixed-precision NVFP4 checkpoints, implements JIT module prewarming for torch.compile, and adds support for NVFP4 nibble swapping. Review feedback suggests broadening exception handling in the FSDP loader and improving the safety of directory management in the build scripts.
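The NVFP4 nibble swapping mentioned above operates on packed bytes that each hold two 4-bit values; some kernels expect the opposite nibble order, so checkpoints may need the bytes transformed. A minimal sketch of the idea (function names and packing order are assumptions, not the PR's actual implementation):

```python
def swap_nibbles(byte: int) -> int:
    """Swap the low and high 4-bit nibbles of one packed byte.

    NVFP4 stores two 4-bit values per uint8; if the consumer expects
    the opposite nibble order, each byte must be swapped like this.
    """
    return ((byte << 4) & 0xF0) | (byte >> 4)


def swap_nibbles_buffer(packed: bytes) -> bytes:
    """Apply the nibble swap to every byte of a packed weight buffer."""
    return bytes(swap_nibbles(b) for b in packed)
```

In practice this would be done with a vectorized tensor op rather than a Python loop, but the byte-level transform is the same.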
```python
try:
    weight_loader(temp_param, full_tensor)
except AssertionError as exc:
    raise AssertionError(
        "Failed to shard/load parameter "
        f"{target_param_name}: full_tensor.shape={tuple(full_tensor.shape)}, "
        f"meta_sharded_param.shape={tuple(meta_sharded_param.shape)}, "
        f"temp_param.shape={tuple(temp_param.shape)}, "
        f"param_cls={type(actual_param).__name__}"
    ) from exc
```
While catching AssertionError provides useful context for shape mismatches during weight loading, it may be safer to also catch RuntimeError, or even a broader Exception, since weight loaders can raise different exception types depending on the underlying failure (e.g., device-side errors or memory-allocation issues). If the intent is strictly to debug shape mismatches, this is fine; otherwise, consider wrapping other loading failures with the same diagnostic information.
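One way to broaden the handling while keeping the diagnostic message, as suggested. This is a sketch: `load_with_diagnostics` is a hypothetical wrapper, and the parameter names stand in for the ones in the diff:

```python
def load_with_diagnostics(weight_loader, temp_param, full_tensor, target_param_name):
    """Run a weight loader, re-raising any loading failure with shape context.

    Catches RuntimeError in addition to AssertionError so that device-side
    and allocation failures also carry the diagnostic information.
    """
    try:
        weight_loader(temp_param, full_tensor)
    except (AssertionError, RuntimeError) as exc:
        raise RuntimeError(
            f"Failed to shard/load parameter {target_param_name}: "
            f"full_tensor.shape={tuple(full_tensor.shape)}, "
            f"temp_param.shape={tuple(temp_param.shape)}"
        ) from exc
```

Re-raising as a single exception type (rather than `type(exc)`) keeps callers' `except` clauses simple, at the cost of flattening the original type into `__cause__`.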
```python
if output_path.exists():
    if not overwrite:
        raise FileExistsError(
            f"Output directory already exists: {output_path}. "
            "Use --overwrite to replace it."
        )
    shutil.rmtree(output_path)
```
The use of shutil.rmtree(output_path) when overwrite=True is dangerous if the user accidentally provides a path to a directory containing important data (like the base model directory). It would be safer to only delete specific files that the tool expects to write, or at least issue a warning before deletion.
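A sketch of the safer pattern the review suggests: refuse to clobber an existing directory unless `--overwrite` is given, and even then delete only the files the tool itself writes. The `EXPECTED_OUTPUTS` set here is a hypothetical example, not the tool's real output manifest:

```python
from pathlib import Path

# Hypothetical manifest of files this tool writes; not the tool's actual list.
EXPECTED_OUTPUTS = {"config.json", "model.safetensors"}

def prepare_output_dir(output_path: Path, overwrite: bool) -> None:
    """Prepare the output directory without rmtree.

    If the directory exists and overwrite is False, fail loudly. If
    overwrite is True, remove only the files this tool expects to write,
    leaving any unrelated user data in place.
    """
    if output_path.exists():
        if not overwrite:
            raise FileExistsError(
                f"Output directory already exists: {output_path}. "
                "Use --overwrite to replace it."
            )
        for name in EXPECTED_OUTPUTS:
            target = output_path / name
            if target.exists():
                target.unlink()
    else:
        output_path.mkdir(parents=True)
```

This way, accidentally pointing `--output-path` at the base model directory would at worst overwrite files sharing the expected names, rather than deleting the whole tree.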
/tag-and-rerun-ci
I dug into the default … The failure is not in …

I also ran two controls: …

Those checks succeeded once I forced the … I pushed a small follow-up on this branch (…).
Summary
Validation
- `torch.compile` disabled throughout benchmark/profile/correctness runs
- `pytest -q python/sglang/multimodal_gen/test/unit/test_transformer_quant.py -q` in the remote diffusion container
- 37.6940s → 29.0421s (22.95% faster)
- 38.2545s → 29.4954s (22.90% faster)
- 0.9933, final image PSNR 28.16 dB

bf16:

nvfp4:
Notes
- `--transformer-path` for the mixed SGLang transformer override.
- `torch.compile` disabled.