Skip to content

Commit bf3f2e1

Browse files
committed
support loading multiple sd loras (up to 4 at once)
1 parent a089284 commit bf3f2e1

3 files changed

Lines changed: 105 additions & 48 deletions

File tree

expose.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ const int images_max = 8;
66
const int audio_max = 4;
77
const int logprobs_max = 10;
88
const int overridekv_max = 4;
9+
const int lora_filenames_max = 4;
910

1011
// match kobold's sampler list and order
1112
enum samplers
@@ -188,7 +189,7 @@ struct sd_load_model_inputs
188189
const char * clip1_filename = nullptr;
189190
const char * clip2_filename = nullptr;
190191
const char * vae_filename = nullptr;
191-
const char * lora_filename = nullptr;
192+
const char * lora_filenames[lora_filenames_max] = {};
192193
const float lora_multiplier = 1.0f;
193194
const int lora_apply_mode = 0;
194195
const char * photomaker_filename = nullptr;

koboldcpp.py

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
default_native_ctx = 16384
6060
overridekv_max = 4
6161
default_autofit_padding = 1024
62+
lora_filenames_max = 4
6263

6364
# abuse prevention
6465
stop_token_max = 256
@@ -311,7 +312,7 @@ class sd_load_model_inputs(ctypes.Structure):
311312
("clip1_filename", ctypes.c_char_p),
312313
("clip2_filename", ctypes.c_char_p),
313314
("vae_filename", ctypes.c_char_p),
314-
("lora_filename", ctypes.c_char_p),
315+
("lora_filenames", ctypes.c_char_p * lora_filenames_max),
315316
("lora_multiplier", ctypes.c_float),
316317
("lora_apply_mode", ctypes.c_int),
317318
("photomaker_filename", ctypes.c_char_p),
@@ -1931,7 +1932,7 @@ def sd_quant_option(value):
19311932
except Exception:
19321933
return 0
19331934

1934-
def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clip1_filename,clip2_filename,photomaker_filename,upscaler_filename):
1935+
def sd_load_model(model_filename,vae_filename,lora_filenames,t5xxl_filename,clip1_filename,clip2_filename,photomaker_filename,upscaler_filename):
19351936
global args
19361937
inputs = sd_load_model_inputs()
19371938
inputs.model_filename = model_filename.encode("UTF-8")
@@ -1954,7 +1955,12 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clip1
19541955
inputs.taesd = True if args.sdvaeauto else False
19551956
inputs.tiled_vae_threshold = args.sdtiledvae
19561957
inputs.vae_filename = vae_filename.encode("UTF-8")
1957-
inputs.lora_filename = lora_filename.encode("UTF-8")
1958+
for n in range(lora_filenames_max):
1959+
if n >= len(lora_filenames):
1960+
inputs.lora_filenames[n] = "".encode("UTF-8")
1961+
else:
1962+
inputs.lora_filenames[n] = lora_filenames[n].encode("UTF-8")
1963+
19581964
inputs.lora_multiplier = args.sdloramult
19591965
inputs.t5xxl_filename = t5xxl_filename.encode("UTF-8")
19601966
inputs.clip1_filename = clip1_filename.encode("UTF-8")
@@ -5173,7 +5179,7 @@ def stop(self):
51735179
sys.exit(0)
51745180

51755181
# Based on https://github.com/mathgeniuszach/xdialog/blob/main/xdialog/zenity_dialogs.py - MIT license | - Expanded version by Henk717
5176-
def zenity(filetypes=None, initialdir="", initialfile="", **kwargs) -> Tuple[int, str]:
5182+
def zenity(filetypes=None, initialdir="", initialfile="", multiple=False, **kwargs) -> Tuple[int, object]:
51775183
global zenity_recent_dir, zenity_permitted
51785184

51795185
if not zenity_permitted:
@@ -5238,6 +5244,10 @@ def zenity_sanity_check(zenity_bin): #make sure zenity is sane
52385244
initialpath = os.path.join(initialdir, initialfile)
52395245
args.append(f'--filename={initialpath}')
52405246

