File tree Expand file tree Collapse file tree 3 files changed +9
-11
lines changed
Expand file tree Collapse file tree 3 files changed +9
-11
lines changed Original file line number Diff line number Diff line change @@ -58,15 +58,17 @@ def generate(
5858 if model not in self ._models :
5959 raise ValueError (f"Model { model } not loaded" )
6060 width , height = list (map (int , size .split ("x" )))
61-
6261 # TODO: Refactor this to switch by LoRA endpoint API
6362 if lora_path != self ._current_lora [model ]:
6463 if lora_path is not None :
6564 adapter_name = os .path .basename (lora_path )
66- _logger .info (f"Loading LORA weights from { adapter_name } " )
65+ _logger .info (
66+ f"Loading LORA weights from { adapter_name } and unload previous weights { self ._current_lora [model ]} "
67+ )
68+ unload_lora_checkpoint (self ._models [model ])
6769 load_lora_checkpoint (self ._models [model ], lora_path , lora_scale )
6870 else :
69- _logger .info ("Unloading LORA weights" )
71+ _logger .info (f "Unloading LORA weights { self . _current_lora [ model ] } " )
7072 unload_lora_checkpoint (self ._models [model ])
7173
7274 self ._current_lora [model ] = lora_path
Original file line number Diff line number Diff line change @@ -9,8 +9,8 @@ def load_lora_checkpoint(
99 lora_model_id_or_path : str ,
1010 lora_scale : float = 1.0 ,
1111) -> None :
12- pipeline .load_lora_weights (lora_model_id_or_path )
13- pipeline .fuse_lora (components = ["transformer" ], lora_scale = lora_scale )
12+ pipeline .load_lora_weights (lora_model_id_or_path , lora_scale = lora_scale )
13+ # pipeline.fuse_lora(components=["transformer"], lora_scale=lora_scale)
1414
1515
1616def unload_lora_checkpoint (
Original file line number Diff line number Diff line change @@ -59,11 +59,7 @@ def get_lora_paths():
5959 if not os .path .exists (lora_dir ):
6060 os .makedirs (lora_dir , exist_ok = True )
6161 return ["None" ]
62- checkpoint_dirs = [
63- d
64- for d in os .listdir (lora_dir )
65- if os .path .isdir (os .path .join (lora_dir , d )) and d .startswith ("checkpoint" )
66- ]
62+ checkpoint_dirs = [d for d in os .listdir (lora_dir ) if os .path .isdir (os .path .join (lora_dir , d ))]
6763
6864 if not checkpoint_dirs :
6965 return ["None" ]
@@ -206,7 +202,7 @@ def refresh_lora_paths():
206202 minimum = 1 ,
207203 maximum = 50 ,
208204 step = 1 ,
209- value = 1 ,
205+ value = 50 ,
210206 )
211207 guidance_scale = gr .Slider (
212208 label = "Guidance Scale" ,
You can’t perform that action at this time.
0 commit comments