Skip to content

Commit ee7c0c5

Browse files
author
Liav Weiss
committed
feat: re-enable EmbeddingGemma-300m support
Signed-off-by: Liav Weiss <[email protected]>
1 parent 8e1c5ca commit ee7c0c5

File tree

10 files changed

+76
-196
lines changed

10 files changed

+76
-196
lines changed

.github/workflows/integration-test-docker.yml

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,22 +82,14 @@ jobs:
8282
- name: Download models
8383
run: |
8484
echo "Downloading minimal models for CI..."
85-
# Authenticate with HuggingFace if token is available
86-
# Note: For PRs from forks, HF_TOKEN is not available (GitHub security feature)
87-
# The makefile will gracefully skip gated models (e.g., embeddinggemma-300m) if token is missing
88-
if [ -n "${HF_TOKEN:-}" ]; then
89-
huggingface-cli login --token "$HF_TOKEN" --add-to-git-credential
90-
export HUGGINGFACE_HUB_TOKEN="$HF_TOKEN"
91-
fi
9285
make download-models
9386
env:
9487
CI: true
9588
CI_MINIMAL_MODELS: true
96-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
9789
HF_HUB_ENABLE_HF_TRANSFER: 1
9890
HF_HUB_DISABLE_TELEMETRY: 1
9991
# HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
100-
# For PRs from forks, this will be empty and the makefile will gracefully skip gated models
92+
# For PRs from forks, this will be empty and the model_manager will gracefully skip gated models
10193
# The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
10294
HF_TOKEN: ${{ secrets.HF_TOKEN }}
10395
HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}

.github/workflows/performance-nightly.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,10 @@ jobs:
7070
- name: Download models (minimal set for nightly)
7171
env:
7272
CI_MINIMAL_MODELS: false
73-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
7473
HF_HUB_ENABLE_HF_TRANSFER: 1
7574
HF_HUB_DISABLE_TELEMETRY: 1
7675
# HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
77-
# For PRs from forks, this will be empty and the makefile will gracefully skip gated models
76+
# For PRs from forks, this will be empty and the model_manager will gracefully skip gated models
7877
# The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
7978
HF_TOKEN: ${{ secrets.HF_TOKEN }}
8079
HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}

.github/workflows/performance-test.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,10 @@ jobs:
7979
- name: Download models (minimal)
8080
env:
8181
CI_MINIMAL_MODELS: true
82-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
8382
HF_HUB_ENABLE_HF_TRANSFER: 1
8483
HF_HUB_DISABLE_TELEMETRY: 1
8584
# HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
86-
# For PRs from forks, this will be empty and the makefile will gracefully skip gated models
85+
# For PRs from forks, this will be empty and the model_manager will gracefully skip gated models
8786
# The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
8887
HF_TOKEN: ${{ secrets.HF_TOKEN }}
8988
HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}

.github/workflows/test-and-build.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,10 @@ jobs:
140140
- name: Download models (minimal on PRs)
141141
env:
142142
CI_MINIMAL_MODELS: ${{ github.event_name == 'pull_request' }}
143-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
144143
HF_HUB_ENABLE_HF_TRANSFER: 1
145144
HF_HUB_DISABLE_TELEMETRY: 1
146145
# HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
147-
# For PRs from forks, this will be empty and the makefile will gracefully skip gated models
146+
# For PRs from forks, this will be empty and the model_manager will gracefully skip gated models
148147
# The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
149148
HF_TOKEN: ${{ secrets.HF_TOKEN }}
150149
HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}

config/model_manager/models.lora.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,7 @@ models:
2929

3030
- id: Qwen3-Embedding-0.6B
3131
repo_id: Qwen/Qwen3-Embedding-0.6B
32+
33+
# Gated model - requires HF_TOKEN (skipped gracefully if the token is not available)
34+
- id: embeddinggemma-300m
35+
repo_id: google/embeddinggemma-300m

config/model_manager/models.minimal.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
# Equivalent to: make download-models-minimal
1111
# or CI_MINIMAL_MODELS=true make download-models
1212
#
13-
# Note: This is the minimal set for fast CI runs. Larger models like
14-
# embeddinggemma-300m are in models.yaml (full set) for local development.
13+
# Note: This is the minimal set for fast CI runs. Gated models like
14+
# embeddinggemma-300m are skipped gracefully if HF_TOKEN is not available.
1515

1616
cache_dir: "models"
1717
verify: "size" # Use size for faster CI runs
@@ -56,6 +56,10 @@ models:
5656
- id: Qwen3-Embedding-0.6B
5757
repo_id: Qwen/Qwen3-Embedding-0.6B
5858

