Skip to content

Commit 0e9762d

Browse files
committed
Merge remote-tracking branch 'origin/main' into packing
2 parents 8fe72de + 9a8fe14 commit 0e9762d

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

src/cogkit/api/services/image_generation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,17 @@ def generate(
5858
if model not in self._models:
5959
raise ValueError(f"Model {model} not loaded")
6060
width, height = list(map(int, size.split("x")))
61-
6261
# TODO: Refactor this to switch by LoRA endpoint API
6362
if lora_path != self._current_lora[model]:
6463
if lora_path is not None:
6564
adapter_name = os.path.basename(lora_path)
66-
_logger.info(f"Loading LORA weights from {adapter_name}")
65+
_logger.info(
66+
f"Loading LORA weights from {adapter_name} and unload previous weights {self._current_lora[model]}"
67+
)
68+
unload_lora_checkpoint(self._models[model])
6769
load_lora_checkpoint(self._models[model], lora_path, lora_scale)
6870
else:
69-
_logger.info("Unloading LORA weights")
71+
_logger.info(f"Unloading LORA weights {self._current_lora[model]}")
7072
unload_lora_checkpoint(self._models[model])
7173

7274
self._current_lora[model] = lora_path

src/cogkit/utils/lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ def load_lora_checkpoint(
99
lora_model_id_or_path: str,
1010
lora_scale: float = 1.0,
1111
) -> None:
12-
pipeline.load_lora_weights(lora_model_id_or_path)
13-
pipeline.fuse_lora(components=["transformer"], lora_scale=lora_scale)
12+
pipeline.load_lora_weights(lora_model_id_or_path, lora_scale=lora_scale)
13+
# pipeline.fuse_lora(components=["transformer"], lora_scale=lora_scale)
1414

1515

1616
def unload_lora_checkpoint(

web/infer.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,7 @@ def get_lora_paths():
5959
if not os.path.exists(lora_dir):
6060
os.makedirs(lora_dir, exist_ok=True)
6161
return ["None"]
62-
checkpoint_dirs = [
63-
d
64-
for d in os.listdir(lora_dir)
65-
if os.path.isdir(os.path.join(lora_dir, d)) and d.startswith("checkpoint")
66-
]
62+
checkpoint_dirs = [d for d in os.listdir(lora_dir) if os.path.isdir(os.path.join(lora_dir, d))]
6763

6864
if not checkpoint_dirs:
6965
return ["None"]
@@ -206,7 +202,7 @@ def refresh_lora_paths():
206202
minimum=1,
207203
maximum=50,
208204
step=1,
209-
value=1,
205+
value=50,
210206
)
211207
guidance_scale = gr.Slider(
212208
label="Guidance Scale",

0 commit comments

Comments (0)