diff --git a/src/huggingface_inference_toolkit/vertex_ai_utils.py b/src/huggingface_inference_toolkit/vertex_ai_utils.py
index cb588174..885c2d07 100644
--- a/src/huggingface_inference_toolkit/vertex_ai_utils.py
+++ b/src/huggingface_inference_toolkit/vertex_ai_utils.py
@@ -15,7 +15,8 @@ def _load_repository_from_gcs(artifact_uri: str, target_dir: Union[str, Path] =
from google.cloud import storage
logger.info(f"Loading model artifacts from {artifact_uri} to {target_dir}")
- target_dir = Path(target_dir)
+ if isinstance(target_dir, str):
+ target_dir = Path(target_dir)
if artifact_uri.startswith(GCS_URI_PREFIX):
matches = re.match(f"{GCS_URI_PREFIX}(.*?)/(.*)", artifact_uri)
@@ -31,9 +32,9 @@ def _load_repository_from_gcs(artifact_uri: str, target_dir: Union[str, Path] =
else name_without_prefix
)
file_split = name_without_prefix.split("/")
- directory = target_dir.join(file_split[0:-1])
+ directory = target_dir / Path(*file_split[0:-1])
directory.mkdir(parents=True, exist_ok=True)
if name_without_prefix and not name_without_prefix.endswith("/"):
- blob.download_to_filename(name_without_prefix)
+ blob.download_to_filename(target_dir / name_without_prefix)
return str(target_dir.absolute())
diff --git a/tests/unit/test_vertex_ai_utils.py b/tests/unit/test_vertex_ai_utils.py
new file mode 100644
index 00000000..eca64f29
--- /dev/null
+++ b/tests/unit/test_vertex_ai_utils.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+
+from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
+
+
+def test__load_repository_from_gcs():
+ """Tests the `_load_repository_from_gcs` function against a public artifact URI.
+
+ References:
+ - https://cloud.google.com/storage/docs/public-datasets/era5
+ - https://console.cloud.google.com/storage/browser/gcp-public-data-arco-era5/raw/date-variable-static/2021/12/31/soil_type?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))
+ """
+
+ public_artifact_uri = (
+ "gs://gcp-public-data-arco-era5/raw/date-variable-static/2021/12/31/soil_type"
+ )
+ target_dir = Path.cwd() / "target"
+ target_dir_path = _load_repository_from_gcs(
+ artifact_uri=public_artifact_uri, target_dir=target_dir
+ )
+
+ assert target_dir == Path(target_dir_path)
+ assert Path(target_dir_path).exists()
+ assert (Path(target_dir_path) / "static.nc").exists()