Commit cec298c

Author: Liav Weiss

feat: re-enable EmbeddingGemma-300m support
Signed-off-by: Liav Weiss <[email protected]>
1 parent 1d1439a commit cec298c

File tree

13 files changed: +222 −50 lines

.github/workflows/integration-test-docker.yml

Lines changed: 35 additions & 0 deletions

@@ -82,12 +82,47 @@ jobs:
       - name: Download models
         run: |
           echo "Downloading minimal models for CI..."
+          echo "Debug: Repository context: ${{ github.repository }}"
+          echo "Debug: Event: ${{ github.event_name }}"
+          echo "Debug: PR head repo: ${{ github.event.pull_request.head.repo.full_name }}"
+          # Check if HF_TOKEN secret is available (even if empty, this helps debug)
+          if [ -z "${HF_TOKEN:-}" ]; then
+            echo "⚠️ HF_TOKEN secret is not available or is empty"
+            echo "  Repository: ${{ github.repository }}"
+            echo "  Event: ${{ github.event_name }}"
+            if [ "${{ github.event_name }}" = "pull_request" ]; then
+              echo "  Note: For PRs from forks, secrets are not available (GitHub security)"
+              echo "  To test with Gemma: Add HF_TOKEN to your fork secrets and use workflow_dispatch"
+            fi
+            echo "  Gemma model download will be gracefully skipped (workflow will continue)"
+          else
+            echo "✅ HF_TOKEN is set (length: ${#HF_TOKEN} characters)"
+            echo "Authenticating with HuggingFace..."
+            huggingface-cli login --token "$HF_TOKEN" --add-to-git-credential
+            echo "✅ HuggingFace authentication successful"
+            # Also ensure environment variable is set for hf CLI
+            export HUGGINGFACE_HUB_TOKEN="$HF_TOKEN"
+          fi
+          # Export all environment variables for make
+          export CI=true
+          export CI_MINIMAL_MODELS=true
+          export HF_HUB_ENABLE_HF_TRANSFER=1
+          export HF_HUB_DISABLE_TELEMETRY=1
+          # Pass token to make if available
+          if [ -n "$HF_TOKEN" ]; then
+            export HF_TOKEN="$HF_TOKEN"
+            export HUGGINGFACE_HUB_TOKEN="$HF_TOKEN"
+          fi
           make download-models
         env:
           CI: true
           CI_MINIMAL_MODELS: true
           HF_HUB_ENABLE_HF_TRANSFER: 1
           HF_HUB_DISABLE_TELEMETRY: 1
+          # HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
+          # The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}

       - name: Start CI services
         run: |
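
Note on the token plumbing above: the workflow exports both HF_TOKEN and HUGGINGFACE_HUB_TOKEN because different Hugging Face CLI versions read different variable names. As a rough illustration (not code from this commit), a Go download helper could propagate the token the same way:

    // Hedged sketch: mirrors the workflow's token handling when shelling
    // out to huggingface-cli. Not part of this commit.
    package models

    import (
    	"fmt"
    	"os"
    	"os/exec"
    )

    func downloadModel(repo, dest string) error {
    	token := os.Getenv("HF_TOKEN")
    	if token == "" {
    		// Mirror the workflow's graceful skip for gated models.
    		fmt.Printf("HF_TOKEN not set; skipping gated model %s\n", repo)
    		return nil
    	}
    	cmd := exec.Command("huggingface-cli", "download", repo, "--local-dir", dest)
    	// The hf CLI reads HUGGINGFACE_HUB_TOKEN, so set both names.
    	cmd.Env = append(os.Environ(),
    		"HF_TOKEN="+token,
    		"HUGGINGFACE_HUB_TOKEN="+token)
    	cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
    	return cmd.Run()
    }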

.github/workflows/integration-test-k8s.yml

Lines changed: 3 additions & 0 deletions

@@ -81,6 +81,9 @@ jobs:

       - name: Run Integration E2E tests (${{ matrix.profile }})
         id: e2e-test
+        env:
+          # Pass HF_TOKEN to E2E tests for downloading gated models (e.g., embeddinggemma-300m)
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
         run: |
           set +e  # Don't exit on error, we want to capture the result
           make e2e-test E2E_PROFILE=${{ matrix.profile }} E2E_VERBOSE=true E2E_KEEP_CLUSTER=false

.github/workflows/performance-nightly.yml

Lines changed: 4 additions & 0 deletions

@@ -71,6 +71,10 @@ jobs:
           CI_MINIMAL_MODELS: false
           HF_HUB_ENABLE_HF_TRANSFER: 1
           HF_HUB_DISABLE_TELEMETRY: 1
+          # HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
+          # The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}
         run: make download-models

       - name: Create reports directory

.github/workflows/performance-test.yml

Lines changed: 4 additions & 0 deletions

@@ -81,6 +81,10 @@ jobs:
           CI_MINIMAL_MODELS: true
           HF_HUB_ENABLE_HF_TRANSFER: 1
           HF_HUB_DISABLE_TELEMETRY: 1
+          # HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
+          # The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}
         run: make download-models

       - name: Download performance baselines

.github/workflows/test-and-build.yml

Lines changed: 4 additions & 0 deletions

@@ -143,6 +143,10 @@ jobs:
           CI_MINIMAL_MODELS: ${{ github.event_name == 'pull_request' }}
           HF_HUB_ENABLE_HF_TRANSFER: 1
           HF_HUB_DISABLE_TELEMETRY: 1
+          # HF_TOKEN is required for downloading gated models (e.g., embeddinggemma-300m)
+          # The hf CLI uses HUGGINGFACE_HUB_TOKEN, so we set both for compatibility
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}
         run: make download-models

       - name: Start Milvus service

candle-binding/semantic-router_test.go

Lines changed: 38 additions & 29 deletions

@@ -1482,34 +1482,33 @@ func TestGetEmbeddingSmart(t *testing.T) {
 	}

 	t.Run("ShortTextHighLatency", func(t *testing.T) {
-		// Short text with high latency priority - uses Qwen3 (1024) since Gemma is not available
+		// Short text with high latency priority should use Gemma (768)
 		text := "Hello world"
 		embedding, err := GetEmbeddingSmart(text, 0.3, 0.8)

 		if err != nil {
 			t.Fatalf("GetEmbeddingSmart failed: %v", err)
 		}

-		// Expect Qwen3 (1024) dimension since Gemma is not available
-		if len(embedding) != 1024 {
-			t.Errorf("Expected 1024-dim embedding, got %d", len(embedding))
+		if len(embedding) != 768 {
+			t.Errorf("Expected 768-dim embedding, got %d", len(embedding))
 		}

 		t.Logf("Short text embedding generated: dim=%d", len(embedding))
 	})

 	t.Run("MediumTextBalanced", func(t *testing.T) {
-		// Medium text with balanced priorities - uses Qwen3 (1024) since Gemma is not available
+		// Medium text with balanced priorities - may select Qwen3 (1024) or Gemma (768)
 		text := strings.Repeat("This is a medium length text with enough words to exceed 512 tokens. ", 10)
 		embedding, err := GetEmbeddingSmart(text, 0.5, 0.5)

 		if err != nil {
 			t.Fatalf("GetEmbeddingSmart failed: %v", err)
 		}

-		// Expect Qwen3 (1024) dimension since Gemma is not available
-		if len(embedding) != 1024 {
-			t.Errorf("Expected 1024-dim embedding, got %d", len(embedding))
+		// Accept both Qwen3 (1024) and Gemma (768) dimensions
+		if len(embedding) != 768 && len(embedding) != 1024 {
+			t.Errorf("Expected 768 or 1024-dim embedding, got %d", len(embedding))
 		}

 		t.Logf("Medium text embedding generated: dim=%d", len(embedding))

@@ -1569,9 +1568,9 @@ func TestGetEmbeddingSmart(t *testing.T) {
 			return
 		}

-		// Expect Qwen3 (1024) since Gemma is not available
-		if len(embedding) != 1024 {
-			t.Errorf("Expected 1024-dim embedding, got %d", len(embedding))
+		// Smart routing may select Qwen3 (1024) or Gemma (768) based on priorities
+		if len(embedding) != 768 && len(embedding) != 1024 {
+			t.Errorf("Expected 768 or 1024-dim embedding, got %d", len(embedding))
 		}
 		t.Logf("Priority test %s: generated %d-dim embedding", tc.desc, len(embedding))
 	})

@@ -1594,9 +1593,9 @@ func TestGetEmbeddingSmart(t *testing.T) {
 			continue
 		}

-		// Expect Qwen3 (1024) since Gemma is not available
-		if len(embedding) != 1024 {
-			t.Errorf("Iteration %d: Expected 1024-dim embedding, got %d", i, len(embedding))
+		// Smart routing may select Qwen3 (1024) or Gemma (768)
+		if len(embedding) != 768 && len(embedding) != 1024 {
+			t.Errorf("Iteration %d: Expected 768 or 1024-dim embedding, got %d", i, len(embedding))
 		}

 		// Verify no nil pointers

@@ -1635,12 +1634,11 @@ func BenchmarkGetEmbeddingSmart(b *testing.B) {
 }

 // Test constants for embedding models (Phase 4.2)
-// Note: Gemma model is gated and requires HF_TOKEN, so tests use Qwen3 only
 const (
 	Qwen3EmbeddingModelPath = "../models/Qwen3-Embedding-0.6B"
-	GemmaEmbeddingModelPath = "" // Gemma is gated, not used in CI tests
+	GemmaEmbeddingModelPath = "../models/embeddinggemma-300m"
 	TestEmbeddingText       = "This is a test sentence for embedding generation"
-	TestLongContextText     = "This is a longer text that might benefit from long-context embedding models like Qwen3"
+	TestLongContextText     = "This is a longer text that might benefit from long-context embedding models like Qwen3 or Gemma"
 )

 // Test constants for Qwen3 Multi-LoRA

@@ -1702,8 +1700,22 @@ func TestInitEmbeddingModels(t *testing.T) {
 	})

 	t.Run("InitGemmaOnly", func(t *testing.T) {
-		// Gemma is a gated model requiring HF_TOKEN, skip in CI
-		t.Skip("Skipping Gemma-only test: Gemma is a gated model requiring HF_TOKEN")
+		err := InitEmbeddingModels("", GemmaEmbeddingModelPath, true)
+		if err != nil {
+			t.Logf("InitEmbeddingModels (Gemma only) returned error (may already be initialized): %v", err)
+
+			// Verify functionality
+			_, testErr := GetEmbeddingSmart("test", 0.5, 0.5)
+			if testErr == nil {
+				t.Log("✓ ModelFactory is functional (already initialized)")
+			} else {
+				if isModelInitializationError(testErr) {
+					t.Skipf("Skipping test due to model unavailability: %v", testErr)
+				}
+			}
+		} else {
+			t.Log("✓ Gemma model initialized successfully")
+		}
 	})

 	t.Run("InitWithInvalidPaths", func(t *testing.T) {

@@ -1785,16 +1797,16 @@ func TestGetEmbeddingWithDim(t *testing.T) {

 	t.Run("OversizedDimension", func(t *testing.T) {
 		// Test graceful degradation when requested dimension exceeds model capacity
-		// Qwen3: 1024, so 2048 should fall back to full dimension
+		// Qwen3: 1024, Gemma: 768, so 2048 should fall back to full dimension
 		embedding, err := GetEmbeddingWithDim(TestEmbeddingText, 0.5, 0.5, 2048)
 		if err != nil {
 			t.Errorf("Should gracefully handle oversized dimension, got error: %v", err)
 			return
 		}

-		// Should return full dimension (1024 for Qwen3)
-		if len(embedding) != 1024 {
-			t.Errorf("Expected full dimension (1024), got %d", len(embedding))
+		// Should return full dimension (1024 for Qwen3 or 768 for Gemma)
+		if len(embedding) != 1024 && len(embedding) != 768 {
+			t.Errorf("Expected full dimension (1024 or 768), got %d", len(embedding))
 		} else {
 			t.Logf("✓ Oversized dimension gracefully degraded to full dimension: %d", len(embedding))
 		}

@@ -1889,9 +1901,6 @@ func TestEmbeddingPriorityRouting(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to initialize embedding models: %v", err)
 	}
-
-	// Note: These tests use Matryoshka dimension truncation (768) with Qwen3 model
-	// The dimension is truncated from Qwen3's full 1024 dimensions
 	testCases := []struct {
 		name            string
 		text            string

@@ -1906,23 +1915,23 @@ func TestEmbeddingPriorityRouting(t *testing.T) {
 			qualityPriority: 0.2,
 			latencyPriority: 0.9,
 			expectedDim:     768,
-			description:     "Uses Qwen3 with Matryoshka 768 truncation",
+			description:     "Should prefer faster embedding model (Gemma > Qwen3)",
 		},
 		{
 			name:            "HighQualityPriority",
 			text:            strings.Repeat("Long context text ", 30),
 			qualityPriority: 0.9,
 			latencyPriority: 0.2,
 			expectedDim:     768,
-			description:     "Uses Qwen3 with Matryoshka 768 truncation",
+			description:     "Should prefer quality model (Qwen3/Gemma)",
 		},
 		{
 			name:            "BalancedPriority",
 			text:            "Medium length text for embedding",
 			qualityPriority: 0.5,
 			latencyPriority: 0.5,
 			expectedDim:     768,
-			description:     "Uses Qwen3 with Matryoshka 768 truncation",
+			description:     "Should select based on text length",
 		},
 	}

deploy/helm/semantic-router/templates/deployment.yaml

Lines changed: 3 additions & 3 deletions

@@ -69,9 +69,9 @@ spec:
           env:
             - name: HF_HUB_CACHE
               value: /tmp/hf_cache
-          {{- with .Values.initContainer.env }}
-          {{- toYaml . | nindent 10 }}
-          {{- end }}
+            {{- with .Values.initContainer.env }}
+            {{- toYaml . | nindent 8 }}
+            {{- end }}
           resources:
             {{- toYaml .Values.initContainer.resources | nindent 10 }}
           volumeMounts:

deploy/helm/semantic-router/values.yaml

Lines changed: 10 additions & 9 deletions

@@ -149,20 +149,21 @@ initContainer:
   # -- Additional environment variables for the init container.
   # For example, to use a private Hugging Face model, you can pass a token
   # and specify an endpoint using a pre-existing Kubernetes secret.
-  # env:
-  #   - name: HF_TOKEN
-  #     valueFrom:
-  #       secretKeyRef:
-  #         name: my-hf-secret
-  #         key: token
-  #   - name: HF_ENDPOINT
-  #     value: "https://huggingface.co"
-  env: []
+  # HF_TOKEN is required for downloading gated models like embeddinggemma-300m
+  env:
+    - name: HF_TOKEN
+      valueFrom:
+        secretKeyRef:
+          name: hf-token-secret
+          key: token
+          optional: true # Allow deployment even if secret doesn't exist (for local testing)
   # -- Models to download
   models:
     # Embedding models for semantic cache and tools
     - name: Qwen3-Embedding-0.6B
       repo: Qwen/Qwen3-Embedding-0.6B
+    - name: embeddinggemma-300m
+      repo: google/embeddinggemma-300m
     - name: all-MiniLM-L12-v2
       repo: sentence-transformers/all-MiniLM-L12-v2
     - name: lora_intent_classifier_bert-base-uncased_model
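
Because the secretKeyRef above is marked optional, the release still installs without a token; to enable the gated embeddinggemma-300m download when installing by hand, the secret can be created manually, e.g. kubectl create secret generic hf-token-secret --from-literal=token=<your HF token> in the release namespace.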

e2e/pkg/framework/runner.go

Lines changed: 75 additions & 0 deletions

@@ -114,6 +114,19 @@ func (r *Runner) Run(ctx context.Context) error {
 	// Set Kubernetes client for report generator
 	r.reporter.SetKubeClient(kubeClient)

+	// Step 3.5: Create HF_TOKEN secret if available (for gated model downloads)
+	// This is required for downloading gated models like google/embeddinggemma-300m
+	if hfToken := os.Getenv("HF_TOKEN"); hfToken != "" {
+		if err := r.createHFTokenSecret(ctx, kubeClient); err != nil {
+			r.log("⚠️ Warning: Failed to create HF_TOKEN secret: %v", err)
+			r.log("   Model downloads may fail if gated models (e.g., embeddinggemma-300m) are required")
+		} else {
+			r.log("✅ Created HF_TOKEN secret for gated model downloads")
+		}
+	} else {
+		r.log("ℹ️ HF_TOKEN not set - gated models (e.g., embeddinggemma-300m) may not be downloadable")
+	}
+
 	// Step 4: Setup profile (deploy Helm charts, etc.)
 	if !r.opts.SkipSetup {
 		setupOpts := &SetupOptions{

@@ -492,6 +505,68 @@ func (r *Runner) collectSemanticRouterLogs(ctx context.Context, client *kubernet
 	return nil
 }

+// createHFTokenSecret creates a Kubernetes secret for HF_TOKEN if it's available in the environment
+// This is required for the init container to download gated models like google/embeddinggemma-300m
+// The secret must be in the same namespace as the semantic-router deployment (vllm-semantic-router-system)
+// because Kubernetes secrets are namespace-scoped
+func (r *Runner) createHFTokenSecret(ctx context.Context, kubeClient *kubernetes.Clientset) error {
+	hfToken := os.Getenv("HF_TOKEN")
+	if hfToken == "" {
+		return nil // No token to create
+	}
+
+	// All E2E profiles deploy semantic-router to this namespace
+	nsName := "vllm-semantic-router-system"
+
+	// First, ensure the namespace exists
+	_, err := kubeClient.CoreV1().Namespaces().Get(ctx, nsName, metav1.GetOptions{})
+	if err != nil {
+		// Namespace doesn't exist, create it
+		ns := &corev1.Namespace{
+			ObjectMeta: metav1.ObjectMeta{
+				Name: nsName,
+			},
+		}
+		_, err = kubeClient.CoreV1().Namespaces().Create(ctx, ns, metav1.CreateOptions{})
+		if err != nil && !strings.Contains(err.Error(), "already exists") {
+			// If we can't create the namespace, that's okay - the profile will create it
+			r.log("⚠️ Could not create namespace %s (will be created by profile): %v", nsName, err)
+		}
+	}
+
+	// Create the secret in the namespace where semantic-router is deployed
+	secret := &corev1.Secret{
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      "hf-token-secret",
+			Namespace: nsName,
+		},
+		Type: corev1.SecretTypeOpaque,
+		StringData: map[string]string{
+			"token": hfToken,
+		},
+	}
+
+	_, err = kubeClient.CoreV1().Secrets(nsName).Create(ctx, secret, metav1.CreateOptions{})
+	if err != nil {
+		// If secret already exists, update it
+		if strings.Contains(err.Error(), "already exists") {
+			_, err = kubeClient.CoreV1().Secrets(nsName).Update(ctx, secret, metav1.UpdateOptions{})
+			if err != nil {
+				return fmt.Errorf("failed to update existing HF_TOKEN secret in %s: %w", nsName, err)
+			}
+			return nil
+		}
+		// If namespace still doesn't exist, that's okay - it will be created by Helm
+		if strings.Contains(err.Error(), "not found") {
+			r.log("⚠️ Namespace %s not found yet (will be created by profile)", nsName)
+			return nil
+		}
+		return fmt.Errorf("failed to create HF_TOKEN secret in %s: %w", nsName, err)
+	}
+
+	return nil
+}
+
 func getPodReadyStatus(pod corev1.Pod) string {
 	readyCount := 0
 	totalCount := len(pod.Status.ContainerStatuses)
e2e/profiles/ai-gateway/values.yaml

Lines changed: 1 addition & 0 deletions

@@ -628,6 +628,7 @@ config:
     # - EmbeddingGemma-300M: Up to 8K context, fast inference, Matryoshka support (768/512/256/128)
     embedding_models:
       qwen3_model_path: "models/Qwen3-Embedding-0.6B"
+      gemma_model_path: "models/embeddinggemma-300m"
       use_cpu: true # Set to false for GPU acceleration (requires CUDA)

   # Observability Configuration
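
The Matryoshka dimensions noted in the comment above (768/512/256/128) mean a prefix of the full Gemma vector is still a usable embedding. A minimal sketch of that truncation, assuming the usual L2 re-normalization step (the router's actual logic lives in the candle-binding layer, see GetEmbeddingWithDim earlier):

    package embedding

    import "math"

    // truncateEmbedding keeps the first k dims of a Matryoshka embedding
    // and re-normalizes so downstream cosine similarity still works.
    func truncateEmbedding(full []float32, k int) []float32 {
    	if k >= len(full) {
    		k = len(full) // graceful degradation to the full dimension
    	}
    	out := make([]float32, k)
    	copy(out, full[:k])
    	var sum float64
    	for _, v := range out {
    		sum += float64(v) * float64(v)
    	}
    	norm := float32(math.Sqrt(sum))
    	if norm > 0 {
    		for i := range out {
    			out[i] /= norm
    		}
    	}
    	return out
    }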
