Skip to content

Commit ba7c2b2

Browse files
committed
Move _download_hf_files_to_cache from test_cli_vlm.py to ov_utils.py (renamed to download_hf_files_to_cache)
1 parent bce9726 commit ba7c2b2

File tree

2 files changed

+67
-58
lines changed

2 files changed

+67
-58
lines changed

tools/who_what_benchmark/tests/ov_utils.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import uuid
77
import shutil
88
import logging
9+
import subprocess # nosec B404
910

1011
from pathlib import Path
1112
from typing import Callable
@@ -133,3 +134,67 @@ def _cleanup_temp(self) -> None:
133134
shutil.rmtree(self.temp_path)
134135
except Exception:
135136
logger.exception("Could not clean up temp directory")
137+
138+
139+
def download_hf_files_to_cache(repo_id: str, cache_dir: Path, filenames: list[str]) -> Path:
    """Download a set of files from a Hugging Face repo into a local cache directory.

    This helper is designed for tests that share a cache across CI jobs. If the
    destination directory already exists, it verifies that all required files are
    present and only downloads missing ones.

    Args:
        repo_id: Hugging Face repo id (e.g. "org/model").
        cache_dir: Destination directory.
        filenames: List of repo file paths to download.

    Returns:
        Path to the destination directory containing the downloaded files.

    Raises:
        AssertionError: If a requested file is still absent after the download step.
        FileExistsError: If ``cache_dir`` exists but is not a directory.
    """
    dest_dir = Path(cache_dir)

    def fetch_files(names: list[str], local_dir: Path) -> None:
        # One CLI invocation per file, each wrapped in retry_request
        # (presumably retries transient failures -- see its definition).
        for filename in names:
            command = [
                "huggingface-cli",
                "download",
                repo_id,
                filename,
                "--local-dir",
                str(local_dir),
            ]

            def _run_download() -> None:
                subprocess.run(command, check=True, text=True, capture_output=True)

            retry_request(_run_download)

    def download_to_local_dir(local_dir: Path) -> None:
        # Single-argument callback handed to AtomicDownloadManager.execute.
        fetch_files(filenames, local_dir)

    # If destination exists (e.g. shared CI cache), make sure all required files are present.
    if dest_dir.exists():
        # No-op for an existing directory; raises FileExistsError right away
        # when the path exists but is not a directory.
        dest_dir.mkdir(parents=True, exist_ok=True)
        missing = [name for name in filenames if not (dest_dir / name).exists()]
        if missing:
            # Stage in a uniquely named sibling directory and move files into
            # place afterwards, so an interrupted download does not leave
            # partial files inside the cache directory.
            temp_dir = dest_dir.parent / f".tmp_{dest_dir.name}_{uuid.uuid4().hex[:8]}"
            temp_dir.mkdir(parents=True, exist_ok=True)
            try:
                # Fetch only what is absent; files already cached are not
                # downloaded again.
                fetch_files(missing, temp_dir)
                for filename in missing:
                    src = temp_dir / filename
                    if not src.exists():
                        raise AssertionError(f"Download failed: {src}")
                    dst = dest_dir / filename
                    dst.parent.mkdir(parents=True, exist_ok=True)
                    src.replace(dst)
            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)
    else:
        # Fresh cache: delegate the whole download to AtomicDownloadManager.
        manager = AtomicDownloadManager(dest_dir)
        manager.execute(download_to_local_dir)

    # Final check covering both branches above.
    for filename in filenames:
        downloaded = dest_dir / filename
        if not downloaded.exists():
            raise AssertionError(f"Download failed: {downloaded}")

    return dest_dir

tools/who_what_benchmark/tests/test_cli_vlm.py

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import logging
33
from pathlib import Path
4-
import subprocess # nosec B404
54
import sys
65

76
from conftest import convert_model, run_wwb
@@ -89,62 +88,6 @@ def run_test(model_id, model_type, optimum_threshold, genai_threshold, tmp_path)
8988
])
9089

