|
3 | 3 | #include "chat.h" |
4 | 4 | #include "common.h" |
5 | 5 | #include "download.h" |
| 6 | +#include "hf-cache.h" |
6 | 7 | #include "json-schema-to-grammar.h" |
7 | 8 | #include "log.h" |
8 | 9 | #include "sampling.h" |
@@ -326,60 +327,48 @@ struct handle_model_result { |
326 | 327 | common_params_model mmproj; |
327 | 328 | }; |
328 | 329 |
|
329 | | -static handle_model_result common_params_handle_model( |
330 | | - struct common_params_model & model, |
331 | | - const std::string & bearer_token, |
332 | | - bool offline) { |
| 330 | +static handle_model_result common_params_handle_model(struct common_params_model & model, |
| 331 | + const std::string & bearer_token, |
| 332 | + bool offline) { |
333 | 333 | handle_model_result result; |
334 | | - // handle pre-fill default model path and url based on hf_repo and hf_file |
335 | | - { |
336 | | - if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths |
337 | | - model.path = common_docker_resolve_model(model.docker_repo); |
338 | | - model.name = model.docker_repo; // set name for consistency |
339 | | - } else if (!model.hf_repo.empty()) { |
340 | | - // short-hand to avoid specifying --hf-file -> default it to --model |
341 | | - if (model.hf_file.empty()) { |
342 | | - if (model.path.empty()) { |
343 | | - auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline); |
344 | | - if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { |
345 | | - exit(1); // error message already printed |
346 | | - } |
347 | | - model.name = model.hf_repo; // repo name with tag |
348 | | - model.hf_repo = auto_detected.repo; // repo name without tag |
349 | | - model.hf_file = auto_detected.ggufFile; |
350 | | - if (!auto_detected.mmprojFile.empty()) { |
351 | | - result.found_mmproj = true; |
352 | | - result.mmproj.hf_repo = model.hf_repo; |
353 | | - result.mmproj.hf_file = auto_detected.mmprojFile; |
354 | | - } |
355 | | - } else { |
356 | | - model.hf_file = model.path; |
357 | | - } |
358 | | - } |
359 | 334 |
|
360 | | - std::string model_endpoint = get_model_endpoint(); |
361 | | - model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; |
362 | | - // make sure model path is present (for caching purposes) |
363 | | - if (model.path.empty()) { |
364 | | - // this is to avoid different repo having same file name, or same file name in different subdirs |
365 | | - std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file); |
366 | | - model.path = fs_get_cache_file(filename); |
367 | | - } |
| 335 | + if (!model.docker_repo.empty()) { |
| 336 | + model.path = common_docker_resolve_model(model.docker_repo); |
| 337 | + model.name = model.docker_repo; |
| 338 | + } else if (!model.hf_repo.empty()) { |
| 339 | + // If -m was used with -hf, treat the model "path" as the hf_file to download |
| 340 | + if (model.hf_file.empty() && !model.path.empty()) { |
| 341 | + model.hf_file = model.path; |
| 342 | + model.path = ""; |
| 343 | + } |
| 344 | + common_download_model_opts opts; |
| 345 | + opts.download_mmproj = true; |
| 346 | + opts.offline = offline; |
| 347 | + auto download_result = common_download_model(model, bearer_token, opts); |
| 348 | + |
| 349 | + if (download_result.model_path.empty()) { |
| 350 | + LOG_ERR("error: failed to download model from Hugging Face\n"); |
| 351 | + exit(1); |
| 352 | + } |
368 | 353 |
|
369 | | - } else if (!model.url.empty()) { |
370 | | - if (model.path.empty()) { |
371 | | - auto f = string_split<std::string>(model.url, '#').front(); |
372 | | - f = string_split<std::string>(f, '?').front(); |
373 | | - model.path = fs_get_cache_file(string_split<std::string>(f, '/').back()); |
374 | | - } |
| 354 | + model.name = model.hf_repo; |
| 355 | + model.path = download_result.model_path; |
375 | 356 |
|
| 357 | + if (!download_result.mmproj_path.empty()) { |
| 358 | + result.found_mmproj = true; |
| 359 | + result.mmproj.path = download_result.mmproj_path; |
| 360 | + } |
| 361 | + } else if (!model.url.empty()) { |
| 362 | + if (model.path.empty()) { |
| 363 | + auto f = string_split<std::string>(model.url, '#').front(); |
| 364 | + f = string_split<std::string>(f, '?').front(); |
| 365 | + model.path = fs_get_cache_file(string_split<std::string>(f, '/').back()); |
376 | 366 | } |
377 | | - } |
378 | 367 |
|
379 | | - // then, download it if needed |
380 | | - if (!model.url.empty()) { |
381 | | - bool ok = common_download_model(model, bearer_token, offline); |
382 | | - if (!ok) { |
| 368 | + common_download_model_opts opts; |
| 369 | + opts.offline = offline; |
| 370 | + auto download_result = common_download_model(model, bearer_token, opts); |
| 371 | + if (download_result.model_path.empty()) { |
383 | 372 | LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); |
384 | 373 | exit(1); |
385 | 374 | } |
@@ -539,6 +528,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context |
539 | 528 | // parse the first time to get -hf option (used for remote preset) |
540 | 529 | parse_cli_args(); |
541 | 530 |
|
| 531 | + // TODO: Remove later |
| 532 | + try { |
| 533 | + hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline); |
| 534 | + } catch (const std::exception & e) { |
| 535 | + LOG_WRN("HF cache migration failed: %s\n", e.what()); |
| 536 | + } |
| 537 | + |
542 | 538 | // maybe handle remote preset |
543 | 539 | if (!params.model.hf_repo.empty()) { |
544 | 540 | std::string cli_hf_repo = params.model.hf_repo; |
@@ -1061,12 +1057,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex |
1061 | 1057 | {"-cl", "--cache-list"}, |
1062 | 1058 | "show list of models in cache", |
1063 | 1059 | [](common_params &) { |
1064 | | - printf("model cache directory: %s\n", fs_get_cache_directory().c_str()); |
1065 | 1060 | auto models = common_list_cached_models(); |
1066 | 1061 | printf("number of models in cache: %zu\n", models.size()); |
1067 | 1062 | for (size_t i = 0; i < models.size(); i++) { |
1068 | | - auto & model = models[i]; |
1069 | | - printf("%4d. %s\n", (int) i + 1, model.to_string().c_str()); |
| 1063 | + printf("%4zu. %s\n", i + 1, models[i].to_string().c_str()); |
1070 | 1064 | } |
1071 | 1065 | exit(0); |
1072 | 1066 | } |
|
0 commit comments