Skip to content

Commit 9a8fe14

Browse files
Merge pull request #20 from THUDM/dev
Fix LoRA weight loading and unloading
2 parents 5d85c63 + 6fde85f commit 9a8fe14

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

src/cogkit/api/services/image_generation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,17 @@ def generate(
5757
if model not in self._models:
5858
raise ValueError(f"Model {model} not loaded")
5959
width, height = list(map(int, size.split("x")))
60-
6160
# TODO: Refactor this to switch by LoRA endpoint API
6261
if lora_path != self._current_lora[model]:
6362
if lora_path is not None:
6463
adapter_name = os.path.basename(lora_path)
65-
_logger.info(f"Loading LORA weights from {adapter_name}")
64+
_logger.info(
65+
f"Loading LORA weights from {adapter_name} and unload previous weights {self._current_lora[model]}"
66+
)
67+
unload_lora_checkpoint(self._models[model])
6668
load_lora_checkpoint(self._models[model], lora_path, lora_scale)
6769
else:
68-
_logger.info("Unloading LORA weights")
70+
_logger.info(f"Unloading LORA weights {self._current_lora[model]}")
6971
unload_lora_checkpoint(self._models[model])
7072

7173
self._current_lora[model] = lora_path

src/cogkit/utils/lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ def load_lora_checkpoint(
99
lora_model_id_or_path: str,
1010
lora_scale: float = 1.0,
1111
) -> None:
12-
pipeline.load_lora_weights(lora_model_id_or_path)
13-
pipeline.fuse_lora(components=["transformer"], lora_scale=lora_scale)
12+
pipeline.load_lora_weights(lora_model_id_or_path, lora_scale=lora_scale)
13+
# pipeline.fuse_lora(components=["transformer"], lora_scale=lora_scale)
1414

1515

1616
def unload_lora_checkpoint(

0 commit comments

Comments (0)