From fa78671433919f8ffbe3c24750d2e22e2bd659cf Mon Sep 17 00:00:00 2001 From: vmpuri Date: Wed, 9 Oct 2024 00:41:05 -0700 Subject: [PATCH 1/4] Download huggingface models to huggingface cache instead of ~/.torchchat --- torchchat/cli/builder.py | 10 ++-- torchchat/cli/cli.py | 2 +- torchchat/cli/convert_hf_checkpoint.py | 74 ++++++----------------- torchchat/cli/download.py | 81 ++++++++++++++++++++------ torchchat/usages/openai_api.py | 6 +- 5 files changed, 89 insertions(+), 84 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 02b1545d0..9d9af8ac4 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -30,6 +30,7 @@ from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE +from torchchat.cli.download import get_model_dir from torchchat.model_config.model_config import resolve_model_config from torchchat.utils.build_utils import ( device_sync, @@ -73,7 +74,7 @@ def __post_init__(self): or (self.pte_path and Path(self.pte_path).is_file()) ): raise RuntimeError( - "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path" + f"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path {self.checkpoint_path}" ) if self.dso_path and self.pte_path: @@ -109,10 +110,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": model_config = resolve_model_config(args.model) checkpoint_path = ( - Path(args.model_directory) - / model_config.name + get_model_dir(model_config, args.model_directory) / model_config.checkpoint_file ) + print(f"Using checkpoint path: {checkpoint_path}") # The transformers config is keyed on the last section # of the name/path. params_table = ( @@ -264,8 +265,7 @@ def from_args(cls, args: argparse.Namespace) -> "TokenizerArgs": elif args.model: # Using a named, well-known model model_config = resolve_model_config(args.model) tokenizer_path = ( - Path(args.model_directory) - / model_config.name + get_model_dir(model_config, args.model_directory) / model_config.tokenizer_file ) elif args.checkpoint_path: diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 1d624c6c4..92f8f9987 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -244,7 +244,7 @@ def _add_jit_downloading_args(parser) -> None: "--model-directory", type=Path, default=default_model_dir, - help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}", + help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}. This is overriden by the huggingface cache directory if the model is downloaded from HuggingFace.", ) diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py index f95cbdaef..12bbae281 100644 --- a/torchchat/cli/convert_hf_checkpoint.py +++ b/torchchat/cli/convert_hf_checkpoint.py @@ -3,7 +3,6 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-import glob import json import os import re @@ -42,12 +41,7 @@ def convert_hf_checkpoint( print(f"Model config {config.__dict__}") # Load the json file containing weight mapping - model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))] - assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files" - if len(model_map_json_matches): - model_map_json = model_map_json_matches[0] - else: - model_map_json = model_dir / "pytorch_model.bin.index.json" + model_map_json = model_dir / "pytorch_model.bin.index.json" # If there is no weight mapping, check for a consolidated model and # tokenizer we can move. Llama 2 and Mistral have weight mappings, while @@ -62,9 +56,10 @@ def convert_hf_checkpoint( str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True ) del loaded_result # No longer needed - print(f"Moving checkpoint to {model_dir / 'model.pth'}.") - os.rename(consolidated_pth, model_dir / "model.pth") - os.rename(tokenizer_pth, model_dir / "tokenizer.model") + print(f"Symlinking checkpoint to {model_dir / 'model.pth'}.") + consolidated_pth = os.path.realpath(consolidated_pth) + os.symlink(consolidated_pth, model_dir / "model.pth") + os.symlink(tokenizer_pth, model_dir / "tokenizer.model") print("Done.") return else: @@ -81,17 +76,10 @@ def convert_hf_checkpoint( "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", - "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias", - "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias", - "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias", - "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias", "model.layers.{}.self_attn.rotary_emb.inv_freq": None, "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", - "model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias", - "model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias", - "model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias", "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", "model.norm.weight": "norm.weight", @@ -100,43 +88,19 @@ def convert_hf_checkpoint( bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()} def permute(w, n_heads): + dim = config.dim return ( - w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:]) + w.view(n_heads, 2, config.head_dim // 2, dim) .transpose(1, 2) - .reshape(w.shape) + .reshape(config.head_dim * n_heads, dim) ) merged_result = {} for file in sorted(bin_files): - - # The state_dict can be loaded from either a torch zip file or - # safetensors. 
We take our best guess from the name and try all - # possibilities - load_pt_mmap = lambda: torch.load( + state_dict = torch.load( str(file), map_location="cpu", mmap=True, weights_only=True ) - load_pt_no_mmap = lambda: torch.load( - str(file), map_location="cpu", mmap=False, weights_only=True - ) - def load_safetensors(): - import safetensors.torch - with open(file, "rb") as handle: - return safetensors.torch.load(handle.read()) - if "safetensors" in str(file): - loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap] - else: - loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors] - - state_dict = None - for loader in loaders: - try: - state_dict = loader() - break - except Exception: - continue - assert state_dict is not None, f"Unable to load tensors from {file}" merged_result.update(state_dict) - final_result = {} for key, value in merged_result.items(): if "layers" in key: @@ -152,18 +116,16 @@ def load_safetensors(): final_result[new_key] = value for key in tuple(final_result.keys()): - if "wq.weight" in key or "wq.bias" in key: - wk_key = key.replace("wq", "wk") - wv_key = key.replace("wq", "wv") + if "wq" in key: q = final_result[key] - k = final_result[wk_key] - v = final_result[wv_key] + k = final_result[key.replace("wq", "wk")] + v = final_result[key.replace("wq", "wv")] q = permute(q, config.n_heads) k = permute(k, config.n_local_heads) final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) del final_result[key] - del final_result[wk_key] - del final_result[wv_key] + del final_result[key.replace("wq", "wk")] + del final_result[key.replace("wq", "wv")] print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.") torch.save(final_result, model_dir / "model.pth") print("Done.") @@ -184,10 +146,10 @@ def convert_hf_checkpoint_to_tune( consolidated_pth = model_dir / "original" / "consolidated.pth" tokenizer_pth = model_dir / "original" / "tokenizer.model" if consolidated_pth.is_file() and tokenizer_pth.is_file(): - print(f"Moving checkpoint to {model_dir / 'model.pth'}.") - os.rename(consolidated_pth, model_dir / "model.pth") - print(f"Moving tokenizer to {model_dir / 'tokenizer.model'}.") - os.rename(tokenizer_pth, model_dir / "tokenizer.model") + print(f"Creating symlink from {consolidated_pth} to {model_dir / 'model.pth'}.") + os.symlink(consolidated_pth, model_dir / "model.pth") + print(f"Creating symlink from {tokenizer_pth} to {model_dir / 'tokenizer.model'}.") + os.symlink(tokenizer_pth, model_dir / "tokenizer.model") print("Done.") else: raise RuntimeError(f"Could not find {consolidated_pth}") diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py index 14dfeb062..3c6579d6f 100644 --- a/torchchat/cli/download.py +++ b/torchchat/cli/download.py @@ -18,15 +18,19 @@ resolve_model_config, ) +# By default, download models from HuggingFace to the Hugginface hub directory. +# Both $HF_HOME and $HUGGINGFACE_HUB_CACHE are valid environment variables for the same directory. +HUGGINGFACE_HOME_PATH = Path(os.environ.get("HF_HOME", os.environ.get("HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub")))) def _download_hf_snapshot( - model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str] + model_config: ModelConfig, hf_token: Optional[str] ): from huggingface_hub import model_info, snapshot_download from requests.exceptions import HTTPError # Download and store the HF model artifacts. 
- print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr) + model_dir = get_model_dir(model_config, None) + print(f"Downloading {model_config.name} from Hugging Face to {model_dir}", file=sys.stderr, flush=True) try: # Fetch the info about the model's repo model_info = model_info(model_config.distribution_path, token=hf_token) @@ -56,8 +60,6 @@ def _download_hf_snapshot( snapshot_download( model_config.distribution_path, - local_dir=artifact_dir, - local_dir_use_symlinks=False, token=hf_token, ignore_patterns=ignore_patterns, ) @@ -76,16 +78,20 @@ def _download_hf_snapshot( else: raise e + # Update the model dir to include the snapshot we just downloaded. + model_dir = get_model_dir(model_config, None) + print("Model downloaded to", model_dir) + # Convert the Multimodal Llama model to the torchtune format. if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}: print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr) - convert_hf_checkpoint_to_tune( model_dir=artifact_dir, model_name=model_config.name) + convert_hf_checkpoint_to_tune( model_dir=model_dir, model_name=model_config.name) else: # Convert the model to the torchchat format. print(f"Converting {model_config.name} to torchchat format...", file=sys.stderr) convert_hf_checkpoint( - model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True + model_dir=model_dir, model_name=model_config.name, remove_bin_files=True ) @@ -99,12 +105,51 @@ def _download_direct( print(f"Downloading {url}...", file=sys.stderr) urllib.request.urlretrieve(url, str(local_path.absolute())) +def _get_hf_artifact_dir(model_config: ModelConfig) -> Path: + """ + Returns the directory where the model artifacts are stored. + + This is the root folder with blobs, refs and snapshots + """ + assert(model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot) + return HUGGINGFACE_HOME_PATH / f"models--{model_config.distribution_path.replace('/', '--')}" + + +def get_model_dir(model_config: ModelConfig, models_dir: Optional[Path]) -> Path: + """ + Returns the directory where the model artifacts are stored. + For HuggingFace snapshots, this is the HuggingFace cache directory. + For all other distribution channels, we use the models_dir. + + For CLI usage, pass in args.model_directory. + """ + if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot: + artifact_dir = _get_hf_artifact_dir(model_config) + + # If these paths doesn't exist, it means the model hasn't been downloaded yet. + if not os.path.isdir(artifact_dir) and not os.path.isdir(artifact_dir / "snapshots"): + return artifact_dir + snapshot = open(artifact_dir / "refs" / "main", "r").read().strip() + return artifact_dir / "snapshots" / snapshot + else: + return models_dir / model_config.name + def download_and_convert( model: str, models_dir: Path, hf_token: Optional[str] = None ) -> None: model_config = resolve_model_config(model) - model_dir = models_dir / model_config.name + model_dir = get_model_dir(model_config, models_dir) + + # HuggingFace download + if ( + model_config.distribution_channel + == ModelDistributionChannel.HuggingFaceSnapshot + ): + _download_hf_snapshot(model_config, hf_token) + return + + # Direct download # Download into a temporary directory. We'll move to the final # location once the download and conversion is complete. 
This @@ -117,11 +162,6 @@ def download_and_convert( try: if ( - model_config.distribution_channel - == ModelDistributionChannel.HuggingFaceSnapshot - ): - _download_hf_snapshot(model_config, temp_dir, hf_token) - elif ( model_config.distribution_channel == ModelDistributionChannel.DirectDownload ): _download_direct(model_config, temp_dir) @@ -144,9 +184,9 @@ def download_and_convert( def is_model_downloaded(model: str, models_dir: Path) -> bool: model_config = resolve_model_config(model) - + # Check if the model directory exists and is not empty. - model_dir = models_dir / model_config.name + model_dir = get_model_dir(model_config, models_dir) return os.path.isdir(model_dir) and os.listdir(model_dir) @@ -194,13 +234,16 @@ def remove_main(args) -> None: return model_config = resolve_model_config(args.model) - model_dir = args.model_directory / model_config.name + model_dir = get_model_dir(model_config, args.model_directory) if not os.path.isdir(model_dir): - print(f"Model {args.model} has no downloaded artifacts.") + print(f"Model {args.model} has no downloaded artifacts in {model_dir}.") return + if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot: + # For HuggingFace models, we need to remove the entire root directory. + model_dir = _get_hf_artifact_dir(model_config) - print(f"Removing downloaded model artifacts for {args.model}...") + print(f"Removing downloaded model artifacts for {args.model} at {model_dir}...") shutil.rmtree(model_dir) print("Done.") @@ -216,10 +259,10 @@ def where_main(args) -> None: return model_config = resolve_model_config(args.model) - model_dir = args.model_directory / model_config.name + model_dir = get_model_dir(model_config, args.model_directory) if not os.path.isdir(model_dir): - raise RuntimeError(f"Model {args.model} has no downloaded artifacts.") + raise RuntimeError(f"Model {args.model} has no downloaded artifacts in {model_dir}.") print(str(os.path.abspath(model_dir))) exit(0) diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index 72a6dfc9b..2adf170bd 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -23,7 +23,7 @@ from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform -from torchchat.cli.download import is_model_downloaded, load_model_configs +from torchchat.cli.download import is_model_downloaded, load_model_configs, get_model_dir from torchchat.generate import Generator, GeneratorArgs from torchchat.model import FlamingoModel @@ -522,7 +522,7 @@ def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]: """ if model_config := load_model_configs().get(model_id): if is_model_downloaded(model_id, args.model_directory): - path = args.model_directory / model_config.name + path = get_model_dir(model_config, args.model_directory) created = int(os.path.getctime(path)) owned_by = getpwuid(os.stat(path).st_uid).pw_name @@ -545,7 +545,7 @@ def get_model_info_list(args) -> ModelInfo: data = [] for model_id, model_config in load_model_configs().items(): if is_model_downloaded(model_id, args.model_directory): - path = args.model_directory / model_config.name + path = get_model_dir(model_config, args.model_directory) created = int(os.path.getctime(path)) owned_by = getpwuid(os.stat(path).st_uid).pw_name From 2e722ba107f1075625b41526d8c7ff0dd7b17b27 Mon Sep 17 00:00:00 2001 From: vmpuri Date: Wed, 9 Oct 2024 13:13:21 -0700 Subject: [PATCH 2/4] Enable hf_transfer for faster download --- install/install_requirements.sh | 2 ++ 
install/requirements.txt | 2 +- torchchat/cli/download.py | 3 +++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index a05e255db..1867f3479 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -106,3 +106,5 @@ fi set -x $PIP_EXECUTABLE install evaluate=="0.4.3" lm-eval=="0.4.2" psutil=="6.0.0" ) + +export HF_HUB_ENABLE_HF_TRANSFER=1 diff --git a/install/requirements.txt b/install/requirements.txt index 3329563b4..ddb32a38f 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -1,7 +1,7 @@ # Requires python >=3.10 # Hugging Face download -huggingface_hub +huggingface_hub[hf_transfer] # GGUF import gguf diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py index 3c6579d6f..14c19f943 100644 --- a/torchchat/cli/download.py +++ b/torchchat/cli/download.py @@ -22,6 +22,9 @@ # Both $HF_HOME and $HUGGINGFACE_HUB_CACHE are valid environment variables for the same directory. HUGGINGFACE_HOME_PATH = Path(os.environ.get("HF_HOME", os.environ.get("HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub")))) +if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", None) is None: + os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" + def _download_hf_snapshot( model_config: ModelConfig, hf_token: Optional[str] ): From 654dbec07c845c9d5b18a19935e536157c729d65 Mon Sep 17 00:00:00 2001 From: vmpuri Date: Wed, 9 Oct 2024 13:51:52 -0700 Subject: [PATCH 3/4] Cleanup --- install/install_requirements.sh | 2 -- torchchat/cli/builder.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 1867f3479..a05e255db 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -106,5 +106,3 @@ fi set -x $PIP_EXECUTABLE install evaluate=="0.4.3" lm-eval=="0.4.2" psutil=="6.0.0" ) - -export HF_HUB_ENABLE_HF_TRANSFER=1 diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 9d9af8ac4..fbca4659c 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -74,7 +74,7 @@ def __post_init__(self): or (self.pte_path and Path(self.pte_path).is_file()) ): raise RuntimeError( - f"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path {self.checkpoint_path}" + f"{self.checkpoint_path} is not a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path" ) if self.dso_path and self.pte_path: From 84602c8a9a5b7817e61ec92d179b88ef269d06ba Mon Sep 17 00:00:00 2001 From: vmpuri Date: Wed, 9 Oct 2024 14:06:26 -0700 Subject: [PATCH 4/4] Delete models from old location for huggingface download --- torchchat/cli/download.py | 104 ++++++++++++++++++++++++++++---------- 1 file changed, 76 insertions(+), 28 deletions(-) diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py index 14c19f943..0a36f0274 100644 --- a/torchchat/cli/download.py +++ b/torchchat/cli/download.py @@ -10,7 +10,10 @@ from pathlib import Path from typing import Optional -from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint, convert_hf_checkpoint_to_tune +from torchchat.cli.convert_hf_checkpoint import ( + convert_hf_checkpoint, + convert_hf_checkpoint_to_tune, +) from torchchat.model_config.model_config import ( load_model_configs, ModelConfig, @@ -20,20 +23,46 @@ # By default, download models from HuggingFace to the Hugginface hub directory. 
# Both $HF_HOME and $HUGGINGFACE_HUB_CACHE are valid environment variables for the same directory. -HUGGINGFACE_HOME_PATH = Path(os.environ.get("HF_HOME", os.environ.get("HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub")))) +HUGGINGFACE_HOME_PATH = Path( + os.environ.get( + "HF_HOME", + os.environ.get( + "HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub") + ), + ) +) if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", None) is None: os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" -def _download_hf_snapshot( - model_config: ModelConfig, hf_token: Optional[str] -): + +# Previously, all models were stored in the torchchat models directory (by default ~/.torchchat/model-cache) +# For Hugging Face models, we now store them in the HuggingFace cache directory. +# This function will delete all model artifacts in the old directory for each model with the Hugging Face distribution path. +def _cleanup_hf_models_from_torchchat_dir(models_dir: Path): + for model_config in load_model_configs().values(): + if ( + model_config.distribution_channel + == ModelDistributionChannel.HuggingFaceSnapshot + ): + if os.path.exists(models_dir / model_config.name): + print( + f"Cleaning up old model artifacts in {models_dir / model_config.name}. New artifacts will be downloaded to {HUGGINGFACE_HOME_PATH}" + ) + shutil.rmtree(models_dir / model_config.name) + + +def _download_hf_snapshot(model_config: ModelConfig, hf_token: Optional[str]): from huggingface_hub import model_info, snapshot_download from requests.exceptions import HTTPError # Download and store the HF model artifacts. model_dir = get_model_dir(model_config, None) - print(f"Downloading {model_config.name} from Hugging Face to {model_dir}", file=sys.stderr, flush=True) + print( + f"Downloading {model_config.name} from Hugging Face to {model_dir}", + file=sys.stderr, + flush=True, + ) try: # Fetch the info about the model's repo model_info = model_info(model_config.distribution_path, token=hf_token) @@ -81,14 +110,17 @@ def _download_hf_snapshot( else: raise e - # Update the model dir to include the snapshot we just downloaded. + # Update the model dir to include the snapshot we just downloaded. model_dir = get_model_dir(model_config, None) print("Model downloaded to", model_dir) # Convert the Multimodal Llama model to the torchtune format. - if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}: + if model_config.name in { + "meta-llama/Llama-3.2-11B-Vision-Instruct", + "meta-llama/Llama-3.2-11B-Vision", + }: print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr) - convert_hf_checkpoint_to_tune( model_dir=model_dir, model_name=model_config.name) + convert_hf_checkpoint_to_tune(model_dir=model_dir, model_name=model_config.name) else: # Convert the model to the torchchat format. @@ -108,32 +140,44 @@ def _download_direct( print(f"Downloading {url}...", file=sys.stderr) urllib.request.urlretrieve(url, str(local_path.absolute())) + def _get_hf_artifact_dir(model_config: ModelConfig) -> Path: """ Returns the directory where the model artifacts are stored. 
- + This is the root folder with blobs, refs and snapshots """ - assert(model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot) - return HUGGINGFACE_HOME_PATH / f"models--{model_config.distribution_path.replace('/', '--')}" + assert ( + model_config.distribution_channel + == ModelDistributionChannel.HuggingFaceSnapshot + ) + return ( + HUGGINGFACE_HOME_PATH + / f"models--{model_config.distribution_path.replace('/', '--')}" + ) def get_model_dir(model_config: ModelConfig, models_dir: Optional[Path]) -> Path: """ - Returns the directory where the model artifacts are stored. - For HuggingFace snapshots, this is the HuggingFace cache directory. - For all other distribution channels, we use the models_dir. - - For CLI usage, pass in args.model_directory. + Returns the directory where the model artifacts are expected to be stored. + For Hugging Face artifacts, this will be the location of the "main" snapshot if it exists, or the expected model directory otherwise. + For all other distribution channels, we use the models_dir. + + For CLI usage, pass in args.model_directory. """ - if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot: - artifact_dir = _get_hf_artifact_dir(model_config) - + if ( + model_config.distribution_channel + == ModelDistributionChannel.HuggingFaceSnapshot + ): + artifact_dir = _get_hf_artifact_dir(model_config) + # If these paths doesn't exist, it means the model hasn't been downloaded yet. - if not os.path.isdir(artifact_dir) and not os.path.isdir(artifact_dir / "snapshots"): + if not os.path.isdir(artifact_dir) and not os.path.isdir( + artifact_dir / "snapshots" + ): return artifact_dir snapshot = open(artifact_dir / "refs" / "main", "r").read().strip() - return artifact_dir / "snapshots" / snapshot + return artifact_dir / "snapshots" / snapshot else: return models_dir / model_config.name @@ -164,9 +208,7 @@ def download_and_convert( os.makedirs(temp_dir, exist_ok=True) try: - if ( - model_config.distribution_channel == ModelDistributionChannel.DirectDownload - ): + if model_config.distribution_channel == ModelDistributionChannel.DirectDownload: _download_direct(model_config, temp_dir) else: raise RuntimeError( @@ -187,7 +229,7 @@ def download_and_convert( def is_model_downloaded(model: str, models_dir: Path) -> bool: model_config = resolve_model_config(model) - + # Check if the model directory exists and is not empty. model_dir = get_model_dir(model_config, models_dir) return os.path.isdir(model_dir) and os.listdir(model_dir) @@ -242,7 +284,10 @@ def remove_main(args) -> None: if not os.path.isdir(model_dir): print(f"Model {args.model} has no downloaded artifacts in {model_dir}.") return - if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot: + if ( + model_config.distribution_channel + == ModelDistributionChannel.HuggingFaceSnapshot + ): # For HuggingFace models, we need to remove the entire root directory. model_dir = _get_hf_artifact_dir(model_config) @@ -265,7 +310,9 @@ def where_main(args) -> None: model_dir = get_model_dir(model_config, args.model_directory) if not os.path.isdir(model_dir): - raise RuntimeError(f"Model {args.model} has no downloaded artifacts in {model_dir}.") + raise RuntimeError( + f"Model {args.model} has no downloaded artifacts in {model_dir}." + ) print(str(os.path.abspath(model_dir))) exit(0) @@ -273,4 +320,5 @@ def where_main(args) -> None: # Subcommand to download model artifacts. 
 def download_main(args) -> None:
+    _cleanup_hf_models_from_torchchat_dir(args.model_directory)
     download_and_convert(args.model, args.model_directory, args.hf_token)
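
For reference, the snapshot lookup that the new get_model_dir() performs follows the standard Hugging Face hub cache layout (models--<org>--<name>/ containing blobs/, refs/, and snapshots/<commit>/). The following is a minimal standalone sketch of that lookup, not part of the patch; it assumes the same default cache location the patch uses, and the repo id in the example is purely illustrative.

import os
from pathlib import Path

# Sketch of the hub-cache snapshot resolution mirrored from get_model_dir()
# in torchchat/cli/download.py. Default cache path matches the patch; the
# repo id used at the bottom is a hypothetical example.
HF_HUB_CACHE = Path(
    os.environ.get(
        "HF_HOME",
        os.environ.get(
            "HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub")
        ),
    )
)

def resolve_snapshot_dir(repo_id: str) -> Path:
    # Each repo lives under models--<org>--<name>/ with blobs/ (content-addressed
    # files), refs/ (branch name -> commit hash), and snapshots/<commit hash>/
    # (symlinks into blobs/).
    repo_root = HF_HUB_CACHE / f"models--{repo_id.replace('/', '--')}"
    refs_main = repo_root / "refs" / "main"
    if not refs_main.is_file():
        # Not downloaded yet; return the expected root so callers can report it.
        return repo_root
    commit = refs_main.read_text().strip()
    return repo_root / "snapshots" / commit

if __name__ == "__main__":
    # Hypothetical repo id, for illustration only.
    print(resolve_snapshot_dir("meta-llama/Meta-Llama-3-8B-Instruct"))

The patch also opts into faster downloads by defaulting HF_HUB_ENABLE_HF_TRANSFER=1 before snapshot_download runs and by installing huggingface_hub[hf_transfer], which lets huggingface_hub use the hf_transfer backend when it is available.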