9190

92-
def _download_hf_files_to_cache(repo_id: str, cache_dir: Path, filenames: list[str]):
    """Ensure the given Hugging Face repo files are present in ``cache_dir``.

    Args:
        repo_id: Hugging Face repo id.
        cache_dir: Destination directory (may already exist, e.g. a shared CI cache).
        filenames: Repo file paths that must be present in ``cache_dir``.

    Returns:
        ``cache_dir`` as a ``Path``.

    Raises:
        AssertionError: If a requested file is still absent after the download step.
        FileExistsError: If ``cache_dir`` exists but is not a directory.
    """
    import shutil
    import uuid

    from ov_utils import AtomicDownloadManager, retry_request

    dest_dir = Path(cache_dir)

    def fetch_files(names: list[str], local_dir: Path) -> None:
        # One CLI invocation per file, each wrapped in retry_request.
        for filename in names:
            command = [
                "huggingface-cli",
                "download",
                repo_id,
                filename,
                "--local-dir",
                str(local_dir),
            ]

            def _run_download() -> None:
                subprocess.run(command, check=True, text=True, capture_output=True)

            retry_request(_run_download)

    def download_to_local_dir(local_dir: Path) -> None:
        # Single-argument callback handed to AtomicDownloadManager.execute.
        fetch_files(filenames, local_dir)

    # If destination exists (e.g. shared CI cache), make sure all required files are present.
    # This test previously cached only adapter_model.safetensors, but peft also requires
    # adapter_config.json next to it.
    if dest_dir.exists():
        # No-op for an existing directory; raises FileExistsError right away
        # when the path exists but is not a directory.
        dest_dir.mkdir(parents=True, exist_ok=True)
        missing = [name for name in filenames if not (dest_dir / name).exists()]
        if missing:
            # Stage in a uniquely named sibling directory, then move into place.
            temp_dir = dest_dir.parent / f".tmp_{dest_dir.name}_{uuid.uuid4().hex[:8]}"
            temp_dir.mkdir(parents=True, exist_ok=True)
            try:
                # Fetch only what is absent; files already cached are not
                # downloaded again.
                fetch_files(missing, temp_dir)
                for filename in missing:
                    src = temp_dir / filename
                    if not src.exists():
                        raise AssertionError(f"Download failed: {src}")
                    dst = dest_dir / filename
                    dst.parent.mkdir(parents=True, exist_ok=True)
                    src.replace(dst)
            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)
    else:
        # Fresh cache: delegate the whole download to AtomicDownloadManager.
        manager = AtomicDownloadManager(dest_dir)
        manager.execute(download_to_local_dir)

    # Final check covering both branches above.
    for filename in filenames:
        downloaded = dest_dir / filename
        if not downloaded.exists():
            raise AssertionError(f"Download failed: {downloaded}")

    return dest_dir
146-
147-
14891
def run_test_with_lora(
14992
model_id: str,
15093
model_type: str,
@@ -165,10 +108,11 @@ def run_test_with_lora(
165108
model_path = convert_model(model_id)
166109

167110
from ov_utils import get_ov_cache_dir
111+
from ov_utils import download_hf_files_to_cache
168112

169113
lora_filenames = ["adapter_model.safetensors", "adapter_config.json"]
170114
lora_cache_dir = get_ov_cache_dir() / "test_data" / lora_cache_subdir
171-
lora_adapter_dir = _download_hf_files_to_cache(lora_repo_id, lora_cache_dir, lora_filenames)
115+
lora_adapter_dir = download_hf_files_to_cache(lora_repo_id, lora_cache_dir, lora_filenames)
172116
lora_adapter_file = lora_adapter_dir / "adapter_model.safetensors"
173117
assert lora_adapter_file.exists(), f"LoRA adapter wasn't downloaded: {lora_adapter_file}"
174118

0 commit comments

Comments
 (0)