Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from backend.server.request_handler import RequestHandler
from backend.server.scheduler_manage import SchedulerManage
from backend.server.server_args import parse_args
from backend.server.static_config import get_model_list, get_node_join_command
from backend.server.static_config import (
get_model_list,
get_node_join_command,
init_model_info_dict_cache,
)
from parallax_utils.ascii_anime import display_parallax_run
from parallax_utils.file_util import get_project_root
from parallax_utils.logging_config import get_logger, set_log_level
Expand Down Expand Up @@ -126,6 +130,7 @@ async def serve_index():
args = parse_args()
set_log_level(args.log_level)
logger.info(f"args: {args}")
init_model_info_dict_cache(args.use_hfcache)
Comment thread
gufengc marked this conversation as resolved.
if args.log_level != "DEBUG":
display_parallax_run()
check_latest_release()
Expand Down
44 changes: 36 additions & 8 deletions src/backend/server/static_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,26 +142,54 @@ def _load_config_only(name: str) -> dict:
return model_info


def get_model_info_list():
def get_model_info_with_try_catch(model_name, use_hfcache: bool = False):
try:
return get_model_info(model_name, use_hfcache)
except Exception as e:
logger.debug(f"Error loading config.json for {model_name}: {e}")
return None


def get_model_info_dict(use_hfcache: bool = False):
model_name_list = list(MODELS.keys())
with concurrent.futures.ThreadPoolExecutor() as executor:
model_info_list = list(executor.map(get_model_info, model_name_list))
return model_info_list
model_info_dict = dict(
executor.map(
lambda name: (name, get_model_info_with_try_catch(name, use_hfcache)),
model_name_list,
)
)
return model_info_dict


model_info_dict_cache = None

model_info_list_cache = get_model_info_list()

def init_model_info_dict_cache(use_hfcache: bool = False):
global model_info_dict_cache
if model_info_dict_cache is not None:
return
model_info_dict_cache = get_model_info_dict(use_hfcache)


def get_model_info_dict_cache():
return model_info_dict_cache


def get_model_list():
model_info_list = model_info_list_cache
model_name_list = list(MODELS.keys())
model_info_dict = get_model_info_dict_cache()

def build_single_model(model_info):
def build_single_model(model_name, model_info):
return {
"name": model_info.model_name,
"name": model_name,
"vram_gb": math.ceil(estimate_vram_gb_required(model_info)),
}

results = [build_single_model(model_info) for model_info in model_info_list]
results = [
build_single_model(model_name, model_info_dict.get(model_name, None))
for model_name in model_name_list
]
return results


Expand Down