77from tests .utils import ensure_similar_images
88
99from refiners .fluxion .utils import no_grad
10- from refiners .foundationals .latent_diffusion .auto_encoder import LatentDiffusionAutoencoder
10+ from refiners .foundationals .latent_diffusion import (
11+ LatentDiffusionAutoencoder ,
12+ SD1Autoencoder ,
13+ SDXLAutoencoder ,
14+ )
1115
1216
1317@pytest .fixture (scope = "module" )
@@ -16,25 +20,37 @@ def sample_image() -> Image.Image:
1620 if not test_image .is_file ():
1721 warn (f"could not reference image at { test_image } , skipping" )
1822 pytest .skip (allow_module_level = True )
19- img = Image .open (test_image ) # type: ignore
23+ img = Image .open (test_image )
2024 assert img .size == (512 , 512 )
2125 return img
2226
2327
24- @pytest .fixture (scope = "module" )
28+ @pytest .fixture (scope = "module" , params = [ "SD1.5" , "SDXL" ] )
2529def autoencoder (
26- refiners_autoencoder : LatentDiffusionAutoencoder ,
30+ request : pytest .FixtureRequest ,
31+ refiners_sd15_autoencoder : SD1Autoencoder ,
32+ refiners_sdxl_autoencoder : SDXLAutoencoder ,
2733 test_device : torch .device ,
34+ test_dtype_fp32_bf16_fp16 : torch .dtype ,
2835) -> LatentDiffusionAutoencoder :
29- return refiners_autoencoder .to (test_device )
36+ model_version = request .param
37+ match (model_version , test_dtype_fp32_bf16_fp16 ):
38+ case ("SD1.5" , _):
39+ return refiners_sd15_autoencoder .to (device = test_device , dtype = test_dtype_fp32_bf16_fp16 )
40+ case ("SDXL" , torch .float16 ):
41+ return refiners_sdxl_autoencoder .to (device = test_device , dtype = test_dtype_fp32_bf16_fp16 )
42+ case ("SDXL" , _):
43+ return refiners_sdxl_autoencoder .to (device = test_device , dtype = test_dtype_fp32_bf16_fp16 )
44+ case _:
45+ raise ValueError (f"Unknown model version: { model_version } " )
3046
3147
3248@no_grad ()
3349def test_encode_decode_image (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
3450 encoded = autoencoder .image_to_latents (sample_image )
3551 decoded = autoencoder .latents_to_image (encoded )
3652
37- assert decoded .mode == "RGB" # type: ignore
53+ assert decoded .mode == "RGB"
3854
3955 # Ensure no saturation. The green channel (band = 1) must not max out.
4056 assert max (iter (decoded .getdata (band = 1 ))) < 255 # type: ignore
@@ -53,7 +69,7 @@ def test_encode_decode_images(autoencoder: LatentDiffusionAutoencoder, sample_im
5369
5470@no_grad ()
5571def test_tiled_autoencoder (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
56- sample_image = sample_image .resize ((2048 , 2048 )) # type: ignore
72+ sample_image = sample_image .resize ((2048 , 2048 ))
5773
5874 with autoencoder .tiled_inference (sample_image , tile_size = (512 , 512 )):
5975 encoded = autoencoder .tiled_image_to_latents (sample_image )
@@ -64,7 +80,7 @@ def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image
6480
6581@no_grad ()
6682def test_tiled_autoencoder_rectangular_tiles (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
67- sample_image = sample_image .resize ((2048 , 2048 )) # type: ignore
83+ sample_image = sample_image .resize ((2048 , 2048 ))
6884
6985 with autoencoder .tiled_inference (sample_image , tile_size = (512 , 1024 )):
7086 encoded = autoencoder .tiled_image_to_latents (sample_image )
@@ -75,7 +91,7 @@ def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoenc
7591
7692@no_grad ()
7793def test_tiled_autoencoder_large_tile (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
78- sample_image = sample_image .resize ((1024 , 1024 )) # type: ignore
94+ sample_image = sample_image .resize ((1024 , 1024 ))
7995
8096 with autoencoder .tiled_inference (sample_image , tile_size = (2048 , 2048 )):
8197 encoded = autoencoder .tiled_image_to_latents (sample_image )
@@ -87,7 +103,7 @@ def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, s
87103@no_grad ()
88104def test_tiled_autoencoder_rectangular_image (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
89105 sample_image = sample_image .crop ((0 , 0 , 300 , 500 ))
90- sample_image = sample_image .resize ((sample_image .width * 4 , sample_image .height * 4 )) # type: ignore
106+ sample_image = sample_image .resize ((sample_image .width * 4 , sample_image .height * 4 ))
91107
92108 with autoencoder .tiled_inference (sample_image , tile_size = (512 , 512 )):
93109 encoded = autoencoder .tiled_image_to_latents (sample_image )
@@ -96,6 +112,30 @@ def test_tiled_autoencoder_rectangular_image(autoencoder: LatentDiffusionAutoenc
96112 ensure_similar_images (sample_image , result , min_psnr = 37 , min_ssim = 0.985 )
97113
98114
115+ @no_grad ()
116+ @pytest .fixture (scope = "module" , params = [240 , 242 , 244 , 254 , 256 , 258 ])
117+ def test_tiled_autoencoder_pathologic_sizes (
118+ request : pytest .FixtureRequest ,
119+ refiners_sd15_autoencoder : SD1Autoencoder ,
120+ sample_image : Image .Image ,
121+ test_device : torch .device ,
122+ ):
123+ # 242 is a tile just larger than (tile size - overlap).
124+ # 242 * 4 = 968 = (128 - 8 + 1) * 8
125+ tile_w = request .param
126+
127+ autoencoder = refiners_sd15_autoencoder .to (device = test_device , dtype = torch .float32 )
128+
129+ sample_image = sample_image .crop ((0 , 0 , tile_w , 400 ))
130+ sample_image = sample_image .resize ((sample_image .width * 4 , sample_image .height * 4 ))
131+
132+ with autoencoder .tiled_inference (sample_image , tile_size = (1024 , 1024 )):
133+ encoded = autoencoder .tiled_image_to_latents (sample_image )
134+ result = autoencoder .tiled_latents_to_image (encoded )
135+
136+ ensure_similar_images (sample_image , result , min_psnr = 37 , min_ssim = 0.985 )
137+
138+
99139def test_value_error_tile_encode_no_context (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ) -> None :
100140 with pytest .raises (ValueError ):
101141 autoencoder .tiled_image_to_latents (sample_image )
0 commit comments