Skip to content

Commit 0bf81b5

Browse files
authored
Merge branch 'main' into torchao-fix-input-dtype
2 parents c51a579 + b8aebf4 commit 0bf81b5

5 files changed

Lines changed: 51 additions & 5 deletions

File tree

src/diffusers/models/attention_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,7 +2140,7 @@ def __call__(
21402140
if attn.norm_added_q is not None:
21412141
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
21422142
if attn.norm_added_k is not None:
2143-
encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
2143+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
21442144

21452145
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
21462146
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -2237,7 +2237,7 @@ def __call__(
22372237
if attn.norm_added_q is not None:
22382238
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
22392239
if attn.norm_added_k is not None:
2240-
encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
2240+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
22412241

22422242
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
22432243
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)

src/diffusers/models/transformers/transformer_ernie_image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
290290

291291
class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin):
292292
_supports_gradient_checkpointing = True
293+
_repeated_blocks = ["ErnieImageSharedAdaLNBlock"]
293294

294295
@register_to_config
295296
def __init__(

src/diffusers/utils/torch_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,10 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
223223
# Non-power of 2 images must be float32
224224
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
225225
x = x.to(dtype=torch.float32)
226-
# fftn does not support bfloat16
227-
elif x.dtype == torch.bfloat16:
226+
# fftn does not support bfloat16, and produces the experimental ComplexHalf
227+
# dtype (torch.complex32) when given float16, which is numerically unstable
228+
# and triggers a UserWarning. Upcast any non-float32 dtype to float32.
229+
elif x.dtype != torch.float32:
228230
x = x.to(dtype=torch.float32)
229231

230232
# FFT

tests/others/test_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,49 @@ def test_deprecate_testing_utils_module(self):
204204
), f"Expected deprecation message substring not found, got: {messages}"
205205

206206

207+
class FourierFilterTester(unittest.TestCase):
208+
"""Tests for :func:`diffusers.utils.torch_utils.fourier_filter` (FreeU helper)."""
209+
210+
def _run_without_complexhalf_warning(self, dtype):
211+
import torch
212+
213+
from diffusers.utils.torch_utils import fourier_filter
214+
215+
x = torch.randn(1, 4, 32, 32, dtype=dtype)
216+
with warnings.catch_warnings(record=True) as caught:
217+
warnings.simplefilter("always")
218+
out = fourier_filter(x, threshold=1, scale=0.5)
219+
220+
messages = [str(w.message) for w in caught]
221+
assert not any("ComplexHalf" in m for m in messages), (
222+
f"Unexpected ComplexHalf warning emitted by fourier_filter: {messages}"
223+
)
224+
return out
225+
226+
def test_fourier_filter_float16_no_complexhalf_warning(self):
227+
import torch
228+
229+
out = self._run_without_complexhalf_warning(torch.float16)
230+
assert out.dtype == torch.float16
231+
232+
def test_fourier_filter_bfloat16_no_complexhalf_warning(self):
233+
import torch
234+
235+
out = self._run_without_complexhalf_warning(torch.bfloat16)
236+
assert out.dtype == torch.bfloat16
237+
238+
def test_fourier_filter_preserves_dtype_and_shape(self):
239+
import torch
240+
241+
from diffusers.utils.torch_utils import fourier_filter
242+
243+
for dtype in (torch.float32, torch.float16, torch.bfloat16):
244+
x = torch.randn(2, 3, 16, 16, dtype=dtype)
245+
out = fourier_filter(x, threshold=1, scale=0.5)
246+
assert out.dtype == dtype
247+
assert out.shape == x.shape
248+
249+
207250
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
208251
class ExpectationsTester(unittest.TestCase):
209252
def test_expectations(self):

tests/pipelines/hidream_image/test_pipeline_hidream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def get_dummy_components(self):
9696

9797
torch.manual_seed(0)
9898
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
99-
text_encoder_3 = T5EncoderModel(config)
99+
text_encoder_3 = T5EncoderModel(config).eval()
100100

101101
torch.manual_seed(0)
102102
text_encoder_4 = LlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

0 commit comments

Comments (0)