tonera commented Dec 12, 2025

Motivation

On a GB10 machine (compute capability 12.1), loading weights with NunchakuFluxTransformer2dModel.from_pretrained(...) takes approximately 37 seconds, while the same load takes approximately 3 seconds on a 3090 machine.

The profiler shows that the time is almost entirely spent in cudaMemcpyAsync operations (roughly 2400+ small-block copies). The weights travel from CPU to GPU memory as many small blocks, so the fixed per-copy overhead is amplified.

Modifications

Solution (how the issue was located and fixed)

Locating the bottleneck: an H2D bandwidth test confirmed that the link itself was fine (pinned H2D bandwidth was very high).

The "small copy" experiment verified that the difference between pageable and pinned was ~50×.

The loading logic lives in nunchaku/models/transformers/transformer_flux.py::from_pretrained: load_file(...) first reads the weights into CPU tensors, which then trigger a large number of small H2D copies.
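In simplified form, the problematic path looks roughly like this (a sketch, not the exact nunchaku code; the file path is illustrative):

```python
from safetensors.torch import load_file

# load_file() materializes every tensor in pageable CPU memory ...
state_dict = load_file("svdq-fp4_r32-flux.1-dev.safetensors")
# ... and each subsequent move issues its own small cudaMemcpyAsync.
state_dict = {k: v.to("cuda") for k, v in state_dict.items()}
```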

Fix: Added a _pin_state_dict(...) helper to perform pin_memory() on the CPU tensors in quantized_part_sd and unquantized_part_sd.

Added a pin_memory switch (default True) to from_pretrained(); pinning is only applied when device.type == "cuda".
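A minimal sketch of what these two changes could look like; the helper body and the wiring into from_pretrained() are illustrative, not the exact patch:

```python
import torch

def _pin_state_dict(sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # pin_memory() page-locks each CPU tensor so the later H2D copy can use
    # the fast DMA path instead of staging through pageable memory.
    return {k: (v.pin_memory() if v.device.type == "cpu" else v) for k, v in sd.items()}

# Inside from_pretrained(..., pin_memory: bool = True), roughly:
# if pin_memory and device.type == "cuda":
#     quantized_part_sd = _pin_state_dict(quantized_part_sd)
#     unquantized_part_sd = _pin_state_dict(unquantized_part_sd)
```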

Results (measured): profiled loading time decreased from ~37.5 s self CPU time to ~3.9 s.

Self CUDA time decreased to roughly 154 ms.

Conclusion: this patch is purely a performance fix, significantly reducing the overhead of small CPU→GPU copies during the weight-loading phase.


tonera commented Dec 17, 2025

Comparison results:

NunchakuFluxTransformer2dModel.from_pretrained:

transformer = NunchakuFluxTransformer2dModel.from_pretrained("weights/nunchaku-flux.1-dev/svdq-fp4_r32-flux.1-dev.safetensors", pin_memory=True)
[2025-12-17 10:49:22.805] [info] Initializing QuantizedFluxModel on device 0
[2025-12-17 10:49:23.084] [info] Loading partial weights from pytorch
[2025-12-17 10:49:23.222] [info] Done.
Injecting quantized module --------------------------------> spent 1s
transformer = NunchakuFluxTransformer2dModel.from_pretrained("weights/nunchaku-flux.1-dev/svdq-fp4_r32-flux.1-dev.safetensors", pin_memory=False)
[2025-12-17 10:49:29.788] [info] Initializing QuantizedFluxModel on device 0
[2025-12-17 10:49:30.068] [info] Loading partial weights from pytorch
[2025-12-17 10:50:06.274] [info] Done.
Injecting quantized module --------------------------------> spent 37s
transformer = NunchakuFluxTransformer2dModel.from_pretrained("weights/nunchaku-flux.1-dev/svdq-fp4_r32-flux.1-dev.safetensors", pin_memory=True)
[2025-12-17 10:50:17.494] [info] Initializing QuantizedFluxModel on device 0
[2025-12-17 10:50:17.749] [info] Loading partial weights from pytorch
[2025-12-17 10:50:17.878] [info] Done.
Injecting quantized module --------------------------------> spent <1s
transformer = NunchakuFluxTransformer2dModel.from_pretrained("weights/nunchaku-flux.1-dev/svdq-fp4_r32-flux.1-dev.safetensors", pin_memory=False)
[2025-12-17 10:50:34.962] [info] Initializing QuantizedFluxModel on device 0
[2025-12-17 10:50:35.217] [info] Loading partial weights from pytorch
[2025-12-17 10:51:10.768] [info] Done.
Injecting quantized module --------------------------------> spent 36s

[screenshot: profiler comparison]


vgabbo commented Dec 17, 2025

Congrats! This is a strong result; I hope it will get merged!
