Fix crash in patch_scale_key when tensor dtypes do not match and prevent prompt corruption #754
1038lab wants to merge 2 commits into nunchaku-ai:main
Conversation
Fix crash in patch_scale_key when tensor dtypes do not match
This was a problem for 3090 and Turing too. What GPU did it actually work for, if the 5xxx series doesn't?
This issue is not specific to RTX 50-series. I reproduced it on a 50-series test machine, but the root cause is a dtype mismatch during model loading (fp16 vs bf16), which can also occur on 3090 / Turing and other GPUs depending on PyTorch version and load path.
This was also fixed upstream: nunchaku-ai/nunchaku#833. I've been running with their PR merged.
Both approaches address the same crash, but at different layers. PR #833 works by explicitly handling fp16 conversion on certain GPUs, which resolves the issue in those specific environments. This PR fixes the underlying assumption in patch_scale_key that model and checkpoint tensors always share the same dtype, making the loading logic more robust across mixed-precision setups in general. Both solutions are valid depending on the use case; feel free to try either and use whichever works best in your environment.
I tried your patch and also PR #833; neither works on an Nvidia RTX 3050 (4GB VRAM). The dtype is already torch.bfloat16, hence neither patch is impacting/involved. I get a ComfyUI crash as reported in nunchaku-ai/nunchaku#841 (comment).

I executed the below code for debug purposes:
print("Val 1: ", torch.cuda.is_available())
Output is:
Val 1: True
Inside new code 2
Requested to load Lumina2

My config is: Checkpoint files will always be loaded safely.
Thanks for the detailed diagnostics. From your logs, the model and checkpoint are already using torch.bfloat16, so the dtype-alignment logic in this PR is not being exercised. That suggests this crash is not caused by the fp16↔bf16 mismatch addressed here, but occurs later during execution. Given the RTX 3050 (4GB) + LOW_VRAM + async offloading setup, this is likely a separate low-VRAM / offload / runtime issue, rather than a regression from this PR. This PR fixes a real dtype-mismatch crash during model loading, but your case looks orthogonal and should probably be tracked as a follow-up issue focused on low-VRAM execution.
Problem
patch_scale_key() assumes model tensors and checkpoint tensors always share the same dtype. In practice this is not guaranteed: the model may be initialized in fp16 while the checkpoint is bf16. This can trigger an AssertionError during load or, worse, lead to corrupted weights that make outputs ignore the prompt.

On bf16-capable GPUs (e.g., RTX 40/50 series), the current logic may still choose fp16 in some paths. Casting bf16 weights to fp16 can introduce NaN/inf values or precision loss for quantized tensors and break conditioning.
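The failing assumption can be reproduced with plain CPU tensors. Note that patch_scale_key_strict below is a hypothetical stand-in for the loader logic being described, not the actual nunchaku code:

```python
import torch

def patch_scale_key_strict(model_t: torch.Tensor, ckpt_t: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the old loader behavior: it assumes the
    model tensor and the checkpoint tensor always share a dtype."""
    assert model_t.dtype == ckpt_t.dtype, "dtype mismatch"
    return ckpt_t

model_weight = torch.zeros(4, dtype=torch.float16)   # model initialized in fp16
ckpt_weight = torch.ones(4, dtype=torch.bfloat16)    # checkpoint stored in bf16

try:
    patch_scale_key_strict(model_weight, ckpt_weight)
except AssertionError as exc:
    print("load crashes:", exc)  # prints: load crashes: dtype mismatch
```

The same mismatch can also slip through silently on paths without the assertion, which is where the corrupted-weights symptom comes from.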
Solution
Robust dtype alignment before patch_scale_key():
- Prefer bf16 when supported: if unet_dtype resolves to fp16, override to bf16 and disable manual casting. This prevents fp16 weight corruption on modern GPUs.
- Safe fp16 cast (legacy GPUs): when fp16 must be used, sanitize the converted weights with torch.nan_to_num(..., nan=0.0, posinf=65504, neginf=-65504).
Why this is safe
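A minimal sketch of the two rules, with hypothetical helper names (the PR's actual code lives in the loader):

```python
import torch

FP16_MAX = 65504.0  # largest finite float16 value

def resolve_unet_dtype(requested: torch.dtype, bf16_supported: bool) -> torch.dtype:
    # Rule 1: prefer bf16 on capable hardware instead of a lossy fp16 cast.
    if requested == torch.float16 and bf16_supported:
        return torch.bfloat16
    return requested

def safe_fp16_cast(t: torch.Tensor) -> torch.Tensor:
    # Rule 2: legacy-GPU path. Values outside the fp16 range become +/-inf
    # on cast; nan_to_num then clamps them to the fp16 extremes and zeroes NaNs.
    t16 = t.to(torch.float16)
    return torch.nan_to_num(t16, nan=0.0, posinf=FP16_MAX, neginf=-FP16_MAX)

w = torch.tensor([1e38, -1e38, float("nan"), 1.5], dtype=torch.bfloat16)
print(safe_fp16_cast(w))  # all values finite in fp16: +/-65504 at the extremes, NaN -> 0
```

Clamping to ±65504 rather than leaving inf/NaN is what keeps the quantized scale tensors usable after the cast.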
Result
Models load successfully across mixed-precision environments without affecting inference correctness.
Checklist