Fix crash in patch_scale_key when tensor dtypes do not match and prevent prompt corruption #754
Problem
patch_scale_key() assumes model tensors and checkpoint tensors always share the same dtype. In practice this is not guaranteed: the model may be initialized in fp16 while the checkpoint is bf16. This can trigger an AssertionError during load or, worse, lead to corrupted weights that make outputs ignore the prompt.

On bf16-capable GPUs (e.g., RTX 40/50 series), the current logic may still choose fp16 in some paths. Casting bf16 weights to fp16 can introduce NaN/inf or precision loss for quantized tensors and break conditioning.
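A minimal, self-contained illustration of the failure mode described above. The variable names are placeholders for this sketch, not the repository's actual code:

```python
import torch

# Model initialized in fp16, checkpoint tensor stored in bf16.
model_weight = torch.zeros(4, dtype=torch.float16)
ckpt_weight = torch.randn(4, dtype=torch.bfloat16)

print(model_weight.dtype == ckpt_weight.dtype)  # False -> a strict assert would raise

# A blind bf16 -> fp16 cast is also unsafe: values beyond the fp16 range
# (~65504) overflow to inf, the kind of weight corruption described above.
big = torch.tensor([1e20], dtype=torch.bfloat16)
print(big.to(torch.float16))  # tensor([inf], dtype=torch.float16)
```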
Solution
- Robust dtype alignment before patch_scale_key(): tensors are brought to a matching dtype before the scale key is patched, so the load no longer hits the dtype assertion (see the sketch after this list).
- Prefer bf16 when supported: if unet_dtype resolves to fp16 on a bf16-capable GPU, override it to bf16 and disable manual casting. This prevents fp16 weight corruption on modern GPUs.
- Safe fp16 cast (legacy GPUs): when fp16 cannot be avoided, sanitize values with torch.nan_to_num(..., nan=0.0, posinf=65504, neginf=-65504) so NaN/inf never reach the patched weights.
- Why this is safe: bf16 shares fp32's exponent range, so preferring it avoids overflow entirely; on fp16-only hardware the nan_to_num pass replaces NaN and ±inf with finite fp16-representable values.
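A minimal sketch of the dtype handling outlined above. It assumes the checkpoint tensor is cast toward the model tensor's dtype, and the helper names (align_dtype, pick_unet_dtype) are illustrative, not the repository's actual functions:

```python
import torch

FP16_MAX = 65504.0  # largest finite float16 value


def pick_unet_dtype(requested: torch.dtype) -> torch.dtype:
    """Prefer bf16 over fp16 when the current CUDA device supports it."""
    if (
        requested == torch.float16
        and torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()
    ):
        return torch.bfloat16
    return requested


def align_dtype(ckpt_tensor: torch.Tensor, model_tensor: torch.Tensor) -> torch.Tensor:
    """Bring a checkpoint tensor to the model tensor's dtype before patching."""
    if ckpt_tensor.dtype == model_tensor.dtype:
        return ckpt_tensor
    if model_tensor.dtype == torch.float16:
        # Safe fp16 path for legacy GPUs: replace NaN and +/-inf with finite
        # fp16-representable values before the downcast.
        sane = torch.nan_to_num(
            ckpt_tensor.to(torch.float32), nan=0.0, posinf=FP16_MAX, neginf=-FP16_MAX
        )
        return sane.to(torch.float16)
    return ckpt_tensor.to(model_tensor.dtype)
```

Usage in this sketch would be a single call per patched tensor, e.g. `ckpt_weight = align_dtype(ckpt_weight, model_weight)` before invoking patch_scale_key().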
Result
Models load successfully across mixed-precision environments without affecting inference correctness.
Checklist