5247+
if multiple:
5248+
args.append("--multiple")
5249+
args.append("--separator=|")
5250+
52415251
clean_env = os.environ.copy()
52425252
clean_env.pop("LD_LIBRARY_PATH", None)
52435253
clean_env["PATH"] = "/usr/bin:/bin"
@@ -5252,15 +5262,18 @@ def zenity_sanity_check(zenity_bin): #make sure zenity is sane
52525262
result = procres.stdout.decode('utf-8').strip()
52535263
if procres.returncode==0 and result:
52545264
directory = result
5255-
if not os.path.isdir(result):
5256-
directory = os.path.dirname(result)
5265+
if multiple:
5266+
result = tuple(result.split("|"))
5267+
directory = result[0]
5268+
if not os.path.isdir(directory):
5269+
directory = os.path.dirname(directory)
52575270
zenity_recent_dir = directory
52585271
return (procres.returncode, result)
52595272

52605273
# note: In this section we wrap around file dialogues to allow for zenity
52615274
def zentk_askopenfilename(**options):
52625275
try:
5263-
result = zenity(filetypes=options.get("filetypes"), initialdir=options.get("initialdir"), title=options.get("title"))[1]
5276+
result = zenity(filetypes=options.get("filetypes"), initialdir=options.get("initialdir"), multiple=False, title=options.get("title"))[1]
52645277
if result and not os.path.isfile(result):
52655278
print("A folder was selected while we need a file, ignoring selection.")
52665279
return ''
@@ -5269,17 +5282,29 @@ def zentk_askopenfilename(**options):
52695282
result = askopenfilename(**options)
52705283
return result
52715284

5285+
def zentk_askopenfilenames(**options):
5286+
try:
5287+
result = zenity(filetypes=options.get("filetypes"), initialdir=options.get("initialdir"), multiple=True, title=options.get("title"))[1]
5288+
for itm in result:
5289+
if itm and not os.path.isfile(itm):
5290+
print("A folder was selected while we need a file, ignoring selection.")
5291+
return ''
5292+
except Exception:
5293+
from tkinter.filedialog import askopenfilenames
5294+
result = askopenfilenames(**options)
5295+
return result
5296+
52725297
def zentk_askdirectory(**options):
52735298
try:
5274-
result = zenity(initialdir=options.get("initialdir"), title=options.get("title"), directory=True)[1]
5299+
result = zenity(initialdir=options.get("initialdir"), multiple=False, title=options.get("title"), directory=True)[1]
52755300
except Exception:
52765301
from tkinter.filedialog import askdirectory
52775302
result = askdirectory(**options)
52785303
return result
52795304

