diff --git a/config/config.yaml b/config/config.yaml index b33264872..2eec11808 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1416,9 +1416,25 @@ global: enable_metrics: true memory: enabled: true + backend: milvus auto_store: true disabled_routes: [] disabled_models: [] + valkey: + host: valkey + port: 6379 + database: 0 + password: valkey-secret + timeout: 10 + collection_prefix: "mem:" + index_name: mem_idx + dimension: 384 + metric_type: COSINE + index_m: 16 + index_ef_construction: 200 + tls_enabled: false + tls_ca_path: "" + tls_insecure_skip_verify: false milvus: address: milvus:19530 collection: agentic_memory diff --git a/deploy/examples/runtime/README.md b/deploy/examples/runtime/README.md index 247431ae5..2f20ca498 100644 --- a/deploy/examples/runtime/README.md +++ b/deploy/examples/runtime/README.md @@ -2,8 +2,10 @@ This directory under `deploy/examples/runtime/` holds repo-owned runtime support examples that are not part of the user-facing `config/` contract. +- `memory/`: agentic memory backend configuration references (Milvus, Valkey) - `semantic-cache/`: external semantic-cache backend example files - `response-api/`: external Response API Redis example files - `tools/`: local tools database examples +- `vector-store/`: vector store backend configuration references These files exist for local development, tutorials, and tests. They are not the canonical router config surface. diff --git a/deploy/examples/runtime/memory/valkey.yaml b/deploy/examples/runtime/memory/valkey.yaml new file mode 100644 index 000000000..81a7aa82c --- /dev/null +++ b/deploy/examples/runtime/memory/valkey.yaml @@ -0,0 +1,93 @@ +# Valkey Memory Store Configuration for Agentic Memory +# This configuration file contains settings for using Valkey (with the Search +# module) as the agentic memory backend. +# +# To use this configuration: +# 1. Set backend: "valkey" in global.stores.memory in your config.yaml +# 2. 
Inline the valkey settings from this file into global.stores.memory.valkey +# 3. Ensure Valkey server with the Search module is running and accessible +# +# Valkey Search module is required for FT.CREATE / FT.SEARCH vector operations. +# Use the valkey/valkey-bundle image or compile Valkey with --enable-search. + +# Connection settings +host: "localhost" # For production: use your Valkey cluster endpoint +port: 6379 # Standard Valkey port +database: 0 # Database number (0-15) +password: "" # Authentication password (leave empty if not required) +timeout: 10 # Connection/request timeout in seconds + +# Index and storage settings +collection_prefix: "mem:" # Key prefix for HASH documents +index_name: "mem_idx" # FT.CREATE index name +dimension: 384 # Embedding vector dimension (must match model) +metric_type: "COSINE" # Distance metric: COSINE, L2, or IP + +# HNSW index tuning +index_m: 16 # Bi-directional links per node (higher = more accurate, more RAM) +index_ef_construction: 256 # Build-time search width (higher = better recall, slower build) + +# TLS settings +tls_enabled: false # Enable TLS for the Valkey connection +tls_ca_path: "" # Path to PEM-encoded CA cert (empty = system trust store) +tls_insecure_skip_verify: false # Skip server cert verification (development only) + +# Full canonical config.yaml usage example: +# +# global: +# stores: +# memory: +# enabled: true +# backend: valkey # <-- select Valkey backend +# auto_store: true +# valkey: # <-- Valkey-specific settings +# host: valkey +# port: 6379 +# database: 0 +# timeout: 10 +# collection_prefix: "mem:" +# index_name: mem_idx +# dimension: 384 +# metric_type: COSINE +# index_m: 16 +# index_ef_construction: 256 +# tls_enabled: false +# tls_ca_path: "" +# tls_insecure_skip_verify: false +# embedding_model: bert +# default_retrieval_limit: 5 +# default_similarity_threshold: 0.70 +# hybrid_search: true +# hybrid_mode: rerank +# adaptive_threshold: true +# quality_scoring: +# initial_strength_days: 30 +# 
prune_threshold: 0.15 +# max_memories_per_user: 200 +# reflection: +# enabled: true +# algorithm: recency_semantic +# max_inject_tokens: 512 +# recency_decay_days: 14 +# dedup_threshold: 0.9 +# +# Example configurations for different environments: +# +# Local Development (Docker): +# host: "localhost" +# port: 6379 +# password: "" +# +# Production (Docker / Kubernetes): +# host: "valkey-service.valkey-system.svc.cluster.local" +# port: 6379 +# password: "${VALKEY_PASSWORD}" # from secret +# index_m: 32 # higher recall for production +# index_ef_construction: 512 +# +# Kubernetes with TLS: +# host: "valkey-tls.valkey-system.svc.cluster.local" +# port: 6380 +# password: "${VALKEY_PASSWORD}" +# tls_enabled: true +# tls_ca_path: "/etc/valkey/certs/ca.pem" # mounted from secret diff --git a/e2e/config/config.memory-user-valkey.yaml b/e2e/config/config.memory-user-valkey.yaml new file mode 100644 index 000000000..b464ad7d6 --- /dev/null +++ b/e2e/config/config.memory-user-valkey.yaml @@ -0,0 +1,134 @@ +# Same routing as config.memory-user.yaml but uses the Valkey memory backend +# instead of Milvus. Requires a Valkey instance with the Search module on the +# same Docker network (e.g. semantic-router-valkey:6379). +# +# Usage: point your router's CONFIG_FILE to this file. 
+ +version: v0.3 +listeners: + - name: http-8888 + address: 0.0.0.0 + port: 8888 + timeout: 300s +providers: + defaults: + default_model: qwen3 + models: + - name: qwen3 + provider_model_id: qwen3 + backend_refs: + - name: llm_katan + weight: 1 + endpoint: host.docker.internal:8000 + protocol: http +routing: + modelCards: + - name: qwen3 + modality: text + signals: + domains: + - name: general + description: General queries for memory testing + mmlu_categories: [other] + keywords: + - name: no_memory_trigger + operator: OR + keywords: [NOMEM_MARKER] + - name: custom_threshold_trigger + operator: OR + keywords: [THRESHOLD_MARKER] + decisions: + - name: no_memory_route + description: Route with memory explicitly disabled for per-decision testing + priority: 200 + rules: + operator: OR + conditions: + - type: keyword + name: no_memory_trigger + modelRefs: + - model: qwen3 + use_reasoning: false + plugins: + - type: system_prompt + configuration: + system_prompt: You are a helpful assistant. Memory access is disabled for this route. + mode: insert + - type: memory + configuration: + enabled: false + + - name: custom_threshold_route + description: Route with high similarity threshold for per-decision testing + priority: 150 + rules: + operator: OR + conditions: + - type: keyword + name: custom_threshold_trigger + modelRefs: + - model: qwen3 + use_reasoning: false + plugins: + - type: system_prompt + configuration: + system_prompt: You are a helpful assistant with strict memory matching. + mode: insert + - type: memory + configuration: + enabled: true + retrieval_limit: 5 + similarity_threshold: 0.99 + auto_store: true + + - name: default_route + description: Default route for memory testing + priority: 1 + rules: + operator: OR + conditions: + - type: domain + name: general + modelRefs: + - model: qwen3 + use_reasoning: false + plugins: + - type: system_prompt + configuration: + system_prompt: You are MoM, a helpful AI assistant with memory. 
You remember important facts about users and use this context to provide personalized assistance. + mode: insert + - type: memory + configuration: + enabled: true + retrieval_limit: 5 + similarity_threshold: 0.45 + auto_store: true +global: + services: + response_api: + enabled: true + store_backend: memory + ttl_seconds: 86400 + stores: + memory: + enabled: true + backend: valkey + auto_store: true + valkey: + host: semantic-router-valkey + port: 6379 + database: 0 + timeout: 10 + collection_prefix: "mem:" + index_name: mem_idx + dimension: 384 + metric_type: COSINE + index_m: 16 + index_ef_construction: 256 + embedding_model: mmbert + default_retrieval_limit: 5 + default_similarity_threshold: 0.45 + semantic_cache: + embedding_model: mmbert + vector_store: + embedding_model: mmbert diff --git a/src/semantic-router/pkg/config/reference_config_global_test.go b/src/semantic-router/pkg/config/reference_config_global_test.go index 697171bef..6cf2891d4 100644 --- a/src/semantic-router/pkg/config/reference_config_global_test.go +++ b/src/semantic-router/pkg/config/reference_config_global_test.go @@ -119,6 +119,7 @@ func assertReferenceConfigSemanticCacheCoverage(t testingT, semanticCache map[st func assertReferenceConfigMemoryCoverage(t testingT, memory map[string]interface{}) { assertMapCoversStructFields(t, memory, reflect.TypeOf(MemoryConfig{}), "global.stores.memory") assertMapCoversStructFields(t, mustMapAt(t, memory, "milvus"), reflect.TypeOf(MemoryMilvusConfig{}), "global.stores.memory.milvus") + assertMapCoversStructFields(t, mustMapAt(t, memory, "valkey"), reflect.TypeOf(MemoryValkeyConfig{}), "global.stores.memory.valkey") assertMapCoversStructFields(t, mustMapAt(t, memory, "quality_scoring"), reflect.TypeOf(MemoryQualityScoringConfig{}), "global.stores.memory.quality_scoring") assertMapCoversStructFields(t, mustMapAt(t, memory, "reflection"), reflect.TypeOf(MemoryReflectionConfig{}), "global.stores.memory.reflection") } diff --git 
a/src/semantic-router/pkg/config/runtime_config.go b/src/semantic-router/pkg/config/runtime_config.go index d34586894..bf6ce5431 100644 --- a/src/semantic-router/pkg/config/runtime_config.go +++ b/src/semantic-router/pkg/config/runtime_config.go @@ -207,10 +207,12 @@ type SemanticCache struct { type MemoryConfig struct { Enabled bool `yaml:"enabled,omitempty"` + Backend string `yaml:"backend,omitempty"` AutoStore bool `yaml:"auto_store,omitempty"` DisabledRoutes []string `yaml:"disabled_routes,omitempty"` DisabledModels []string `yaml:"disabled_models,omitempty"` Milvus MemoryMilvusConfig `yaml:"milvus,omitempty"` + Valkey *MemoryValkeyConfig `yaml:"valkey,omitempty"` RedisCache *MemoryRedisCacheConfig `yaml:"redis_cache,omitempty"` EmbeddingModel string `yaml:"embedding_model,omitempty"` ExtractionBatchSize int `yaml:"extraction_batch_size,omitempty"` @@ -262,6 +264,40 @@ type MemoryMilvusConfig struct { NumPartitions int `yaml:"num_partitions,omitempty"` } +// MemoryValkeyConfig holds configuration for the Valkey memory store backend. +// Uses Valkey with the Search module for vector similarity operations. +type MemoryValkeyConfig struct { + // Host is the Valkey server hostname (default "localhost"). + Host string `yaml:"host"` + // Port is the Valkey server port (default 6379). + Port int `yaml:"port"` + // Database number (default 0). + Database int `yaml:"database"` + // Password for Valkey authentication (optional). + Password string `yaml:"password,omitempty"` + // Timeout is the connection/request timeout in seconds (default 10). + Timeout int `yaml:"timeout"` + // CollectionPrefix is the prefix for hash keys (default "mem:"). + CollectionPrefix string `yaml:"collection_prefix,omitempty"` + // IndexName is the FT index name (default "mem_idx"). + IndexName string `yaml:"index_name,omitempty"` + // Dimension is the embedding vector dimension (default 384). 
+ Dimension int `yaml:"dimension,omitempty"` + // MetricType is the distance metric: "COSINE", "L2", or "IP" (default "COSINE"). + MetricType string `yaml:"metric_type,omitempty"` + // IndexM is the HNSW M parameter (default 16). + IndexM int `yaml:"index_m,omitempty"` + // IndexEfConstruction is the HNSW efConstruction parameter (default 256). + IndexEfConstruction int `yaml:"index_ef_construction,omitempty"` + // TLSEnabled enables TLS for the Valkey connection. + TLSEnabled bool `yaml:"tls_enabled,omitempty"` + // TLSCAPath is the path to a PEM-encoded CA certificate file for server verification. + // When empty and TLS is enabled, the system's default trust store is used. + TLSCAPath string `yaml:"tls_ca_path,omitempty"` + // TLSInsecureSkipVerify skips server certificate verification (development only). + TLSInsecureSkipVerify bool `yaml:"tls_insecure_skip_verify,omitempty"` +} + // ResponseAPIConfig controls response and conversation history storage. // StoreBackend defaults to "redis" for durable storage that survives router // restarts. 
Set to "memory" only for local development — all history is lost diff --git a/src/semantic-router/pkg/extproc/router_memory.go b/src/semantic-router/pkg/extproc/router_memory.go index 657ac688c..e1d22b1cd 100644 --- a/src/semantic-router/pkg/extproc/router_memory.go +++ b/src/semantic-router/pkg/extproc/router_memory.go @@ -6,6 +6,8 @@ import ( "time" "github.com/milvus-io/milvus-sdk-go/v2/client" + glide "github.com/valkey-io/valkey-glide/go/v2" + glideconfig "github.com/valkey-io/valkey-glide/go/v2/config" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/memory" @@ -23,10 +25,15 @@ func createMemoryRuntime(cfg *config.RouterConfig) (memory.Store, *memory.Memory return nil, nil } + memory.SetGlobalMemoryStore(memoryStore) + backend := cfg.Memory.Backend + if backend == "" { + backend = "milvus" + } if rc := cfg.Memory.RedisCache; rc != nil && rc.Enabled && rc.Address != "" { - logging.Infof("Memory enabled with Milvus backend and Redis hot cache") + logging.Infof("Memory enabled with %s backend and Redis hot cache", backend) } else { - logging.Infof("Memory enabled with Milvus backend") + logging.Infof("Memory enabled with %s backend", backend) } memoryExtractor := memory.NewMemoryChunkStore(memoryStore) @@ -53,7 +60,52 @@ func isMemoryEnabled(cfg *config.RouterConfig) bool { } // createMemoryStore creates a memory store based on configuration. +// Switches on cfg.Memory.Backend: "valkey" creates a ValkeyStore, "milvus" (or empty) creates a MilvusStore. 
func createMemoryStore(cfg *config.RouterConfig) (memory.Store, error) { + backend := cfg.Memory.Backend + + var store memory.Store + var err error + + switch backend { + case "valkey": + store, err = createValkeyMemoryStore(cfg) + case "milvus", "": + store, err = createMilvusMemoryStore(cfg) + default: + return nil, fmt.Errorf("unsupported memory backend: %q (supported: milvus, valkey)", backend) + } + + if err != nil { + return nil, err + } + + // Optionally wrap with Redis hot cache + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + result := store + if rc := cfg.Memory.RedisCache; rc != nil && rc.Enabled && rc.Address != "" { + cacheCfg := &memory.RedisCacheConfig{ + Address: rc.Address, + Password: rc.Password, + DB: rc.DB, + KeyPrefix: rc.KeyPrefix, + TTLSeconds: rc.TTLSeconds, + } + redisCache, err := memory.NewRedisCache(ctx, cacheCfg) + if err != nil { + logging.Warnf("Memory: Redis cache disabled (connection failed: %v)", err) + } else { + result = memory.NewCachingStore(store, redisCache) + } + } + + return result, nil +} + +// createMilvusMemoryStore creates a MilvusStore backend. 
+func createMilvusMemoryStore(cfg *config.RouterConfig) (memory.Store, error) { milvusAddress := cfg.Memory.Milvus.Address if milvusAddress == "" { milvusAddress = "localhost:19530" @@ -101,31 +153,106 @@ func createMemoryStore(cfg *config.RouterConfig) (memory.Store, error) { return nil, fmt.Errorf("failed to create memory store: %w", err) } - logging.Infof( - "Memory store initialized: address=%s, collection=%s, embedding=%s", - milvusAddress, - collectionName, - embeddingConfig.Model, - ) + logging.Infof("Memory store initialized: backend=milvus, address=%s, collection=%s, embedding=%s", + milvusAddress, collectionName, embeddingConfig.Model) - var result memory.Store = store - if rc := cfg.Memory.RedisCache; rc != nil && rc.Enabled && rc.Address != "" { - cacheCfg := &memory.RedisCacheConfig{ - Address: rc.Address, - Password: rc.Password, - DB: rc.DB, - KeyPrefix: rc.KeyPrefix, - TTLSeconds: rc.TTLSeconds, + return store, nil +} + +// createValkeyMemoryStore creates a ValkeyStore backend. +func createValkeyMemoryStore(cfg *config.RouterConfig) (memory.Store, error) { + vc := cfg.Memory.Valkey + if vc == nil { + return nil, fmt.Errorf("memory.valkey configuration is required when backend is 'valkey'") + } + + host := vc.Host + if host == "" { + host = "localhost" + } + port := vc.Port + if port <= 0 { + port = 6379 + } + + embeddingConfig := &memory.EmbeddingConfig{ + Model: memory.EmbeddingModelType(detectMemoryEmbeddingModel(cfg)), + Dimension: vc.Dimension, + } + + logging.Infof("Memory: connecting to Valkey at %s:%d, embedding=%s", host, port, embeddingConfig.Model) + + clientConfig := glideconfig.NewClientConfiguration(). 
+ WithAddress(&glideconfig.NodeAddress{ + Host: host, + Port: port, + }) + + if vc.Password != "" { + clientConfig = clientConfig.WithCredentials( + glideconfig.NewServerCredentials("", vc.Password), + ) + } + + if vc.Database != 0 { + clientConfig = clientConfig.WithDatabaseId(vc.Database) + } + + if vc.Timeout > 0 { + timeout := time.Duration(vc.Timeout) * time.Second + clientConfig = clientConfig.WithRequestTimeout(timeout) + } + + if vc.TLSEnabled { + tlsCfg, tlsErr := buildValkeyTLSConfig(vc) + if tlsErr != nil { + return nil, tlsErr } - redisCache, err := memory.NewRedisCache(ctx, cacheCfg) + clientConfig = clientConfig.WithUseTLS(true). + WithAdvancedConfiguration( + glideconfig.NewAdvancedClientConfiguration().WithTlsConfiguration(tlsCfg), + ) + logging.Infof("Memory: Valkey TLS enabled (ca_path=%q, insecure_skip_verify=%v)", vc.TLSCAPath, vc.TLSInsecureSkipVerify) + } + + valkeyClient, err := glide.NewClient(clientConfig) + if err != nil { + return nil, fmt.Errorf("failed to create Valkey client: %w", err) + } + + store, err := memory.NewValkeyStore(memory.ValkeyStoreOptions{ + Client: valkeyClient, + Config: cfg.Memory, + ValkeyConfig: vc, + Enabled: true, + EmbeddingConfig: embeddingConfig, + }) + if err != nil { + valkeyClient.Close() + return nil, fmt.Errorf("failed to create Valkey memory store: %w", err) + } + + logging.Infof("Memory store initialized: backend=valkey, address=%s:%d, embedding=%s", + host, port, embeddingConfig.Model) + + return store, nil +} + +// buildValkeyTLSConfig constructs a glide TLS configuration from the Valkey config. 
+func buildValkeyTLSConfig(vc *config.MemoryValkeyConfig) (*glideconfig.TlsConfiguration, error) { + tlsConfig := glideconfig.NewTlsConfiguration() + if vc.TLSCAPath != "" { + caCert, err := glideconfig.LoadRootCertificatesFromFile(vc.TLSCAPath) if err != nil { - logging.Warnf("Memory: Redis cache disabled (connection failed: %v)", err) - } else { - result = memory.NewCachingStore(store, redisCache) + return nil, fmt.Errorf("failed to load TLS CA certificate from %s: %w", vc.TLSCAPath, err) } + tlsConfig = tlsConfig.WithRootCertificates(caCert) } - - return result, nil + if vc.TLSInsecureSkipVerify { + tlsConfig = tlsConfig.WithInsecureTLS(true) + logging.Warnf("Memory: Valkey TLS certificate verification is DISABLED — do not use in production") + } + return tlsConfig, nil } func detectMemoryEmbeddingModel(cfg *config.RouterConfig) string { diff --git a/src/semantic-router/pkg/memory/consolidation.go b/src/semantic-router/pkg/memory/consolidation.go index ee400c05f..6d7e2a495 100644 --- a/src/semantic-router/pkg/memory/consolidation.go +++ b/src/semantic-router/pkg/memory/consolidation.go @@ -20,15 +20,16 @@ const ( // deleted. This reduces redundancy and improves retrieval quality over time. // // Designed to be called from a background goroutine on a schedule. -func (m *MilvusStore) ConsolidateUser(ctx context.Context, userID string) (merged int, deleted int, err error) { - if !m.enabled { - return 0, 0, fmt.Errorf("milvus store is not enabled") +// Accepts any Store implementation so it works with both Milvus and Valkey backends. 
+func ConsolidateUser(ctx context.Context, store Store, userID string) (merged int, deleted int, err error) { + if !store.IsEnabled() { + return 0, 0, fmt.Errorf("memory store is not enabled") } if userID == "" { return 0, 0, fmt.Errorf("user ID is required") } - result, err := m.List(ctx, ListOptions{ + result, err := store.List(ctx, ListOptions{ UserID: userID, Limit: consolidationMaxListLimit, }) @@ -58,13 +59,13 @@ func (m *MilvusStore) ConsolidateUser(ctx context.Context, userID string) (merge Importance: maxImportance(group), } - if err := m.Store(ctx, summaryMem); err != nil { + if err := store.Store(ctx, summaryMem); err != nil { logging.Warnf("ConsolidateUser: failed to store merged memory: %v", err) continue } for _, old := range group { - if ferr := m.Forget(ctx, old.ID); ferr != nil { + if ferr := store.Forget(ctx, old.ID); ferr != nil { logging.Warnf("ConsolidateUser: failed to delete original memory id=%s: %v", old.ID, ferr) } else { deleted++ diff --git a/src/semantic-router/pkg/memory/valkey_store.go b/src/semantic-router/pkg/memory/valkey_store.go new file mode 100644 index 000000000..57536a7fa --- /dev/null +++ b/src/semantic-router/pkg/memory/valkey_store.go @@ -0,0 +1,746 @@ +package memory + +import ( + "context" + "fmt" + "sort" + "strconv" + "strings" + "time" + + glide "github.com/valkey-io/valkey-glide/go/v2" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging" +) + +// ValkeyStore provides memory storage and retrieval using Valkey with the Search module. +// Implements the Store interface with HASH-based storage and FT.SEARCH for vector similarity. 
+type ValkeyStore struct { + client *glide.Client + config config.MemoryConfig + valkeyConfig *config.MemoryValkeyConfig + enabled bool + maxRetries int + retryBaseDelay time.Duration + embeddingConfig EmbeddingConfig + indexName string + collectionPrefix string + metricType string + dimension int +} + +// ValkeyStoreOptions contains configuration for creating a ValkeyStore. +type ValkeyStoreOptions struct { + // Client is the valkey-glide client instance. + Client *glide.Client + // Config is the memory configuration. + Config config.MemoryConfig + // ValkeyConfig is the Valkey-specific configuration. + ValkeyConfig *config.MemoryValkeyConfig + // Enabled controls whether the store is active. + Enabled bool + // EmbeddingConfig is the unified embedding configuration. + EmbeddingConfig *EmbeddingConfig +} + +// NewValkeyStore creates a new ValkeyStore instance. +func NewValkeyStore(options ValkeyStoreOptions) (*ValkeyStore, error) { + if !options.Enabled { + logging.Debugf("ValkeyStore: disabled, returning stub") + return &ValkeyStore{enabled: false}, nil + } + + if options.Client == nil { + return nil, fmt.Errorf("valkey client is required") + } + + if options.ValkeyConfig == nil { + return nil, fmt.Errorf("valkey config is required") + } + + cfg := options.Config + if cfg.EmbeddingModel == "" { + cfg = DefaultMemoryConfig() + } + + var embeddingCfg EmbeddingConfig + if options.EmbeddingConfig != nil { + embeddingCfg = *options.EmbeddingConfig + } else { + embeddingCfg = EmbeddingConfig{Model: EmbeddingModelBERT} + } + + vc := options.ValkeyConfig + + indexName := vc.IndexName + if indexName == "" { + indexName = "mem_idx" + } + collectionPrefix := vc.CollectionPrefix + if collectionPrefix == "" { + collectionPrefix = "mem:" + } + metricType := vc.MetricType + if metricType == "" { + metricType = "COSINE" + } + dimension := vc.Dimension + if dimension <= 0 { + dimension = 384 + } + + store := &ValkeyStore{ + client: options.Client, + config: cfg, + valkeyConfig: 
vc, + enabled: options.Enabled, + maxRetries: DefaultMaxRetries, + retryBaseDelay: DefaultRetryBaseDelay * time.Millisecond, + embeddingConfig: embeddingCfg, + indexName: indexName, + collectionPrefix: collectionPrefix, + metricType: strings.ToUpper(metricType), + dimension: dimension, + } + + // Use the configured timeout for index initialization (default 10s). + indexTimeout := 10 * time.Second + if vc.Timeout > 0 { + indexTimeout = time.Duration(vc.Timeout) * time.Second + } + ctx, cancel := context.WithTimeout(context.Background(), indexTimeout) + defer cancel() + if err := store.ensureIndex(ctx); err != nil { + return nil, fmt.Errorf("failed to ensure index exists: %w", err) + } + + logging.Infof("ValkeyStore: initialized with index='%s', prefix='%s', embedding_model='%s', dimension=%d", + store.indexName, store.collectionPrefix, store.embeddingConfig.Model, store.dimension) + + return store, nil +} + +// ensureIndex checks if the FT index exists and creates it if not. +func (v *ValkeyStore) ensureIndex(ctx context.Context) error { + _, err := v.client.CustomCommand(ctx, []string{"FT.INFO", v.indexName}) + if err == nil { + logging.Debugf("ValkeyStore: index '%s' already exists", v.indexName) + return nil + } + + logging.Infof("ValkeyStore: creating index '%s' with dimension %d", v.indexName, v.dimension) + + indexM := v.valkeyConfig.IndexM + if indexM <= 0 { + indexM = 16 + } + efConstruction := v.valkeyConfig.IndexEfConstruction + if efConstruction <= 0 { + efConstruction = 256 + } + + createCmd := []string{ + "FT.CREATE", v.indexName, + "ON", "HASH", + "PREFIX", "1", v.collectionPrefix, + "SCHEMA", + "id", "TAG", + "user_id", "TAG", + "project_id", "TAG", + "memory_type", "TAG", + "content", "TEXT", + "source", "TAG", + "embedding", "VECTOR", "HNSW", "10", + "TYPE", "FLOAT32", + "DIM", strconv.Itoa(v.dimension), + "DISTANCE_METRIC", v.metricType, + "M", strconv.Itoa(indexM), + "EF_CONSTRUCTION", strconv.Itoa(efConstruction), + "created_at", "NUMERIC", 
"SORTABLE", + "updated_at", "NUMERIC", + "access_count", "NUMERIC", + "importance", "NUMERIC", + } + + _, err = v.client.CustomCommand(ctx, createCmd) + if err != nil { + return fmt.Errorf("FT.CREATE failed: %w", err) + } + + logging.Infof("ValkeyStore: index '%s' created successfully", v.indexName) + return nil +} + +// hashKey returns the HASH key for a memory document. +func (v *ValkeyStore) hashKey(id string) string { + return v.collectionPrefix + id +} + +// Store saves a new memory to Valkey. +// Generates embedding for the content and inserts as a HASH key. +func (v *ValkeyStore) Store(ctx context.Context, memory *Memory) error { + startTime := time.Now() + backend := "valkey" + operation := "store" + status := "success" + + defer func() { + duration := time.Since(startTime).Seconds() + RecordMemoryStoreOperation(backend, operation, status, duration) + }() + + if !v.enabled { + status = "error" + return fmt.Errorf("valkey store is not enabled") + } + + if err := valkeyValidateMemory(memory); err != nil { + status = "error" + return err + } + + logging.Debugf("ValkeyStore.Store: id=%s, user=%s, type=%s, content_len=%d", + memory.ID, memory.UserID, memory.Type, len(memory.Content)) + + var embedding []float32 + if len(memory.Embedding) > 0 { + embedding = memory.Embedding + } else { + var err error + embedding, err = GenerateEmbedding(memory.Content, v.embeddingConfig) + if err != nil { + status = "error" + return fmt.Errorf("failed to generate embedding: %w", err) + } + } + + now := time.Now() + if memory.CreatedAt.IsZero() { + memory.CreatedAt = now + } + memory.UpdatedAt = now + if memory.LastAccessed.IsZero() { + memory.LastAccessed = now + } + + fields, err := valkeyBuildHashFields(memory, embedding) + if err != nil { + status = "error" + return fmt.Errorf("failed to build hash fields: %w", err) + } + + key := v.hashKey(memory.ID) + + err = v.retryWithBackoff(ctx, func() error { + _, hsetErr := v.client.HSet(ctx, key, fields) + return hsetErr + }) + if 
err != nil { + status = "error" + return fmt.Errorf("valkey HSET failed for memory id=%s: %w", memory.ID, err) + } + + logging.Debugf("ValkeyStore.Store: successfully stored memory id=%s", memory.ID) + return nil +} + +// rerankAndFilter applies hybrid re-ranking, adaptive threshold, score filtering, and access tracking. +func (v *ValkeyStore) rerankAndFilter(candidates []*RetrieveResult, opts RetrieveOptions, threshold float32, limit int) []*RetrieveResult { + if opts.HybridSearch && len(candidates) > 1 { + candidates = v.hybridRerank(candidates, opts) + } + + if opts.AdaptiveThreshold && len(candidates) > 1 { + threshold = adaptiveThresholdElbow(candidates, threshold) + } + + results := make([]*RetrieveResult, 0, limit) + for _, c := range candidates { + if c.Score < threshold { + continue + } + results = append(results, c) + if len(results) >= limit { + break + } + } + + if len(results) > 0 { + ids := make([]string, len(results)) + for i, r := range results { + ids[i] = r.Memory.ID + } + go v.recordRetrievalBatch(ids) + } + + return results +} + +// buildRetrieveSearchCmd constructs the FT.SEARCH command for vector similarity retrieval. +// Note: RETURN fetches fields from the underlying HASH document, not from the +// index schema. The "metadata" field is stored in the HASH but not indexed as +// a schema field, which is intentional — it's only needed for result parsing. 
+func (v *ValkeyStore) buildRetrieveSearchCmd(opts RetrieveOptions, embedding []float32, limit int) []string { + filterExpr := fmt.Sprintf("@user_id:{%s}", valkeyEscapeTagValue(opts.UserID)) + + if len(opts.Types) > 0 { + typeValues := make([]string, len(opts.Types)) + for i, memType := range opts.Types { + typeValues[i] = valkeyEscapeTagValue(string(memType)) + } + filterExpr = fmt.Sprintf("%s @memory_type:{%s}", filterExpr, strings.Join(typeValues, " | ")) + } + + searchTopK := limit * 4 + if opts.HybridSearch { + searchTopK = limit * 8 + } + if searchTopK < 20 { + searchTopK = 20 + } + + embeddingBytes := valkeyFloat32ToBytes(embedding) + query := fmt.Sprintf("(%s)=>[KNN %d @embedding $BLOB AS vector_distance]", filterExpr, searchTopK) + + return []string{ + "FT.SEARCH", v.indexName, query, + "PARAMS", "2", "BLOB", string(embeddingBytes), + "RETURN", "5", "id", "content", "memory_type", "metadata", "vector_distance", + "LIMIT", "0", strconv.Itoa(searchTopK), + "DIALECT", "2", + } +} + +// Retrieve searches for memories in Valkey with similarity threshold filtering. 
+func (v *ValkeyStore) Retrieve(ctx context.Context, opts RetrieveOptions) ([]*RetrieveResult, error) { + startTime := time.Now() + backend := "valkey" + operation := "retrieve" + status := "success" + resultCount := 0 + + defer func() { + duration := time.Since(startTime).Seconds() + RecordMemoryRetrieval(backend, operation, status, opts.UserID, duration, resultCount) + }() + + if !v.enabled { + status = "error" + return nil, fmt.Errorf("valkey store is not enabled") + } + + limit := opts.Limit + if limit <= 0 { + limit = v.config.DefaultRetrievalLimit + } + + threshold := opts.Threshold + if threshold <= 0 { + threshold = v.config.DefaultSimilarityThreshold + } + + if err := valkeyValidateRetrieveOpts(opts); err != nil { + status = "error" + return nil, err + } + + logging.Debugf("ValkeyStore.Retrieve: query='%s', user_id='%s', limit=%d, threshold=%.4f, hybrid=%v (mode=%s)", + opts.Query, opts.UserID, limit, threshold, opts.HybridSearch, opts.HybridMode) + + embedding, err := GenerateEmbedding(opts.Query, v.embeddingConfig) + if err != nil { + status = "error" + return nil, fmt.Errorf("failed to generate embedding: %w", err) + } + + searchCmd := v.buildRetrieveSearchCmd(opts, embedding, limit) + + var searchResult any + err = v.retryWithBackoff(ctx, func() error { + var retryErr error + searchResult, retryErr = v.client.CustomCommand(ctx, searchCmd) + return retryErr + }) + if err != nil { + status = "error" + return nil, fmt.Errorf("valkey FT.SEARCH failed after retries: %w", err) + } + + candidates := v.parseSearchCandidates(searchResult, opts.UserID) + if len(candidates) == 0 { + status = "miss" + return []*RetrieveResult{}, nil + } + + results := v.rerankAndFilter(candidates, opts, threshold, limit) + resultCount = len(results) + if resultCount > 0 { + status = "hit" + } else { + status = "miss" + } + + return results, nil +} + +// Get retrieves a memory by ID from Valkey. 
+func (v *ValkeyStore) Get(ctx context.Context, id string) (*Memory, error) { + if !v.enabled { + return nil, fmt.Errorf("valkey store is not enabled") + } + + if id == "" { + return nil, fmt.Errorf("memory ID is required") + } + + logging.Debugf("ValkeyStore.Get: retrieving memory id=%s", id) + + key := v.hashKey(id) + + var fields map[string]string + err := v.retryWithBackoff(ctx, func() error { + var retryErr error + fields, retryErr = v.client.HGetAll(ctx, key) + return retryErr + }) + if err != nil { + return nil, fmt.Errorf("valkey HGETALL failed for memory id=%s: %w", id, err) + } + + if len(fields) == 0 { + return nil, fmt.Errorf("memory not found: %s", id) + } + + mem := valkeyFieldsToMemory(fields) + if mem.ID == "" { + return nil, fmt.Errorf("memory not found: %s", id) + } + + logging.Debugf("ValkeyStore.Get: found memory id=%s, user_id=%s", mem.ID, mem.UserID) + return mem, nil +} + +// Update modifies an existing memory in Valkey using HSET (atomic overwrite). +// Preserves CreatedAt from the existing row and sets UpdatedAt to now. 
+func (v *ValkeyStore) Update(ctx context.Context, id string, memory *Memory) error { + startTime := time.Now() + backend := "valkey" + operation := "update" + status := "success" + + defer func() { + duration := time.Since(startTime).Seconds() + RecordMemoryStoreOperation(backend, operation, status, duration) + }() + + if !v.enabled { + status = "error" + return fmt.Errorf("valkey store is not enabled") + } + + if id == "" { + status = "error" + return fmt.Errorf("memory ID is required") + } + + logging.Debugf("ValkeyStore.Update: upserting memory id=%s", id) + + memory.ID = id + memory.UpdatedAt = time.Now() + + // If CreatedAt or Embedding are missing, fetch from the existing row so we don't lose data + if memory.CreatedAt.IsZero() || len(memory.Embedding) == 0 { + existing, err := v.Get(ctx, id) + if err != nil { + status = "error" + return fmt.Errorf("memory not found: %s", id) + } + if memory.CreatedAt.IsZero() { + memory.CreatedAt = existing.CreatedAt + } + if len(memory.Embedding) == 0 { + memory.Embedding = existing.Embedding + } + } + + err := v.upsert(ctx, memory) + if err != nil { + status = "error" + return err + } + return nil +} + +// upsert atomically replaces a memory in Valkey by HSET on its hash key. +// The memory must be fully populated (including Embedding, timestamps, etc.). +// Reuses valkeyBuildHashFields to avoid duplicating field-building logic. 
+func (v *ValkeyStore) upsert(ctx context.Context, memory *Memory) error { + if len(memory.Embedding) == 0 { + return fmt.Errorf("embedding is required for upsert") + } + + fields, err := valkeyBuildHashFields(memory, memory.Embedding) + if err != nil { + return fmt.Errorf("failed to build hash fields: %w", err) + } + + key := v.hashKey(memory.ID) + + err = v.retryWithBackoff(ctx, func() error { + _, hsetErr := v.client.HSet(ctx, key, fields) + return hsetErr + }) + if err != nil { + return fmt.Errorf("valkey HSET upsert failed for memory id=%s: %w", memory.ID, err) + } + + logging.Debugf("ValkeyStore.upsert: successfully upserted memory id=%s", memory.ID) + return nil +} + +// List returns memories matching the filter criteria with pagination. +// Uses FT.SEARCH with TAG filters (no vector search) and returns paginated results. +// Total count comes from the FT.SEARCH header (first element of the result array), +// which reports the full match count regardless of LIMIT, so we only fetch the +// requested page. +func (v *ValkeyStore) List(ctx context.Context, opts ListOptions) (*ListResult, error) { + if !v.enabled { + return nil, fmt.Errorf("valkey store is not enabled") + } + + if opts.UserID == "" { + return nil, fmt.Errorf("user ID is required for listing memories") + } + + limit := opts.Limit + if limit <= 0 { + limit = 20 + } + if limit > 100 { + limit = 100 + } + + logging.Debugf("ValkeyStore.List: user_id=%s, types=%v, limit=%d", + opts.UserID, opts.Types, limit) + + // Build filter expression + filterExpr := fmt.Sprintf("@user_id:{%s}", valkeyEscapeTagValue(opts.UserID)) + + if len(opts.Types) > 0 { + typeValues := make([]string, len(opts.Types)) + for i, memType := range opts.Types { + typeValues[i] = valkeyEscapeTagValue(string(memType)) + } + filterExpr = fmt.Sprintf("%s @memory_type:{%s}", filterExpr, strings.Join(typeValues, " | ")) + } + + // Fetch limit+1 pages worth of data so we can sort client-side and still + // respect the limit. 
We over-fetch by a factor to allow client-side sorting + // by created_at (FT.SEARCH does not support ORDER BY on NUMERIC fields + // without SORTABLE in all Valkey Search versions). + // The total count comes from the FT.SEARCH header element. + fetchLimit := limit * 5 + if fetchLimit < 100 { + fetchLimit = 100 + } + if fetchLimit > 10000 { + fetchLimit = 10000 + } + + searchCmd := []string{ + "FT.SEARCH", v.indexName, filterExpr, + "RETURN", "7", "id", "content", "user_id", "memory_type", "metadata", "created_at", "updated_at", + "LIMIT", "0", strconv.Itoa(fetchLimit), + "DIALECT", "2", + } + + result, err := v.client.CustomCommand(ctx, searchCmd) + if err != nil { + return nil, fmt.Errorf("valkey FT.SEARCH list failed: %w", err) + } + + // Extract total count from the FT.SEARCH header. + total := v.extractTotalCount(result) + + memories := v.parseListSearchResults(result) + + // Sort by created_at descending for deterministic results. + sort.Slice(memories, func(i, j int) bool { + return memories[i].CreatedAt.After(memories[j].CreatedAt) + }) + + if limit < len(memories) { + memories = memories[:limit] + } + + logging.Debugf("ValkeyStore.List: found %d total, returning %d (limit=%d)", + total, len(memories), limit) + + return &ListResult{ + Memories: memories, + Total: total, + Limit: limit, + }, nil +} + +// Forget deletes a memory by ID from Valkey. 
+func (v *ValkeyStore) Forget(ctx context.Context, id string) error { + startTime := time.Now() + backend := "valkey" + operation := "forget" + status := "success" + + defer func() { + duration := time.Since(startTime).Seconds() + RecordMemoryStoreOperation(backend, operation, status, duration) + }() + + if !v.enabled { + status = "error" + return fmt.Errorf("valkey store is not enabled") + } + + if id == "" { + status = "error" + return fmt.Errorf("memory ID is required") + } + + logging.Debugf("ValkeyStore.Forget: deleting memory id=%s", id) + + key := v.hashKey(id) + + err := v.retryWithBackoff(ctx, func() error { + deleted, delErr := v.client.Del(ctx, []string{key}) + if delErr != nil { + return delErr + } + if deleted == 0 { + logging.Debugf("ValkeyStore.Forget: key %s did not exist (already deleted)", key) + } + return nil + }) + if err != nil { + status = "error" + return fmt.Errorf("valkey DEL failed for memory id=%s: %w", id, err) + } + + logging.Debugf("ValkeyStore.Forget: successfully deleted memory id=%s", id) + return nil +} + +// ForgetByScope deletes all memories matching the scope from Valkey. +// Scope includes UserID (required), ProjectID (optional), Types (optional). +// Uses batch DEL for efficiency instead of one-by-one deletion. +func (v *ValkeyStore) ForgetByScope(ctx context.Context, scope MemoryScope) error { + startTime := time.Now() + backend := "valkey" + operation := "forget_by_scope" + status := "success" + + defer func() { + duration := time.Since(startTime).Seconds() + RecordMemoryStoreOperation(backend, operation, status, duration) + }() + + if !v.enabled { + status = "error" + return fmt.Errorf("valkey store is not enabled") + } + + if scope.UserID == "" { + status = "error" + return fmt.Errorf("user ID is required for scope deletion") + } + + logging.Debugf("ValkeyStore.ForgetByScope: deleting memories for user_id=%s, project_id=%s, types=%v", + scope.UserID, scope.ProjectID, scope.Types) + + // Build filter expression. 
+ // project_id is indexed as a TAG field in the schema, so we can filter + // directly when the caller specifies a ProjectID. + filterExpr := fmt.Sprintf("@user_id:{%s}", valkeyEscapeTagValue(scope.UserID)) + + if scope.ProjectID != "" { + filterExpr = fmt.Sprintf("%s @project_id:{%s}", filterExpr, valkeyEscapeTagValue(scope.ProjectID)) + } + + if len(scope.Types) > 0 { + typeValues := make([]string, len(scope.Types)) + for i, memType := range scope.Types { + typeValues[i] = valkeyEscapeTagValue(string(memType)) + } + filterExpr = fmt.Sprintf("%s @memory_type:{%s}", filterExpr, strings.Join(typeValues, " | ")) + } + + // Delete in batches: always search from offset 0, delete found keys, repeat + // until none remain. We intentionally re-query at offset 0 each iteration + // because the previous batch's DEL removes those keys from the index, so the + // next FT.SEARCH at offset 0 naturally returns the next page of results. + // Incrementing the offset would skip keys that shifted forward after deletion. + // This follows the same pattern as ValkeyBackend.DeleteByFileID in + // pkg/vectorstore/valkey_backend.go. + const pageSize = 1000 + totalDeleted := 0 + + for { + searchCmd := []string{ + "FT.SEARCH", v.indexName, filterExpr, + "RETURN", "1", "id", + "LIMIT", "0", strconv.Itoa(pageSize), + "DIALECT", "2", + } + + result, err := v.client.CustomCommand(ctx, searchCmd) + if err != nil { + status = "error" + return fmt.Errorf("valkey FT.SEARCH failed for scope deletion: %w", err) + } + + keys := v.extractHashKeysFromSearchResult(result) + if len(keys) == 0 { + break + } + + _, err = v.client.Del(ctx, keys) + if err != nil { + status = "error" + return fmt.Errorf("valkey batch DEL failed for scope deletion: %w", err) + } + + totalDeleted += len(keys) + } + + logging.Debugf("ValkeyStore.ForgetByScope: deleted %d memories", totalDeleted) + return nil +} + +// IsEnabled returns whether the store is enabled. 
+func (v *ValkeyStore) IsEnabled() bool { + return v.enabled +} + +// CheckConnection verifies the Valkey connection is healthy. +func (v *ValkeyStore) CheckConnection(ctx context.Context) error { + if !v.enabled { + return nil + } + + if v.client == nil { + return fmt.Errorf("valkey client is not initialized") + } + + // Verify the FT index exists + _, err := v.client.CustomCommand(ctx, []string{"FT.INFO", v.indexName}) + if err != nil { + return fmt.Errorf("valkey FT.INFO failed for index '%s': %w", v.indexName, err) + } + + return nil +} + +// Close releases resources held by the store. +func (v *ValkeyStore) Close() error { + // The caller is responsible for managing the client lifecycle + return nil +} diff --git a/src/semantic-router/pkg/memory/valkey_store_helpers.go b/src/semantic-router/pkg/memory/valkey_store_helpers.go new file mode 100644 index 000000000..9e5fbf7a8 --- /dev/null +++ b/src/semantic-router/pkg/memory/valkey_store_helpers.go @@ -0,0 +1,606 @@ +package memory + +import ( + "context" + "encoding/binary" + "encoding/json" + "fmt" + "math" + "sort" + "strconv" + "strings" + "time" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability/logging" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/vectorstore" +) + +// --------------------------------------------------------------------------- +// Background access tracking +// --------------------------------------------------------------------------- + +// recordRetrievalBatch updates LastAccessed and AccessCount for each retrieved memory in the background. +// Uses targeted HINCRBY + HSET instead of full read-modify-write for efficiency. +// The user-facing behavior matches the Milvus backend (access_count incremented, timestamps updated). 
+func (v *ValkeyStore) recordRetrievalBatch(ids []string) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + for _, id := range ids { + if err := v.recordRetrieval(ctx, id); err != nil { + logging.Warnf("ValkeyStore.recordRetrievalBatch: id=%s: %v", id, err) + } + } +} + +// recordRetrieval updates LastAccessed and AccessCount for a single memory (reinforcement: S += 1, t = 0). +// +// The authoritative access_count and updated_at live as top-level HASH fields and are updated +// atomically via HINCRBY / HSET. The metadata JSON blob also carries copies of these values for +// convenience during result parsing. Rather than doing a racy read-modify-write on the JSON +// (which could lose concurrent increments), we read the current access_count HASH field after +// the atomic increment and rebuild the metadata JSON from that single source of truth. +func (v *ValkeyStore) recordRetrieval(ctx context.Context, id string) error { + key := v.hashKey(id) + now := time.Now() + nowUnix := strconv.FormatInt(now.Unix(), 10) + + // Increment access_count atomically. + newCount, err := v.client.CustomCommand(ctx, []string{"HINCRBY", key, "access_count", "1"}) + if err != nil { + return fmt.Errorf("HINCRBY access_count failed: %w", err) + } + + // Update timestamps. + _, err = v.client.HSet(ctx, key, map[string]string{ + "updated_at": nowUnix, + }) + if err != nil { + return fmt.Errorf("HSET timestamps failed: %w", err) + } + + // Sync metadata JSON with the authoritative HASH fields. + // We read the current metadata, overwrite access_count and last_accessed + // with the values we just wrote atomically above, and write it back. + // This avoids the previous read-modify-write race: even if two goroutines + // run concurrently, each writes the post-increment count it received from + // HINCRBY, so the JSON converges to the correct value. 
+ fields, err := v.client.HGetAll(ctx, key) + if err != nil { + return nil // Non-critical: top-level HASH fields are already updated + } + if metadataStr, ok := fields["metadata"]; ok && metadataStr != "" { + var metadata map[string]interface{} + if jsonErr := json.Unmarshal([]byte(metadataStr), &metadata); jsonErr == nil { + metadata["last_accessed"] = now.Unix() + // Use the authoritative count returned by HINCRBY instead of + // incrementing the stale JSON value. + metadata["access_count"] = valkeyToInt64(newCount) + if updated, mErr := json.Marshal(metadata); mErr == nil { + _, _ = v.client.HSet(ctx, key, map[string]string{"metadata": string(updated)}) + } + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// Hybrid re-ranking +// --------------------------------------------------------------------------- + +// hybridRerank applies BM25 + n-gram scoring on top of vector results. +func (v *ValkeyStore) hybridRerank(candidates []*RetrieveResult, opts RetrieveOptions) []*RetrieveResult { + pseudoChunks := make(map[string]vectorstore.EmbeddedChunk, len(candidates)) + vectorScores := make(map[string]float64, len(candidates)) + keyToCandidate := make(map[string]*RetrieveResult, len(candidates)) + + for i, c := range candidates { + key := fmt.Sprintf("_mem_%d", i) + pseudoChunks[key] = vectorstore.EmbeddedChunk{ID: key, Content: c.Memory.Content} + vectorScores[key] = float64(c.Score) + keyToCandidate[key] = c + } + + hybridCfg := &vectorstore.HybridSearchConfig{Mode: opts.HybridMode} + + bm25K1 := hybridCfg.BM25K1 + if bm25K1 == 0 { + bm25K1 = 1.2 + } + bm25B := hybridCfg.BM25B + if bm25B == 0 { + bm25B = 0.75 + } + ngramSize := hybridCfg.NgramSize + if ngramSize <= 0 { + ngramSize = 3 + } + + bm25Idx := vectorstore.NewBM25Index(pseudoChunks) + bm25Scores := bm25Idx.Score(opts.Query, bm25K1, bm25B) + + ngramIdx := vectorstore.NewNgramIndex(pseudoChunks, ngramSize) + ngramScores := ngramIdx.Score(opts.Query) + 
+ fused := vectorstore.FuseScores(vectorScores, bm25Scores, ngramScores, hybridCfg) + + reranked := make([]*RetrieveResult, 0, len(fused)) + for _, fc := range fused { + c, ok := keyToCandidate[fc.ChunkID] + if !ok { + continue + } + c.Score = float32(fc.FinalScore) + reranked = append(reranked, c) + } + return reranked +} + +// --------------------------------------------------------------------------- +// Result parsing +// --------------------------------------------------------------------------- + +// valkeyIterateSearchDocs extracts field maps from an FT.SEARCH result array. +// It skips the total-count header and yields each document's field map. +func valkeyIterateSearchDocs(result any) []map[string]interface{} { + arr, ok := result.([]interface{}) + if !ok || len(arr) < 1 { + return nil + } + + totalCount := valkeyToInt64(arr[0]) + if totalCount == 0 { + return nil + } + + var docs []map[string]interface{} + for i := 1; i < len(arr); i++ { + docMap, ok := arr[i].(map[string]interface{}) + if !ok { + continue + } + for _, docValue := range docMap { + fieldsMap, mapOk := docValue.(map[string]interface{}) + if !mapOk { + continue + } + docs = append(docs, fieldsMap) + } + } + return docs +} + +func (v *ValkeyStore) parseSearchCandidates(result any, defaultUserID string) []*RetrieveResult { + docs := valkeyIterateSearchDocs(result) + if len(docs) == 0 { + return nil + } + + var candidates []*RetrieveResult + + for _, fieldsMap := range docs { + id := fmt.Sprint(fieldsMap["id"]) + content := fmt.Sprint(fieldsMap["content"]) + memType := fmt.Sprint(fieldsMap["memory_type"]) + + if id == "" || id == "" || content == "" || content == "" { + continue + } + + score := valkeyParseScoreFromMap(fieldsMap, "vector_distance", v.metricType) + + mem := &Memory{ + ID: id, + Content: content, + Type: MemoryType(memType), + } + + if metadataStr, ok := fieldsMap["metadata"].(string); ok { + valkeyParseMetadata(mem, metadataStr) + } + + if mem.UserID == "" { + mem.UserID = 
defaultUserID + } + + candidates = append(candidates, &RetrieveResult{Memory: mem, Score: float32(score)}) + } + + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].Score > candidates[j].Score + }) + + return candidates +} + +// parseListSearchResults parses FT.SEARCH results for List operations (no vector_distance). +func (v *ValkeyStore) parseListSearchResults(result any) []*Memory { + docs := valkeyIterateSearchDocs(result) + if len(docs) == 0 { + return nil + } + + var memories []*Memory + for _, fieldsMap := range docs { + mem := valkeyFieldsMapToMemory(fieldsMap) + if mem.ID == "" { + continue + } + memories = append(memories, mem) + } + + return memories +} + +// valkeyMatchesProjectFilter checks whether a document's metadata contains the expected project_id. +func valkeyMatchesProjectFilter(fieldsMap map[string]interface{}, projectIDFilter string) bool { + metadataStr, ok := fieldsMap["metadata"].(string) + if !ok || metadataStr == "" { + return false + } + var metadata map[string]interface{} + if err := json.Unmarshal([]byte(metadataStr), &metadata); err != nil { + return false + } + projectID, ok := metadata["project_id"].(string) + return ok && projectID == projectIDFilter +} + +// extractIDsFromSearchResult extracts memory IDs from FT.SEARCH results, optionally filtering by project_id. +func (v *ValkeyStore) extractIDsFromSearchResult(result any, projectIDFilter string) []string { + docs := valkeyIterateSearchDocs(result) + if len(docs) == 0 { + return nil + } + + var ids []string + for _, fieldsMap := range docs { + id := fmt.Sprint(fieldsMap["id"]) + if id == "" || id == "" { + continue + } + if projectIDFilter != "" && !valkeyMatchesProjectFilter(fieldsMap, projectIDFilter) { + continue + } + ids = append(ids, id) + } + + return ids +} + +// extractHashKeysFromSearchResult extracts the full hash keys (document keys) +// from an FT.SEARCH result. These are the actual Valkey keys that can be passed +// to DEL for batch deletion. 
Follows the same pattern as +// extractKeysFromSearchResult in pkg/vectorstore/valkey_backend.go. +func (v *ValkeyStore) extractHashKeysFromSearchResult(result any) []string { + arr, ok := result.([]interface{}) + if !ok || len(arr) < 2 { + return nil + } + + var keys []string + for i := 1; i < len(arr); i++ { + switch val := arr[i].(type) { + case string: + keys = append(keys, val) + case map[string]interface{}: + for docKey := range val { + keys = append(keys, docKey) + } + } + } + return keys +} + +// extractTotalCount extracts the total match count from the FT.SEARCH result header. +// The first element of the result array is always the total count of matching documents, +// regardless of the LIMIT clause. +func (v *ValkeyStore) extractTotalCount(result any) int { + arr, ok := result.([]interface{}) + if !ok || len(arr) < 1 { + return 0 + } + return int(valkeyToInt64(arr[0])) +} + +// --------------------------------------------------------------------------- +// Retry logic +// --------------------------------------------------------------------------- + +// retryWithBackoff retries an operation with exponential backoff for transient errors. 
+func (v *ValkeyStore) retryWithBackoff(ctx context.Context, operation func() error) error { + var lastErr error + + for attempt := 0; attempt < v.maxRetries; attempt++ { + lastErr = operation() + + if lastErr == nil || !isTransientError(lastErr) { + return lastErr + } + + if attempt == v.maxRetries-1 { + logging.Warnf("ValkeyStore: operation failed after %d retries: %v", v.maxRetries, lastErr) + return lastErr + } + + exponent := attempt + if exponent < 0 { + exponent = 0 + } else if exponent > 30 { + exponent = 30 + } + delay := v.retryBaseDelay * time.Duration(1< 0 { + _, _ = client.Del(ctx, keys) + } + if cursor == "0" { + return + } + } +} + +// --------------------------------------------------------------------------- +// CheckConnection +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_CheckConnection(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + err := store.CheckConnection(ctx) + assert.NoError(t, err) +} + +// --------------------------------------------------------------------------- +// Store + Get (full CRUD lifecycle) +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_StoreAndGet(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + mem := &Memory{ + ID: fmt.Sprintf("mem_integ_%d", time.Now().UnixNano()), + Type: MemoryTypeSemantic, + Content: "The user's preferred programming language is Go", + UserID: "test_user_1", + ProjectID: "proj_1", + Source: "conversation", + Importance: 0.8, + } + + // Store + err := store.Store(ctx, mem) + require.NoError(t, err) + + // Allow index to catch up + time.Sleep(200 * time.Millisecond) + + // Get + retrieved, err := store.Get(ctx, mem.ID) + require.NoError(t, err) + assert.Equal(t, mem.ID, retrieved.ID) + assert.Equal(t, mem.Content, retrieved.Content) + assert.Equal(t, mem.UserID, retrieved.UserID) + 
assert.Equal(t, MemoryTypeSemantic, retrieved.Type) + assert.Equal(t, "proj_1", retrieved.ProjectID) + assert.Equal(t, "conversation", retrieved.Source) + assert.InDelta(t, 0.8, float64(retrieved.Importance), 0.01) + assert.NotNil(t, retrieved.Embedding, "Get should return the embedding") + assert.False(t, retrieved.CreatedAt.IsZero()) + assert.False(t, retrieved.UpdatedAt.IsZero()) +} + +// --------------------------------------------------------------------------- +// Get — not found +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_GetNotFound(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + _, err := store.Get(ctx, "nonexistent_id_12345") + assert.Error(t, err) + assert.Contains(t, err.Error(), "memory not found") +} + +// --------------------------------------------------------------------------- +// Retrieve (semantic search) +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_Retrieve(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + // Store some memories + memories := []*Memory{ + { + ID: fmt.Sprintf("mem_ret_%d_1", time.Now().UnixNano()), + Type: MemoryTypeSemantic, + Content: "The user prefers dark mode in all their applications", + UserID: "retrieve_user", + }, + { + ID: fmt.Sprintf("mem_ret_%d_2", time.Now().UnixNano()), + Type: MemoryTypeProcedural, + Content: "To deploy the application, run make deploy in the project root", + UserID: "retrieve_user", + }, + { + ID: fmt.Sprintf("mem_ret_%d_3", time.Now().UnixNano()), + Type: MemoryTypeSemantic, + Content: "The capital of France is Paris, a well known European city", + UserID: "retrieve_user", + }, + } + + for _, m := range memories { + require.NoError(t, store.Store(ctx, m)) + } + + // Allow index to catch up + time.Sleep(500 * time.Millisecond) + + // Search for dark mode preference + results, 
err := store.Retrieve(ctx, RetrieveOptions{ + Query: "What are the user's display preferences?", + UserID: "retrieve_user", + Limit: 3, + Threshold: 0.3, // Low threshold for integration test + }) + require.NoError(t, err) + assert.NotEmpty(t, results, "should find at least one result") + + // Verify results have scores + for _, r := range results { + assert.NotEmpty(t, r.Memory.ID) + assert.NotEmpty(t, r.Memory.Content) + assert.Positive(t, r.Score, "score should be positive") + } +} + +// --------------------------------------------------------------------------- +// Retrieve — empty results +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_RetrieveEmpty(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + results, err := store.Retrieve(ctx, RetrieveOptions{ + Query: "something that matches nothing in an empty store", + UserID: "empty_user_" + strconv.FormatInt(time.Now().UnixNano(), 36), + Limit: 5, + Threshold: 0.99, // Very high threshold + }) + require.NoError(t, err) + assert.Empty(t, results) +} + +// --------------------------------------------------------------------------- +// Retrieve with type filter +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_RetrieveWithTypeFilter(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + ts := time.Now().UnixNano() + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_tf_%d_1", ts), Type: MemoryTypeSemantic, + Content: "User likes Python programming language very much", UserID: "filter_user", + })) + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_tf_%d_2", ts), Type: MemoryTypeProcedural, + Content: "To run Python tests use pytest command in terminal", UserID: "filter_user", + })) + + time.Sleep(500 * time.Millisecond) + + results, err := store.Retrieve(ctx, RetrieveOptions{ + 
Query: "Python programming", + UserID: "filter_user", + Types: []MemoryType{MemoryTypeProcedural}, + Limit: 5, + Threshold: 0.3, + }) + require.NoError(t, err) + + for _, r := range results { + assert.Equal(t, MemoryTypeProcedural, r.Memory.Type, "should only return procedural memories") + } +} + +// --------------------------------------------------------------------------- +// User-scoped isolation +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_UserScopedIsolation(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + ts := time.Now().UnixNano() + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_iso_%d_a", ts), Type: MemoryTypeSemantic, + Content: "User A's secret preference is dark chocolate", UserID: "user_a", + })) + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_iso_%d_b", ts), Type: MemoryTypeSemantic, + Content: "User B's secret preference is white chocolate", UserID: "user_b", + })) + + time.Sleep(500 * time.Millisecond) + + // User A should only see their own memories + resultsA, err := store.Retrieve(ctx, RetrieveOptions{ + Query: "chocolate preference", UserID: "user_a", Limit: 5, Threshold: 0.3, + }) + require.NoError(t, err) + for _, r := range resultsA { + assert.Equal(t, "user_a", r.Memory.UserID, "user_a should only see their own memories") + } + + // User B should only see their own memories + resultsB, err := store.Retrieve(ctx, RetrieveOptions{ + Query: "chocolate preference", UserID: "user_b", Limit: 5, Threshold: 0.3, + }) + require.NoError(t, err) + for _, r := range resultsB { + assert.Equal(t, "user_b", r.Memory.UserID, "user_b should only see their own memories") + } +} + +// --------------------------------------------------------------------------- +// Update +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_Update(t *testing.T) { + store, 
_ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + id := fmt.Sprintf("mem_upd_%d", time.Now().UnixNano()) + original := &Memory{ + ID: id, Type: MemoryTypeSemantic, + Content: "Original content for update test", UserID: "update_user", + Importance: 0.5, + } + require.NoError(t, store.Store(ctx, original)) + time.Sleep(200 * time.Millisecond) + + // Update content and importance + updated := &Memory{ + Content: "Updated content after modification", + UserID: "update_user", + Type: MemoryTypeSemantic, + Importance: 0.9, + } + err := store.Update(ctx, id, updated) + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + + // Verify update + retrieved, err := store.Get(ctx, id) + require.NoError(t, err) + assert.Equal(t, "Updated content after modification", retrieved.Content) + assert.InDelta(t, 0.9, float64(retrieved.Importance), 0.01) + assert.False(t, retrieved.CreatedAt.IsZero(), "CreatedAt should be preserved") +} + +// --------------------------------------------------------------------------- +// Update — not found +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_UpdateNotFound(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + err := store.Update(ctx, "nonexistent_update_id", &Memory{ + Content: "will not work", UserID: "u1", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "memory not found") +} + +// --------------------------------------------------------------------------- +// Forget +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_Forget(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + id := fmt.Sprintf("mem_forget_%d", time.Now().UnixNano()) + require.NoError(t, store.Store(ctx, &Memory{ + ID: id, Type: MemoryTypeSemantic, + Content: "Memory to be forgotten", UserID: "forget_user", + })) + time.Sleep(200 
* time.Millisecond) + + // Verify it exists + _, err := store.Get(ctx, id) + require.NoError(t, err) + + // Forget + err = store.Forget(ctx, id) + require.NoError(t, err) + + // Verify deleted + _, err = store.Get(ctx, id) + assert.Error(t, err) + assert.Contains(t, err.Error(), "memory not found") +} + +// --------------------------------------------------------------------------- +// ForgetByScope — user only +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_ForgetByScope_UserOnly(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + userID := fmt.Sprintf("scope_user_%d", time.Now().UnixNano()) + for i := 0; i < 3; i++ { + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_scope_%s_%d", userID, i), + Type: MemoryTypeSemantic, + Content: fmt.Sprintf("Scoped memory %d for deletion", i), + UserID: userID, + })) + } + time.Sleep(500 * time.Millisecond) + + // Verify they exist + list, err := store.List(ctx, ListOptions{UserID: userID, Limit: 10}) + require.NoError(t, err) + assert.Equal(t, 3, list.Total) + + // Delete by scope + err = store.ForgetByScope(ctx, MemoryScope{UserID: userID}) + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + + // Verify all deleted + list, err = store.List(ctx, ListOptions{UserID: userID, Limit: 10}) + require.NoError(t, err) + assert.Equal(t, 0, list.Total) +} + +// --------------------------------------------------------------------------- +// ForgetByScope — with type filter +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_ForgetByScope_WithTypeFilter(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + userID := fmt.Sprintf("scope_type_user_%d", time.Now().UnixNano()) + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_st_%s_1", userID), Type: MemoryTypeSemantic, + Content: "Semantic 
memory to keep", UserID: userID, + })) + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_st_%s_2", userID), Type: MemoryTypeProcedural, + Content: "Procedural memory to delete", UserID: userID, + })) + time.Sleep(500 * time.Millisecond) + + // Delete only procedural + err := store.ForgetByScope(ctx, MemoryScope{ + UserID: userID, + Types: []MemoryType{MemoryTypeProcedural}, + }) + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + + // Verify only semantic remains + list, err := store.List(ctx, ListOptions{UserID: userID, Limit: 10}) + require.NoError(t, err) + assert.Equal(t, 1, list.Total) + assert.Equal(t, MemoryTypeSemantic, list.Memories[0].Type) +} + +// --------------------------------------------------------------------------- +// ForgetByScope — missing UserID +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_ForgetByScope_MissingUserID(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + err := store.ForgetByScope(ctx, MemoryScope{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "user ID is required") +} + +// --------------------------------------------------------------------------- +// List +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_List(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + userID := fmt.Sprintf("list_user_%d", time.Now().UnixNano()) + for i := 0; i < 5; i++ { + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_list_%s_%d", userID, i), + Type: MemoryTypeSemantic, + Content: fmt.Sprintf("List test memory number %d", i), + UserID: userID, + })) + time.Sleep(50 * time.Millisecond) // ensure different timestamps + } + time.Sleep(500 * time.Millisecond) + + list, err := store.List(ctx, ListOptions{UserID: userID, Limit: 3}) + require.NoError(t, err) + assert.Equal(t, 5, 
list.Total) + assert.Len(t, list.Memories, 3) + + // Verify sorted by created_at descending + for i := 1; i < len(list.Memories); i++ { + assert.False(t, list.Memories[i-1].CreatedAt.Before(list.Memories[i].CreatedAt), + "memories should be sorted by created_at descending") + } +} + +// --------------------------------------------------------------------------- +// List — missing UserID +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_List_MissingUserID(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + _, err := store.List(ctx, ListOptions{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "user ID is required") +} + +// --------------------------------------------------------------------------- +// Duplicate keys (overwrite behavior) +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_DuplicateKeys(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + id := fmt.Sprintf("mem_dup_%d", time.Now().UnixNano()) + + // Store first version + require.NoError(t, store.Store(ctx, &Memory{ + ID: id, Type: MemoryTypeSemantic, + Content: "First version", UserID: "dup_user", + })) + time.Sleep(200 * time.Millisecond) + + // Store again with same ID (HSET overwrites) + require.NoError(t, store.Store(ctx, &Memory{ + ID: id, Type: MemoryTypeSemantic, + Content: "Second version", UserID: "dup_user", + })) + time.Sleep(200 * time.Millisecond) + + // Verify latest content + retrieved, err := store.Get(ctx, id) + require.NoError(t, err) + assert.Equal(t, "Second version", retrieved.Content) +} + +// --------------------------------------------------------------------------- +// Concurrent access +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_ConcurrentAccess(t *testing.T) { + store, _ := 
setupValkeyMemoryIntegration(t) + ctx := context.Background() + + const numGoroutines = 10 + var wg sync.WaitGroup + errs := make(chan error, numGoroutines*2) + + userID := fmt.Sprintf("concurrent_user_%d", time.Now().UnixNano()) + + // Concurrent stores + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + err := store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_conc_%s_%d", userID, idx), + Type: MemoryTypeSemantic, + Content: fmt.Sprintf("Concurrent memory %d for testing parallel access", idx), + UserID: userID, + }) + if err != nil { + errs <- err + } + }(i) + } + wg.Wait() + close(errs) + + for err := range errs { + t.Errorf("concurrent store error: %v", err) + } + + time.Sleep(500 * time.Millisecond) + + // Verify all stored + list, err := store.List(ctx, ListOptions{UserID: userID, Limit: 100}) + require.NoError(t, err) + assert.Equal(t, numGoroutines, list.Total) +} + +// --------------------------------------------------------------------------- +// ForgetByScope — with project filter +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_ForgetByScope_WithProjectFilter(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + userID := fmt.Sprintf("scope_proj_user_%d", time.Now().UnixNano()) + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_sp_%s_1", userID), Type: MemoryTypeSemantic, + Content: "Memory for project A", UserID: userID, ProjectID: "projA", + })) + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_sp_%s_2", userID), Type: MemoryTypeSemantic, + Content: "Memory for project B", UserID: userID, ProjectID: "projB", + })) + time.Sleep(500 * time.Millisecond) + + // Delete only projA + err := store.ForgetByScope(ctx, MemoryScope{ + UserID: userID, + ProjectID: "projA", + }) + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + + // Verify only projB remains + list, err := 
store.List(ctx, ListOptions{UserID: userID, Limit: 10}) + require.NoError(t, err) + assert.Equal(t, 1, list.Total) + assert.Equal(t, "projB", list.Memories[0].ProjectID) +} + +// --------------------------------------------------------------------------- +// ConsolidateUser (standalone function with ValkeyStore) +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_ConsolidateUser(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + userID := fmt.Sprintf("consol_user_%d", time.Now().UnixNano()) + + // Store memories with similar content (high Jaccard similarity) + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_c_%s_1", userID), Type: MemoryTypeSemantic, + Content: "The user prefers dark mode in all applications", UserID: userID, + })) + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_c_%s_2", userID), Type: MemoryTypeSemantic, + Content: "The user prefers dark mode in all their applications and IDEs", UserID: userID, + })) + // Store a different memory + require.NoError(t, store.Store(ctx, &Memory{ + ID: fmt.Sprintf("mem_c_%s_3", userID), Type: MemoryTypeSemantic, + Content: "Python is installed at /usr/bin/python3", UserID: userID, + })) + time.Sleep(500 * time.Millisecond) + + merged, deleted, err := ConsolidateUser(ctx, store, userID) + require.NoError(t, err) + + // The two similar memories should be merged + assert.GreaterOrEqual(t, merged, 0, "merged count should be non-negative") + assert.GreaterOrEqual(t, deleted, 0, "deleted count should be non-negative") + + // The total should be reduced + list, err := store.List(ctx, ListOptions{UserID: userID, Limit: 100}) + require.NoError(t, err) + t.Logf("ConsolidateUser: merged=%d, deleted=%d, remaining=%d", merged, deleted, list.Total) +} diff --git a/src/semantic-router/pkg/memory/valkey_store_integration_validation_test.go 
b/src/semantic-router/pkg/memory/valkey_store_integration_validation_test.go new file mode 100644 index 000000000..12bc4e012 --- /dev/null +++ b/src/semantic-router/pkg/memory/valkey_store_integration_validation_test.go @@ -0,0 +1,139 @@ +//go:build !windows && cgo + +package memory + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" +) + +// --------------------------------------------------------------------------- +// Store validation errors +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_StoreValidation(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + t.Run("missing ID", func(t *testing.T) { + err := store.Store(ctx, &Memory{Content: "test", UserID: "u1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "memory ID is required") + }) + + t.Run("missing content", func(t *testing.T) { + err := store.Store(ctx, &Memory{ID: "test_id", UserID: "u1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "memory content is required") + }) + + t.Run("missing user ID", func(t *testing.T) { + err := store.Store(ctx, &Memory{ID: "test_id", Content: "test"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "user ID is required") + }) +} + +// --------------------------------------------------------------------------- +// Retrieve validation errors +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_RetrieveValidation(t *testing.T) { + store, _ := setupValkeyMemoryIntegration(t) + ctx := context.Background() + + t.Run("missing query", func(t *testing.T) { + _, err := store.Retrieve(ctx, RetrieveOptions{UserID: "u1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "query is required") + }) + + t.Run("missing user ID", func(t *testing.T) { + _, 
err := store.Retrieve(ctx, RetrieveOptions{Query: "test"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "user id is required") + }) +} + +// --------------------------------------------------------------------------- +// IsEnabled / disabled store +// --------------------------------------------------------------------------- + +func TestValkeyStoreInteg_DisabledStore(t *testing.T) { + t.Parallel() + + store := &ValkeyStore{enabled: false} + + assert.False(t, store.IsEnabled()) + + ctx := context.Background() + assert.Error(t, store.Store(ctx, &Memory{})) + _, err := store.Retrieve(ctx, RetrieveOptions{}) + assert.Error(t, err) + _, err = store.Get(ctx, "id") + assert.Error(t, err) + assert.Error(t, store.Update(ctx, "id", &Memory{})) + _, err = store.List(ctx, ListOptions{}) + assert.Error(t, err) + assert.Error(t, store.Forget(ctx, "id")) + assert.Error(t, store.ForgetByScope(ctx, MemoryScope{})) +} + +// --------------------------------------------------------------------------- +// TLS integration tests +// --------------------------------------------------------------------------- + +// TestValkeyStoreInteg_TLS_ConfigPropagation verifies that TLS config fields +// are accepted by NewValkeyStore without error when the store is disabled. +// The actual TLS handshake is not tested here (would require a TLS-enabled +// Valkey instance); the wiring from config to glide client lives in +// router_memory.go and is validated via the unit tests in valkey_store_test.go. +func TestValkeyStoreInteg_TLS_ConfigPropagation(t *testing.T) { + t.Parallel() + + vc := &config.MemoryValkeyConfig{ + Host: "localhost", + Port: 6380, + TLSEnabled: true, + TLSCAPath: "/nonexistent/ca.pem", + TLSInsecureSkipVerify: true, + Dimension: 384, + MetricType: "COSINE", + } + + // Disabled store should accept TLS config without attempting a connection. 
+ store, err := NewValkeyStore(ValkeyStoreOptions{ + Enabled: false, + ValkeyConfig: vc, + }) + require.NoError(t, err) + assert.False(t, store.IsEnabled()) + + // The config struct itself carries the TLS values correctly. + assert.True(t, vc.TLSEnabled) + assert.Equal(t, "/nonexistent/ca.pem", vc.TLSCAPath) + assert.True(t, vc.TLSInsecureSkipVerify) +} + +// TestValkeyStoreInteg_TLS_BadCAPath verifies that createValkeyMemoryStore +// would fail with a clear error when given a non-existent CA path. We test +// this at the config level since the actual client creation happens in +// router_memory.go and requires the full extproc wiring. +func TestValkeyStoreInteg_TLS_BadCAPathConfig(t *testing.T) { + t.Parallel() + + vc := &config.MemoryValkeyConfig{ + TLSEnabled: true, + TLSCAPath: "/definitely/does/not/exist/ca.pem", + } + + // The CA file validation happens in router_memory.go (LoadRootCertificatesFromFile), + // not in NewValkeyStore itself. Verify the config is valid at the store level. + assert.True(t, vc.TLSEnabled) + assert.NotEmpty(t, vc.TLSCAPath) +} diff --git a/src/semantic-router/pkg/memory/valkey_store_test.go b/src/semantic-router/pkg/memory/valkey_store_test.go new file mode 100644 index 000000000..deac3d3fb --- /dev/null +++ b/src/semantic-router/pkg/memory/valkey_store_test.go @@ -0,0 +1,707 @@ +package memory + +import ( + "encoding/binary" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" +) + +// Verify ValkeyStore satisfies the Store interface at compile time. 
+var _ Store = (*ValkeyStore)(nil) + +// --------------------------------------------------------------------------- +// TLS configuration (config struct tests — no live Valkey required) +// --------------------------------------------------------------------------- + +func TestMemoryValkeyConfig_TLSFields(t *testing.T) { + t.Parallel() + + t.Run("defaults are zero-values", func(t *testing.T) { + t.Parallel() + cfg := config.MemoryValkeyConfig{} + assert.False(t, cfg.TLSEnabled) + assert.Empty(t, cfg.TLSCAPath) + assert.False(t, cfg.TLSInsecureSkipVerify) + }) + + t.Run("all fields populated", func(t *testing.T) { + t.Parallel() + cfg := config.MemoryValkeyConfig{ + TLSEnabled: true, + TLSCAPath: "/etc/certs/ca.pem", + TLSInsecureSkipVerify: false, + } + assert.True(t, cfg.TLSEnabled) + assert.Equal(t, "/etc/certs/ca.pem", cfg.TLSCAPath) + assert.False(t, cfg.TLSInsecureSkipVerify) + }) + + t.Run("insecure skip verify", func(t *testing.T) { + t.Parallel() + cfg := config.MemoryValkeyConfig{ + TLSEnabled: true, + TLSInsecureSkipVerify: true, + } + assert.True(t, cfg.TLSEnabled) + assert.True(t, cfg.TLSInsecureSkipVerify) + }) +} + +func TestNewValkeyStore_DisabledWithTLS(t *testing.T) { + t.Parallel() + // TLS fields should not matter when the store is disabled. + store, err := NewValkeyStore(ValkeyStoreOptions{ + Enabled: false, + ValkeyConfig: &config.MemoryValkeyConfig{ + TLSEnabled: true, + TLSCAPath: "/nonexistent/ca.pem", + }, + }) + require.NoError(t, err) + assert.False(t, store.IsEnabled()) +} + +func TestNewValkeyStore_RequiresClient(t *testing.T) { + t.Parallel() + // Enabled store with TLS but no client should fail with clear error. 
+ _, err := NewValkeyStore(ValkeyStoreOptions{ + Enabled: true, + ValkeyConfig: &config.MemoryValkeyConfig{ + TLSEnabled: true, + }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "valkey client is required") +} + +// --------------------------------------------------------------------------- +// valkeyFloat32ToBytes / valkeyBytesToFloat32 +// --------------------------------------------------------------------------- + +func TestValkeyFloat32ToBytes_Roundtrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input []float32 + }{ + {"basic values", []float32{1.0, 2.0, 3.0}}, + {"negative and special", []float32{0.0, -1.5, 3.14, math.MaxFloat32, math.SmallestNonzeroFloat32}}, + {"empty", []float32{}}, + {"single value", []float32{42.0}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + b := valkeyFloat32ToBytes(tc.input) + assert.Len(t, b, len(tc.input)*4) + + // Verify little-endian encoding + for i, expected := range tc.input { + bits := binary.LittleEndian.Uint32(b[i*4:]) + assert.Equal(t, expected, math.Float32frombits(bits)) + } + + // Roundtrip + result := valkeyBytesToFloat32(b) + assert.Equal(t, tc.input, result) + }) + } +} + +func TestValkeyBytesToFloat32_InvalidLength(t *testing.T) { + t.Parallel() + // Not a multiple of 4 + assert.Nil(t, valkeyBytesToFloat32([]byte{1, 2, 3})) + assert.Nil(t, valkeyBytesToFloat32([]byte{1})) +} + +// --------------------------------------------------------------------------- +// valkeyEscapeTagValue +// --------------------------------------------------------------------------- + +func TestValkeyEscapeTagValue(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + {"hyphens", "file-123", "file\\-123"}, + {"dots", "doc.txt", "doc\\.txt"}, + {"colons", "ns:val", "ns\\:val"}, + {"slashes", "path/to", "path\\/to"}, + {"spaces", "hello world", "hello\\ world"}, + {"multiple specials", "a-b.c:d/e 
f", "a\\-b\\.c\\:d\\/e\\ f"}, + {"safe string", "abc123", "abc123"}, + {"empty", "", ""}, + {"braces", "a{b}c", "a\\{b\\}c"}, + {"brackets", "a[b]c", "a\\[b\\]c"}, + {"pipe", "a|b", "a\\|b"}, + {"at sign", "user@host", "user\\@host"}, + {"parens", "f(x)", "f\\(x\\)"}, + {"asterisk", "a*b", "a\\*b"}, + {"exclamation", "no!", "no\\!"}, + {"tilde", "~user", "\\~user"}, + {"caret", "a^b", "a\\^b"}, + {"quotes", `a"b'c`, `a\"b\'c`}, + {"hash", "a#b", "a\\#b"}, + {"dollar", "a$b", "a\\$b"}, + {"percent", "a%b", "a\\%b"}, + {"ampersand", "a&b", "a\\&b"}, + {"plus equals", "a+=b", "a\\+\\=b"}, + {"backslash", "a\\b", "a\\\\b"}, + {"semicolon", "a;b", "a\\;b"}, + {"comma", "a,b", "a\\,b"}, + {"angle brackets", "a<b>c", "a\\<b\\>c"}, + {"tab", "a\tb", "a\\\tb"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, valkeyEscapeTagValue(tc.input)) + }) + } +} + +// --------------------------------------------------------------------------- +// valkeyDistanceToSimilarity +// --------------------------------------------------------------------------- + +func TestValkeyDistanceToSimilarity(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metricType string + distance float64 + expected float64 + tolerance float64 + }{ + {"COSINE zero", "COSINE", 0.0, 1.0, 0.001}, + {"COSINE 0.2", "COSINE", 0.2, 0.9, 0.001}, + {"COSINE 2.0", "COSINE", 2.0, 0.0, 0.001}, + {"L2 zero", "L2", 0.0, 1.0, 0.001}, + {"L2 0.3", "L2", 0.3, 0.769, 0.01}, + {"IP identity", "IP", 0.95, 0.95, 0.001}, + {"case insensitive", "cosine", 0.2, 0.9, 0.001}, + {"unknown metric warns", "UNKNOWN", 0.3, 0.7, 0.001}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := valkeyDistanceToSimilarity(tc.metricType, tc.distance) + assert.InDelta(t, tc.expected, result, tc.tolerance) + }) + } +} + +// --------------------------------------------------------------------------- +// valkeyToInt64 +// 
--------------------------------------------------------------------------- + +func TestValkeyToInt64(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input interface{} + expected int64 + }{ + {"int64", int64(42), 42}, + {"float64", float64(42.9), 42}, + {"string", "123", 123}, + {"nil", nil, 0}, + {"bool", true, 0}, + {"invalid string", "abc", 0}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, valkeyToInt64(tc.input)) + }) + } +} + +// --------------------------------------------------------------------------- +// valkeyParseScoreFromMap +// --------------------------------------------------------------------------- + +func TestValkeyParseScoreFromMap(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + fields map[string]interface{} + key string + metricType string + expected float64 + tolerance float64 + }{ + { + "valid COSINE distance", + map[string]interface{}{"vector_distance": "0.2"}, + "vector_distance", "COSINE", 0.9, 0.01, + }, + { + "missing key", + map[string]interface{}{"other": "0.2"}, + "vector_distance", "COSINE", 0.0, 0.001, + }, + { + "invalid number", + map[string]interface{}{"vector_distance": "abc"}, + "vector_distance", "COSINE", 0.0, 0.001, + }, + { + "float64 value", + map[string]interface{}{"vector_distance": float64(0.4)}, + "vector_distance", "COSINE", 0.8, 0.01, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := valkeyParseScoreFromMap(tc.fields, tc.key, tc.metricType) + assert.InDelta(t, tc.expected, result, tc.tolerance) + }) + } +} + +// --------------------------------------------------------------------------- +// valkeyFieldsToMemory +// --------------------------------------------------------------------------- + +func TestValkeyFieldsToMemory(t *testing.T) { + t.Parallel() + + t.Run("full fields", func(t *testing.T) { + t.Parallel() + embedding := []float32{0.1, 0.2, 
0.3} + fields := map[string]string{ + "id": "mem_123", + "content": "test content", + "user_id": "user1", + "memory_type": "semantic", + "metadata": `{"project_id":"proj1","source":"conversation","importance":0.8,"access_count":3,"last_accessed":1700000000}`, + "created_at": "1700000000", + "updated_at": "1700000100", + "embedding": string(valkeyFloat32ToBytes(embedding)), + "access_count": "3", + "importance": "0.8", + } + + mem := valkeyFieldsToMemory(fields) + assert.Equal(t, "mem_123", mem.ID) + assert.Equal(t, "test content", mem.Content) + assert.Equal(t, "user1", mem.UserID) + assert.Equal(t, MemoryType("semantic"), mem.Type) + assert.Equal(t, "proj1", mem.ProjectID) + assert.Equal(t, "conversation", mem.Source) + assert.InDelta(t, float32(0.8), mem.Importance, 0.01) + assert.Equal(t, 3, mem.AccessCount) + assert.False(t, mem.CreatedAt.IsZero()) + assert.False(t, mem.UpdatedAt.IsZero()) + require.Len(t, mem.Embedding, 3) + assert.InDelta(t, float32(0.1), mem.Embedding[0], 0.001) + }) + + t.Run("empty fields", func(t *testing.T) { + t.Parallel() + mem := valkeyFieldsToMemory(map[string]string{}) + assert.Empty(t, mem.ID) + assert.Nil(t, mem.Embedding) + }) + + t.Run("invalid metadata JSON", func(t *testing.T) { + t.Parallel() + fields := map[string]string{ + "id": "mem_456", + "metadata": "not valid json", + } + mem := valkeyFieldsToMemory(fields) + assert.Equal(t, "mem_456", mem.ID) + // Should not panic; metadata fields remain zero-value + assert.Empty(t, mem.ProjectID) + }) + + t.Run("invalid timestamps", func(t *testing.T) { + t.Parallel() + fields := map[string]string{ + "id": "mem_789", + "created_at": "not_a_number", + "updated_at": "", + } + mem := valkeyFieldsToMemory(fields) + assert.True(t, mem.CreatedAt.IsZero()) + assert.True(t, mem.UpdatedAt.IsZero()) + }) +} + +// --------------------------------------------------------------------------- +// valkeyFieldsMapToMemory (FT.SEARCH result format) +// 
--------------------------------------------------------------------------- + +func TestValkeyFieldsMapToMemory(t *testing.T) { + t.Parallel() + + t.Run("full fields", func(t *testing.T) { + t.Parallel() + fields := map[string]interface{}{ + "id": "mem_100", + "content": "search result content", + "user_id": "user2", + "memory_type": "procedural", + "metadata": `{"project_id":"proj2","source":"extraction","importance":0.5,"access_count":1}`, + "created_at": "1700000000", + "updated_at": "1700000200", + } + + mem := valkeyFieldsMapToMemory(fields) + assert.Equal(t, "mem_100", mem.ID) + assert.Equal(t, "search result content", mem.Content) + assert.Equal(t, "user2", mem.UserID) + assert.Equal(t, MemoryType("procedural"), mem.Type) + assert.Equal(t, "proj2", mem.ProjectID) + assert.Equal(t, "extraction", mem.Source) + assert.InDelta(t, float32(0.5), mem.Importance, 0.01) + }) + + t.Run("empty map", func(t *testing.T) { + t.Parallel() + mem := valkeyFieldsMapToMemory(map[string]interface{}{}) + assert.Empty(t, mem.ID) + }) + + t.Run("non-string fields ignored", func(t *testing.T) { + t.Parallel() + fields := map[string]interface{}{ + "id": 123, // not a string + "content": true, + } + mem := valkeyFieldsMapToMemory(fields) + assert.Empty(t, mem.ID) + assert.Empty(t, mem.Content) + }) +} + +// --------------------------------------------------------------------------- +// parseSearchCandidates +// --------------------------------------------------------------------------- + +func TestValkeyStore_ParseSearchCandidates(t *testing.T) { + t.Parallel() + + store := &ValkeyStore{metricType: "COSINE"} + + t.Run("nil input", func(t *testing.T) { + t.Parallel() + assert.Nil(t, store.parseSearchCandidates(nil, "user1")) + }) + + t.Run("non-array input", func(t *testing.T) { + t.Parallel() + assert.Nil(t, store.parseSearchCandidates("not an array", "user1")) + }) + + t.Run("zero total count", func(t *testing.T) { + t.Parallel() + assert.Nil(t, 
store.parseSearchCandidates([]interface{}{int64(0)}, "user1")) + }) + + t.Run("valid single result", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(1), map[string]interface{}{ + "mem:1": map[string]interface{}{ + "id": "mem_1", "content": "hello world", "memory_type": "semantic", + "metadata": `{"user_id":"user1","project_id":"proj1","source":"conversation"}`, + "vector_distance": "0.2", + }, + }} + + candidates := store.parseSearchCandidates(result, "user1") + require.Len(t, candidates, 1) + assert.Equal(t, "mem_1", candidates[0].Memory.ID) + assert.Equal(t, "hello world", candidates[0].Memory.Content) + assert.Equal(t, "user1", candidates[0].Memory.UserID) + assert.Equal(t, "proj1", candidates[0].Memory.ProjectID) + assert.InDelta(t, 0.9, float64(candidates[0].Score), 0.01) + }) + + t.Run("multiple results sorted by score descending", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(2), map[string]interface{}{ + "mem:1": map[string]interface{}{ + "id": "mem_1", "content": "good", "memory_type": "semantic", + "metadata": `{"user_id":"u1"}`, "vector_distance": "0.4", + }, + "mem:2": map[string]interface{}{ + "id": "mem_2", "content": "better", "memory_type": "semantic", + "metadata": `{"user_id":"u1"}`, "vector_distance": "0.1", + }, + }} + + candidates := store.parseSearchCandidates(result, "u1") + require.Len(t, candidates, 2) + assert.GreaterOrEqual(t, candidates[0].Score, candidates[1].Score, "results should be sorted descending by score") + assert.Equal(t, "mem_2", candidates[0].Memory.ID) // lower distance = higher similarity + }) + + t.Run("skips entries with missing id", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(1), map[string]interface{}{ + "mem:1": map[string]interface{}{ + "content": "no id", "memory_type": "semantic", + "metadata": `{}`, "vector_distance": "0.1", + }, + }} + + candidates := store.parseSearchCandidates(result, "u1") + assert.Empty(t, candidates) + }) + + t.Run("default 
user ID from parameter", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(1), map[string]interface{}{ + "mem:1": map[string]interface{}{ + "id": "mem_1", "content": "test", "memory_type": "semantic", + "metadata": `{}`, "vector_distance": "0.2", + }, + }} + + candidates := store.parseSearchCandidates(result, "default_user") + require.Len(t, candidates, 1) + assert.Equal(t, "default_user", candidates[0].Memory.UserID) + }) +} + +// --------------------------------------------------------------------------- +// parseListSearchResults +// --------------------------------------------------------------------------- + +func TestValkeyStore_ParseListSearchResults(t *testing.T) { + t.Parallel() + + store := &ValkeyStore{metricType: "COSINE"} + + t.Run("nil input", func(t *testing.T) { + t.Parallel() + assert.Nil(t, store.parseListSearchResults(nil)) + }) + + t.Run("zero total count", func(t *testing.T) { + t.Parallel() + assert.Nil(t, store.parseListSearchResults([]interface{}{int64(0)})) + }) + + t.Run("valid results", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(2), map[string]interface{}{ + "mem:1": map[string]interface{}{ + "id": "mem_1", "content": "first", "user_id": "u1", + "memory_type": "semantic", "metadata": `{"project_id":"p1"}`, + "created_at": "1700000000", "updated_at": "1700000100", + }, + "mem:2": map[string]interface{}{ + "id": "mem_2", "content": "second", "user_id": "u1", + "memory_type": "procedural", "metadata": `{"project_id":"p2"}`, + "created_at": "1700000200", "updated_at": "1700000300", + }, + }} + + memories := store.parseListSearchResults(result) + require.Len(t, memories, 2) + }) +} + +// --------------------------------------------------------------------------- +// extractIDsFromSearchResult +// --------------------------------------------------------------------------- + +func TestValkeyStore_ExtractIDsFromSearchResult(t *testing.T) { + t.Parallel() + + store := &ValkeyStore{} + + t.Run("no project 
filter", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(2), map[string]interface{}{ + "mem:1": map[string]interface{}{"id": "mem_1", "metadata": `{"project_id":"proj1"}`}, + "mem:2": map[string]interface{}{"id": "mem_2", "metadata": `{"project_id":"proj2"}`}, + }} + + ids := store.extractIDsFromSearchResult(result, "") + assert.Len(t, ids, 2) + }) + + t.Run("with project filter", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(2), map[string]interface{}{ + "mem:1": map[string]interface{}{"id": "mem_1", "metadata": `{"project_id":"proj1"}`}, + "mem:2": map[string]interface{}{"id": "mem_2", "metadata": `{"project_id":"proj2"}`}, + }} + + ids := store.extractIDsFromSearchResult(result, "proj1") + assert.Len(t, ids, 1) + assert.Equal(t, "mem_1", ids[0]) + }) + + t.Run("nil input", func(t *testing.T) { + t.Parallel() + assert.Nil(t, store.extractIDsFromSearchResult(nil, "")) + }) + + t.Run("zero results", func(t *testing.T) { + t.Parallel() + assert.Nil(t, store.extractIDsFromSearchResult([]interface{}{int64(0)}, "")) + }) +} + +// --------------------------------------------------------------------------- +// hashKey +// --------------------------------------------------------------------------- + +func TestValkeyStore_HashKey(t *testing.T) { + t.Parallel() + + store := &ValkeyStore{collectionPrefix: "mem:"} + assert.Equal(t, "mem:abc123", store.hashKey("abc123")) + + store2 := &ValkeyStore{collectionPrefix: "custom_prefix:"} + assert.Equal(t, "custom_prefix:xyz", store2.hashKey("xyz")) +} + +// --------------------------------------------------------------------------- +// NewValkeyStore validation +// --------------------------------------------------------------------------- + +func TestNewValkeyStore_Disabled(t *testing.T) { + t.Parallel() + store, err := NewValkeyStore(ValkeyStoreOptions{Enabled: false}) + require.NoError(t, err) + assert.False(t, store.IsEnabled()) +} + +func TestNewValkeyStore_NilClient(t *testing.T) { + 
t.Parallel() + _, err := NewValkeyStore(ValkeyStoreOptions{ + Enabled: true, + ValkeyConfig: &config.MemoryValkeyConfig{}, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "valkey client is required") +} + +func TestNewValkeyStore_NilConfig(t *testing.T) { + t.Parallel() + _, err := NewValkeyStore(ValkeyStoreOptions{ + Enabled: true, + // Client would be non-nil in a real test, but we'll hit the config check first + // since we validate config before using client + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "valkey client is required") +} + +// --------------------------------------------------------------------------- +// Config defaults +// --------------------------------------------------------------------------- + +func TestValkeyStore_ConfigDefaults(t *testing.T) { + t.Parallel() + + // Verify default values are applied when config fields are zero-valued + store := &ValkeyStore{} + + // These are tested implicitly through NewValkeyStore, but we also verify + // the struct fields directly + assert.Empty(t, store.indexName) + assert.Empty(t, store.collectionPrefix) + assert.Empty(t, store.metricType) + assert.Equal(t, 0, store.dimension) +} + +// --------------------------------------------------------------------------- +// extractHashKeysFromSearchResult +// --------------------------------------------------------------------------- + +func TestValkeyStore_ExtractHashKeysFromSearchResult(t *testing.T) { + t.Parallel() + + store := &ValkeyStore{} + + t.Run("nil input", func(t *testing.T) { + t.Parallel() + assert.Nil(t, store.extractHashKeysFromSearchResult(nil)) + }) + + t.Run("empty array", func(t *testing.T) { + t.Parallel() + assert.Nil(t, store.extractHashKeysFromSearchResult([]interface{}{int64(0)})) + }) + + t.Run("map format", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(2), map[string]interface{}{ + "mem:key1": map[string]interface{}{"id": "mem_1"}, + "mem:key2": map[string]interface{}{"id": 
"mem_2"}, + }} + + keys := store.extractHashKeysFromSearchResult(result) + assert.Len(t, keys, 2) + assert.Contains(t, keys, "mem:key1") + assert.Contains(t, keys, "mem:key2") + }) + + t.Run("string format", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(1), "mem:key1"} + + keys := store.extractHashKeysFromSearchResult(result) + assert.Len(t, keys, 1) + assert.Equal(t, "mem:key1", keys[0]) + }) +} + +// --------------------------------------------------------------------------- +// extractTotalCount +// --------------------------------------------------------------------------- + +func TestValkeyStore_ExtractTotalCount(t *testing.T) { + t.Parallel() + + store := &ValkeyStore{} + + t.Run("nil input", func(t *testing.T) { + t.Parallel() + assert.Equal(t, 0, store.extractTotalCount(nil)) + }) + + t.Run("valid count", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(42), map[string]interface{}{}} + assert.Equal(t, 42, store.extractTotalCount(result)) + }) + + t.Run("zero count", func(t *testing.T) { + t.Parallel() + result := []interface{}{int64(0)} + assert.Equal(t, 0, store.extractTotalCount(result)) + }) + + t.Run("non-array input", func(t *testing.T) { + t.Parallel() + assert.Equal(t, 0, store.extractTotalCount("not an array")) + }) +} diff --git a/src/vllm-sr/cli/models_memory.py b/src/vllm-sr/cli/models_memory.py index 7ba9a89f8..ccef0e28b 100644 --- a/src/vllm-sr/cli/models_memory.py +++ b/src/vllm-sr/cli/models_memory.py @@ -11,6 +11,25 @@ class MemoryMilvusConfig(BaseModel): dimension: int = 384 +class MemoryValkeyConfig(BaseModel): + """Valkey configuration for memory storage using the Search module.""" + + host: str = "localhost" + port: int = 6379 + database: int = 0 + password: str | None = None + timeout: int = 10 + collection_prefix: str = "mem:" + index_name: str = "mem_idx" + dimension: int = 384 + metric_type: str = "COSINE" + index_m: int = 16 + index_ef_construction: int = 256 + tls_enabled: bool = False + 
tls_ca_path: str | None = None + tls_insecure_skip_verify: bool = False + + class MemoryRedisCacheConfig(BaseModel): """Redis hot-cache configuration for memory retrieval.""" @@ -34,8 +53,10 @@ class MemoryConfig(BaseModel): """ enabled: bool = True + backend: str = "" # "" or "milvus" → Milvus (default); "valkey" → Valkey auto_store: bool = False # Auto-store extracted facts after each response milvus: MemoryMilvusConfig | None = None + valkey: MemoryValkeyConfig | None = None redis_cache: MemoryRedisCacheConfig | None = None # Embedding model to use for memory vectors # Options: "bert", "mmbert", "multimodal", "qwen3", "gemma" diff --git a/website/docs/installation/valkey-memory.md b/website/docs/installation/valkey-memory.md new file mode 100644 index 000000000..3c583ac4a --- /dev/null +++ b/website/docs/installation/valkey-memory.md @@ -0,0 +1,284 @@ +--- +sidebar_position: 6 +--- + +# Valkey Agentic Memory + +This guide covers deploying Valkey as the agentic memory backend for the Semantic Router. Valkey provides a lightweight, Redis-compatible alternative to Milvus for vector similarity storage using the built-in Search module. + +:::note +Valkey is optional. The default memory backend is Milvus. Use Valkey when you want a single-binary deployment without external dependencies like etcd or MinIO, or when you already run Valkey for caching. 
+::: + +## When to Use Valkey vs Milvus + +| Concern | Valkey | Milvus | +|---------|--------|--------| +| Deployment complexity | Single binary with Search module | Requires etcd, MinIO/S3, optional Pulsar | +| Horizontal scaling | Cluster mode (manual sharding) | Native distributed architecture | +| Memory model | In-memory with optional persistence | Disk-based with memory-mapped indexes | +| Best for | Small-to-medium workloads, dev/test, existing Redis/Valkey infra | Large-scale production, billions of vectors | +| Vector index | HNSW via FT.CREATE | HNSW, IVF_FLAT, IVF_SQ8, and more | + +## Prerequisites + +- Valkey 8.0+ **with the Search module** enabled +- The `valkey/valkey-bundle` Docker image includes Search out of the box +- For Kubernetes: Helm 3.x and `kubectl` configured + +## Deploy with Docker + +### Quick Start + +```bash +docker run -d --name valkey-memory \ + -p 6379:6379 \ + valkey/valkey-bundle:latest +``` + +Verify the Search module is loaded: + +```bash +docker exec valkey-memory valkey-cli MODULE LIST | grep search +``` + +### With Persistence + +```bash +docker run -d --name valkey-memory \ + -p 6379:6379 \ + -v valkey-data:/data \ + valkey/valkey-bundle:latest \ + valkey-server --appendonly yes +``` + +## Deploy in Kubernetes + +### Using a StatefulSet + +```yaml +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: valkey-memory + namespace: vllm-semantic-router-system +spec: + serviceName: valkey-memory + replicas: 1 + selector: + matchLabels: + app: valkey-memory + template: + metadata: + labels: + app: valkey-memory + spec: + containers: + - name: valkey + image: valkey/valkey-bundle:latest + ports: + - containerPort: 6379 + args: ["valkey-server", "--appendonly", "yes"] + # For production, add --requirepass or mount a Secret: + # args: ["valkey-server", "--appendonly", "yes", "--requirepass", "$(VALKEY_PASSWORD)"] + volumeMounts: + - name: data + mountPath: /data + resources: + requests: + memory: "256Mi" + cpu: "250m" + limits: 
+ memory: "1Gi" + cpu: "1000m" + volumeClaimTemplates: + - metadata: + name: data + spec: + accessModes: ["ReadWriteOnce"] + resources: + requests: + storage: 5Gi +--- +apiVersion: v1 +kind: Service +metadata: + name: valkey-memory + namespace: vllm-semantic-router-system +spec: + selector: + app: valkey-memory + ports: + - port: 6379 + targetPort: 6379 + clusterIP: None +``` + +## Configure the Router + +Add the Valkey memory backend to your `config.yaml`: + +```yaml +global: + stores: + memory: + enabled: true + backend: valkey + auto_store: true + valkey: + host: valkey-memory # Service name or hostname + port: 6379 + database: 0 + timeout: 10 + collection_prefix: "mem:" + index_name: mem_idx + dimension: 384 # Must match your embedding model + metric_type: COSINE # COSINE, L2, or IP + index_m: 16 + index_ef_construction: 256 + embedding_model: bert + default_retrieval_limit: 5 + default_similarity_threshold: 0.70 + hybrid_search: true + hybrid_mode: rerank + adaptive_threshold: true +``` + +### Configuration Reference + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `host` | `localhost` | Valkey server hostname | +| `port` | `6379` | Valkey server port | +| `database` | `0` | Database number (0-15) | +| `password` | _(empty)_ | Authentication password | +| `timeout` | `10` | Connection timeout in seconds | +| `collection_prefix` | `mem:` | Key prefix for HASH documents | +| `index_name` | `mem_idx` | FT.CREATE index name | +| `dimension` | `384` | Embedding vector dimension | +| `metric_type` | `COSINE` | Distance metric: `COSINE`, `L2`, or `IP` | +| `index_m` | `16` | HNSW M parameter (links per node) | +| `index_ef_construction` | `256` | HNSW build-time search width | + +### Optional Redis Hot Cache + +You can layer a Redis/Valkey hot cache in front of the Valkey memory store for frequently accessed memories: + +```yaml + redis_cache: + enabled: true + address: "valkey-memory:6379" + ttl_seconds: 900 + db: 1 # Use a 
different DB to avoid key collisions + key_prefix: "memory_cache:" +``` + +## Per-Decision Memory Plugin + +Routes can override global memory settings using the `memory` plugin: + +```yaml +routing: + decisions: + - name: personalized_route + plugins: + - type: memory + configuration: + enabled: true + retrieval_limit: 10 + similarity_threshold: 0.60 + auto_store: true +``` + +See the [Memory plugin tutorial](/docs/tutorials/plugin/memory) for details. + +## Performance Tuning + +### HNSW Index Parameters + +- **`index_m`** (default 16): Higher values improve recall at the cost of memory. Use 32-64 for production workloads requiring high accuracy. +- **`index_ef_construction`** (default 256): Higher values improve index quality at the cost of slower builds. Use 512+ for production. + +### Memory Sizing + +Each memory entry uses approximately: + +- HASH fields: ~500-2000 bytes (content, metadata, timestamps) +- Embedding vector: `dimension * 4` bytes (e.g., 384 * 4 = 1.5 KB for BERT) +- HNSW index overhead: ~`dimension * index_m * 4` bytes per entry + +For 100K memories with 384-dimensional embeddings and M=16: + +- Data: ~300 MB +- Index: ~240 MB +- **Total: ~540 MB** plus Valkey base overhead + +### Persistence + +Enable AOF (Append-Only File) for durability: + +```bash +valkey-server --appendonly yes --appendfsync everysec +``` + +For RDB snapshots (point-in-time backups): + +```bash +valkey-server --save 900 1 --save 300 10 +``` + +## Troubleshooting + +### Search Module Not Loaded + +``` +FT.CREATE failed: unknown command 'FT.CREATE' +``` + +Ensure you are using `valkey/valkey-bundle` (includes Search) rather than plain `valkey/valkey`: + +```bash +valkey-cli MODULE LIST +# Should show: name search ver ... 
+``` + +### Connection Timeout + +``` +valkey: connection timeout +``` + +- Verify the hostname resolves: `nslookup valkey-memory` +- Check port connectivity: `nc -zv valkey-memory 6379` +- Increase `timeout` in the config if the network is slow + +### Index Already Exists + +The router checks for existing indexes on startup and skips creation if one exists. If you need to recreate the index (e.g., after changing `dimension` or `metric_type`): + +```bash +valkey-cli FT.DROPINDEX mem_idx +``` + +The router will recreate it on the next request. + +### Out of Memory + +Valkey stores all data in memory. If you hit the memory limit: + +1. Set `maxmemory` and `maxmemory-policy` in Valkey config +2. Use `quality_scoring.max_memories_per_user` to cap per-user storage +3. Enable memory consolidation to merge similar memories + +## Migration from Milvus + +To switch an existing deployment from Milvus to Valkey: + +1. Update `config.yaml` to set `backend: valkey` and add the `valkey:` block +2. Remove or comment out the `milvus:` block +3. Restart the router — it will create the Valkey index automatically +4. Existing memories in Milvus are **not** automatically migrated + +:::warning +Switching backends does not migrate data. If you need to preserve existing memories, export them from Milvus and re-import via the memory API before switching. +::: diff --git a/website/docs/tutorials/global/stores-and-tools.md b/website/docs/tutorials/global/stores-and-tools.md index c82a173e8..b054cdf80 100644 --- a/website/docs/tutorials/global/stores-and-tools.md +++ b/website/docs/tutorials/global/stores-and-tools.md @@ -41,13 +41,43 @@ global: ### Memory +The memory store supports two backends: `milvus` (default) and `valkey`. 
+ +**Milvus backend** (default): + +```yaml +global: + stores: + memory: + enabled: true + milvus: + address: milvus:19530 + collection: agentic_memory + dimension: 384 +``` + +**Valkey backend** (requires Valkey with Search module): + ```yaml global: stores: memory: enabled: true + backend: valkey + valkey: + host: valkey + port: 6379 + dimension: 384 + collection_prefix: "mem:" + index_name: mem_idx + metric_type: COSINE ``` +For full deployment instructions, see: + +- [Valkey Agentic Memory](../../installation/valkey-memory.md) — Docker, Kubernetes, config reference, tuning, and troubleshooting +- `deploy/examples/runtime/memory/` for backend-specific configuration references + ### Vector Store ```yaml diff --git a/website/docs/tutorials/plugin/memory.md b/website/docs/tutorials/plugin/memory.md index 4ca9587ed..f8b4ae7c4 100644 --- a/website/docs/tutorials/plugin/memory.md +++ b/website/docs/tutorials/plugin/memory.md @@ -24,6 +24,13 @@ Not every route should pay the complexity or privacy cost of memory. `memory` le ## Configuration +The memory plugin requires a backing store configured under `global.stores.memory`. The router supports two backends: + +- **Milvus** (default) — distributed vector database, best for large-scale production +- **Valkey** — lightweight single-binary option using the Search module, best for dev/test or existing Valkey infra + +See the [Stores and Tools](../global/stores-and-tools.md) tutorial for global memory configuration, or the [Valkey Memory deployment guide](../../installation/valkey-memory.md) for Valkey-specific setup. + Use this fragment under `routing.decisions[].plugins`: ```yaml diff --git a/website/sidebars.ts b/website/sidebars.ts index 03057599d..35f866729 100644 --- a/website/sidebars.ts +++ b/website/sidebars.ts @@ -52,6 +52,13 @@ const sidebars: SidebarsConfig = { 'installation/k8s/dynamo', ], }, + { + type: 'category', + label: 'Backend Stores', + items: [ + 'installation/valkey-memory', + ], + }, ], }, {