@@ -204,6 +204,49 @@ def test_deprecate_testing_utils_module(self):
204204 ), f"Expected deprecation message substring not found, got: { messages } "
205205
206206
207+ class FourierFilterTester (unittest .TestCase ):
208+ """Tests for :func:`diffusers.utils.torch_utils.fourier_filter` (FreeU helper)."""
209+
210+ def _run_without_complexhalf_warning (self , dtype ):
211+ import torch
212+
213+ from diffusers .utils .torch_utils import fourier_filter
214+
215+ x = torch .randn (1 , 4 , 32 , 32 , dtype = dtype )
216+ with warnings .catch_warnings (record = True ) as caught :
217+ warnings .simplefilter ("always" )
218+ out = fourier_filter (x , threshold = 1 , scale = 0.5 )
219+
220+ messages = [str (w .message ) for w in caught ]
221+ assert not any ("ComplexHalf" in m for m in messages ), (
222+ f"Unexpected ComplexHalf warning emitted by fourier_filter: { messages } "
223+ )
224+ return out
225+
226+ def test_fourier_filter_float16_no_complexhalf_warning (self ):
227+ import torch
228+
229+ out = self ._run_without_complexhalf_warning (torch .float16 )
230+ assert out .dtype == torch .float16
231+
232+ def test_fourier_filter_bfloat16_no_complexhalf_warning (self ):
233+ import torch
234+
235+ out = self ._run_without_complexhalf_warning (torch .bfloat16 )
236+ assert out .dtype == torch .bfloat16
237+
238+ def test_fourier_filter_preserves_dtype_and_shape (self ):
239+ import torch
240+
241+ from diffusers .utils .torch_utils import fourier_filter
242+
243+ for dtype in (torch .float32 , torch .float16 , torch .bfloat16 ):
244+ x = torch .randn (2 , 3 , 16 , 16 , dtype = dtype )
245+ out = fourier_filter (x , threshold = 1 , scale = 0.5 )
246+ assert out .dtype == dtype
247+ assert out .shape == x .shape
248+
249+
207250# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
208251class ExpectationsTester (unittest .TestCase ):
209252 def test_expectations (self ):
0 commit comments