Skip to content

Fix crash in patch_scale_key when tensor dtypes do not match and prevent prompt corruption#754

Open
1038lab wants to merge 2 commits intonunchaku-ai:mainfrom
1038lab:main
Open

Fix crash in patch_scale_key when tensor dtypes do not match and prevent prompt corruption#754
1038lab wants to merge 2 commits intonunchaku-ai:mainfrom
1038lab:main

Conversation

@1038lab
Copy link

@1038lab 1038lab commented Dec 30, 2025

Problem

patch_scale_key() assumes model tensors and checkpoint tensors always share the same dtype. In practice this is not guaranteed: the model may be initialized in fp16 while the checkpoint is bf16. This can trigger an AssertionError during load or, worse, lead to corrupted weights that make outputs ignore the prompt.

On bf16-capable GPUs (e.g., RTX 40/50 series), the current logic may still choose fp16 in some paths. Casting bf16 weights to fp16 can introduce NaN/inf or precision loss for quantized tensors and break conditioning.

Solution

  1. Robust dtype alignment before patch_scale_key()

    • If a checkpoint tensor is floating-point and its dtype differs from the model tensor, cast the checkpoint tensor to the model tensor’s dtype.
    • Preserve the original assertion behavior for non-floating tensors.
  2. Prefer bf16 when supported

    • If the GPU supports bf16 and unet_dtype resolves to fp16, override to bf16 and disable manual casting. This prevents fp16 weight corruption on modern GPUs.
  3. Safe fp16 cast (legacy GPUs)

    • When casting to fp16, sanitize NaN/inf values using torch.nan_to_num(..., nan=0.0, posinf=65504, neginf=-65504).

Why this is safe

  • Only floating-point tensors are cast; quantized/integer tensors remain untouched.
  • Behavior is unchanged when dtypes already match.
  • bf16 override only applies on bf16-capable GPUs and avoids known fp16 corruption paths.

Result

Models load successfully across mixed-precision environments without affecting inference correctness.


Checklist

  • Fixes a real crash during model loading
  • Does not change inference behavior
  • Safe for mixed-precision environments
  • Additional test coverage (optional / follow-up)

Fix crash in patch_scale_key when tensor dtypes do not match
@Ph0rk0z
Copy link

Ph0rk0z commented Dec 30, 2025

This was a problem for 3090 and turning too. What GPU did it actually work for if 5xxx doesn't.

@1038lab
Copy link
Author

1038lab commented Dec 30, 2025

This was a problem for 3090 and turning too. What GPU did it actually work for if 5xxx doesn't.

This issue is not specific to RTX 50-series. I reproduced it on a 50-series test machine, but the root cause is a dtype mismatch during model loading (fp16 vs bf16), which can also occur on 3090 / Turing and other GPUs depending on PyTorch version and load path.

@Ph0rk0z
Copy link

Ph0rk0z commented Dec 30, 2025

This guy also fixed it upstream: nunchaku-ai/nunchaku#833 I've been using their PR merged.

@1038lab
Copy link
Author

1038lab commented Dec 30, 2025

Both approaches address the same crash, but at different layers.

PR #833 works by explicitly handling fp16 conversion on certain GPUs, which resolves the issue in those specific environments.

This PR fixes the underlying assumption in patch_scale_key that model and checkpoint tensors always share the same dtype, making the loading logic more robust across mixed-precision setups in general.

Both solutions are valid depending on the use case, feel free to try either and use whichever works best in your environment.

@1038lab 1038lab changed the title Fix crash in patch_scale_key when tensor dtypes do not match Fix crash in patch_scale_key when tensor dtypes do not match and prevent prompt corruption Dec 31, 2025
@praveenmaniyan
Copy link

I tried your patch and also the PR #833, both of them do no work on Nvidia RTX 3050 (4GB VRAM). The dtype is already torch.bfloat16. hence both the patches are not impacting/involved. I get a ComfyUI crash as reported in nunchaku-ai/nunchaku#841 (comment)

I executed the below code for debug purposes.

print("Val 1: ", torch.cuda.is_available())
print("Val 2: ", torch.cuda.is_bf16_supported())
print("Val 3: ", unet_dtype)
if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and unet_dtype == torch.float16:
unet_dtype = torch.bfloat16
manual_cast_dtype = None
print("Inside new code 1")

patched_sd = _patch_state_dict(new_sd)

model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model = model_config.get_model(patched_sd, "")

model_sd = model.diffusion_model.state_dict()
for key, model_value in model_sd.items():
    print("Inside new code 2")
    if key in patched_sd:
        print("Inside new code 3")
        ckpt_value = patched_sd[key]
        if torch.is_tensor(ckpt_value) and torch.is_tensor(model_value):
            if ckpt_value.is_floating_point() and ckpt_value.dtype != model_value.dtype:
                cast_value = ckpt_value.to(dtype=model_value.dtype)
                print("Inside new code 4")
                if model_value.dtype == torch.float16:
                    print("Inside new code 5")
                    cast_value = torch.nan_to_num(cast_value, nan=0.0, posinf=65504, neginf=-65504)
                patched_sd[key] = cast_value
                print("Inside new code 6")

Output is:

Val 1: True
Val 2: True
Val 3: torch.bfloat16
model weight dtype torch.bfloat16, manual cast: None
model_type FLOW
Torch dtype: torch.bfloat16

Inside new code 2
Inside new code 3
Inside new code 2
Inside new code 3

Requested to load Lumina2
loaded partially; 2419.36 MB usable, 2356.90 MB loaded, 261.38 MB offloaded, 62.46 MB buffer reserved, lowvram patches: 0
0%| | 0/9 [00:00<?, ?it/s]Aborted (core dumped)

My config is:

Checkpoint files will always be loaded safely.
Total VRAM 3770 MB, total RAM 63989 MB
pytorch version: 2.9.1+cu130
xformers version: 0.0.33.post2
Set vram state to: LOW_VRAM
Device: cuda:0 NVIDIA GeForce RTX 3050 Laptop GPU : cudaMallocAsync
Using async weight offloading with 2 streams
Enabled pinned memory 60789.0
working around nvidia conv3d memory bug.
Using sage attention
Python version: 3.12.11 | packaged by Anaconda, Inc. | (main, Jun 5 2025, 13:09:17) [GCC 11.2.0]
ComfyUI version: 0.6.0
ComfyUI frontend version: 1.35.9
[Prompt Server] web root: /home/pm/anaconda3/envs/comfygpu/lib/python3.12/site-packages/comfyui_frontend_package/static
Total VRAM 3770 MB, total RAM 63989 MB
pytorch version: 2.9.1+cu130
xformers version: 0.0.33.post2
Set vram state to: LOW_VRAM
Device: cuda:0 NVIDIA GeForce RTX 3050 Laptop GPU : cudaMallocAsync
Using async weight offloading with 2 streams
Enabled pinned memory 60789.0

@1038lab
Copy link
Author

1038lab commented Jan 1, 2026

Thanks for the detailed diagnostics.

From your logs, the model and checkpoint are already using torch.bfloat16, so the dtype-alignment logic in this PR is not being exercised. That suggests this crash is not caused by the fp16↔bf16 mismatch addressed here, but occurs later during execution.

Given the RTX 3050 (4GB) + LOW_VRAM + async offloading setup, this is likely a separate low-VRAM / offload / runtime issue, rather than a regression from this PR.

This PR fixes a real dtype-mismatch crash during model loading, but your case looks orthogonal and should probably be tracked as a follow-up issue focused on low-VRAM execution.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants