Skip to content

Commit 4fa457f

Browse files
author
Hen Schwartz
committed
refactor(milvus): share lifecycle across stores
Introduce a shared Milvus lifecycle seam and refactor cache, memory, vectorstore, and routerreplay to use it.

Also normalize localhost to 127.0.0.1 for Redis/Valkey clients and make Valkey integration index names unique to avoid test collisions.

Signed-off-by: Hen Schwartz <hschwart@hschwart-thinkpadp1gen7.raanaii.csb>
1 parent 4c55e59 commit 4fa457f

File tree

10 files changed

+478
-330
lines changed

10 files changed

+478
-330
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package cache
2+
3+
import "strings"
4+
5+
// normalizeLocalHostForContainerRuntimes rewrites the hostname "localhost" to
// the IPv4 loopback address "127.0.0.1".
//
// Leading and trailing whitespace is stripped before the comparison, and the
// comparison ignores case. Any other host is returned trimmed but otherwise
// unchanged.
//
// The rewrite exists because, in some rootless/containerized local
// environments, "localhost" can resolve to "::1" and produce EOF/reset errors
// when services only listen on IPv4.
func normalizeLocalHostForContainerRuntimes(host string) string {
	const ipv4Loopback = "127.0.0.1"

	trimmed := strings.TrimSpace(host)
	if !strings.EqualFold(trimmed, "localhost") {
		return trimmed
	}
	return ipv4Loopback
}

src/semantic-router/pkg/cache/milvus_cache.go

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515

1616
candle_binding "github.com/vllm-project/semantic-router/candle-binding"
1717
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
18+
milvuslifecycle "github.com/vllm-project/semantic-router/src/semantic-router/pkg/milvus"
1819
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging"
1920
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/metrics"
2021
)
@@ -46,6 +47,8 @@ type MilvusCacheOptions struct {
4647
}
4748

4849
// NewMilvusCache initializes a new Milvus-backed semantic cache instance
50+
//
51+
//nolint:funlen
4952
func NewMilvusCache(options MilvusCacheOptions) (*MilvusCache, error) {
5053
if !options.Enabled {
5154
logging.Debugf("MilvusCache: disabled, returning stub")
@@ -82,7 +85,7 @@ func NewMilvusCache(options MilvusCacheOptions) (*MilvusCache, error) {
8285
defer cancel()
8386
logging.Debugf("MilvusCache: connection timeout set to %s", timeout)
8487
}
85-
milvusClient, err := client.NewGrpcClient(dialCtx, connectionString)
88+
milvusClient, err := milvuslifecycle.ConnectGRPC(dialCtx, connectionString, 0)
8689
if err != nil {
8790
logging.Debugf("MilvusCache: failed to connect: %v", err)
8891
return nil, fmt.Errorf("failed to create Milvus client: %w", err)
@@ -107,7 +110,7 @@ func NewMilvusCache(options MilvusCacheOptions) (*MilvusCache, error) {
107110
// Test connection using the new CheckConnection method
108111
if err := cache.CheckConnection(); err != nil {
109112
logging.Debugf("MilvusCache: connection check failed: %v", err)
110-
milvusClient.Close()
113+
_ = milvusClient.Close() // best-effort close
111114
return nil, err
112115
}
113116
logging.Debugf("MilvusCache: successfully connected to Milvus")
@@ -116,7 +119,7 @@ func NewMilvusCache(options MilvusCacheOptions) (*MilvusCache, error) {
116119
logging.Debugf("MilvusCache: initializing collection '%s'", milvusConfig.Collection.Name)
117120
if err := cache.initializeCollection(); err != nil {
118121
logging.Debugf("MilvusCache: failed to initialize collection: %v", err)
119-
milvusClient.Close()
122+
_ = milvusClient.Close() // best-effort close
120123
return nil, fmt.Errorf("failed to initialize collection: %w", err)
121124
}
122125
logging.Debugf("MilvusCache: initialization complete")
@@ -125,6 +128,8 @@ func NewMilvusCache(options MilvusCacheOptions) (*MilvusCache, error) {
125128
}
126129

127130
// loadMilvusConfig reads and parses the Milvus configuration from file (Deprecated)
131+
//
132+
//nolint:cyclop,funlen
128133
func loadMilvusConfig(configPath string) (*config.MilvusConfig, error) {
129134
if configPath == "" {
130135
return nil, fmt.Errorf("milvus config path is required")
@@ -213,7 +218,6 @@ func (c *MilvusCache) initializeCollection() error {
213218
logging.Debugf("MilvusCache: failed to drop collection: %v", err)
214219
return fmt.Errorf("failed to drop collection: %w", err)
215220
}
216-
hasCollection = false
217221
logging.Debugf("MilvusCache: dropped existing collection '%s' for development", c.collectionName)
218222
logging.LogEvent("collection_dropped", map[string]interface{}{
219223
"backend": "milvus",
@@ -222,17 +226,14 @@ func (c *MilvusCache) initializeCollection() error {
222226
})
223227
}
224228

225-
// Create collection if it doesn't exist
226-
if !hasCollection {
227-
fmt.Printf("[DEBUG] Collection '%s' does not exist. AutoCreateCollection=%v\n",
229+
if err := milvuslifecycle.EnsureCollectionLoaded(ctx, c.client, c.collectionName, func(innerCtx context.Context) error {
230+
logging.Debugf("MilvusCache: collection '%s' does not exist. AutoCreateCollection=%v",
228231
c.collectionName, c.config.Development.AutoCreateCollection)
229232
if !c.config.Development.AutoCreateCollection {
230233
return fmt.Errorf("collection %s does not exist and auto-creation is disabled", c.collectionName)
231234
}
232-
233-
if err := c.createCollection(); err != nil {
234-
logging.Debugf("MilvusCache: failed to create collection: %v", err)
235-
return fmt.Errorf("failed to create collection: %w", err)
235+
if err := c.createCollection(innerCtx); err != nil {
236+
return err
236237
}
237238
logging.Debugf("MilvusCache: created new collection '%s' with dimension %d",
238239
c.collectionName, c.config.Collection.VectorField.Dimension)
@@ -241,16 +242,12 @@ func (c *MilvusCache) initializeCollection() error {
241242
"collection": c.collectionName,
242243
"dimension": c.config.Collection.VectorField.Dimension,
243244
})
245+
return nil
246+
}); err != nil {
247+
logging.Debugf("MilvusCache: failed to ensure/load collection: %v", err)
248+
return fmt.Errorf("failed to ensure/load collection: %w", err)
244249
}
245250

246-
// Load collection into memory for queries
247-
logging.Debugf("MilvusCache: loading collection '%s' into memory", c.collectionName)
248-
if err := c.client.LoadCollection(ctx, c.collectionName, false); err != nil {
249-
logging.Debugf("MilvusCache: failed to load collection: %v", err)
250-
return fmt.Errorf("failed to load collection: %w", err)
251-
}
252-
logging.Debugf("MilvusCache: collection loaded successfully")
253-
254251
return nil
255252
}
256253

@@ -299,9 +296,7 @@ func (c *MilvusCache) getEmbedding(text string) ([]float32, error) {
299296
}
300297

301298
// createCollection builds the Milvus collection with the appropriate schema
302-
func (c *MilvusCache) createCollection() error {
303-
ctx := context.Background()
304-
299+
func (c *MilvusCache) createCollection(ctx context.Context) error {
305300
// Determine embedding dimension automatically
306301
testEmbedding, err := c.getEmbedding("test")
307302
if err != nil {
@@ -446,6 +441,8 @@ func (c *MilvusCache) AddPendingRequest(requestID string, model string, query st
446441
}
447442

448443
// UpdateWithResponse completes a pending request by adding the response
444+
//
445+
//nolint:gocognit,cyclop,funlen
449446
func (c *MilvusCache) UpdateWithResponse(requestID string, responseBody []byte, ttlSeconds int) error {
450447
start := time.Now()
451448

@@ -571,6 +568,8 @@ func (c *MilvusCache) AddEntry(requestID string, model string, query string, req
571568
}
572569

573570
// AddEntriesBatch stores multiple request-response pairs in the cache efficiently
571+
//
572+
//nolint:funlen
574573
func (c *MilvusCache) AddEntriesBatch(entries []CacheEntry) error {
575574
start := time.Now()
576575

@@ -667,6 +666,8 @@ func (c *MilvusCache) Flush() error {
667666
}
668667

669668
// addEntry handles the internal logic for storing entries in Milvus
669+
//
670+
//nolint:funlen
670671
func (c *MilvusCache) addEntry(id string, requestID string, model string, query string, requestBody, responseBody []byte, ttlSeconds int) error {
671672
// Determine effective TTL: use provided value or fall back to cache default
672673
effectiveTTL := ttlSeconds
@@ -752,6 +753,8 @@ func (c *MilvusCache) FindSimilar(model string, query string) ([]byte, bool, err
752753
}
753754

754755
// FindSimilarWithThreshold searches for semantically similar cached requests using a specific threshold
756+
//
757+
//nolint:cyclop,funlen
755758
func (c *MilvusCache) FindSimilarWithThreshold(model string, query string, threshold float32) ([]byte, bool, error) {
756759
start := time.Now()
757760

@@ -969,6 +972,8 @@ func isHexString(s string) bool {
969972
// GetByID retrieves a document from Milvus by its request ID
970973
// This is much more efficient than FindSimilar when you already know the ID
971974
// Used by hybrid cache to fetch documents after local HNSW search
975+
//
976+
//nolint:funlen,cyclop,nestif
972977
func (c *MilvusCache) GetByID(ctx context.Context, requestID string) ([]byte, error) {
973978
start := time.Now()
974979

@@ -1070,6 +1075,8 @@ func (c *MilvusCache) Close() error {
10701075
//
10711076
// If these parameters are empty/zero, the method uses the cache collection's configuration.
10721077
// This allows RAG collections to use different configurations when needed.
1078+
//
1079+
//nolint:gocognit,cyclop,funlen,nestif
10731080
func (c *MilvusCache) SearchDocuments(ctx context.Context, collectionName string, queryEmbedding []float32, threshold float32, topK int, filterExpr string, contentField string, vectorFieldName string, metricType string, ef int) ([]string, []float32, error) {
10741081
if !c.enabled {
10751082
return nil, nil, fmt.Errorf("milvus cache is not enabled")
@@ -1179,6 +1186,8 @@ func (c *MilvusCache) SearchDocuments(ctx context.Context, collectionName string
11791186
}
11801187

11811188
// GetStats provides current cache performance metrics
1189+
//
1190+
//nolint:nestif
11821191
func (c *MilvusCache) GetStats() CacheStats {
11831192
c.mu.RLock()
11841193
defer c.mu.RUnlock()

src/semantic-router/pkg/cache/redis_cache.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,12 @@ func NewRedisCache(options RedisCacheOptions) (*RedisCache, error) {
7373
redisConfig.Connection.Host, redisConfig.Connection.Port, redisConfig.Index.Name)
7474

7575
// Establish connection to Redis server
76-
logging.Debugf("RedisCache: connecting to Redis at %s:%d", redisConfig.Connection.Host, redisConfig.Connection.Port)
76+
resolvedHost := normalizeLocalHostForContainerRuntimes(redisConfig.Connection.Host)
77+
logging.Debugf("RedisCache: connecting to Redis at %s:%d (configured host=%s)",
78+
resolvedHost, redisConfig.Connection.Port, redisConfig.Connection.Host)
7779

7880
redisClient := redis.NewClient(&redis.Options{
79-
Addr: fmt.Sprintf("%s:%d", redisConfig.Connection.Host, redisConfig.Connection.Port),
81+
Addr: fmt.Sprintf("%s:%d", resolvedHost, redisConfig.Connection.Port),
8082
Password: redisConfig.Connection.Password,
8183
DB: redisConfig.Connection.Database,
8284
Protocol: 2, // Use RESP2 protocol for compatibility

src/semantic-router/pkg/cache/valkey_cache.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,13 @@ func NewValkeyCache(options ValkeyCacheOptions) (*ValkeyCache, error) {
6363
logging.Debugf("ValkeyCache: config loaded - host=%s:%d, index=%s, dimension=auto-detect",
6464
valkeyConfig.Connection.Host, valkeyConfig.Connection.Port, valkeyConfig.Index.Name)
6565

66-
logging.Debugf("ValkeyCache: connecting to Valkey at %s:%d", valkeyConfig.Connection.Host, valkeyConfig.Connection.Port)
66+
resolvedHost := normalizeLocalHostForContainerRuntimes(valkeyConfig.Connection.Host)
67+
logging.Debugf("ValkeyCache: connecting to Valkey at %s:%d (configured host=%s)",
68+
resolvedHost, valkeyConfig.Connection.Port, valkeyConfig.Connection.Host)
6769

6870
clientConfig := config.NewClientConfiguration().
6971
WithAddress(&config.NodeAddress{
70-
Host: valkeyConfig.Connection.Host,
72+
Host: resolvedHost,
7173
Port: valkeyConfig.Connection.Port,
7274
})
7375

src/semantic-router/pkg/cache/valkey_cache_integration_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ func TestValkeyCacheIntegration_FLATIndexType(t *testing.T) {
476476
valkeyConfig.Connection.Port = 6379
477477
valkeyConfig.Connection.Database = 0
478478

479-
valkeyConfig.Index.Name = "test_flat_idx"
479+
valkeyConfig.Index.Name = fmt.Sprintf("test_flat_idx_%d", time.Now().UnixNano())
480480
valkeyConfig.Index.Prefix = "flat:"
481481
valkeyConfig.Index.VectorField.Name = "embedding"
482482
valkeyConfig.Index.VectorField.Dimension = 384
@@ -608,7 +608,7 @@ func TestValkeyCacheIntegration_L2MetricType(t *testing.T) {
608608
valkeyConfig.Connection.Port = 6379
609609
valkeyConfig.Connection.Database = 0
610610

611-
valkeyConfig.Index.Name = "test_l2_idx"
611+
valkeyConfig.Index.Name = fmt.Sprintf("test_l2_idx_%d", time.Now().UnixNano())
612612
valkeyConfig.Index.Prefix = "l2:"
613613
valkeyConfig.Index.VectorField.Name = "embedding"
614614
valkeyConfig.Index.VectorField.Dimension = 384
@@ -657,7 +657,7 @@ func TestValkeyCacheIntegration_IPMetricType(t *testing.T) {
657657
valkeyConfig.Connection.Port = 6379
658658
valkeyConfig.Connection.Database = 0
659659

660-
valkeyConfig.Index.Name = "test_ip_idx"
660+
valkeyConfig.Index.Name = fmt.Sprintf("test_ip_idx_%d", time.Now().UnixNano())
661661
valkeyConfig.Index.Prefix = "ip:"
662662
valkeyConfig.Index.VectorField.Name = "embedding"
663663
valkeyConfig.Index.VectorField.Dimension = 384

0 commit comments

Comments
 (0)