52805305
def zentk_asksaveasfilename(**options):
52815306
try:
5282-
result = zenity(filetypes=options.get("filetypes"), initialdir=options.get("initialdir"), initialfile=options.get("initialfile"), title=options.get("title"), save=True)[1]
5307+
result = zenity(filetypes=options.get("filetypes"), initialdir=options.get("initialdir"), initialfile=options.get("initialfile"), multiple=False, title=options.get("title"), save=True)[1]
52835308
except Exception:
52845309
from tkinter.filedialog import asksaveasfilename
52855310
result = asksaveasfilename(**options)
@@ -5724,7 +5749,7 @@ def makelabelentry(parent, text, var, row=0, width=50, padx=8, singleline=False,
57245749
return entry, label
57255750

57265751
#file dialog types: 0=openfile,1=savefile,2=opendir
5727-
def makefileentry(parent, text, searchtext, var, row=0, width=200, filetypes=[], onchoosefile=None, singlerow=False, singlecol=True, dialog_type=0, tooltiptxt=""):
5752+
def makefileentry(parent, text, searchtext, var, row=0, width=200, filetypes=[], onchoosefile=None, singlerow=False, singlecol=True, dialog_type=0, tooltiptxt="", multiple=False):
57285753
label = makelabel(parent, text, row,0,tooltiptxt,columnspan=3)
57295754
def getfilename(var, text):
57305755
initialDir = os.path.dirname(var.get())
@@ -5740,7 +5765,11 @@ def getfilename(var, text):
57405765
fnam = str(fnam).strip()
57415766
fnam = f"{fnam}.jsondb" if ".jsondb" not in fnam.lower() else fnam
57425767
else:
5743-
fnam = zentk_askopenfilename(title=text,filetypes=filetypes, initialdir=initialDir)
5768+
if multiple:
5769+
fnam = zentk_askopenfilenames(title=text,filetypes=filetypes, initialdir=initialDir)
5770+
fnam = "|".join(fnam)
5771+
else:
5772+
fnam = zentk_askopenfilename(title=text,filetypes=filetypes, initialdir=initialDir)
57445773
if fnam:
57455774
var.set(fnam)
57465775
if onchoosefile:
@@ -6383,7 +6412,7 @@ def togglehorde(a,b,c):
63836412
makelabelcombobox(images_tab, "Compress Weights: ", sd_quant_var, 8, width=(60), padx=(126), labelpadx=8, tooltiptxt="Quantizes the SD model weights to save memory.\nHigher levels save more memory, and cause more quality degradation.", values=sd_quant_choices)
63846413
sd_quant_var.trace_add("write", changed_gpulayers_estimate)
63856414

6386-
makefileentry(images_tab, "Image LoRA:", "Select SD lora file",sd_lora_var, 20, width=160, singlerow=True, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD LoRA model file to be loaded. Should be unquantized!")
6415+
makefileentry(images_tab, "Image LoRA:", "Select SD lora file",sd_lora_var, 20, width=160, singlerow=True, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD LoRA model file to be loaded. Should be unquantized!", multiple=True)
63876416
makelabelentry(images_tab, "Multiplier:" , sd_loramult_var, 20, 50,padx=(390),singleline=True,tooltip="What mutiplier value to apply the SD LoRA with.",labelpadx=(330))
63886417

63896418
makefileentry(images_tab, "T5-XXL File:", "Select T5-XXL model file (SD3, Flux, WAN)",sd_t5xxl_var, 24, width=280, singlerow=True, filetypes=[("*.safetensors *.gguf","*.safetensors *.gguf")],tooltiptxt="Select a .safetensors t5xxl file to be loaded.")
@@ -6711,10 +6740,10 @@ def export_vars():
67116740
args.sdupscaler = sd_upscaler_var.get()
67126741
args.sdquant = sd_quant_option(sd_quant_var.get())
67136742
if sd_lora_var.get() != "":
6714-
args.sdlora = sd_lora_var.get()
6743+
args.sdlora = [item.strip() for item in sd_lora_var.get().split("|") if item]
67156744
args.sdloramult = float(sd_loramult_var.get())
67166745
else:
6717-
args.sdlora = ""
6746+
args.sdlora = None
67186747

67196748
if gen_defaults_var.get() != "":
67206749
args.gendefaults = gen_defaults_var.get()
@@ -6959,8 +6988,13 @@ def import_vars(dict):
69596988
sd_upscaler_var.set(dict["sdupscaler"] if ("sdupscaler" in dict and dict["sdupscaler"]) else "")
69606989
sd_vaeauto_var.set(1 if ("sdvaeauto" in dict and dict["sdvaeauto"]) else 0)
69616990
sd_tiled_vae_var.set(str(dict["sdtiledvae"]) if ("sdtiledvae" in dict and dict["sdtiledvae"]) else str(default_vae_tile_threshold))
6962-
6963-
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
6991+
if "sdlora" in dict and dict["sdlora"]:
6992+
if isinstance((dict["sdlora"]), list):
6993+
sd_lora_var.set("|".join(dict["sdlora"]))
6994+
else:
6995+
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
6996+
else:
6997+
sd_lora_var.set("")
69646998
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
69656999
gen_defaults_var.set(dict["gendefaults"] if ("gendefaults" in dict and dict["gendefaults"]) else "")
69667000
gen_defaults_overwrite_var.set(1 if "gendefaultsoverwrite" in dict and dict["gendefaultsoverwrite"] else 0)
@@ -7401,6 +7435,8 @@ def convert_invalid_args(args):
74017435
dict["gendefaults"] = dict["sdgendefaults"]
74027436
if "flashattention" in dict and "noflashattention" not in dict:
74037437
dict["noflashattention"] = not dict["flashattention"]
7438+
if "sdlora" in dict and isinstance(dict["sdlora"], str):
7439+
dict["sdlora"] = ([dict["sdlora"]] if dict["sdlora"] else None)
74047440
return args
74057441

74067442
def setuptunnel(global_memory, has_sd):
@@ -8220,10 +8256,11 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
82208256
dlfile = download_model_from_url(args.sdvae,[".gguf",".safetensors"],min_file_size=500000)
82218257
if dlfile:
82228258
args.sdvae = dlfile
8223-
if args.sdlora and args.sdlora!="":
8224-
dlfile = download_model_from_url(args.sdlora,[".gguf",".safetensors"],min_file_size=500000)
8225-
if dlfile:
8226-
args.sdlora = dlfile
8259+
if args.sdlora and len(args.sdlora)>0:
8260+
for i in range(0,len(args.sdlora)):
8261+
dlfile = download_model_from_url(args.sdlora[i],[".gguf",".safetensors"],min_file_size=500000)
8262+
if dlfile:
8263+
args.sdlora[i] = dlfile
82278264
if args.mmproj and args.mmproj!="":
82288265
dlfile = download_model_from_url(args.mmproj,[".gguf"],min_file_size=500000)
82298266
if dlfile:
@@ -8499,18 +8536,20 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
84998536
exitcounter = 999
85008537
exit_with_error(2,f"Cannot find image model file: {imgmodel}")
85018538
else:
8502-
imglora = ""
8539+
imgloras = []
85038540
imgvae = ""
85048541
imgt5xxl = ""
85058542
imgclip1 = ""
85068543
imgclip2 = ""
85078544
imgphotomaker = ""
85088545
imgupscaler = ""
8509-
if args.sdlora:
8510-
if os.path.exists(args.sdlora):
8511-
imglora = os.path.abspath(args.sdlora)
8512-
else:
8513-
print("Missing SD LORA model file...")
8546+
if args.sdlora and len(args.sdlora)>0:
8547+
for i in range (0,len(args.sdlora)):
8548+
curr = args.sdlora[i]
8549+
if os.path.exists(curr):
8550+
imgloras.append(os.path.abspath(curr))
8551+
else:
8552+
print(f"Missing SD LORA model file {curr}...")
85148553
if args.sdvae:
85158554
if os.path.exists(args.sdvae):
85168555
imgvae = os.path.abspath(args.sdvae)
@@ -8547,7 +8586,7 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
85478586
friendlysdmodelname = os.path.basename(imgmodel)
85488587
friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0]
85498588
friendlysdmodelname = sanitize_string(friendlysdmodelname)
8550-
loadok = sd_load_model(imgmodel,imgvae,imglora,imgt5xxl,imgclip1,imgclip2,imgphotomaker,imgupscaler)
8589+
loadok = sd_load_model(imgmodel,imgvae,imgloras,imgt5xxl,imgclip1,imgclip2,imgphotomaker,imgupscaler)
85518590
print("Load Image Model OK: " + str(loadok))
85528591
if not loadok:
85538592
exitcounter = 999
@@ -9008,8 +9047,8 @@ def range_checker(arg: str):
90089047
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in tiny VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')
90099048
sdparsergrouplora = sdparsergroup.add_mutually_exclusive_group()
90109049
sdparsergrouplora.add_argument("--sdquant", metavar=('[quantization level 0/1/2]'), help="If specified, loads the model quantized to save memory. 0=off, 1=q8, 2=q4", type=int, choices=[0,1,2], nargs="?", const=2, default=0)
9011-
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify an image generation LORA safetensors model to be applied.", default="")
9012-
sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the image LORA model to be applied.", type=float, default=1.0)
9050+
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify image generation LoRAs safetensors models to be applied. Multiple LoRAs are accepted.", nargs='+')
9051+
sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the image LoRA model to be applied.", type=float, default=1.0)
90139052
sdparsergroup.add_argument("--sdtiledvae", metavar=('[maxres]'), help="Adjust the automatic VAE tiling trigger for images above this size. 0 disables vae tiling.", type=int, default=default_vae_tile_threshold)
90149053
whisperparsergroup = parser.add_argument_group('Whisper Transcription Commands')
90159054
whisperparsergroup.add_argument("--whispermodel", metavar=('[filename]'), help="Specify a Whisper .bin model to enable Speech-To-Text transcription.", default="")

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ struct SDParams {
7878

7979
bool chroma_use_dit_mask = true;
8080

81-
std::string lora_path;
82-
sd_lora_t lora_spec;
81+
std::vector<std::string> lora_paths;
82+
std::vector<sd_lora_t> lora_specs;
8383
uint32_t lora_count;
8484
};
8585

@@ -207,7 +207,15 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
207207
set_sd_quiet(sd_is_quiet);
208208
executable_path = inputs.executable_path;
209209
std::string taesdpath = "";
210-
std::string lorafilename = inputs.lora_filename;
210+
std::vector<std::string> lorafilenames;
211+
for(int i=0;i<lora_filenames_max;++i)
212+
{
213+
std::string temp = inputs.lora_filenames[i];
214+
if(temp!="")
215+
{
216+
lorafilenames.push_back(temp);
217+
}
218+
}
211219
std::string vaefilename = inputs.vae_filename;
212220
std::string t5xxl_filename = inputs.t5xxl_filename;
213221
std::string clip1_filename = inputs.clip1_filename;
@@ -223,13 +231,16 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
223231

224232
int lora_apply_mode = std::max(0, std::min(2, inputs.lora_apply_mode));
225233

226-
if(lorafilename!="")
234+
if(lorafilenames.size()>0)
227235
{
228-
const char* lora_apply_mode_name = lora_apply_mode == 1 ? "immediately"
229-
: lora_apply_mode == 2 ? "at runtime"
230-
: "auto";
231-
printf("With LoRA: %s at %f power, apply mode: %s\n",
232-
lorafilename.c_str(),inputs.lora_multiplier,lora_apply_mode_name);
236+
for(int i=0;i<lorafilenames.size();++i)
237+
{
238+
const char* lora_apply_mode_name = lora_apply_mode == 1 ? "immediately"
239+
: lora_apply_mode == 2 ? "at runtime"
240+
: "auto";
241+
printf("With LoRA: %s at %f power, apply mode: %s\n",
242+
lorafilenames[i].c_str(),inputs.lora_multiplier,lora_apply_mode_name);
243+
}
233244
}
234245
if(inputs.taesd)
235246
{
@@ -315,7 +326,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
315326
sd_params->clip_l_path = clip1_filename;
316327
sd_params->clip_g_path = clip2_filename;
317328
sd_params->stacked_id_embeddings_path = photomaker_filename;
318-
sd_params->lora_path = lorafilename;
329+
sd_params->lora_paths = lorafilenames;
319330
//if t5 is set, and model is a gguf, load it as a diffusion model path
320331
bool endswithgguf = (sd_params->model_path.rfind(".gguf") == sd_params->model_path.size() - 5);
321332
if((sd_params->t5xxl_path!="" || sd_params->clip_l_path!="" || sd_params->clip_g_path!="") && endswithgguf)
@@ -405,15 +416,21 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
405416
std::filesystem::path mpath(inputs.model_filename);
406417
sdmodelfilename = mpath.filename().string();
407418

408-
sd_params->lora_spec = {};
409-
sd_params->lora_spec.path = sd_params->lora_path.c_str();
410-
sd_params->lora_spec.multiplier = inputs.lora_multiplier;
419+
sd_params->lora_specs.clear();
420+
sd_params->lora_specs.reserve(lora_filenames_max*2);
421+
for(int i=0;i<sd_params->lora_paths.size();++i)
422+
{
423+
sd_lora_t spec = {};
424+
spec.path = sd_params->lora_paths[i].c_str();
425+
spec.multiplier = inputs.lora_multiplier;
426+
sd_params->lora_specs.push_back(spec);
427+
}
411428

412-
if(sd_params->lora_path!="" && sd_params->lora_spec.multiplier>0)
429+
if(sd_params->lora_specs.size()>0 && inputs.lora_multiplier>0)
413430
{
414-
printf("\nApply LoRA...\n");
415-
sd_params->lora_count = 1;
416-
sd_ctx->sd->apply_loras(&sd_params->lora_spec, sd_params->lora_count);
431+
printf("\nApply %d LoRAs...\n",sd_params->lora_specs.size());
432+
sd_params->lora_count = sd_params->lora_specs.size();
433+
sd_ctx->sd->apply_loras(sd_params->lora_specs.data(), sd_params->lora_count);
417434
}
418435

419436
input_extraimage_buffers.reserve(max_extra_images);
@@ -1011,7 +1028,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
10111028

10121029
// needs to be "reapplied" because sdcpp tracks previously applied LoRAs
10131030
// and weights, and apply/unapply the differences at each gen
1014-
params.loras = &sd_params->lora_spec;
1031+
params.loras = sd_params->lora_specs.data();
10151032
params.lora_count = sd_params->lora_count;
10161033

10171034
params.ref_images = reference_imgs.data();

0 commit comments

Comments
 (0)