diff --git a/src/huggingface_inference_toolkit/vertex_ai_utils.py b/src/huggingface_inference_toolkit/vertex_ai_utils.py index cb588174..885c2d07 100644 --- a/src/huggingface_inference_toolkit/vertex_ai_utils.py +++ b/src/huggingface_inference_toolkit/vertex_ai_utils.py @@ -15,7 +15,8 @@ def _load_repository_from_gcs(artifact_uri: str, target_dir: Union[str, Path] = from google.cloud import storage logger.info(f"Loading model artifacts from {artifact_uri} to {target_dir}") - target_dir = Path(target_dir) + if isinstance(target_dir, str): + target_dir = Path(target_dir) if artifact_uri.startswith(GCS_URI_PREFIX): matches = re.match(f"{GCS_URI_PREFIX}(.*?)/(.*)", artifact_uri) @@ -31,9 +32,9 @@ def _load_repository_from_gcs(artifact_uri: str, target_dir: Union[str, Path] = else name_without_prefix ) file_split = name_without_prefix.split("/") - directory = target_dir.join(file_split[0:-1]) + directory = target_dir / Path(*file_split[0:-1]) directory.mkdir(parents=True, exist_ok=True) if name_without_prefix and not name_without_prefix.endswith("/"): - blob.download_to_filename(name_without_prefix) + blob.download_to_filename(target_dir / name_without_prefix) return str(target_dir.absolute()) diff --git a/tests/unit/test_vertex_ai_utils.py b/tests/unit/test_vertex_ai_utils.py new file mode 100644 index 00000000..eca64f29 --- /dev/null +++ b/tests/unit/test_vertex_ai_utils.py @@ -0,0 +1,24 @@ +from pathlib import Path + +from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs + + +def test__load_repository_from_gcs(): + """Tests the `_load_repository_from_gcs` function against a public artifact URI. + + References: + - https://cloud.google.com/storage/docs/public-datasets/era5 + - https://console.cloud.google.com/storage/browser/gcp-public-data-arco-era5/raw/date-variable-static/2021/12/31/soil_type?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22)) + """ + + public_artifact_uri = ( + "gs://gcp-public-data-arco-era5/raw/date-variable-static/2021/12/31/soil_type" + ) + target_dir = Path.cwd() / "target" + target_dir_path = _load_repository_from_gcs( + artifact_uri=public_artifact_uri, target_dir=target_dir + ) + + assert target_dir == Path(target_dir_path) + assert Path(target_dir_path).exists() + assert (Path(target_dir_path) / "static.nc").exists()