@@ -29,74 +29,11 @@ def ref_path(test_e2e_path: Path) -> Path:
2929 return test_e2e_path / "test_doc_examples_ref"
3030
3131
32- @pytest .fixture (scope = "module" )
33- def sdxl_text_encoder_weights (test_weights_path : Path ) -> Path :
34- path = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
35- if not path .is_file ():
36- warn (message = f"could not find weights at { path } , skipping" )
37- pytest .skip (allow_module_level = True )
38- return path
39-
40-
41- @pytest .fixture (scope = "module" )
42- def sdxl_lda_fp16_fix_weights (test_weights_path : Path ) -> Path :
43- path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
44- if not path .is_file ():
45- warn (message = f"could not find weights at { path } , skipping" )
46- pytest .skip (allow_module_level = True )
47- return path
48-
49-
50- @pytest .fixture (scope = "module" )
51- def sdxl_unet_weights (test_weights_path : Path ) -> Path :
52- path = test_weights_path / "sdxl-unet.safetensors"
53- if not path .is_file ():
54- warn (message = f"could not find weights at { path } , skipping" )
55- pytest .skip (allow_module_level = True )
56- return path
57-
58-
59- @pytest .fixture (scope = "module" )
60- def sdxl_ip_adapter_plus_weights (test_weights_path : Path ) -> Path :
61- path = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
62- if not path .is_file ():
63- warn (f"could not find weights at { path } , skipping" )
64- pytest .skip (allow_module_level = True )
65- return path
66-
67-
68- @pytest .fixture (scope = "module" )
69- def image_encoder_weights (test_weights_path : Path ) -> Path :
70- path = test_weights_path / "CLIPImageEncoderH.safetensors"
71- if not path .is_file ():
72- warn (f"could not find weights at { path } , skipping" )
73- pytest .skip (allow_module_level = True )
74- return path
75-
76-
77- @pytest .fixture
78- def scifi_lora_weights (test_weights_path : Path ) -> Path :
79- path = test_weights_path / "loras" / "Sci-fi_Environments_sdxl.safetensors"
80- if not path .is_file ():
81- warn (message = f"could not find weights at { path } , skipping" )
82- pytest .skip (allow_module_level = True )
83- return path
84-
85-
86- @pytest .fixture
87- def pixelart_lora_weights (test_weights_path : Path ) -> Path :
88- path = test_weights_path / "loras" / "pixel-art-xl-v1.1.safetensors"
89- if not path .is_file ():
90- warn (message = f"could not find weights at { path } , skipping" )
91- pytest .skip (allow_module_level = True )
92- return path
93-
94-
9532@pytest .fixture
9633def sdxl (
97- sdxl_text_encoder_weights : Path ,
98- sdxl_lda_fp16_fix_weights : Path ,
99- sdxl_unet_weights : Path ,
34+ sdxl_text_encoder_weights_path : Path ,
35+ sdxl_autoencoder_fp16fix_weights_path : Path ,
36+ sdxl_unet_weights_path : Path ,
10037 test_device : torch .device ,
10138) -> StableDiffusion_XL :
10239 if test_device .type == "cpu" :
@@ -105,9 +42,9 @@ def sdxl(
10542
10643 sdxl = StableDiffusion_XL (device = test_device , dtype = torch .float16 )
10744
108- sdxl .clip_text_encoder .load_from_safetensors (tensors_path = sdxl_text_encoder_weights )
109- sdxl .lda .load_from_safetensors (tensors_path = sdxl_lda_fp16_fix_weights )
110- sdxl .unet .load_from_safetensors (tensors_path = sdxl_unet_weights )
45+ sdxl .clip_text_encoder .load_from_safetensors (tensors_path = sdxl_text_encoder_weights_path )
46+ sdxl .lda .load_from_safetensors (tensors_path = sdxl_autoencoder_fp16fix_weights_path )
47+ sdxl .unet .load_from_safetensors (tensors_path = sdxl_unet_weights_path )
11148
11249 return sdxl
11350
@@ -180,7 +117,7 @@ def test_guide_adapting_sdxl_vanilla(
180117def test_guide_adapting_sdxl_single_lora (
181118 test_device : torch .device ,
182119 sdxl : StableDiffusion_XL ,
183- scifi_lora_weights : Path ,
120+ lora_scifi_weights_path : Path ,
184121 expected_image_guide_adapting_sdxl_single_lora : Image .Image ,
185122) -> None :
186123 if test_device .type == "cpu" :
@@ -195,7 +132,7 @@ def test_guide_adapting_sdxl_single_lora(
195132 sdxl .set_self_attention_guidance (enable = True , scale = 0.75 )
196133
197134 manager = SDLoraManager (sdxl )
198- manager .add_loras ("scifi-lora" , load_from_safetensors (scifi_lora_weights ))
135+ manager .add_loras ("scifi-lora" , load_from_safetensors (lora_scifi_weights_path ))
199136
200137 clip_text_embedding , pooled_text_embedding = sdxl .compute_clip_text_embedding (
201138 text = prompt + ", best quality, high quality" ,
@@ -222,8 +159,8 @@ def test_guide_adapting_sdxl_single_lora(
222159def test_guide_adapting_sdxl_multiple_loras (
223160 test_device : torch .device ,
224161 sdxl : StableDiffusion_XL ,
225- scifi_lora_weights : Path ,
226- pixelart_lora_weights : Path ,
162+ lora_scifi_weights_path : Path ,
163+ lora_pixelart_weights_path : Path ,
227164 expected_image_guide_adapting_sdxl_multiple_loras : Image .Image ,
228165) -> None :
229166 if test_device .type == "cpu" :
@@ -238,8 +175,8 @@ def test_guide_adapting_sdxl_multiple_loras(
238175 sdxl .set_self_attention_guidance (enable = True , scale = 0.75 )
239176
240177 manager = SDLoraManager (sdxl )
241- manager .add_loras ("scifi-lora" , load_from_safetensors (scifi_lora_weights ))
242- manager .add_loras ("pixel-art-lora" , load_from_safetensors (pixelart_lora_weights ), scale = 1.4 )
178+ manager .add_loras ("scifi-lora" , load_from_safetensors (lora_scifi_weights_path ))
179+ manager .add_loras ("pixel-art-lora" , load_from_safetensors (lora_pixelart_weights_path ), scale = 1.4 )
243180
244181 clip_text_embedding , pooled_text_embedding = sdxl .compute_clip_text_embedding (
245182 text = prompt + ", best quality, high quality" ,
@@ -266,10 +203,10 @@ def test_guide_adapting_sdxl_multiple_loras(
266203def test_guide_adapting_sdxl_loras_ip_adapter (
267204 test_device : torch .device ,
268205 sdxl : StableDiffusion_XL ,
269- sdxl_ip_adapter_plus_weights : Path ,
270- image_encoder_weights : Path ,
271- scifi_lora_weights : Path ,
272- pixelart_lora_weights : Path ,
206+ ip_adapter_sdxl_plus_weights_path : Path ,
207+ clip_image_encoder_huge_weights_path : Path ,
208+ lora_scifi_weights_path : Path ,
209+ lora_pixelart_weights_path : Path ,
273210 image_prompt_german_castle : Image .Image ,
274211 expected_image_guide_adapting_sdxl_loras_ip_adapter : Image .Image ,
275212) -> None :
@@ -285,16 +222,16 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
285222 sdxl .set_self_attention_guidance (enable = True , scale = 0.75 )
286223
287224 manager = SDLoraManager (sdxl )
288- manager .add_loras ("scifi-lora" , load_from_safetensors (scifi_lora_weights ), scale = 1.5 )
289- manager .add_loras ("pixel-art-lora" , load_from_safetensors (pixelart_lora_weights ), scale = 1.55 )
225+ manager .add_loras ("scifi-lora" , load_from_safetensors (lora_scifi_weights_path ), scale = 1.5 )
226+ manager .add_loras ("pixel-art-lora" , load_from_safetensors (lora_pixelart_weights_path ), scale = 1.55 )
290227
291228 ip_adapter = SDXLIPAdapter (
292229 target = sdxl .unet ,
293- weights = load_from_safetensors (sdxl_ip_adapter_plus_weights ),
230+ weights = load_from_safetensors (ip_adapter_sdxl_plus_weights_path ),
294231 scale = 1.0 ,
295232 fine_grained = True ,
296233 )
297- ip_adapter .clip_image_encoder .load_from_safetensors (image_encoder_weights )
234+ ip_adapter .clip_image_encoder .load_from_safetensors (clip_image_encoder_huge_weights_path )
298235 ip_adapter .inject ()
299236
300237 clip_text_embedding , pooled_text_embedding = sdxl .compute_clip_text_embedding (
0 commit comments