2 files changed, +7 −5 lines

@@ -57,15 +57,17 @@ def generate(
         if model not in self._models:
             raise ValueError(f"Model {model} not loaded")
         width, height = list(map(int, size.split("x")))
-
         # TODO: Refactor this to switch by LoRA endpoint API
         if lora_path != self._current_lora[model]:
             if lora_path is not None:
                 adapter_name = os.path.basename(lora_path)
-                _logger.info(f"Loading LORA weights from {adapter_name}")
+                _logger.info(
+                    f"Loading LORA weights from {adapter_name} and unloading previous weights {self._current_lora[model]}"
+                )
+                unload_lora_checkpoint(self._models[model])
                 load_lora_checkpoint(self._models[model], lora_path, lora_scale)
             else:
-                _logger.info("Unloading LORA weights")
+                _logger.info(f"Unloading LORA weights {self._current_lora[model]}")
                 unload_lora_checkpoint(self._models[model])
 
         self._current_lora[model] = lora_path
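For reference, a minimal sketch of the unload-then-load switch this hunk implements: the previously active LoRA is dropped before the requested checkpoint is loaded, so stale adapter weights never stack. It assumes a diffusers-style pipeline; switch_lora and its arguments are illustrative, not part of this repository.

    # Sketch only: unload the current adapter before loading the next one
    # (assumes diffusers' load_lora_weights / unload_lora_weights; the helper
    # name and signature are hypothetical).
    from diffusers import DiffusionPipeline

    def switch_lora(pipeline: DiffusionPipeline, current: str | None, new: str | None) -> str | None:
        if new == current:
            return current                      # same adapter requested: nothing to do
        pipeline.unload_lora_weights()          # drop whatever adapter was active before
        if new is not None:
            pipeline.load_lora_weights(new)     # load the newly requested checkpoint
        return new                              # caller records this as the current LoRA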
@@ -9,8 +9,8 @@ def load_lora_checkpoint(
     lora_model_id_or_path: str,
     lora_scale: float = 1.0,
 ) -> None:
-    pipeline.load_lora_weights(lora_model_id_or_path)
-    pipeline.fuse_lora(components=["transformer"], lora_scale=lora_scale)
+    pipeline.load_lora_weights(lora_model_id_or_path, lora_scale=lora_scale)
+    # pipeline.fuse_lora(components=["transformer"], lora_scale=lora_scale)
 
 
 def unload_lora_checkpoint(
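Design note on this hunk: fuse_lora bakes the scaled adapter weights into the base model, which makes a later unload or swap more involved, whereas keeping the adapter unfused (as the new code does) leaves unload_lora_checkpoint a cheap operation. A hedged alternative for controlling adapter strength without fusing, assuming diffusers' PEFT-backed LoRA support (the adapter name below is illustrative):

    # Sketch only: load the adapter under an explicit name, then set its
    # strength separately instead of fusing it into the transformer.
    pipeline.load_lora_weights(lora_model_id_or_path, adapter_name="style_lora")
    pipeline.set_adapters(["style_lora"], adapter_weights=[lora_scale])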