77from tests .utils import ensure_similar_images
88
99from refiners .fluxion .utils import no_grad
10- from refiners .foundationals .latent_diffusion .auto_encoder import LatentDiffusionAutoencoder
10+ from refiners .foundationals .latent_diffusion import (
11+ LatentDiffusionAutoencoder ,
12+ SD1Autoencoder ,
13+ SDXLAutoencoder ,
14+ )
1115
1216
1317@pytest .fixture (scope = "module" )
@@ -16,25 +20,32 @@ def sample_image() -> Image.Image:
1620 if not test_image .is_file ():
1721 warn (f"could not reference image at { test_image } , skipping" )
1822 pytest .skip (allow_module_level = True )
19- img = Image .open (test_image ) # type: ignore
23+ img = Image .open (test_image )
2024 assert img .size == (512 , 512 )
2125 return img
2226
2327
24- @pytest .fixture (scope = "module" )
28+ @pytest .fixture (scope = "module" , params = [ "SD1.5" , "SDXL" ] )
2529def autoencoder (
26- refiners_autoencoder : LatentDiffusionAutoencoder ,
30+ request : pytest .FixtureRequest ,
31+ refiners_sd15_autoencoder : SD1Autoencoder ,
32+ refiners_sdxl_autoencoder : SDXLAutoencoder ,
2733 test_device : torch .device ,
34+ test_dtype_fp32_bf16_fp16 : torch .dtype ,
2835) -> LatentDiffusionAutoencoder :
29- return refiners_autoencoder .to (test_device )
36+ model_version = request .param
37+ if model_version == "SDXL" and test_dtype_fp32_bf16_fp16 == torch .float16 :
38+ pytest .skip ("SDXL autoencoder does not support float16" )
39+ ae = refiners_sd15_autoencoder if model_version == "SD1.5" else refiners_sdxl_autoencoder
40+ return ae .to (device = test_device , dtype = test_dtype_fp32_bf16_fp16 )
3041
3142
3243@no_grad ()
3344def test_encode_decode_image (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
3445 encoded = autoencoder .image_to_latents (sample_image )
3546 decoded = autoencoder .latents_to_image (encoded )
3647
37- assert decoded .mode == "RGB" # type: ignore
48+ assert decoded .mode == "RGB"
3849
3950 # Ensure no saturation. The green channel (band = 1) must not max out.
4051 assert max (iter (decoded .getdata (band = 1 ))) < 255 # type: ignore
@@ -53,7 +64,7 @@ def test_encode_decode_images(autoencoder: LatentDiffusionAutoencoder, sample_im
5364
5465@no_grad ()
5566def test_tiled_autoencoder (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
56- sample_image = sample_image .resize ((2048 , 2048 )) # type: ignore
67+ sample_image = sample_image .resize ((2048 , 2048 ))
5768
5869 with autoencoder .tiled_inference (sample_image , tile_size = (512 , 512 )):
5970 encoded = autoencoder .tiled_image_to_latents (sample_image )
@@ -64,7 +75,7 @@ def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image
6475
6576@no_grad ()
6677def test_tiled_autoencoder_rectangular_tiles (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
67- sample_image = sample_image .resize ((2048 , 2048 )) # type: ignore
78+ sample_image = sample_image .resize ((2048 , 2048 ))
6879
6980 with autoencoder .tiled_inference (sample_image , tile_size = (512 , 1024 )):
7081 encoded = autoencoder .tiled_image_to_latents (sample_image )
@@ -75,7 +86,7 @@ def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoenc
7586
7687@no_grad ()
7788def test_tiled_autoencoder_large_tile (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
78- sample_image = sample_image .resize ((1024 , 1024 )) # type: ignore
89+ sample_image = sample_image .resize ((1024 , 1024 ))
7990
8091 with autoencoder .tiled_inference (sample_image , tile_size = (2048 , 2048 )):
8192 encoded = autoencoder .tiled_image_to_latents (sample_image )
@@ -87,7 +98,7 @@ def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, s
8798@no_grad ()
8899def test_tiled_autoencoder_rectangular_image (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ):
89100 sample_image = sample_image .crop ((0 , 0 , 300 , 500 ))
90- sample_image = sample_image .resize ((sample_image .width * 4 , sample_image .height * 4 )) # type: ignore
101+ sample_image = sample_image .resize ((sample_image .width * 4 , sample_image .height * 4 ))
91102
92103 with autoencoder .tiled_inference (sample_image , tile_size = (512 , 512 )):
93104 encoded = autoencoder .tiled_image_to_latents (sample_image )
0 commit comments