
Commit fee395f

feat: re-enable EmbeddingGemma-300m support

Author: Liav Weiss
Signed-off-by: Liav Weiss <[email protected]>
1 parent 317dfcf, commit fee395f

File tree

13 files changed: +376 −62 lines changed

.github/workflows/integration-test-docker.yml

Lines changed: 12 additions & 0 deletions

@@ -82,12 +82,24 @@ jobs:
       - name: Download models
         run: |
           echo "Downloading minimal models for CI..."
+          # Authenticate with HuggingFace if token is available
+          # Note: For PRs from forks, HF_TOKEN is not available (GitHub security feature)
+          # The makefile will gracefully skip gated models (e.g., embeddinggemma-300m) if token is missing
+          if [ -n "${HF_TOKEN:-}" ]; then
+            huggingface-cli login --token "$HF_TOKEN" --add-to-git-credential
+            export HUGGINGFACE_HUB_TOKEN="$HF_TOKEN"
+          fi
           make download-models
         env:
           CI: true
           CI_MINIMAL_MODELS: true
           HF_HUB_ENABLE_HF_TRANSFER: 1
           HF_HUB_DISABLE_TELEMETRY: 1
+          # HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
+          # For PRs from forks, this will be empty and the makefile will gracefully skip gated models
+          # The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}

       - name: Start CI services
         run: |
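For reference, the "graceful skip" the comments above describe boils down to checking for a token before attempting the gated download. The repository's real download-models target is not part of this diff, so the recipe below is only an illustrative Makefile sketch of that pattern; target body, paths, and CLI invocations are assumptions:

# Illustrative sketch only, not the repository's actual Makefile
download-models:
	@if [ -z "$${HF_TOKEN:-}" ] && [ -z "$${HUGGINGFACE_HUB_TOKEN:-}" ]; then \
		echo "HF_TOKEN not set; skipping gated model embeddinggemma-300m"; \
	else \
		huggingface-cli download google/embeddinggemma-300m --local-dir models/embeddinggemma-300m; \
	fi
	huggingface-cli download Qwen/Qwen3-Embedding-0.6B --local-dir models/Qwen3-Embedding-0.6B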

.github/workflows/integration-test-k8s.yml

Lines changed: 4 additions & 0 deletions

@@ -81,6 +81,10 @@ jobs:

       - name: Run Integration E2E tests (${{ matrix.profile }})
         id: e2e-test
+        env:
+          # Pass HF_TOKEN to E2E tests for downloading gated models (e.g., embeddinggemma-300m)
+          # For PRs from forks, this will be empty and the E2E framework will gracefully skip gated model downloads
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
         run: |
           set +e # Don't exit on error, we want to capture the result
           make e2e-test E2E_PROFILE=${{ matrix.profile }} E2E_VERBOSE=true E2E_KEEP_CLUSTER=false
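To reproduce this step outside CI, the same variable can be passed directly on the make command line shown above; both placeholders below stand in for real values, since the profile comes from the job matrix:

# Local run with a gated-model token; <your-token> and <profile> are placeholders
HF_TOKEN=<your-token> make e2e-test E2E_PROFILE=<profile> E2E_VERBOSE=true E2E_KEEP_CLUSTER=false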

.github/workflows/performance-nightly.yml

Lines changed: 5 additions & 0 deletions

@@ -71,6 +71,11 @@ jobs:
           CI_MINIMAL_MODELS: true
           HF_HUB_ENABLE_HF_TRANSFER: 1
           HF_HUB_DISABLE_TELEMETRY: 1
+          # HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
+          # For PRs from forks, this will be empty and the makefile will gracefully skip gated models
+          # The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}
         run: make download-models

       - name: Create reports directory

.github/workflows/performance-test.yml

Lines changed: 5 additions & 0 deletions

@@ -81,6 +81,11 @@ jobs:
           CI_MINIMAL_MODELS: true
           HF_HUB_ENABLE_HF_TRANSFER: 1
           HF_HUB_DISABLE_TELEMETRY: 1
+          # HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
+          # For PRs from forks, this will be empty and the makefile will gracefully skip gated models
+          # The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}
         run: make download-models

       - name: Download performance baselines

.github/workflows/test-and-build.yml

Lines changed: 5 additions & 0 deletions

@@ -143,6 +143,11 @@ jobs:
           CI_MINIMAL_MODELS: ${{ github.event_name == 'pull_request' }}
           HF_HUB_ENABLE_HF_TRANSFER: 1
           HF_HUB_DISABLE_TELEMETRY: 1
+          # HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
+          # For PRs from forks, this will be empty and the makefile will gracefully skip gated models
+          # The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}
         run: make download-models

       - name: Start Milvus service

candle-binding/semantic-router_test.go

Lines changed: 38 additions & 29 deletions
@@ -1482,34 +1482,33 @@ func TestGetEmbeddingSmart(t *testing.T) {
 	}

 	t.Run("ShortTextHighLatency", func(t *testing.T) {
-		// Short text with high latency priority - uses Qwen3 (1024) since Gemma is not available
+		// Short text with high latency priority should use Gemma (768)
 		text := "Hello world"
 		embedding, err := GetEmbeddingSmart(text, 0.3, 0.8)

 		if err != nil {
 			t.Fatalf("GetEmbeddingSmart failed: %v", err)
 		}

-		// Expect Qwen3 (1024) dimension since Gemma is not available
-		if len(embedding) != 1024 {
-			t.Errorf("Expected 1024-dim embedding, got %d", len(embedding))
+		if len(embedding) != 768 {
+			t.Errorf("Expected 768-dim embedding, got %d", len(embedding))
 		}

 		t.Logf("Short text embedding generated: dim=%d", len(embedding))
 	})

 	t.Run("MediumTextBalanced", func(t *testing.T) {
-		// Medium text with balanced priorities - uses Qwen3 (1024) since Gemma is not available
+		// Medium text with balanced priorities - may select Qwen3 (1024) or Gemma (768)
 		text := strings.Repeat("This is a medium length text with enough words to exceed 512 tokens. ", 10)
 		embedding, err := GetEmbeddingSmart(text, 0.5, 0.5)

 		if err != nil {
 			t.Fatalf("GetEmbeddingSmart failed: %v", err)
 		}

-		// Expect Qwen3 (1024) dimension since Gemma is not available
-		if len(embedding) != 1024 {
-			t.Errorf("Expected 1024-dim embedding, got %d", len(embedding))
+		// Accept both Qwen3 (1024) and Gemma (768) dimensions
+		if len(embedding) != 768 && len(embedding) != 1024 {
+			t.Errorf("Expected 768 or 1024-dim embedding, got %d", len(embedding))
 		}

 		t.Logf("Medium text embedding generated: dim=%d", len(embedding))

@@ -1569,9 +1568,9 @@ func TestGetEmbeddingSmart(t *testing.T) {
 			return
 		}

-		// Expect Qwen3 (1024) since Gemma is not available
-		if len(embedding) != 1024 {
-			t.Errorf("Expected 1024-dim embedding, got %d", len(embedding))
+		// Smart routing may select Qwen3 (1024) or Gemma (768) based on priorities
+		if len(embedding) != 768 && len(embedding) != 1024 {
+			t.Errorf("Expected 768 or 1024-dim embedding, got %d", len(embedding))
 		}
 		t.Logf("Priority test %s: generated %d-dim embedding", tc.desc, len(embedding))
 	})

@@ -1594,9 +1593,9 @@ func TestGetEmbeddingSmart(t *testing.T) {
 			continue
 		}

-		// Expect Qwen3 (1024) since Gemma is not available
-		if len(embedding) != 1024 {
-			t.Errorf("Iteration %d: Expected 1024-dim embedding, got %d", i, len(embedding))
+		// Smart routing may select Qwen3 (1024) or Gemma (768)
+		if len(embedding) != 768 && len(embedding) != 1024 {
+			t.Errorf("Iteration %d: Expected 768 or 1024-dim embedding, got %d", i, len(embedding))
 		}

 		// Verify no nil pointers

@@ -1635,12 +1634,11 @@ func BenchmarkGetEmbeddingSmart(b *testing.B) {
 }

 // Test constants for embedding models (Phase 4.2)
-// Note: Gemma model is gated and requires HF_TOKEN, so tests use Qwen3 only
 const (
 	Qwen3EmbeddingModelPath = "../models/Qwen3-Embedding-0.6B"
-	GemmaEmbeddingModelPath = "" // Gemma is gated, not used in CI tests
+	GemmaEmbeddingModelPath = "../models/embeddinggemma-300m"
 	TestEmbeddingText       = "This is a test sentence for embedding generation"
-	TestLongContextText     = "This is a longer text that might benefit from long-context embedding models like Qwen3"
+	TestLongContextText     = "This is a longer text that might benefit from long-context embedding models like Qwen3 or Gemma"
 )

 // Test constants for Qwen3 Multi-LoRA

@@ -1702,8 +1700,22 @@ func TestInitEmbeddingModels(t *testing.T) {
 	})

 	t.Run("InitGemmaOnly", func(t *testing.T) {
-		// Gemma is a gated model requiring HF_TOKEN, skip in CI
-		t.Skip("Skipping Gemma-only test: Gemma is a gated model requiring HF_TOKEN")
+		err := InitEmbeddingModels("", GemmaEmbeddingModelPath, true)
+		if err != nil {
+			t.Logf("InitEmbeddingModels (Gemma only) returned error (may already be initialized): %v", err)
+
+			// Verify functionality
+			_, testErr := GetEmbeddingSmart("test", 0.5, 0.5)
+			if testErr == nil {
+				t.Log("✓ ModelFactory is functional (already initialized)")
+			} else {
+				if isModelInitializationError(testErr) {
+					t.Skipf("Skipping test due to model unavailability: %v", testErr)
+				}
+			}
+		} else {
+			t.Log("✓ Gemma model initialized successfully")
+		}
 	})

 	t.Run("InitWithInvalidPaths", func(t *testing.T) {

@@ -1785,16 +1797,16 @@ func TestGetEmbeddingWithDim(t *testing.T) {

 	t.Run("OversizedDimension", func(t *testing.T) {
 		// Test graceful degradation when requested dimension exceeds model capacity
-		// Qwen3: 1024, so 2048 should fall back to full dimension
+		// Qwen3: 1024, Gemma: 768, so 2048 should fall back to full dimension
 		embedding, err := GetEmbeddingWithDim(TestEmbeddingText, 0.5, 0.5, 2048)
 		if err != nil {
 			t.Errorf("Should gracefully handle oversized dimension, got error: %v", err)
 			return
 		}

-		// Should return full dimension (1024 for Qwen3)
-		if len(embedding) != 1024 {
-			t.Errorf("Expected full dimension (1024), got %d", len(embedding))
+		// Should return full dimension (1024 for Qwen3 or 768 for Gemma)
+		if len(embedding) != 1024 && len(embedding) != 768 {
+			t.Errorf("Expected full dimension (1024 or 768), got %d", len(embedding))
 		} else {
 			t.Logf("✓ Oversized dimension gracefully degraded to full dimension: %d", len(embedding))
 		}

@@ -1889,9 +1901,6 @@ func TestEmbeddingPriorityRouting(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to initialize embedding models: %v", err)
 	}
-
-	// Note: These tests use Matryoshka dimension truncation (768) with Qwen3 model
-	// The dimension is truncated from Qwen3's full 1024 dimensions
 	testCases := []struct {
 		name            string
 		text            string

@@ -1906,23 +1915,23 @@ func TestEmbeddingPriorityRouting(t *testing.T) {
 			qualityPriority: 0.2,
 			latencyPriority: 0.9,
 			expectedDim:     768,
-			description:     "Uses Qwen3 with Matryoshka 768 truncation",
+			description:     "Should prefer faster embedding model (Gemma > Qwen3)",
 		},
 		{
 			name:            "HighQualityPriority",
 			text:            strings.Repeat("Long context text ", 30),
 			qualityPriority: 0.9,
 			latencyPriority: 0.2,
 			expectedDim:     768,
-			description:     "Uses Qwen3 with Matryoshka 768 truncation",
+			description:     "Should prefer quality model (Qwen3/Gemma)",
 		},
 		{
 			name:            "BalancedPriority",
 			text:            "Medium length text for embedding",
 			qualityPriority: 0.5,
 			latencyPriority: 0.5,
 			expectedDim:     768,
-			description:     "Uses Qwen3 with Matryoshka 768 truncation",
+			description:     "Should select based on text length",
 		},
 	}
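Taken together, the test changes describe a dual-model routing contract: latency-leaning priorities may route to Gemma (768 dims), quality-leaning ones to Qwen3 (1024 dims), and both initialization and lookup degrade gracefully when a model is missing. A condensed sketch of that contract, written as if it lived in the same package as these tests; the meaning of InitEmbeddingModels' final bool argument is assumed from its usage here (taken to select CPU execution), not documented in this diff:

// Sketch only: mirrors the call patterns exercised by the tests above.
func smokeTestEmbeddingRouting(t *testing.T) {
	// Initialize both models; the trailing bool is assumed to select CPU execution.
	if err := InitEmbeddingModels(Qwen3EmbeddingModelPath, GemmaEmbeddingModelPath, true); err != nil {
		t.Skipf("models unavailable: %v", err) // graceful skip, as in the tests above
	}

	// Latency-leaning priorities (quality=0.3, latency=0.8) favor Gemma (768 dims);
	// quality-leaning priorities may pick Qwen3 (1024 dims).
	emb, err := GetEmbeddingSmart("Hello world", 0.3, 0.8)
	if err != nil {
		t.Fatalf("GetEmbeddingSmart failed: %v", err)
	}
	if len(emb) != 768 && len(emb) != 1024 {
		t.Errorf("unexpected embedding dimension: %d", len(emb))
	}

	// Matryoshka truncation: request 768 dims explicitly; oversized requests
	// (e.g., 2048) fall back to the selected model's full dimension.
	if emb768, err := GetEmbeddingWithDim("Hello world", 0.5, 0.5, 768); err == nil {
		t.Logf("truncated embedding: %d dims", len(emb768))
	}
}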

deploy/helm/semantic-router/templates/deployment.yaml

Lines changed: 35 additions & 14 deletions

@@ -49,29 +49,50 @@ spec:
             {{- range .Values.initContainer.models }}
             # Download {{ .name }}
             echo "Downloading {{ .name }} from {{ .repo }}..."
-            # Remove .cache directory to ensure fresh download
-            rm -rf "{{ .name }}/.cache" 2>/dev/null || true
-            # Download with ignore_patterns to exclude ONNX-only files if pytorch model exists
-            python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='{{ .repo }}', local_dir='{{ .name }}', ignore_patterns=['*.onnx', '*.msgpack', '*.h5', '*.tflite'] if '{{ .name }}' == 'all-MiniLM-L12-v2' else None)"
-
-            # Check for required model files
-            echo "Checking {{ .name }} for required files:"
-            if [ -f "{{ .name }}/pytorch_model.bin" ] || [ -f "{{ .name }}/model.safetensors" ]; then
-              echo "✓ Found PyTorch model weights in {{ .name }}"
+
+            # Check if this is a gated model and if token is missing
+            {{- if or (eq .name "embeddinggemma-300m") (contains "embeddinggemma" .name) }}
+            if [ -z "${HF_TOKEN:-}" ] && [ -z "${HUGGINGFACE_HUB_TOKEN:-}" ]; then
+              echo "⚠️ Warning: HF_TOKEN not set, skipping {{ .name }} download (gated model requires authentication)"
+              echo "   This is expected for PRs from forks where secrets are not available"
+              echo "   Continuing with other models..."
             else
-              echo "✗ WARNING: No PyTorch model weights found in {{ .name }}"
-              ls -la "{{ .name }}/" | head -20
+            {{- end }}
+            # Remove .cache directory to ensure fresh download
+            rm -rf "{{ .name }}/.cache" 2>/dev/null || true
+            # Download with ignore_patterns to exclude ONNX-only files if pytorch model exists
+            python -c "
+            from huggingface_hub import snapshot_download
+
+            repo_id = '{{ .repo }}'
+            local_dir = '{{ .name }}'
+            ignore_patterns = ['*.onnx', '*.msgpack', '*.h5', '*.tflite'] if '{{ .name }}' == 'all-MiniLM-L12-v2' else None
+
+            snapshot_download(repo_id=repo_id, local_dir=local_dir, ignore_patterns=ignore_patterns)
+            print(f'✓ Successfully downloaded {repo_id}')
+            "
+
+            # Check for required model files
+            echo "Checking {{ .name }} for required files:"
+            if [ -f "{{ .name }}/pytorch_model.bin" ] || [ -f "{{ .name }}/model.safetensors" ]; then
+              echo "✓ Found PyTorch model weights in {{ .name }}"
+            else
+              echo "✗ WARNING: No PyTorch model weights found in {{ .name }}"
+              ls -la "{{ .name }}/" | head -20
+            fi
+            {{- if or (eq .name "embeddinggemma-300m") (contains "embeddinggemma" .name) }}
             fi
+            {{- end }}
             {{- end }}
             echo "All models downloaded successfully!"
             ls -la /app/models/
           env:
             - name: HF_HUB_CACHE
               value: /tmp/hf_cache
-            {{- with .Values.initContainer.env }}
-            {{- toYaml . | nindent 10 }}
-            {{- end }}
+          {{- with .Values.initContainer.env }}
+          {{- toYaml . | nindent 8 }}
+          {{- end }}
           resources:
             {{- toYaml .Values.initContainer.resources | nindent 10 }}
           volumeMounts:
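One detail worth noting about the expanded python -c block: huggingface_hub resolves credentials from the environment, so once the init container receives HF_TOKEN / HUGGINGFACE_HUB_TOKEN (see the values.yaml change below), snapshot_download can fetch the gated google/embeddinggemma-300m repo without an explicit login. A standalone equivalent, with the token passed explicitly for clarity (paths mirror the chart's defaults):

# Standalone sketch of the init container's download step
import os
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="google/embeddinggemma-300m",
    local_dir="embeddinggemma-300m",
    token=os.environ.get("HF_TOKEN"),  # None falls back to env/cached-token resolution
)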

deploy/helm/semantic-router/values.yaml

Lines changed: 17 additions & 9 deletions

@@ -149,20 +149,28 @@ initContainer:
   # -- Additional environment variables for the init container.
   # For example, to use a private Hugging Face model, you can pass a token
   # and specify an endpoint using a pre-existing Kubernetes secret.
-  # env:
-  #   - name: HF_TOKEN
-  #     valueFrom:
-  #       secretKeyRef:
-  #         name: my-hf-secret
-  #         key: token
-  #   - name: HF_ENDPOINT
-  #     value: "https://huggingface.co"
-  env: []
+  # HF_TOKEN is required for downloading gated models like embeddinggemma-300m
+  # For PRs from forks, this will be empty and gated models will be gracefully skipped
+  env:
+    - name: HF_TOKEN
+      valueFrom:
+        secretKeyRef:
+          name: hf-token-secret
+          key: token
+          optional: true # Allow deployment even if secret doesn't exist (for local testing)
+    - name: HUGGINGFACE_HUB_TOKEN
+      valueFrom:
+        secretKeyRef:
+          name: hf-token-secret
+          key: token
+          optional: true # Allow deployment even if secret doesn't exist (for local testing)
   # -- Models to download
   models:
     # Embedding models for semantic cache and tools
     - name: Qwen3-Embedding-0.6B
       repo: Qwen/Qwen3-Embedding-0.6B
+    - name: embeddinggemma-300m
+      repo: google/embeddinggemma-300m
     - name: all-MiniLM-L12-v2
       repo: sentence-transformers/all-MiniLM-L12-v2
     - name: lora_intent_classifier_bert-base-uncased_model
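The chart now assumes an optional Secret named hf-token-secret carrying the token under the key token. One way to create it before installing the chart (the token value is a placeholder):

kubectl create secret generic hf-token-secret --from-literal=token=<your-hf-token>

Because both env references are marked optional: true, the deployment still starts when the secret is absent; the init container then skips the gated embeddinggemma-300m download and continues with the remaining models.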
