Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 146 additions & 32 deletions modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,33 +135,123 @@ def get_dir_or_set_default(key, default_value):
path_outputs = get_dir_or_set_default('path_outputs', '../outputs/')


def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False, corrector=None):
global config_dict, visited_keys

debug_mode=False
if debug_mode:
print(f"Checking key: {key}")

if key not in visited_keys:
visited_keys.append(key)

if key not in config_dict:
config_dict[key] = default_value
return default_value

v = config_dict.get(key, None)
if debug_mode:
print(f"Value for key {key}: {v}")

if not disable_empty_as_none:
if v is None or v == '':
v = 'None'
v = default_value
if debug_mode:
print(f"Value for key {key} is None or empty, setting to default: {v}")

if validator(v):
return v
if debug_mode:
print(f"Value for key {key} passed validation.")
elif corrector:
corrected_v = corrector(v)
if validator(corrected_v):
if debug_mode:
print(f"Value for key {key} passed validation after correction.")
v = corrected_v
else:
print(f"Failed to load config key after correction. Using default: {default_value}")
v = default_value
else:
if v is not None:
print(f'Failed to load config key: {json.dumps({key:v})} is invalid; will use {json.dumps({key:default_value})} instead.')
config_dict[key] = default_value
return default_value
print(f"Failed to load config key: {json.dumps({key: v})} is invalid. Using default: {default_value}")
v = default_value

config_dict[key] = v
return v

def get_model_filenames(folder_path, name_filter=None):
return get_files_from_folder(folder_path, ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'], name_filter)

def update_all_model_names():
global model_filenames, lora_filenames
model_filenames = get_model_filenames(path_checkpoints)
lora_filenames = get_model_filenames(path_loras)
return

model_filenames = []
lora_filenames = []
update_all_model_names()

def model_validator(value):
if isinstance(value, str) and (value == "" or value in model_filenames):
return True
else :
print(f"model_filenames: {model_filenames}") # Debug print
print(f"failed model_validator: {value}") # Debug printà
return False

def correct_case_sensitivity(value, valid_values):
"""Corrects case sensitivity of a value or list of values based on a list of valid values.
If a valid value is contained in the input value, it replaces it with the full valid value."""
def find_full_match(partial_value):
if isinstance(partial_value, str):
lower_partial_value = partial_value.lower()
for valid_value in valid_values:
if lower_partial_value == valid_value.lower():
return valid_value
elif lower_partial_value in valid_value.lower():
return valid_value
return partial_value

print(f"Initial value: {value}") # Debug print

# If value is a string, apply find_full_match directly
if isinstance(value, str):
corrected_value = find_full_match(value)
print(f"Corrected string value: {corrected_value}") # Debug print
return corrected_value

# If value is a list
elif isinstance(value, list):
print(f"Processing list value: {value}") # Debug print before if cases

# Check if value is a list of lists (as in default_loras)
if all(isinstance(item, list) and len(item) == 2 for item in value):
corrected_list = []
for sub_value in value:
print(f"Processing sub_value: {sub_value}") # Debug print
if isinstance(sub_value, list):
corrected_element = find_full_match(sub_value[0])
print(f"Correcting {sub_value[0]} to {corrected_element}") # Debug print
corrected_list.append([corrected_element, sub_value[1]])
else:
corrected_list.append(sub_value)
print(f"Corrected list of lists: {corrected_list}") # Debug print
return corrected_list

# If value is a regular list
else:
corrected_list = [find_full_match(val) for val in value]
print(f"Corrected regular list: {corrected_list}") # Debug print
return corrected_list

return value



def model_corrector(value):
return correct_case_sensitivity(value, model_filenames)

default_base_model_name = get_config_item_or_set_default(
key='default_model',
default_value='model.safetensors',
validator=lambda x: isinstance(x, str)
validator=model_validator,
corrector=model_corrector
)
previous_default_models = get_config_item_or_set_default(
key='previous_default_models',
Expand All @@ -171,19 +261,40 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
default_refiner_model_name = get_config_item_or_set_default(
key='default_refiner',
default_value='None',
validator=lambda x: isinstance(x, str)
validator=model_validator,
corrector=model_corrector
)
default_refiner_switch = get_config_item_or_set_default(
key='default_refiner_switch',
default_value=0.8,
validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1
)

def loras_validator(x):
if not isinstance(x, list):
print(f"Validation failed: 'x' is not a list. Value of x: {x}") # Debug print
return False

for y in x:
if not (len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], (numbers.Number, float))):
print(f"Validation failed: Element structure is incorrect. Element: {y}") # Debug print
return False
if y[0] != "None" and y[0] not in lora_filenames:
print(f"Validation failed: Lora filename not found in lora_filenames. Lora filename: {y[0]}") # Debug print
print(f"Available lora_filenames: {lora_filenames}") # Debug print
return False

return True

def loras_corrector(value):
return correct_case_sensitivity(value, lora_filenames)

default_loras = get_config_item_or_set_default(
key='default_loras',
default_value=[
[
"None",
1.0
"sd_xl_offset_example-lora_1.0.safetensors",
0.1
],
[
"None",
Expand Down Expand Up @@ -224,14 +335,33 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
default_value='karras',
validator=lambda x: x in modules.flags.scheduler_list
)

def sdxl_styles_validator(x):
if not isinstance(x, list):
print("Validation failed: The variable 'x' is not a list.")
print(f"Type of x: {type(x)}")
return False

for y in x:
if y not in modules.sdxl_styles.legal_style_names:
print("Validation failed: An element in 'x' is not in legal_style_names.")
print(f"Failed element: {y}")
return False

return True

def sdxl_styles_corrector(value):
return correct_case_sensitivity(value, modules.sdxl_styles.legal_style_names)

default_styles = get_config_item_or_set_default(
key='default_styles',
default_value=[
"Fooocus V2",
"Fooocus Enhance",
"Fooocus Sharp"
],
validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x)
validator=sdxl_styles_validator,
corrector=sdxl_styles_corrector
)
default_prompt_negative = get_config_item_or_set_default(
key='default_prompt_negative',
Expand Down Expand Up @@ -388,20 +518,6 @@ def add_ratio(x):

os.makedirs(path_outputs, exist_ok=True)

model_filenames = []
lora_filenames = []


def get_model_filenames(folder_path, name_filter=None):
return get_files_from_folder(folder_path, ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'], name_filter)


def update_all_model_names():
global model_filenames, lora_filenames
model_filenames = get_model_filenames(path_checkpoints)
lora_filenames = get_model_filenames(path_loras)
return


def downloading_inpaint_models(v):
assert v in modules.flags.inpaint_engine_versions
Expand Down Expand Up @@ -514,5 +630,3 @@ def downloading_upscale_model():
)
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')


update_all_model_names()