59+
# Gated model - requires HF_TOKEN (skipped gracefully if the token is not available)
60+
- id: embeddinggemma-300m
61+
repo_id: google/embeddinggemma-300m
62+
5963
# =============================================================================
6064
# Hallucination Detection - Required for hallucination tests
6165
# =============================================================================

src/model_manager/__init__.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
MissingModelError,
2525
BadChecksumError,
2626
DownloadError,
27+
GatedModelError,
2728
)
2829

2930
__version__ = "0.1.0"
@@ -41,6 +42,7 @@
4142
"MissingModelError",
4243
"BadChecksumError",
4344
"DownloadError",
45+
"GatedModelError",
4446
]
4547

4648

@@ -96,15 +98,25 @@ def ensure_all(self) -> dict[str, str]:
9698
continue
9799

98100
logger.info(f"Downloading model '{spec.id}' from {spec.repo_id}...")
99-
local_path = download_model(spec, self.config.cache_dir)
100-
101-
if self.config.verify != "none":
102-
logger.info(f"Verifying model '{spec.id}'...")
103-
if not verify_model(local_path, self.config.verify):
104-
raise BadChecksumError(f"Verification failed for model '{spec.id}'")
105-
106-
results[spec.id] = local_path
107-
logger.info(f"Model '{spec.id}' ready at {local_path}")
101+
try:
102+
local_path = download_model(spec, self.config.cache_dir)
103+
104+
if self.config.verify != "none":
105+
logger.info(f"Verifying model '{spec.id}'...")
106+
if not verify_model(local_path, self.config.verify):
107+
raise BadChecksumError(
108+
f"Verification failed for model '{spec.id}'"
109+
)
110+
111+
results[spec.id] = local_path
112+
logger.info(f"Model '{spec.id}' ready at {local_path}")
113+
except GatedModelError as e:
114+
# Gracefully skip gated models when token is not available
115+
logger.warning(
116+
f"⚠️ Skipping gated model '{spec.id}': {e}. "
117+
"This is expected for PRs from forks where HF_TOKEN is not available."
118+
)
119+
continue
108120

109121
return results
110122

@@ -117,7 +129,14 @@ def ensure_model(self, model_id: str) -> str:
117129
118130
Returns:
119131
Local path to the model
132+
133+
Raises:
134+
GatedModelError: If the model is gated and HF_TOKEN is not available
120135
"""
136+
import logging
137+
138+
logger = logging.getLogger(__name__)
139+
121140
spec = self.get_model_spec(model_id)
122141
if spec is None:
123142
raise MissingModelError(f"Model '{model_id}' not found in configuration")

src/model_manager/downloader.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,14 @@
1111
from huggingface_hub import snapshot_download
1212
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError
1313

14+
try:
15+
from huggingface_hub.errors import GatedRepoError
16+
except ImportError:
17+
# Fallback for older versions of huggingface_hub
18+
GatedRepoError = None
19+
1420
from .config import ModelSpec
15-
from .errors import DownloadError, MissingModelError
21+
from .errors import DownloadError, MissingModelError, GatedModelError
1622

1723
logger = logging.getLogger(__name__)
1824

@@ -86,6 +92,22 @@ def download_model(spec: ModelSpec, cache_dir: str) -> str:
8692
"Check if the revision (commit/tag/branch) exists."
8793
)
8894
except Exception as e:
95+
# Check if this is a gated model error: a GatedRepoError instance, or a message containing "401", "unauthorized", or "gated"
96+
error_str = str(e).lower()
97+
is_gated_error = (
98+
(GatedRepoError is not None and isinstance(e, GatedRepoError))
99+
or "401" in error_str
100+
or "unauthorized" in error_str
101+
or "gated" in error_str
102+
or "gatedrepoerror" in error_str
103+
)
104+
105+
if is_gated_error:
106+
raise GatedModelError(
107+
f"Gated model '{spec.id}' requires HF_TOKEN authentication. "
108+
f"Set HF_TOKEN or HUGGINGFACE_HUB_TOKEN environment variable to download."
109+
) from e
110+
89111
raise DownloadError(
90112
f"Failed to download model '{spec.id}' from '{spec.repo_id}': {e}"
91113
) from e

src/model_manager/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,9 @@ class ConfigurationError(ModelManagerError):
3131
"""Raised when configuration is invalid or missing."""
3232

3333
pass
34+
35+
36+
class GatedModelError(ModelManagerError):
37+
"""Raised when attempting to download a gated model without authentication."""
38+
39+
pass

0 commit comments

Comments
 (0)