@@ -34,15 +34,10 @@ def autoencoder(
3434 test_dtype_fp32_bf16_fp16 : torch .dtype ,
3535) -> LatentDiffusionAutoencoder :
3636 model_version = request .param
37- match (model_version , test_dtype_fp32_bf16_fp16 ):
38- case ("SD1.5" , _):
39- return refiners_sd15_autoencoder .to (device = test_device , dtype = test_dtype_fp32_bf16_fp16 )
40- case ("SDXL" , torch .float16 ):
41- return refiners_sdxl_autoencoder .to (device = test_device , dtype = test_dtype_fp32_bf16_fp16 )
42- case ("SDXL" , _):
43- return refiners_sdxl_autoencoder .to (device = test_device , dtype = test_dtype_fp32_bf16_fp16 )
44- case _:
45- raise ValueError (f"Unknown model version: { model_version } " )
37+ if model_version == "SDXL" and test_dtype_fp32_bf16_fp16 == torch .float16 :
38+ pytest .skip ("SDXL autoencoder does not support float16" )
39+ ae = refiners_sd15_autoencoder if model_version == "SD1.5" else refiners_sdxl_autoencoder
40+ return ae .to (device = test_device , dtype = test_dtype_fp32_bf16_fp16 )
4641
4742
4843@no_grad ()
@@ -112,30 +107,6 @@ def test_tiled_autoencoder_rectangular_image(autoencoder: LatentDiffusionAutoenc
112107 ensure_similar_images (sample_image , result , min_psnr = 37 , min_ssim = 0.985 )
113108
114109
115- @no_grad ()
116- @pytest .fixture (scope = "module" , params = [240 , 242 , 244 , 254 , 256 , 258 ])
117- def test_tiled_autoencoder_pathologic_sizes (
118- request : pytest .FixtureRequest ,
119- refiners_sd15_autoencoder : SD1Autoencoder ,
120- sample_image : Image .Image ,
121- test_device : torch .device ,
122- ):
123- # 242 is a tile just larger than (tile size - overlap).
124- # 242 * 4 = 968 = (128 - 8 + 1) * 8
125- tile_w = request .param
126-
127- autoencoder = refiners_sd15_autoencoder .to (device = test_device , dtype = torch .float32 )
128-
129- sample_image = sample_image .crop ((0 , 0 , tile_w , 400 ))
130- sample_image = sample_image .resize ((sample_image .width * 4 , sample_image .height * 4 ))
131-
132- with autoencoder .tiled_inference (sample_image , tile_size = (1024 , 1024 )):
133- encoded = autoencoder .tiled_image_to_latents (sample_image )
134- result = autoencoder .tiled_latents_to_image (encoded )
135-
136- ensure_similar_images (sample_image , result , min_psnr = 37 , min_ssim = 0.985 )
137-
138-
139110def test_value_error_tile_encode_no_context (autoencoder : LatentDiffusionAutoencoder , sample_image : Image .Image ) -> None :
140111 with pytest .raises (ValueError ):
141112 autoencoder .tiled_image_to_latents (sample_image )
0 commit comments