Conversation


@1038lab 1038lab commented Dec 30, 2025

Problem

patch_scale_key() assumes model tensors and checkpoint tensors always share the same dtype. In practice this is not guaranteed: the model may be initialized in fp16 while the checkpoint is bf16. This can trigger an AssertionError during load or, worse, lead to corrupted weights that make outputs ignore the prompt.

On bf16-capable GPUs (e.g., RTX 40/50 series), the current logic may still choose fp16 in some paths. Casting bf16 weights to fp16 can introduce NaN/inf or precision loss for quantized tensors and break conditioning.
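
As a quick illustration of the overflow described above (a hedged example assuming plain PyTorch; the tensor values are made up for demonstration), a bf16 value beyond the fp16 range becomes inf when narrowed, and NaNs pass straight through:

```python
import torch

# bf16 can represent magnitudes far beyond the fp16 maximum (65504),
# so a plain cast overflows to inf, and NaNs are preserved as-is.
x = torch.tensor([1e5, float("nan")], dtype=torch.bfloat16)
print(x.to(torch.float16))  # tensor([inf, nan], dtype=torch.float16)
```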

Solution

  1. Robust dtype alignment before patch_scale_key()

    • If a checkpoint tensor is floating-point and its dtype differs from the model tensor, cast the checkpoint tensor to the model tensor’s dtype.
    • Preserve the original assertion behavior for non-floating tensors.
  2. Prefer bf16 when supported

    • If the GPU supports bf16 and unet_dtype resolves to fp16, override to bf16 and disable manual casting. This prevents fp16 weight corruption on modern GPUs.
  3. Safe fp16 cast (legacy GPUs)

    • When casting to fp16, sanitize NaN/inf values using torch.nan_to_num(..., nan=0.0, posinf=65504, neginf=-65504). A combined sketch of these three steps follows this list.
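
A minimal sketch of the three steps above, assuming a standard PyTorch setup. The helper names (align_checkpoint_dtype, prefer_bf16) and call sites are illustrative placeholders, not the exact code in this PR:

```python
import torch

FP16_MAX = 65504.0  # largest finite float16 value


def align_checkpoint_dtype(model_tensor: torch.Tensor,
                           ckpt_tensor: torch.Tensor) -> torch.Tensor:
    """Cast a floating-point checkpoint tensor to the model tensor's dtype.

    Non-floating (e.g. quantized/integer) tensors are returned unchanged, so
    the original assertion in patch_scale_key() still applies to them.
    """
    if ckpt_tensor.dtype == model_tensor.dtype:
        return ckpt_tensor  # behavior unchanged when dtypes already match
    if not ckpt_tensor.is_floating_point():
        return ckpt_tensor  # leave quantized/integer tensors untouched

    target = model_tensor.dtype
    if target == torch.float16:
        # Safe fp16 cast (legacy GPUs): sanitize NaN/inf before narrowing.
        ckpt_tensor = torch.nan_to_num(ckpt_tensor, nan=0.0,
                                       posinf=FP16_MAX, neginf=-FP16_MAX)
    return ckpt_tensor.to(target)


def prefer_bf16(unet_dtype: torch.dtype) -> torch.dtype:
    """Prefer bf16 over fp16 on GPUs that support it."""
    if (unet_dtype == torch.float16 and torch.cuda.is_available()
            and torch.cuda.is_bf16_supported()):
        return torch.bfloat16  # also implies disabling manual fp16 casting
    return unet_dtype
```

In this sketch, align_checkpoint_dtype would be applied to each checkpoint tensor just before patch_scale_key() runs, so the existing assertion only ever sees matching dtypes or non-floating tensors.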

Why this is safe

  • Only floating-point tensors are cast; quantized/integer tensors remain untouched.
  • Behavior is unchanged when dtypes already match.
  • bf16 override only applies on bf16-capable GPUs and avoids known fp16 corruption paths.

Result

Models load successfully across mixed-precision environments without affecting inference correctness.


Checklist

  • Fixes a real crash during model loading
  • Does not change inference behavior
  • Safe for mixed-precision environments
  • Additional test coverage (optional / follow-up)

Fix crash in patch_scale_key when tensor dtypes do not match

Ph0rk0z commented Dec 30, 2025

This was a problem for 3090 and Turing too. Which GPU did it actually work on if 5xxx doesn't?


1038lab commented Dec 30, 2025

This was a problem for 3090 and Turing too. Which GPU did it actually work on if 5xxx doesn't?

This issue is not specific to RTX 50-series. I reproduced it on a 50-series test machine, but the root cause is a dtype mismatch during model loading (fp16 vs bf16), which can also occur on 3090 / Turing and other GPUs depending on PyTorch version and load path.


Ph0rk0z commented Dec 30, 2025

This was also fixed upstream: nunchaku-tech/nunchaku#833. I've been running with that PR merged.


1038lab commented Dec 30, 2025

Both approaches address the same crash, but at different layers.

PR #833 works by explicitly handling fp16 conversion on certain GPUs, which resolves the issue in those specific environments.

This PR fixes the underlying assumption in patch_scale_key that model and checkpoint tensors always share the same dtype, making the loading logic more robust across mixed-precision setups in general.

Both solutions are valid depending on the use case; feel free to try either and use whichever works best in your environment.

@1038lab 1038lab changed the title Fix crash in patch_scale_key when tensor dtypes do not match Fix crash in patch_scale_key when tensor dtypes do not match and prevent prompt corruption Dec 31, 2025
