This repository was archived by the owner on Sep 26, 2025. It is now read-only.

Commit 293458a

Author: Laurent
Commit message: update tests to use new fixtures
1 parent: a5096b2 · commit: 293458a

26 files changed: +834, -1068 lines

tests/adapters/test_lora_manager.py

Lines changed: 2 additions & 9 deletions
@@ -1,5 +1,4 @@
 from pathlib import Path
-from warnings import warn
 
 import pytest
 import torch
@@ -16,14 +15,8 @@ def manager() -> SDLoraManager:
 
 
 @pytest.fixture
-def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
-    weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
-
-    if not weights_path.is_file():
-        warn(f"could not find weights at {weights_path}, skipping")
-        pytest.skip(allow_module_level=True)
-
-    return load_tensors(weights_path)
+def weights(lora_pokemon_weights_path: Path) -> dict[str, torch.Tensor]:
+    return load_tensors(lora_pokemon_weights_path)
 
 
 def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
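The rewritten `weights` fixture no longer locates the checkpoint or handles a missing file itself; it simply loads tensors from whatever path the new shared `lora_pokemon_weights_path` fixture provides. That shared fixture is not part of this diff; the following is a minimal sketch of what it could look like, assuming it lives in a shared conftest.py, points at the same file as the deleted code, and keeps the old warn-and-skip behavior (the location, scope, and skip message are assumptions):

```python
# Hypothetical shared fixture (e.g. tests/conftest.py) -- not shown in this commit.
from pathlib import Path
from warnings import warn

import pytest


@pytest.fixture
def lora_pokemon_weights_path(test_weights_path: Path) -> Path:
    # Same file the deleted local fixture resolved; skip the requesting test if it is absent.
    path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
    if not path.is_file():
        warn(f"could not find weights at {path}, skipping")
        pytest.skip(f"could not find weights at {path}")
    return path
```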

tests/e2e/test_diffusion.py

Lines changed: 219 additions & 325 deletions
Large diffs are not rendered by default.

tests/e2e/test_doc_examples.py

Lines changed: 20 additions & 83 deletions
@@ -29,74 +29,11 @@ def ref_path(test_e2e_path: Path) -> Path:
     return test_e2e_path / "test_doc_examples_ref"
 
 
-@pytest.fixture(scope="module")
-def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
-    if not path.is_file():
-        warn(message=f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-
-
-@pytest.fixture(scope="module")
-def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
-    if not path.is_file():
-        warn(message=f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-
-
-@pytest.fixture(scope="module")
-def sdxl_unet_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "sdxl-unet.safetensors"
-    if not path.is_file():
-        warn(message=f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-
-
-@pytest.fixture(scope="module")
-def sdxl_ip_adapter_plus_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "ip-adapter-plus_sdxl_vit-h.safetensors"
-    if not path.is_file():
-        warn(f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-
-
-@pytest.fixture(scope="module")
-def image_encoder_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "CLIPImageEncoderH.safetensors"
-    if not path.is_file():
-        warn(f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-
-
-@pytest.fixture
-def scifi_lora_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "loras" / "Sci-fi_Environments_sdxl.safetensors"
-    if not path.is_file():
-        warn(message=f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-
-
-@pytest.fixture
-def pixelart_lora_weights(test_weights_path: Path) -> Path:
-    path = test_weights_path / "loras" / "pixel-art-xl-v1.1.safetensors"
-    if not path.is_file():
-        warn(message=f"could not find weights at {path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return path
-
-
 @pytest.fixture
 def sdxl(
-    sdxl_text_encoder_weights: Path,
-    sdxl_lda_fp16_fix_weights: Path,
-    sdxl_unet_weights: Path,
+    sdxl_text_encoder_weights_path: Path,
+    sdxl_autoencoder_fp16fix_weights_path: Path,
+    sdxl_unet_weights_path: Path,
     test_device: torch.device,
 ) -> StableDiffusion_XL:
     if test_device.type == "cpu":
@@ -105,9 +42,9 @@ def sdxl(
 
     sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16)
 
-    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
-    sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
-    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
+    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights_path)
+    sdxl.lda.load_from_safetensors(tensors_path=sdxl_autoencoder_fp16fix_weights_path)
+    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights_path)
 
     return sdxl
 
@@ -180,7 +117,7 @@ def test_guide_adapting_sdxl_vanilla(
 def test_guide_adapting_sdxl_single_lora(
     test_device: torch.device,
     sdxl: StableDiffusion_XL,
-    scifi_lora_weights: Path,
+    lora_scifi_weights_path: Path,
     expected_image_guide_adapting_sdxl_single_lora: Image.Image,
 ) -> None:
     if test_device.type == "cpu":
@@ -195,7 +132,7 @@ def test_guide_adapting_sdxl_single_lora(
     sdxl.set_self_attention_guidance(enable=True, scale=0.75)
 
     manager = SDLoraManager(sdxl)
-    manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
+    manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path))
 
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
         text=prompt + ", best quality, high quality",
@@ -222,8 +159,8 @@ def test_guide_adapting_sdxl_single_lora(
 def test_guide_adapting_sdxl_multiple_loras(
     test_device: torch.device,
     sdxl: StableDiffusion_XL,
-    scifi_lora_weights: Path,
-    pixelart_lora_weights: Path,
+    lora_scifi_weights_path: Path,
+    lora_pixelart_weights_path: Path,
     expected_image_guide_adapting_sdxl_multiple_loras: Image.Image,
 ) -> None:
     if test_device.type == "cpu":
@@ -238,8 +175,8 @@ def test_guide_adapting_sdxl_multiple_loras(
     sdxl.set_self_attention_guidance(enable=True, scale=0.75)
 
     manager = SDLoraManager(sdxl)
-    manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights))
-    manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.4)
+    manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path))
+    manager.add_loras("pixel-art-lora", load_from_safetensors(lora_pixelart_weights_path), scale=1.4)
 
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
         text=prompt + ", best quality, high quality",
@@ -266,10 +203,10 @@ def test_guide_adapting_sdxl_multiple_loras(
 def test_guide_adapting_sdxl_loras_ip_adapter(
     test_device: torch.device,
     sdxl: StableDiffusion_XL,
-    sdxl_ip_adapter_plus_weights: Path,
-    image_encoder_weights: Path,
-    scifi_lora_weights: Path,
-    pixelart_lora_weights: Path,
+    ip_adapter_sdxl_plus_weights_path: Path,
+    clip_image_encoder_huge_weights_path: Path,
+    lora_scifi_weights_path: Path,
+    lora_pixelart_weights_path: Path,
     image_prompt_german_castle: Image.Image,
     expected_image_guide_adapting_sdxl_loras_ip_adapter: Image.Image,
 ) -> None:
@@ -285,16 +222,16 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
     sdxl.set_self_attention_guidance(enable=True, scale=0.75)
 
     manager = SDLoraManager(sdxl)
-    manager.add_loras("scifi-lora", load_from_safetensors(scifi_lora_weights), scale=1.5)
-    manager.add_loras("pixel-art-lora", load_from_safetensors(pixelart_lora_weights), scale=1.55)
+    manager.add_loras("scifi-lora", load_from_safetensors(lora_scifi_weights_path), scale=1.5)
+    manager.add_loras("pixel-art-lora", load_from_safetensors(lora_pixelart_weights_path), scale=1.55)
 
     ip_adapter = SDXLIPAdapter(
         target=sdxl.unet,
-        weights=load_from_safetensors(sdxl_ip_adapter_plus_weights),
+        weights=load_from_safetensors(ip_adapter_sdxl_plus_weights_path),
         scale=1.0,
         fine_grained=True,
     )
-    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.clip_image_encoder.load_from_safetensors(clip_image_encoder_huge_weights_path)
     ip_adapter.inject()
 
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
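Every fixture deleted above repeated the same pattern: build a path under `test_weights_path`, warn and skip if the file is missing, return the path. The shared replacements (`sdxl_text_encoder_weights_path`, `sdxl_autoencoder_fp16fix_weights_path`, `sdxl_unet_weights_path`, `ip_adapter_sdxl_plus_weights_path`, `clip_image_encoder_huge_weights_path`, `lora_scifi_weights_path`, `lora_pixelart_weights_path`) are defined elsewhere and not shown in this diff. A sketch of how such fixtures could be centralized, assuming a small helper (`ensure_weights_file` is a hypothetical name) and that the new fixtures resolve to the same files as the deleted ones:

```python
# Hypothetical shared conftest helpers -- the helper name, location, and scopes are assumptions.
from pathlib import Path
from warnings import warn

import pytest


def ensure_weights_file(path: Path) -> Path:
    """Return path if the weights file exists, otherwise warn and skip the requesting test."""
    if not path.is_file():
        warn(f"could not find weights at {path}, skipping")
        pytest.skip(f"could not find weights at {path}")
    return path


@pytest.fixture(scope="module")
def sdxl_text_encoder_weights_path(test_weights_path: Path) -> Path:
    return ensure_weights_file(test_weights_path / "DoubleCLIPTextEncoder.safetensors")


@pytest.fixture(scope="module")
def sdxl_autoencoder_fp16fix_weights_path(test_weights_path: Path) -> Path:
    return ensure_weights_file(test_weights_path / "sdxl-lda-fp16-fix.safetensors")
```

Keeping the missing-weights decision in one place means each test module only declares which paths it needs, instead of re-implementing the warn-and-skip logic as seven near-identical fixtures.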

tests/e2e/test_lcm.py

Lines changed: 22 additions & 67 deletions
@@ -26,51 +26,6 @@ def ensure_gc():
     gc.collect()
 
 
-@pytest.fixture
-def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
-    r = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
-    if not r.is_file():
-        warn(f"could not find weights at {r}, skipping")
-        pytest.skip(allow_module_level=True)
-    return r
-
-
-@pytest.fixture
-def sdxl_unet_weights(test_weights_path: Path) -> Path:
-    r = test_weights_path / "sdxl-unet.safetensors"
-    if not r.is_file():
-        warn(f"could not find weights at {r}, skipping")
-        pytest.skip(allow_module_level=True)
-    return r
-
-
-@pytest.fixture
-def sdxl_lcm_unet_weights(test_weights_path: Path) -> Path:
-    r = test_weights_path / "sdxl-lcm-unet.safetensors"
-    if not r.is_file():
-        warn(f"could not find weights at {r}, skipping")
-        pytest.skip(allow_module_level=True)
-    return r
-
-
-@pytest.fixture
-def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
-    r = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
-    if not r.is_file():
-        warn(f"could not find weights at {r}, skipping")
-        pytest.skip(allow_module_level=True)
-    return r
-
-
-@pytest.fixture
-def sdxl_lcm_lora_weights(test_weights_path: Path) -> Path:
-    r = test_weights_path / "sdxl-lcm-lora.safetensors"
-    if not r.is_file():
-        warn(f"could not find weights at {r}, skipping")
-        pytest.skip(allow_module_level=True)
-    return r
-
-
 @pytest.fixture(scope="module")
 def ref_path(test_e2e_path: Path) -> Path:
     return test_e2e_path / "test_lcm_ref"
@@ -94,9 +49,9 @@ def expected_lcm_lora_1_2(ref_path: Path) -> Image.Image:
 @no_grad()
 def test_lcm_base(
     test_device: torch.device,
-    sdxl_lda_fp16_fix_weights: Path,
-    sdxl_lcm_unet_weights: Path,
-    sdxl_text_encoder_weights: Path,
+    sdxl_autoencoder_fp16fix_weights_path: Path,
+    sdxl_unet_lcm_weights_path: Path,
+    sdxl_text_encoder_weights_path: Path,
     expected_lcm_base: Image.Image,
 ) -> None:
     if test_device.type == "cpu":
@@ -111,9 +66,9 @@ def test_lcm_base(
     # not in the diffusion loop.
     SDXLLcmAdapter(sdxl.unet, condition_scale=8.0).inject()
 
-    sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
-    sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
-    sdxl.unet.load_from_safetensors(sdxl_lcm_unet_weights)
+    sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
+    sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
+    sdxl.unet.load_from_safetensors(sdxl_unet_lcm_weights_path)
 
     prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
     expected_image = expected_lcm_base
@@ -141,10 +96,10 @@ def test_lcm_base(
 @pytest.mark.parametrize("condition_scale", [1.0, 1.2])
 def test_lcm_lora_with_guidance(
     test_device: torch.device,
-    sdxl_lda_fp16_fix_weights: Path,
-    sdxl_unet_weights: Path,
-    sdxl_text_encoder_weights: Path,
-    sdxl_lcm_lora_weights: Path,
+    sdxl_autoencoder_fp16fix_weights_path: Path,
+    sdxl_unet_weights_path: Path,
+    sdxl_text_encoder_weights_path: Path,
+    lora_sdxl_lcm_weights_path: Path,
     expected_lcm_lora_1_0: Image.Image,
     expected_lcm_lora_1_2: Image.Image,
     condition_scale: float,
@@ -156,12 +111,12 @@ def test_lcm_lora_with_guidance(
     solver = LCMSolver(num_inference_steps=4)
     sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
 
-    sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
-    sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
-    sdxl.unet.load_from_safetensors(sdxl_unet_weights)
+    sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
+    sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
+    sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
 
     manager = SDLoraManager(sdxl)
-    add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
+    add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lcm_weights_path))
 
     prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
     expected_image = expected_lcm_lora_1_0 if condition_scale == 1.0 else expected_lcm_lora_1_2
@@ -191,10 +146,10 @@ def test_lcm_lora_with_guidance(
 @no_grad()
 def test_lcm_lora_without_guidance(
     test_device: torch.device,
-    sdxl_lda_fp16_fix_weights: Path,
-    sdxl_unet_weights: Path,
-    sdxl_text_encoder_weights: Path,
-    sdxl_lcm_lora_weights: Path,
+    sdxl_autoencoder_fp16fix_weights_path: Path,
+    sdxl_unet_weights_path: Path,
+    sdxl_text_encoder_weights_path: Path,
+    lora_sdxl_lcm_weights_path: Path,
     expected_lcm_lora_1_0: Image.Image,
 ) -> None:
     if test_device.type == "cpu":
@@ -205,12 +160,12 @@ def test_lcm_lora_without_guidance(
     sdxl = StableDiffusion_XL(device=test_device, dtype=torch.float16, solver=solver)
     sdxl.classifier_free_guidance = False
 
-    sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights)
-    sdxl.lda.load_from_safetensors(sdxl_lda_fp16_fix_weights)
-    sdxl.unet.load_from_safetensors(sdxl_unet_weights)
+    sdxl.clip_text_encoder.load_from_safetensors(sdxl_text_encoder_weights_path)
+    sdxl.lda.load_from_safetensors(sdxl_autoencoder_fp16fix_weights_path)
+    sdxl.unet.load_from_safetensors(sdxl_unet_weights_path)
 
     manager = SDLoraManager(sdxl)
-    add_lcm_lora(manager, load_from_safetensors(sdxl_lcm_lora_weights))
+    add_lcm_lora(manager, load_from_safetensors(lora_sdxl_lcm_weights_path))
 
     prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
     expected_image = expected_lcm_lora_1_0
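Taken together, the renames in this commit follow one scheme: every weights fixture gains a `_weights_path` suffix plus a prefix naming the artifact family (`lora_`, `ip_adapter_`, `clip_image_encoder_`, `sdxl_`). The old-to-new mapping visible in this diff is:

- sdxl_text_encoder_weights → sdxl_text_encoder_weights_path
- sdxl_lda_fp16_fix_weights → sdxl_autoencoder_fp16fix_weights_path
- sdxl_unet_weights → sdxl_unet_weights_path
- sdxl_lcm_unet_weights → sdxl_unet_lcm_weights_path
- sdxl_lcm_lora_weights → lora_sdxl_lcm_weights_path
- scifi_lora_weights → lora_scifi_weights_path
- pixelart_lora_weights → lora_pixelart_weights_path
- sdxl_ip_adapter_plus_weights → ip_adapter_sdxl_plus_weights_path
- image_encoder_weights → clip_image_encoder_huge_weights_path
- the local pokemon LoRA lookup in test_lora_manager.py → lora_pokemon_weights_path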
