diff --git a/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go b/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go index a68ddb2c1..f2be0d302 100644 --- a/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go +++ b/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go @@ -126,6 +126,7 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr return nil, fmt.Errorf("failed to create engine adapter: %w", err) } + config.KVEventsConfig.ModelRegistry = kvCacheIndexer.ModelRegistry() pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex(), tokenProcessor, adapter) pool.Start(ctx) diff --git a/examples/kv_events/online/main.go b/examples/kv_events/online/main.go index 3bebbed49..ecd7b5ed3 100644 --- a/examples/kv_events/online/main.go +++ b/examples/kv_events/online/main.go @@ -99,7 +99,7 @@ func run(ctx context.Context) error { } // Setup events pool - eventsPool, err := setupEventsPool(ctx, kvCacheIndexer.KVBlockIndex()) + eventsPool, err := setupEventsPool(ctx, kvCacheIndexer) if err != nil { return err } @@ -212,7 +212,7 @@ func setupKVCacheIndexer(ctx context.Context) (*kvcache.Indexer, error) { return kvCacheIndexer, nil } -func setupEventsPool(ctx context.Context, kvBlockIndex kvblock.Index) (*kvevents.Pool, error) { +func setupEventsPool(ctx context.Context, indexer *kvcache.Indexer) (*kvevents.Pool, error) { logger := log.FromContext(ctx) cfg := getEventsPoolConfig() @@ -227,7 +227,9 @@ func setupEventsPool(ctx context.Context, kvBlockIndex kvblock.Index) (*kvevents return nil, err } - pool := kvevents.NewPool(cfg, kvBlockIndex, tokenProcessor, adapter) + // Use model registry from indexer (configured via modelConfigs in config) + cfg.ModelRegistry = indexer.ModelRegistry() + pool := kvevents.NewPool(cfg, indexer.KVBlockIndex(), tokenProcessor, adapter) return pool, nil } diff --git a/pkg/kvcache/backend.go b/pkg/kvcache/backend.go index 488b9c6e1..075aac519 100644 --- a/pkg/kvcache/backend.go +++ 
b/pkg/kvcache/backend.go @@ -16,6 +16,8 @@ limitations under the License. package kvcache +import "sync" + type KVCacheBackendConfig struct { // Name is the identifier for this medium (e.g., "gpu", "cpu", "disk") Name string `json:"name"` @@ -29,3 +31,133 @@ func DefaultKVCacheBackendConfig() []*KVCacheBackendConfig { {Name: "cpu", Weight: 0.8}, } } + +// AttentionType defines the type of attention mechanism used by an attention group. +type AttentionType string + +const ( + // AttentionTypeFull represents full/global attention (attends to all previous tokens). + AttentionTypeFull AttentionType = "full" + // AttentionTypeSlidingWindow represents sliding window attention (attends to last N tokens). + AttentionTypeSlidingWindow AttentionType = "sliding_window" + // AttentionTypeLocal represents local attention (attends to nearby tokens only). + AttentionTypeLocal AttentionType = "local" +) + +// AttentionGroupConfig holds configuration for a single attention group in HMA models. +type AttentionGroupConfig struct { + // GroupID is the attention group identifier (e.g., 0 for full attention, 1 for sliding window) + GroupID int `json:"groupId"` + // AttentionType specifies the type of attention mechanism + AttentionType AttentionType `json:"attentionType"` + // BlockSize is the number of tokens per KV-cache block for this group + BlockSize int `json:"blockSize"` + // SlidingWindowSize is the window size for sliding window attention (0 or omitted for full attention) + SlidingWindowSize int `json:"slidingWindowSize,omitempty"` +} + +// ModelConfig holds the configuration for a specific model. +type ModelConfig struct { + // Name is the model identifier (e.g., "Qwen/Qwen3-8B", "DeepSeek-V3") + Name string `json:"name"` + // IsHMA indicates whether this model uses Hybrid Multi-head Attention. + // When true, StoredGroups tracking is enabled for cache entries. + // When false, StoredGroups is left nil to save memory. 
+ IsHMA bool `json:"isHMA"` + // AttentionGroups defines the attention group configuration for HMA models. + // Only used when IsHMA is true. + // Example for DeepSeek-V3: + // [{GroupID: 0, BlockSize: 64, SlidingWindowSize: 0}, // Full attention + // {GroupID: 1, BlockSize: 64, SlidingWindowSize: 4096}] // Sliding window + AttentionGroups []AttentionGroupConfig `json:"attentionGroups,omitempty"` +} + +// ModelRegistry manages model configurations. +// It provides thread-safe access to model metadata needed for event processing. +type ModelRegistry struct { + mu sync.RWMutex + configs map[string]*ModelConfig +} + +// NewModelRegistry creates a new ModelRegistry with optional initial configs. +func NewModelRegistry(initialConfigs []*ModelConfig) *ModelRegistry { + registry := &ModelRegistry{ + configs: make(map[string]*ModelConfig), + } + + for _, config := range initialConfigs { + registry.configs[config.Name] = config + } + + return registry +} + +// GetModelConfig retrieves the configuration for a given model name. +// If the model is not registered, it returns a default non-HMA config. +func (r *ModelRegistry) GetModelConfig(modelName string) *ModelConfig { + r.mu.RLock() + defer r.mu.RUnlock() + + if config, exists := r.configs[modelName]; exists { + return config + } + + // Default: treat unknown models as non-HMA for memory efficiency + return &ModelConfig{ + Name: modelName, + IsHMA: false, + } +} + +// RegisterModel adds or updates a model configuration. +func (r *ModelRegistry) RegisterModel(config *ModelConfig) { + r.mu.Lock() + defer r.mu.Unlock() + r.configs[config.Name] = config +} + +// IsHMA checks if a model uses Hybrid Multi-head Attention. +// Returns false for unknown models. +func (r *ModelRegistry) IsHMA(modelName string) bool { + return r.GetModelConfig(modelName).IsHMA +} + +// GetAttentionGroups returns the attention group configuration for a model. +// Returns nil for simple (non-HMA) models or unknown models. 
+func (r *ModelRegistry) GetAttentionGroups(modelName string) []AttentionGroupConfig { + config := r.GetModelConfig(modelName) + if !config.IsHMA { + return nil + } + return config.AttentionGroups +} + +// GetGroupBlockSize returns the block size for a specific attention group. +// Returns 0 if the model or group is not found. +func (r *ModelRegistry) GetGroupBlockSize(modelName string, groupID int) int { + groups := r.GetAttentionGroups(modelName) + for _, group := range groups { + if group.GroupID == groupID { + return group.BlockSize + } + } + return 0 +} + +// GetGroupSlidingWindow returns the sliding window size for a specific attention group. +// Returns 0 for full attention groups or if not found. +func (r *ModelRegistry) GetGroupSlidingWindow(modelName string, groupID int) int { + groups := r.GetAttentionGroups(modelName) + for _, group := range groups { + if group.GroupID == groupID { + return group.SlidingWindowSize + } + } + return 0 +} + +// NewDefaultModelRegistry creates a ModelRegistry with common defaults. +// By default, all models are treated as non-HMA for memory efficiency. +func NewDefaultModelRegistry() *ModelRegistry { + return NewModelRegistry(nil) +} diff --git a/pkg/kvcache/indexer.go b/pkg/kvcache/indexer.go index 20d447179..cdd1201ea 100644 --- a/pkg/kvcache/indexer.go +++ b/pkg/kvcache/indexer.go @@ -42,6 +42,7 @@ type Config struct { KVBlockScorerConfig *KVBlockScorerConfig // not exported TokenizersPoolConfig *tokenization.Config `json:"tokenizersPoolConfig"` BackendConfigs []*KVCacheBackendConfig `json:"kvCacheBackendConfigs"` + ModelConfigs []*ModelConfig `json:"modelConfigs,omitempty"` } // NewDefaultConfig returns a default configuration for the Indexer module. 
@@ -66,6 +67,7 @@ type Indexer struct { tokenProcessor kvblock.TokenProcessor // turns tokens to kv block keys kvBlockIndex kvblock.Index // looks up pods for block keys kvBlockScorer KVBlockScorer // scores pods based on block hits + modelRegistry *ModelRegistry // manages model-specific configurations tokenizersPool TokenizersPool } @@ -104,11 +106,21 @@ func NewKVCacheIndexer(ctx context.Context, config *Config, tokenProcessor kvblo return nil, fmt.Errorf("failed to create tokenizers pool: %w", err) } + // Create model registry from config + var modelRegistry *ModelRegistry + if len(config.ModelConfigs) > 0 { + modelRegistry = NewModelRegistry(config.ModelConfigs) + } else { + // Use default registry (all models treated as non-HMA) + modelRegistry = NewDefaultModelRegistry() + } + return &Indexer{ config: config, tokenProcessor: tokenProcessor, kvBlockIndex: kvBlockIndex, kvBlockScorer: scorer, + modelRegistry: modelRegistry, tokenizersPool: tokenizersPool, }, nil } @@ -123,6 +135,11 @@ func (k *Indexer) KVBlockIndex() kvblock.Index { return k.kvBlockIndex } +// ModelRegistry returns the ModelRegistry used by the Indexer. +func (k *Indexer) ModelRegistry() *ModelRegistry { + return k.modelRegistry +} + // ComputeBlockKeys computes the KV-block keys for a given prompt and model name. // This method extracts the tokenization and block key computation logic so that // callers (e.g., IGW::EPP::PrepareDataPlugin) can compute block keys once and reuse them diff --git a/pkg/kvcache/kvblock/cost_aware_memory.go b/pkg/kvcache/kvblock/cost_aware_memory.go index e3f2817a8..ad04fa79e 100644 --- a/pkg/kvcache/kvblock/cost_aware_memory.go +++ b/pkg/kvcache/kvblock/cost_aware_memory.go @@ -103,23 +103,85 @@ func (m *CostAwareMemoryIndex) MaxCost() int64 { // CostPodCache wraps a sync.Map of PodEntry and provides cost calculation for memory usage estimation. 
type CostPodCache struct { - cache sync.Map // map[PodEntry]struct{} + cache sync.Map // map[string]*PodEntry (key: "podID@tier") // size tracks the number of entries in cache for O(1) Len(). size atomic.Int64 } -// Add adds a PodEntry to the cache. +// Add adds or updates a PodEntry in the cache, merging StoredGroups if the entry exists. func (c *CostPodCache) Add(entry PodEntry) { - if _, loaded := c.cache.LoadOrStore(entry, struct{}{}); !loaded { + cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative) + + // Try to load existing entry + if existingVal, loaded := c.cache.Load(cacheKey); loaded { //nolint:nestif // Existing complexity + if existingEntry, ok := existingVal.(*PodEntry); ok { + // Check StoredGroups to determine simple vs HMA model (same pattern as Evict) + if entry.StoredGroups == nil { + // Simple model - no group tracking needed + } else { + // HMA model - merge groups + existingEntry.StoredGroups = mergeGroupsUnique(existingEntry.StoredGroups, entry.StoredGroups) + } + // Store updated entry + c.cache.Store(cacheKey, existingEntry) + } + } else { + // Create new entry + newEntry := &PodEntry{ + PodIdentifier: entry.PodIdentifier, + DeviceTier: entry.DeviceTier, + Speculative: entry.Speculative, + StoredGroups: entry.StoredGroups, // nil for simple models, []int for HMA + } + c.cache.Store(cacheKey, newEntry) c.size.Add(1) } } -// Delete removes a PodEntry from the cache. +// Delete removes a PodEntry from the cache entirely. func (c *CostPodCache) Delete(entry PodEntry) { - if _, loaded := c.cache.LoadAndDelete(entry); loaded { + cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative) + if _, loaded := c.cache.LoadAndDelete(cacheKey); loaded { + c.size.Add(-1) + } +} + +// RemoveGroups removes specified groups from a PodEntry's StoredGroups. +// If no groups remain, the entry is deleted. 
+func (c *CostPodCache) RemoveGroups(entry PodEntry) bool { + cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative) + + existingVal, loaded := c.cache.Load(cacheKey) + if !loaded { + return false + } + + existingEntry, ok := existingVal.(*PodEntry) + if !ok { + return false + } + + // For simple (non-HMA) models: StoredGroups is nil, remove entire entry + if entry.StoredGroups == nil { + c.cache.Delete(cacheKey) + c.size.Add(-1) + return true + } + + // For HMA models: remove specific groups + updatedGroups := removeGroups(existingEntry.StoredGroups, entry.StoredGroups) + + if len(updatedGroups) == 0 { + // No groups left, delete the entry + c.cache.Delete(cacheKey) c.size.Add(-1) + return true } + + // Update with remaining groups + existingEntry.StoredGroups = updatedGroups + c.cache.Store(cacheKey, existingEntry) + return false } // Len returns the number of entries in the cache. @@ -141,16 +203,22 @@ func (c *CostPodCache) CalculateByteSize(keyStr string) int64 { // Count entries and calculate their size c.cache.Range(func(key, value interface{}) bool { - entry, ok := key.(PodEntry) - if !ok { + // key is now a string, value is *PodEntry + keyStr, okKey := key.(string) + entry, okEntry := value.(*PodEntry) + if !okKey || !okEntry { return true } entryCount++ - totalBytes += int64(len(entry.PodIdentifier)) // PodIdentifier string content - totalBytes += int64(len(entry.DeviceTier)) // DeviceTier string content - totalBytes += 32 // string headers (16 bytes each for 2 strings) - totalBytes += 8 // struct padding/alignment + totalBytes += int64(len(keyStr)) // cache key string + totalBytes += int64(len(entry.PodIdentifier)) // PodIdentifier string content + totalBytes += int64(len(entry.DeviceTier)) // DeviceTier string content + totalBytes += int64(len(entry.StoredGroups) * 8) // StoredGroups slice (8 bytes per int) + totalBytes += 32 // string headers (16 bytes each for 2 strings) + totalBytes += 24 // slice header for StoredGroups + 
totalBytes += 8 // pointer to PodEntry + totalBytes += 8 // struct padding/alignment return true }) @@ -234,17 +302,17 @@ func (m *CostAwareMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHa if podIdentifierSet.Len() == 0 { // If no pod identifiers are provided, return all pods pods.cache.Range(func(k, value interface{}) bool { - if pod, ok := k.(PodEntry); ok { - podsPerKey[key] = append(podsPerKey[key], pod) + if pod, ok := value.(*PodEntry); ok { + podsPerKey[key] = append(podsPerKey[key], *pod) } return true }) } else { // Filter pods based on the provided pod identifiers pods.cache.Range(func(k, value interface{}) bool { - if pod, ok := k.(PodEntry); ok { + if pod, ok := value.(*PodEntry); ok { if podIdentifierSet.Has(pod.PodIdentifier) { - podsPerKey[key] = append(podsPerKey[key], pod) + podsPerKey[key] = append(podsPerKey[key], *pod) } } return true @@ -307,7 +375,8 @@ func (m *CostAwareMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType podCacheLenBefore := podCache.Len() for _, entry := range entries { - podCache.Delete(entry) + // Remove groups from the entry; if no groups remain, the entry is deleted + podCache.RemoveGroups(entry) } if podCache.Len() == 0 { diff --git a/pkg/kvcache/kvblock/in_memory.go b/pkg/kvcache/kvblock/in_memory.go index bc3a50c54..5a4448bd5 100644 --- a/pkg/kvcache/kvblock/in_memory.go +++ b/pkg/kvcache/kvblock/in_memory.go @@ -88,9 +88,9 @@ var _ Index = &InMemoryIndex{} // PodCache represents a cache for pod entries. type PodCache struct { - // cache is an LRU cache that maps PodEntry to their last access time. - // thread-safe. - cache *lru.Cache[PodEntry, struct{}] + // cache is an LRU cache that maps "podID@tier" keys to PodEntry pointers. + // This allows in-place updates of StoredGroups without recreating entries. + cache *lru.Cache[string, *PodEntry] // mu protects the cache from concurrent access during check-and-set operations. 
mu sync.Mutex } @@ -126,12 +126,14 @@ func (m *InMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHash, if podIdentifierSet.Len() == 0 { // If no pod identifiers are provided, return all pods - podsPerKey[requestKey] = pods.cache.Keys() + for _, podEntry := range pods.cache.Values() { + podsPerKey[requestKey] = append(podsPerKey[requestKey], *podEntry) + } } else { // Filter pods based on the provided pod identifiers - for _, pod := range pods.cache.Keys() { - if podIdentifierSet.Has(pod.PodIdentifier) { - podsPerKey[requestKey] = append(podsPerKey[requestKey], pod) + for _, podEntry := range pods.cache.Values() { + if podIdentifierSet.Has(podEntry.PodIdentifier) { + podsPerKey[requestKey] = append(podsPerKey[requestKey], *podEntry) } } } @@ -174,7 +176,7 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block //nolint:nestif // double-checked locking pattern if !found { // Create new cache - cache, err := lru.New[PodEntry, struct{}](m.podCacheSize) + cache, err := lru.New[string, *PodEntry](m.podCacheSize) if err != nil { return fmt.Errorf("failed to create pod cache for key %s: %w", requestKey.String(), err) } @@ -201,11 +203,37 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block podCache.mu.Lock() for _, entry := range entries { - podCache.cache.Add(entry, struct{}{}) + cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative) + + // Check if entry already exists + existingEntry, found := podCache.cache.Get(cacheKey) + if found { + // Check StoredGroups to determine simple vs HMA model (same pattern as Evict) + if entry.StoredGroups == nil { + // Simple model - no group tracking needed, just update LRU + traceLogger.Info("updated existing pod entry (simple model)", + "requestKey", requestKey, "pod", existingEntry) + } else { + // HMA model - merge groups + existingEntry.StoredGroups = mergeGroupsUnique(existingEntry.StoredGroups, entry.StoredGroups) + 
traceLogger.Info("updated existing pod entry with merged groups", + "requestKey", requestKey, "pod", existingEntry) + } + // Re-add to update LRU position + podCache.cache.Add(cacheKey, existingEntry) + } else { + // Create new entry (copy to avoid mutation) + newEntry := &PodEntry{ + PodIdentifier: entry.PodIdentifier, + DeviceTier: entry.DeviceTier, + Speculative: entry.Speculative, + StoredGroups: entry.StoredGroups, // nil for simple models, []int for HMA + } + podCache.cache.Add(cacheKey, newEntry) + traceLogger.Info("added new pod entry", "requestKey", requestKey, "pod", newEntry) + } } podCache.mu.Unlock() - - traceLogger.Info("added pods to key", "requestKey", requestKey, "pods", entries) } return nil @@ -252,13 +280,43 @@ func (m *InMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType KeyTyp podCache.mu.Lock() for _, entry := range entries { - podCache.cache.Remove(entry) + cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative) + + existingEntry, found := podCache.cache.Get(cacheKey) + if !found { + traceLogger.Info("pod entry not found for eviction, skipping", + "requestKey", requestKey, "podID", entry.PodIdentifier, "tier", entry.DeviceTier) + continue + } + + // For simple (non-HMA) models: StoredGroups is nil, remove entire entry + if entry.StoredGroups == nil { + podCache.cache.Remove(cacheKey) + traceLogger.Info("removed pod entry (simple model)", + "requestKey", requestKey, "pod", existingEntry) + } else { + // For HMA models: remove specific groups + updatedGroups := removeGroups(existingEntry.StoredGroups, entry.StoredGroups) + + if len(updatedGroups) == 0 { + // No groups left, remove the entire pod entry + podCache.cache.Remove(cacheKey) + traceLogger.Info("removed pod entry (no groups remaining)", + "requestKey", requestKey, "pod", existingEntry) + } else { + // Update with remaining groups + existingEntry.StoredGroups = updatedGroups + podCache.cache.Add(cacheKey, existingEntry) + traceLogger.Info("updated 
pod entry after group removal", + "requestKey", requestKey, "pod", existingEntry, "remainingGroups", updatedGroups) + } + } + } isEmpty := podCache.cache.Len() == 0 podCache.mu.Unlock() - traceLogger.Info("evicted pods from key", "requestKey", requestKey, "key", key, "keyType", keyType, "pods", entries) + traceLogger.Info("processed eviction", "requestKey", requestKey, "key", key, "keyType", keyType, "entries", entries) // Remove key from main cache if empty. // Re-fetch and hold the lock through removal to prevent racing with Add. @@ -294,6 +352,59 @@ func (m *InMemoryIndex) GetRequestKey(ctx context.Context, engineKey BlockHash) return requestKey, nil } +// podCacheKey generates a cache key for a pod entry. +// Format: "podIdentifier@deviceTier" or "podIdentifier@deviceTier[speculative]". +func podCacheKey(podIdentifier, deviceTier string, speculative bool) string { + key := podIdentifier + "@" + deviceTier + if speculative { + key += "[speculative]" + } + return key +} + +// mergeGroupsUnique appends the first element of 'incoming' to 'existing' if it is not already present, preserving order. +// NOTE(review): only incoming[0] is merged — confirm callers never pass more than one new group per event. +// This function should only be called for HMA models where group tracking is enabled. +func mergeGroupsUnique(existing, incoming []int) []int { + // If incoming is empty, return existing as-is + if len(incoming) == 0 { + return existing + } + firstIncoming := incoming[0] + + for _, v := range existing { + if v == firstIncoming { + return existing // Already there, nothing to do + } + } + result := make([]int, 0, len(existing)+1) + result = append(result, existing...) + result = append(result, firstIncoming) + return result +} + +// removeGroups removes the first element of 'toRemove' from 'existing' in place, +// maintaining order of remaining elements. NOTE(review): only toRemove[0] is removed — confirm callers never pass multiple groups. +// This function should only be called for HMA models where group tracking is enabled. 
+func removeGroups(existing, toRemove []int) []int { + if len(toRemove) == 0 || len(existing) == 0 { + return existing + } + target := toRemove[0] + targetIdx := -1 + for i, v := range existing { + if v == target { + targetIdx = i + break + } + } + if targetIdx == -1 { + return existing + } + copy(existing[targetIdx:], existing[targetIdx+1:]) + return existing[:len(existing)-1] +} + // podsPerKeyPrintHelper formats a map of keys to pod names for printing. func podsPerKeyPrintHelper(ks map[BlockHash][]PodEntry) string { var b strings.Builder diff --git a/pkg/kvcache/kvblock/index.go b/pkg/kvcache/kvblock/index.go index 9e0a49457..360b84cbf 100644 --- a/pkg/kvcache/kvblock/index.go +++ b/pkg/kvcache/kvblock/index.go @@ -172,6 +172,10 @@ type PodEntry struct { DeviceTier string // Speculative indicates the entry was added predictively before a KV event confirmed it. Speculative bool + // StoredGroups tracks the group IDs that have stored this block (for HMA models). + // - nil: Simple (non-HMA) model - no group tracking, saves memory + // - []int{0, 1, ...}: HMA model - tracks which attention groups cached this block + StoredGroups []int } // String returns a string representation of the PodEntry. 
diff --git a/pkg/kvcache/kvblock/redis.go b/pkg/kvcache/kvblock/redis.go index 921b47767..f7f965d65 100644 --- a/pkg/kvcache/kvblock/redis.go +++ b/pkg/kvcache/kvblock/redis.go @@ -18,6 +18,7 @@ package kvblock import ( "context" + "encoding/json" "errors" "fmt" "strconv" @@ -173,12 +174,11 @@ func (r *RedisIndex) Lookup(ctx context.Context, requestKeys []BlockHash, // pipeline for single RTT pipe := r.RedisClient.Pipeline() - results := make([]*redis.StringSliceCmd, len(requestKeys)) + results := make([]*redis.MapStringStringCmd, len(requestKeys)) - // queue an HKeys command for each key in the pipeline + // queue an HGetAll command for each key in the pipeline to get all field:value pairs for i, key := range requestKeys { - // HKeys gets all field names - results[i] = pipe.HKeys(ctx, key.String()) + results[i] = pipe.HGetAll(ctx, key.String()) } _, execErr := pipe.Exec(ctx) @@ -191,8 +191,8 @@ func (r *RedisIndex) Lookup(ctx context.Context, requestKeys []BlockHash, for idx, cmd := range results { key := requestKeys[idx] - // cmd.Result() returns the slice of strings (pod IDs) which is the first layer in the mapping - pods, cmdErr := cmd.Result() + // cmd.Result() returns a map[string]string of entryKey -> JSON data + entryMap, cmdErr := cmd.Result() if cmdErr != nil { if !errors.Is(cmdErr, redis.Nil) { logger.Error(cmdErr, "failed to get pods for key", "key", key) @@ -201,23 +201,26 @@ func (r *RedisIndex) Lookup(ctx context.Context, requestKeys []BlockHash, return podsPerKey, nil // early stop since prefix-chain breaks here } + if len(entryMap) == 0 { + logger.Info("no pods found for key, cutting search", "key", key) + return podsPerKey, nil // early stop since prefix-chain breaks here + } + var filteredPods []PodEntry - for _, p := range pods { - ip := strings.SplitN(p, "@", 2)[0] - if !filterPods || podIdentifierSet.Has(ip) { - tier := strings.SplitN(p, "@", 2)[1] - speculative := false - // Strip annotation suffix e.g. 
"gpu[speculative]" -> "gpu" - if idx := strings.Index(tier, "["); idx != -1 { - speculative = strings.Contains(tier[idx:], "speculative") - tier = tier[:idx] - } - filteredPods = append(filteredPods, PodEntry{PodIdentifier: ip, DeviceTier: tier, Speculative: speculative}) + for _, jsonData := range entryMap { + var entry PodEntry + if err := json.Unmarshal([]byte(jsonData), &entry); err != nil { + logger.Error(err, "failed to unmarshal pod entry", "key", key, "data", jsonData) + continue + } + + if !filterPods || podIdentifierSet.Has(entry.PodIdentifier) { + filteredPods = append(filteredPods, entry) } } if len(filteredPods) == 0 { - logger.Info("no pods found for key, cutting search", "key", key) + logger.Info("no pods found for key after filtering, cutting search", "key", key) return podsPerKey, nil // early stop since prefix-chain breaks here } @@ -238,22 +241,73 @@ func (r *RedisIndex) Add(ctx context.Context, engineKeys, requestKeys []BlockHas return fmt.Errorf("mismatch between engine keys and request keys length") } - pipe := r.RedisClient.Pipeline() for i, requestKey := range requestKeys { redisKey := requestKey.String() // Store engineKey -> requestKey mapping (only if engineKeys provided) if engineKeys != nil { - pipe.Set(ctx, redisEngineKey(engineKeys[i]), redisKey, 0) + if err := r.RedisClient.Set(ctx, redisEngineKey(engineKeys[i]), redisKey, 0).Err(); err != nil { + return fmt.Errorf("failed to set engine key mapping: %w", err) + } } + for _, entry := range entries { - // Use HSet to add the pod identifier as a field in the hash - pipe.HSet(ctx, redisKey, entry.String(), "") - } - } + entryKey := podEntryKey(entry.PodIdentifier, entry.DeviceTier) + + // Get existing entry if it exists + existingData, err := r.RedisClient.HGet(ctx, redisKey, entryKey).Result() + + // Handle errors (excluding key-not-found) + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("failed to check existing entry: %w", err) + } + + // Entry doesn't exist, create 
new + if errors.Is(err, redis.Nil) { + newEntry := PodEntry{ + PodIdentifier: entry.PodIdentifier, + DeviceTier: entry.DeviceTier, + Speculative: entry.Speculative, + StoredGroups: entry.StoredGroups, // nil for simple models, []int for HMA + } + + data, err := json.Marshal(newEntry) + if err != nil { + return fmt.Errorf("failed to marshal new entry: %w", err) + } + if err := r.RedisClient.HSet(ctx, redisKey, entryKey, data).Err(); err != nil { + return fmt.Errorf("failed to add entry to Redis: %w", err) + } + continue + } - if _, err := pipe.Exec(ctx); err != nil { - return fmt.Errorf("failed to add entries to Redis: %w", err) + // Entry exists, merge groups (if HMA model) + var existingEntry PodEntry + if err := json.Unmarshal([]byte(existingData), &existingEntry); err != nil { + return fmt.Errorf("failed to unmarshal existing entry: %w", err) + } + + // Check StoredGroups to determine simple vs HMA model (same pattern as Evict) + if entry.StoredGroups == nil { + // Simple model - no group tracking needed + } else { + // HMA model - merge groups + existingEntry.StoredGroups = mergeGroupsUnique(existingEntry.StoredGroups, entry.StoredGroups) + } + // Update speculative flag if new entry is confirmed + if !entry.Speculative { + existingEntry.Speculative = false + } + + // Serialize and store updated entry + data, err := json.Marshal(existingEntry) + if err != nil { + return fmt.Errorf("failed to marshal updated entry: %w", err) + } + if err := r.RedisClient.HSet(ctx, redisKey, entryKey, data).Err(); err != nil { + return fmt.Errorf("failed to update entry in Redis: %w", err) + } + } } return nil @@ -286,15 +340,51 @@ func (r *RedisIndex) Evict(ctx context.Context, key BlockHash, keyType KeyType, } redisKey := requestKey.String() - pipe := r.RedisClient.Pipeline() for _, entry := range entries { - // Use HDel to remove the pod identifier field from the hash - pipe.HDel(ctx, redisKey, entry.String()) - } + entryKey := podEntryKey(entry.PodIdentifier, 
entry.DeviceTier) + + // Get existing entry + existingData, err := r.RedisClient.HGet(ctx, redisKey, entryKey).Result() + if errors.Is(err, redis.Nil) { + // Entry doesn't exist, nothing to evict + continue + } else if err != nil { + return fmt.Errorf("failed to get existing entry: %w", err) + } - if _, err := pipe.Exec(ctx); err != nil { - return fmt.Errorf("failed to evict entries from Redis: %w", err) + // Deserialize existing entry + var existingEntry PodEntry + if err := json.Unmarshal([]byte(existingData), &existingEntry); err != nil { + return fmt.Errorf("failed to unmarshal existing entry: %w", err) + } + + // For simple (non-HMA) models: StoredGroups is nil, remove entire entry + if entry.StoredGroups == nil { //nolint:nestif // Existing complexity + if err := r.RedisClient.HDel(ctx, redisKey, entryKey).Err(); err != nil { + return fmt.Errorf("failed to delete entry from Redis: %w", err) + } + } else { + // For HMA models: remove specific groups + updatedGroups := removeGroups(existingEntry.StoredGroups, entry.StoredGroups) + + if len(updatedGroups) == 0 { + // No groups left, remove the entire entry + if err := r.RedisClient.HDel(ctx, redisKey, entryKey).Err(); err != nil { + return fmt.Errorf("failed to delete entry from Redis: %w", err) + } + } else { + // Update with remaining groups + existingEntry.StoredGroups = updatedGroups + data, err := json.Marshal(existingEntry) + if err != nil { + return fmt.Errorf("failed to marshal updated entry: %w", err) + } + if err := r.RedisClient.HSet(ctx, redisKey, entryKey, data).Err(); err != nil { + return fmt.Errorf("failed to update entry in Redis: %w", err) + } + } + } } // Atomically check hash length and delete engine key if empty (only if engine key mapping exists) @@ -324,3 +414,9 @@ func (r *RedisIndex) GetRequestKey(ctx context.Context, engineKey BlockHash) (Bl func redisEngineKey(engineKey BlockHash) string { return "engine:" + engineKey.String() } + +// podEntryKey generates a hash field key for a pod 
entry. +// Format: "podIdentifier@deviceTier". +func podEntryKey(podIdentifier, deviceTier string) string { + return podIdentifier + "@" + deviceTier +} diff --git a/pkg/kvcache/model_registry_test.go b/pkg/kvcache/model_registry_test.go new file mode 100644 index 000000000..4c88c41c8 --- /dev/null +++ b/pkg/kvcache/model_registry_test.go @@ -0,0 +1,319 @@ +// Copyright 2025 The llm-d Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvcache //nolint:testpackage // Tests internal model registry implementation + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestModelRegistryDefaultBehavior verifies model registry default behavior. +// 1. If modelConfigs is not present (nil/empty) → all models are non-HMA +// 2. If modelConfigs is present but model not in list → non-HMA (default) +// Unknown or unconfigured models must always fall back to the non-HMA default. 
+func TestModelRegistryDefaultBehavior(t *testing.T) { + t.Run("NoConfig_AllModelsNonHMA", func(t *testing.T) { + // Create registry with no configs (nil) + registry := NewModelRegistry(nil) + + // All models should default to non-HMA + assert.False(t, registry.IsHMA("DeepSeek-V3"), "Unknown model should be non-HMA") + assert.False(t, registry.IsHMA("Qwen/Qwen3-8B"), "Unknown model should be non-HMA") + assert.False(t, registry.IsHMA("any-random-model"), "Unknown model should be non-HMA") + + // GetModelConfig should return default config + config := registry.GetModelConfig("unknown-model") + require.NotNil(t, config) + assert.Equal(t, "unknown-model", config.Name) + assert.False(t, config.IsHMA) + assert.Nil(t, config.AttentionGroups) + }) + + t.Run("EmptyConfig_AllModelsNonHMA", func(t *testing.T) { + // Create registry with empty slice + registry := NewModelRegistry([]*ModelConfig{}) + + // All models should default to non-HMA + assert.False(t, registry.IsHMA("DeepSeek-V3"), "Unknown model should be non-HMA") + assert.False(t, registry.IsHMA("Qwen/Qwen3-8B"), "Unknown model should be non-HMA") + }) + + t.Run("ModelNotInList_NonHMA", func(t *testing.T) { + // Create registry with one HMA model + registry := NewModelRegistry([]*ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []AttentionGroupConfig{ + {GroupID: 0, AttentionType: AttentionTypeFull, BlockSize: 64}, + {GroupID: 1, AttentionType: AttentionTypeSlidingWindow, BlockSize: 64, SlidingWindowSize: 4096}, + }, + }, + }) + + // Configured model should be HMA + assert.True(t, registry.IsHMA("DeepSeek-V3"), "Configured HMA model should be HMA") + + // Unknown models should default to non-HMA + assert.False(t, registry.IsHMA("Qwen/Qwen3-8B"), "Model not in config should be non-HMA") + assert.False(t, registry.IsHMA("unknown-model"), "Model not in config should be non-HMA") + + // GetModelConfig for unknown model returns default + config := registry.GetModelConfig("Qwen/Qwen3-8B") + 
require.NotNil(t, config) + assert.Equal(t, "Qwen/Qwen3-8B", config.Name) + assert.False(t, config.IsHMA) + }) + + t.Run("ModelInList_UseIsHMAFlag_True", func(t *testing.T) { + registry := NewModelRegistry([]*ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []AttentionGroupConfig{ + {GroupID: 0, AttentionType: AttentionTypeFull, BlockSize: 64}, + }, + }, + }) + + // Should use IsHMA flag from config + assert.True(t, registry.IsHMA("DeepSeek-V3"), "Should use IsHMA=true from config") + + // GetModelConfig should return configured values + config := registry.GetModelConfig("DeepSeek-V3") + require.NotNil(t, config) + assert.Equal(t, "DeepSeek-V3", config.Name) + assert.True(t, config.IsHMA) + assert.Len(t, config.AttentionGroups, 1) + }) + + t.Run("ModelInList_UseIsHMAFlag_False", func(t *testing.T) { + registry := NewModelRegistry([]*ModelConfig{ + { + Name: "Qwen/Qwen3-8B", + IsHMA: false, + }, + }) + + // Should use IsHMA flag from config + assert.False(t, registry.IsHMA("Qwen/Qwen3-8B"), "Should use IsHMA=false from config") + + // GetModelConfig should return configured values + config := registry.GetModelConfig("Qwen/Qwen3-8B") + require.NotNil(t, config) + assert.Equal(t, "Qwen/Qwen3-8B", config.Name) + assert.False(t, config.IsHMA) + assert.Nil(t, config.AttentionGroups) + }) + + t.Run("MixedConfig_CorrectBehavior", func(t *testing.T) { + registry := NewModelRegistry([]*ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []AttentionGroupConfig{ + {GroupID: 0, AttentionType: AttentionTypeFull, BlockSize: 64}, + {GroupID: 1, AttentionType: AttentionTypeSlidingWindow, BlockSize: 64, SlidingWindowSize: 4096}, + }, + }, + { + Name: "Qwen/Qwen3-8B", + IsHMA: false, + }, + { + Name: "meta-llama/Llama-3.1-8B", + IsHMA: false, + }, + }) + + // Configured HMA model + assert.True(t, registry.IsHMA("DeepSeek-V3"), "DeepSeek-V3 should be HMA") + + // Configured non-HMA models + assert.False(t, registry.IsHMA("Qwen/Qwen3-8B"), 
"Qwen should be non-HMA") + assert.False(t, registry.IsHMA("meta-llama/Llama-3.1-8B"), "Llama should be non-HMA") + + // Unknown model (not in config) + assert.False(t, registry.IsHMA("mistralai/Mistral-7B-v0.1"), "Unknown model should be non-HMA") + }) +} + +func TestModelRegistryAttentionGroups(t *testing.T) { + t.Run("NonHMAModel_NoAttentionGroups", func(t *testing.T) { + registry := NewModelRegistry([]*ModelConfig{ + {Name: "Qwen/Qwen3-8B", IsHMA: false}, + }) + + groups := registry.GetAttentionGroups("Qwen/Qwen3-8B") + assert.Nil(t, groups, "Non-HMA model should have no attention groups") + + blockSize := registry.GetGroupBlockSize("Qwen/Qwen3-8B", 0) + assert.Equal(t, 0, blockSize, "Non-HMA model should return 0 for block size") + + windowSize := registry.GetGroupSlidingWindow("Qwen/Qwen3-8B", 0) + assert.Equal(t, 0, windowSize, "Non-HMA model should return 0 for window size") + }) + + t.Run("HMAModel_HasAttentionGroups", func(t *testing.T) { + registry := NewModelRegistry([]*ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []AttentionGroupConfig{ + { + GroupID: 0, + AttentionType: AttentionTypeFull, + BlockSize: 64, + }, + { + GroupID: 1, + AttentionType: AttentionTypeSlidingWindow, + BlockSize: 64, + SlidingWindowSize: 4096, + }, + }, + }, + }) + + groups := registry.GetAttentionGroups("DeepSeek-V3") + require.NotNil(t, groups) + assert.Len(t, groups, 2) + + // Group 0: Full attention + assert.Equal(t, 64, registry.GetGroupBlockSize("DeepSeek-V3", 0)) + assert.Equal(t, 0, registry.GetGroupSlidingWindow("DeepSeek-V3", 0)) + + // Group 1: Sliding window + assert.Equal(t, 64, registry.GetGroupBlockSize("DeepSeek-V3", 1)) + assert.Equal(t, 4096, registry.GetGroupSlidingWindow("DeepSeek-V3", 1)) + + // Non-existent group + assert.Equal(t, 0, registry.GetGroupBlockSize("DeepSeek-V3", 99)) + }) + + t.Run("UnknownModel_NoAttentionGroups", func(t *testing.T) { + registry := NewModelRegistry([]*ModelConfig{ + {Name: "DeepSeek-V3", IsHMA: 
true}, + }) + + groups := registry.GetAttentionGroups("unknown-model") + assert.Nil(t, groups, "Unknown model should have no attention groups") + }) +} + +func TestModelRegistryRegisterModel(t *testing.T) { + t.Run("RegisterNewModel", func(t *testing.T) { + registry := NewModelRegistry(nil) + + // Initially unknown + assert.False(t, registry.IsHMA("new-model")) + + // Register as HMA + registry.RegisterModel(&ModelConfig{ + Name: "new-model", + IsHMA: true, + AttentionGroups: []AttentionGroupConfig{ + {GroupID: 0, AttentionType: AttentionTypeFull, BlockSize: 32}, + }, + }) + + // Now should be HMA + assert.True(t, registry.IsHMA("new-model")) + assert.Equal(t, 32, registry.GetGroupBlockSize("new-model", 0)) + }) + + t.Run("UpdateExistingModel", func(t *testing.T) { + registry := NewModelRegistry([]*ModelConfig{ + {Name: "test-model", IsHMA: false}, + }) + + // Initially non-HMA + assert.False(t, registry.IsHMA("test-model")) + + // Update to HMA + registry.RegisterModel(&ModelConfig{ + Name: "test-model", + IsHMA: true, + AttentionGroups: []AttentionGroupConfig{ + {GroupID: 0, AttentionType: AttentionTypeFull, BlockSize: 64}, + }, + }) + + // Now should be HMA + assert.True(t, registry.IsHMA("test-model")) + }) +} + +func TestNewDefaultModelRegistry(t *testing.T) { + registry := NewDefaultModelRegistry() + + // All models should be non-HMA by default + assert.False(t, registry.IsHMA("DeepSeek-V3")) + assert.False(t, registry.IsHMA("Qwen/Qwen3-8B")) + assert.False(t, registry.IsHMA("any-model")) +} + +func TestPreConfiguredHMARegistry(t *testing.T) { + // Create registry with DeepSeek-V3 HMA configuration + registry := NewModelRegistry([]*ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []AttentionGroupConfig{ + { + GroupID: 0, + AttentionType: AttentionTypeFull, + BlockSize: 64, + }, + { + GroupID: 1, + AttentionType: AttentionTypeSlidingWindow, + BlockSize: 64, + SlidingWindowSize: 4096, + }, + }, + }, + }) + + // DeepSeek-V3 should be HMA 
+ assert.True(t, registry.IsHMA("DeepSeek-V3")) + + // Should have 2 attention groups + groups := registry.GetAttentionGroups("DeepSeek-V3") + require.NotNil(t, groups) + assert.Len(t, groups, 2) + + // Group 0: Full attention + assert.Equal(t, AttentionTypeFull, groups[0].AttentionType) + assert.Equal(t, 64, groups[0].BlockSize) + assert.Equal(t, 0, groups[0].SlidingWindowSize) + + // Group 1: Sliding window + assert.Equal(t, AttentionTypeSlidingWindow, groups[1].AttentionType) + assert.Equal(t, 64, groups[1].BlockSize) + assert.Equal(t, 4096, groups[1].SlidingWindowSize) + + // Unknown models should still be non-HMA + assert.False(t, registry.IsHMA("unknown-model")) +} + +func TestAttentionTypeConstants(t *testing.T) { + // Verify attention type constants are defined correctly + assert.Equal(t, AttentionType("full"), AttentionTypeFull) + assert.Equal(t, AttentionType("sliding_window"), AttentionTypeSlidingWindow) + assert.Equal(t, AttentionType("local"), AttentionTypeLocal) +} diff --git a/pkg/kvevents/engineadapter/vllm_adapter.go b/pkg/kvevents/engineadapter/vllm_adapter.go index 8e7085d4b..368c0034b 100644 --- a/pkg/kvevents/engineadapter/vllm_adapter.go +++ b/pkg/kvevents/engineadapter/vllm_adapter.go @@ -141,6 +141,7 @@ func fieldAt(fields []any, i int) any { // [6] medium string|nil (optional, omit_defaults) // [7] lora_name string|nil (optional, omit_defaults) // [8] extra_keys [][]any|nil (optional, omit_defaults) +// [9] group_idx int|nil (optional, omit_defaults) // // Trailing fields may be absent in older vLLM versions. Extra trailing fields // from newer vLLM versions are silently ignored. 
@@ -221,6 +222,16 @@ func (v *VLLMAdapter) convertBlockStoredEvent(fields []any) (kvevents.GenericEve } } + // [9] group_idx (optional) + var groupIdx int + if raw := fieldAt(fields, 9); raw != nil { + idx, err := toInt(raw) + if err == nil { + groupIdx = idx + } + // Silently ignore field if it's not an int (forward compatibility) + } + return &kvevents.BlockStoredEvent{ BlockHashes: blockHashes, Tokens: tokens, @@ -229,6 +240,7 @@ func (v *VLLMAdapter) convertBlockStoredEvent(fields []any) (kvevents.GenericEve LoraID: loraID, LoraName: loraName, ExtraKeys: extraKeys, + GroupIdx: groupIdx, }, nil } @@ -238,6 +250,7 @@ func (v *VLLMAdapter) convertBlockStoredEvent(fields []any) (kvevents.GenericEve // [0] tag string // [1] block_hashes []hash // [2] medium string|nil (optional, omit_defaults) +// [3] group_idx int|nil (optional, omit_defaults) func (v *VLLMAdapter) convertBlockRemovedEvent(fields []any) (kvevents.GenericEvent, error) { if len(fields) < 2 { return nil, fmt.Errorf("BlockRemoved: need at least 2 fields, got %d", len(fields)) @@ -261,9 +274,20 @@ func (v *VLLMAdapter) convertBlockRemovedEvent(fields []any) (kvevents.GenericEv deviceTier = s } + // [3] group_idx (optional) + var groupIdx int + if raw := fieldAt(fields, 3); raw != nil { + idx, err := toInt(raw) + if err == nil { + groupIdx = idx + } + // Silently ignore field if it's not an int (forward compatibility) + } + return &kvevents.BlockRemovedEvent{ BlockHashes: blockHashes, DeviceTier: deviceTier, + GroupIdx: groupIdx, }, nil } diff --git a/pkg/kvevents/events.go b/pkg/kvevents/events.go index 1a30725d9..185f7789c 100644 --- a/pkg/kvevents/events.go +++ b/pkg/kvevents/events.go @@ -72,6 +72,7 @@ type BlockStoredEvent struct { LoraID *int LoraName *string ExtraKeys [][]any + GroupIdx int // Attention group ID } // Type returns the event type. 
@@ -83,6 +84,7 @@ func (e *BlockStoredEvent) Type() EventType { type BlockRemovedEvent struct { BlockHashes []uint64 DeviceTier string + GroupIdx int // Attention group ID being evicted } // Type returns the event type. diff --git a/pkg/kvevents/pool.go b/pkg/kvevents/pool.go index 149a8c63a..7df0d70a3 100644 --- a/pkg/kvevents/pool.go +++ b/pkg/kvevents/pool.go @@ -51,6 +51,9 @@ type Config struct { // PodDiscoveryConfig holds the configuration for pod discovery. // Only used when DiscoverPods is true. PodDiscoveryConfig *PodDiscoveryConfig `json:"podDiscoveryConfig,omitempty"` + // ModelRegistry provides model configuration for HMA vs simple model handling. + // If nil, NewDefaultModelRegistry() is used. + ModelRegistry ModelRegistry `json:"-"` } // PodDiscoveryConfig holds configuration for the Kubernetes pod reconciler. @@ -93,9 +96,26 @@ type Pool struct { index kvblock.Index tokenProcessor kvblock.TokenProcessor adapter EngineAdapter + modelRegistry ModelRegistry wg sync.WaitGroup } +// ModelRegistry interface defines methods for retrieving model configurations. +type ModelRegistry interface { + IsHMA(modelName string) bool +} + +// defaultModelRegistry is a simple implementation that treats all models as non-HMA. +type defaultModelRegistry struct{} + +func newDefaultModelRegistry() ModelRegistry { + return &defaultModelRegistry{} +} + +func (r *defaultModelRegistry) IsHMA(modelName string) bool { + return false +} + // NewPool creates a Pool with a sharded worker setup. // Subscribers are managed by SubscriberManager which is controlled by the pod // reconciler. 
@@ -106,12 +126,20 @@ func NewPool(cfg *Config, index kvblock.Index, tokenProcessor kvblock.TokenProce cfg = DefaultConfig() } + + // Use provided model registry or default + modelRegistry := cfg.ModelRegistry + if modelRegistry == nil { + // No registry supplied; default treats every model as non-HMA. + modelRegistry = newDefaultModelRegistry() + } + p := &Pool{ queues: make([]workqueue.TypedRateLimitingInterface[*RawMessage], cfg.Concurrency), concurrency: cfg.Concurrency, index: index, tokenProcessor: tokenProcessor, adapter: adapter, + modelRegistry: modelRegistry, } for i := 0; i < p.concurrency; i++ { @@ -230,7 +258,21 @@ func (p *Pool) processEventBatch(ctx context.Context, batch *EventBatch, podIden } // Create PodEntry for this specific event's device tier - podEntries := []kvblock.PodEntry{{PodIdentifier: podIdentifier, DeviceTier: deviceTier}} + // Check once if model uses HMA to avoid repeated lookups + isHMA := p.modelRegistry.IsHMA(effectiveModelName) + + // Only populate StoredGroups for HMA models to save CPU and memory + // For simple models: skip group processing entirely (nil StoredGroups) + var storedGroups []int + if isHMA { + storedGroups = []int{ev.GroupIdx} + } + + podEntries := []kvblock.PodEntry{{ + PodIdentifier: podIdentifier, + DeviceTier: deviceTier, + StoredGroups: storedGroups, + }} engineKeys := make([]kvblock.BlockHash, len(ev.BlockHashes)) for i, hash := range ev.BlockHashes { @@ -309,7 +351,21 @@ func (p *Pool) processEventBatch(ctx context.Context, batch *EventBatch, podIden } // Create PodEntry for this specific event's device tier - podEntries := []kvblock.PodEntry{{PodIdentifier: podIdentifier, DeviceTier: deviceTier}} + // Check once if model uses HMA to avoid repeated lookups + isHMA := p.modelRegistry.IsHMA(modelName) + + // Only populate StoredGroups for HMA models to save CPU and memory + // For simple models: nil StoredGroups → evict entire entry immediately + var storedGroups []int + if isHMA { + storedGroups = []int{ev.GroupIdx} + } + + 
podEntries := []kvblock.PodEntry{{ + PodIdentifier: podIdentifier, + DeviceTier: deviceTier, + StoredGroups: storedGroups, + }} // Iterate over the hashes and evict each key. for _, hash := range ev.BlockHashes { diff --git a/pkg/kvevents/pool_test.go b/pkg/kvevents/pool_test.go new file mode 100644 index 000000000..a9d605a4d --- /dev/null +++ b/pkg/kvevents/pool_test.go @@ -0,0 +1,819 @@ +// Copyright 2025 The llm-d Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvevents //nolint:testpackage // Tests internal pool event processing + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/llm-d/llm-d-kv-cache/pkg/kvcache" + "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" +) + +// TestPoolEventProcessing_GroupIDHandling verifies that: +// 1. Event with group_id AND model is HMA → StoredGroups = []int{group_id}. +// 2. Event with group_id BUT model is non-HMA → StoredGroups = nil (ignore group_id). +// 3. Event without group_id (group_id=0 default) → depends on model config. +func TestPoolEventProcessing_GroupIDHandling(t *testing.T) { + ctx := context.Background() + + t.Run("HMAModel_WithGroupID_TracksGroup", func(t *testing.T) { + // Setup: HMA model registry. 
+ modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []kvcache.AttentionGroupConfig{ + {GroupID: 0, AttentionType: kvcache.AttentionTypeFull, BlockSize: 64}, + {GroupID: 1, AttentionType: kvcache.AttentionTypeSlidingWindow, BlockSize: 64, SlidingWindowSize: 4096}, + }, + }, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Event with GroupIdx = 1 for HMA model + // Need at least 16 tokens to create one complete block (defaultBlockSize = 16) + // 16 tokens = 1 block, so we need 1 engineKey (BlockHash) + tokens := []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{100}, // 1 block = 1 hash + Tokens: tokens, + ParentHash: 0, + DeviceTier: "gpu", + GroupIdx: 1, // Group 1 + }, + }, + } + + pool.processEventBatch(ctx, batch, "test-pod", "DeepSeek-V3") + + // Verify StoredGroups was set to [1] + requestKeys, err := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, "DeepSeek-V3", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + + // Should have entries with StoredGroups = [1] + found := false + for _, entries := range result { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.NotNil(t, entry.StoredGroups, "HMA model MUST have non-nil StoredGroups") + assert.Equal(t, []int{1}, entry.StoredGroups, "Should track group ID 1") + found = true + } + } + } + assert.True(t, found, "Should have found entry with group 
tracking") + }) + + t.Run("NonHMAModel_WithGroupID_IgnoresGroup", func(t *testing.T) { + // Setup: Simple (non-HMA) model registry + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + { + Name: "Qwen/Qwen3-8B", + IsHMA: false, + }, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Event with GroupIdx = 1, but model is non-HMA + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{200}, + Tokens: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + ParentHash: 0, + DeviceTier: "gpu", + GroupIdx: 1, // Should be ignored for non-HMA + }, + }, + } + + pool.processEventBatch(ctx, batch, "test-pod", "Qwen/Qwen3-8B") + + // Verify StoredGroups is nil (group_id ignored) + requestKeys, err := tokenProcessor.TokensToKVBlockKeys( + kvblock.EmptyBlockHash, + []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + "Qwen/Qwen3-8B", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + + // Should have entries with StoredGroups = nil + found := false + for _, entries := range result { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.Nil(t, entry.StoredGroups, "Non-HMA model MUST have nil StoredGroups (ignore group_id)") + found = true + } + } + } + assert.True(t, found, "Should have found entry without group tracking") + }) + + t.Run("HMAModel_WithGroupIDZero_TracksGroup", func(t *testing.T) { + // Setup: HMA model + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []kvcache.AttentionGroupConfig{ + {GroupID: 0, 
AttentionType: kvcache.AttentionTypeFull, BlockSize: 64}, + }, + }, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Event with GroupIdx = 0 (default) for HMA model + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{300}, + Tokens: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + ParentHash: 0, + DeviceTier: "gpu", + GroupIdx: 0, // Group 0 + }, + }, + } + + pool.processEventBatch(ctx, batch, "test-pod", "DeepSeek-V3") + + // Verify StoredGroups = [0] + requestKeys, err := tokenProcessor.TokensToKVBlockKeys( + kvblock.EmptyBlockHash, + []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + "DeepSeek-V3", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + + found := false + for _, entries := range result { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.NotNil(t, entry.StoredGroups, "HMA model should track groups") + assert.Equal(t, []int{0}, entry.StoredGroups, "Should track group ID 0") + found = true + } + } + } + assert.True(t, found, "Should have found entry with group 0") + }) + + t.Run("NonHMAModel_WithGroupIDZero_IgnoresGroup", func(t *testing.T) { + // Setup: Non-HMA model + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + { + Name: "Qwen/Qwen3-8B", + IsHMA: false, + }, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = 
modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Event with GroupIdx = 0 (default) + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{400}, + Tokens: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + ParentHash: 0, + DeviceTier: "gpu", + GroupIdx: 0, + }, + }, + } + + pool.processEventBatch(ctx, batch, "test-pod", "Qwen/Qwen3-8B") + + // Verify StoredGroups is nil + requestKeys, err := tokenProcessor.TokensToKVBlockKeys( + kvblock.EmptyBlockHash, + []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + "Qwen/Qwen3-8B", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + + found := false + for _, entries := range result { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.Nil(t, entry.StoredGroups, "Non-HMA model should have nil StoredGroups") + found = true + } + } + } + assert.True(t, found, "Should have found entry") + }) + + t.Run("UnknownModel_WithGroupID_DefaultsToNonHMA", func(t *testing.T) { + // Setup: Empty registry (all models default to non-HMA) + modelRegistry := kvcache.NewDefaultModelRegistry() + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Event with GroupIdx = 1 for unknown model + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{500}, + Tokens: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + ParentHash: 0, + DeviceTier: "gpu", + GroupIdx: 1, // Should be ignored (unknown model = non-HMA) + }, + }, + } + + pool.processEventBatch(ctx, batch, "test-pod", "unknown-model") + + // Verify 
StoredGroups is nil (unknown model defaults to non-HMA) + requestKeys, err := tokenProcessor.TokensToKVBlockKeys( + kvblock.EmptyBlockHash, + []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + "unknown-model", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + + found := false + for _, entries := range result { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.Nil(t, entry.StoredGroups, "Unknown model should default to non-HMA (nil StoredGroups)") + found = true + } + } + } + assert.True(t, found, "Should have found entry") + }) +} + +func TestPoolEventProcessing_MultipleBlockHashes(t *testing.T) { + ctx := context.Background() + + t.Run("HMAModel_WithMultipleBlocks_TracksAllBlocks", func(t *testing.T) { + // Setup: HMA model registry + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []kvcache.AttentionGroupConfig{ + {GroupID: 0, AttentionType: kvcache.AttentionTypeFull, BlockSize: 64}, + {GroupID: 1, AttentionType: kvcache.AttentionTypeSlidingWindow, BlockSize: 64, SlidingWindowSize: 4096}, + }, + }, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Event with multiple blocks (32 tokens = 2 blocks, each 16 tokens) + tokens := []uint32{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // Block 1 + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, // Block 2 + } + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{100, 101}, // 2 blocks + Tokens: tokens, + ParentHash: 0, + DeviceTier: "gpu", + GroupIdx: 
1, + }, + }, + } + + pool.processEventBatch(ctx, batch, "test-pod", "DeepSeek-V3") + + // Verify both blocks are tracked + requestKeys, err := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, "DeepSeek-V3", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + + // Should have 2 entries, one for each block + assert.Len(t, result, 2, "Should have entries for both blocks") + + // Verify each entry has StoredGroups = [1] + for _, entries := range result { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.NotNil(t, entry.StoredGroups, "HMA model MUST have non-nil StoredGroups") + assert.Equal(t, []int{1}, entry.StoredGroups, "Should track group ID 1") + } + } + } + }) + + t.Run("NonHMAModel_WithMultipleBlocks_TracksAllBlocks", func(t *testing.T) { + // Setup: Simple model registry + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + { + Name: "Qwen/Qwen3-8B", + IsHMA: false, + }, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Event with multiple blocks (32 tokens = 2 blocks) + tokens := []uint32{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + } + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{200, 201}, // 2 blocks + Tokens: tokens, + ParentHash: 0, + DeviceTier: "gpu", + GroupIdx: 0, + }, + }, + } + + pool.processEventBatch(ctx, batch, "test-pod", "Qwen/Qwen3-8B") + + // Verify both blocks are tracked with nil StoredGroups + requestKeys, err := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, 
tokens, "Qwen/Qwen3-8B", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + + assert.Len(t, result, 2, "Should have entries for both blocks") + + for _, entries := range result { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.Nil(t, entry.StoredGroups, "Non-HMA model MUST have nil StoredGroups") + } + } + } + }) +} + +func TestPoolEventProcessing_WithParentHash(t *testing.T) { + ctx := context.Background() + + t.Run("HMAModel_WithParentHash_CreatesChain", func(t *testing.T) { + // Setup: HMA model registry + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []kvcache.AttentionGroupConfig{ + {GroupID: 0, AttentionType: kvcache.AttentionTypeFull, BlockSize: 64}, + }, + }, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // First event: parent block (16 tokens) + parentTokens := []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + parentBatch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{1000}, // Parent block hash + Tokens: parentTokens, + ParentHash: 0, // Root block + DeviceTier: "gpu", + GroupIdx: 0, + }, + }, + } + + pool.processEventBatch(ctx, parentBatch, "test-pod", "DeepSeek-V3") + + // Second event: child block with parent hash + childTokens := []uint32{17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + childBatch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{1001}, // Child block hash + Tokens: childTokens, + ParentHash: 1000, // References parent block + DeviceTier: 
"gpu", + GroupIdx: 0, + }, + }, + } + + pool.processEventBatch(ctx, childBatch, "test-pod", "DeepSeek-V3") + + // Verify both parent and child blocks are stored + parentRequestKeys, err := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, parentTokens, "DeepSeek-V3", nil) + require.NoError(t, err) + parentResult, err := index.Lookup(ctx, parentRequestKeys, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, parentResult, 1, "Parent block should be stored") + + // Verify child block is stored with parent reference + childRequestKeys, err := tokenProcessor.TokensToKVBlockKeys(parentRequestKeys[0], childTokens, "DeepSeek-V3", nil) + require.NoError(t, err) + childResult, err := index.Lookup(ctx, childRequestKeys, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, childResult, 1, "Child block should be stored") + + // Verify group tracking + for _, entries := range childResult { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.NotNil(t, entry.StoredGroups) + assert.Equal(t, []int{0}, entry.StoredGroups) + } + } + } + }) + + t.Run("NonHMAModel_WithParentHash_CreatesChain", func(t *testing.T) { + // Setup: Simple model registry + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + { + Name: "Qwen/Qwen3-8B", + IsHMA: false, + }, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // First event: parent block + parentTokens := []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + parentBatch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{2000}, + Tokens: parentTokens, + ParentHash: 0, + DeviceTier: "gpu", + GroupIdx: 0, + }, + }, + } + + 
pool.processEventBatch(ctx, parentBatch, "test-pod", "Qwen/Qwen3-8B") + + // Second event: child block with parent hash + childTokens := []uint32{17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + childBatch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{2001}, + Tokens: childTokens, + ParentHash: 2000, + DeviceTier: "gpu", + GroupIdx: 0, + }, + }, + } + + pool.processEventBatch(ctx, childBatch, "test-pod", "Qwen/Qwen3-8B") + + // Verify both blocks stored with nil StoredGroups + parentRequestKeys, err := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, parentTokens, "Qwen/Qwen3-8B", nil) + require.NoError(t, err) + parentResult, err := index.Lookup(ctx, parentRequestKeys, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, parentResult, 1) + + childRequestKeys, err := tokenProcessor.TokensToKVBlockKeys(parentRequestKeys[0], childTokens, "Qwen/Qwen3-8B", nil) + require.NoError(t, err) + childResult, err := index.Lookup(ctx, childRequestKeys, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, childResult, 1) + + // Verify nil StoredGroups for simple model + for _, entries := range childResult { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.Nil(t, entry.StoredGroups, "Non-HMA model should have nil StoredGroups") + } + } + } + }) +} + +func TestPoolEventProcessing_BlockRemovedEvent(t *testing.T) { + ctx := context.Background() + + t.Run("HMAModel_RemovesSpecificGroup", func(t *testing.T) { + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []kvcache.AttentionGroupConfig{ + {GroupID: 0, AttentionType: kvcache.AttentionTypeFull, BlockSize: 64}, + {GroupID: 1, AttentionType: kvcache.AttentionTypeSlidingWindow, BlockSize: 64, SlidingWindowSize: 4096}, + }, + }, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + 
tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Add with group 0 + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{600}, + Tokens: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + GroupIdx: 0, + }, + }, + } + pool.processEventBatch(ctx, batch, "test-pod", "DeepSeek-V3") + + // Add group 1 + batchEvent := batch.Events[0].(*BlockStoredEvent) //nolint:errcheck // Test - type assertion is safe + batchEvent.GroupIdx = 1 + pool.processEventBatch(ctx, batch, "test-pod", "DeepSeek-V3") + + // Verify both groups present + requestKeys, err := tokenProcessor.TokensToKVBlockKeys( + kvblock.EmptyBlockHash, + []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + "DeepSeek-V3", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + for _, entries := range result { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.ElementsMatch(t, []int{0, 1}, entry.StoredGroups, "Should have both groups") + } + } + } + + // Remove group 0 only + removeEvent := &EventBatch{ + Events: []GenericEvent{ + &BlockRemovedEvent{ + BlockHashes: []uint64{600}, + GroupIdx: 0, + }, + }, + } + pool.processEventBatch(ctx, removeEvent, "test-pod", "DeepSeek-V3") + + // Verify only group 1 remains + result, _ = index.Lookup(ctx, requestKeys, sets.Set[string]{}) //nolint:errcheck // Test cleanup - errors not critical + for _, entries := range result { + for _, entry := range entries { + if entry.PodIdentifier == "test-pod" { + assert.Equal(t, []int{1}, entry.StoredGroups, "Only group 1 should remain") + } + } + } + }) + + t.Run("NonHMAModel_RemovesEntireEntry", func(t *testing.T) { + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ 
+ {Name: "Qwen/Qwen3-8B", IsHMA: false}, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Add entry + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{700}, + Tokens: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + GroupIdx: 0, + }, + }, + } + pool.processEventBatch(ctx, batch, "test-pod", "Qwen/Qwen3-8B") + + // Verify entry exists + requestKeys, err := tokenProcessor.TokensToKVBlockKeys( + kvblock.EmptyBlockHash, + []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + "Qwen/Qwen3-8B", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + assert.NotEmpty(t, result, "Entry should exist") + + // Remove entry (group_id ignored for non-HMA) + removeEvent := &EventBatch{ + Events: []GenericEvent{ + &BlockRemovedEvent{ + BlockHashes: []uint64{700}, + GroupIdx: 0, // Ignored + }, + }, + } + pool.processEventBatch(ctx, removeEvent, "test-pod", "Qwen/Qwen3-8B") + + // Verify entry is completely removed + result, _ = index.Lookup(ctx, requestKeys, sets.Set[string]{}) //nolint:errcheck // Test cleanup - errors not critical + assert.Empty(t, result, "Entry should be completely removed for non-HMA model") + }) + + t.Run("HMAModel_RemovesMultipleBlocks", func(t *testing.T) { + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + { + Name: "DeepSeek-V3", + IsHMA: true, + AttentionGroups: []kvcache.AttentionGroupConfig{ + {GroupID: 0, AttentionType: kvcache.AttentionTypeFull, BlockSize: 64}, + {GroupID: 1, AttentionType: kvcache.AttentionTypeSlidingWindow, BlockSize: 64, SlidingWindowSize: 4096}, + }, + }, + }) + + 
index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Add multiple blocks (32 tokens = 2 blocks) + tokens := []uint32{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + } + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{800, 801}, + Tokens: tokens, + GroupIdx: 0, + }, + }, + } + pool.processEventBatch(ctx, batch, "test-pod", "DeepSeek-V3") + + // Verify both blocks are stored + requestKeys, err := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, "DeepSeek-V3", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, result, 2, "Both blocks should be stored") + + // Remove both blocks at once + removeEvent := &EventBatch{ + Events: []GenericEvent{ + &BlockRemovedEvent{ + BlockHashes: []uint64{800, 801}, // Remove multiple blocks + GroupIdx: 0, + }, + }, + } + pool.processEventBatch(ctx, removeEvent, "test-pod", "DeepSeek-V3") + + // Verify both blocks are removed + result, _ = index.Lookup(ctx, requestKeys, sets.Set[string]{}) //nolint:errcheck // Test cleanup - errors not critical + assert.Empty(t, result, "All blocks should be removed") + }) + + t.Run("NonHMAModel_RemovesMultipleBlocks", func(t *testing.T) { + modelRegistry := kvcache.NewModelRegistry([]*kvcache.ModelConfig{ + {Name: "Qwen/Qwen3-8B", IsHMA: false}, + }) + + index, err := kvblock.NewInMemoryIndex(kvblock.DefaultInMemoryIndexConfig()) + require.NoError(t, err) + tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + require.NoError(t, err) + cfg := 
DefaultConfig() + cfg.ModelRegistry = modelRegistry + pool := NewPool(cfg, index, tokenProcessor, nil) + + // Add multiple blocks + tokens := []uint32{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + } + batch := &EventBatch{ + Events: []GenericEvent{ + &BlockStoredEvent{ + BlockHashes: []uint64{900, 901}, + Tokens: tokens, + GroupIdx: 0, + }, + }, + } + pool.processEventBatch(ctx, batch, "test-pod", "Qwen/Qwen3-8B") + + // Verify blocks are stored + requestKeys, err := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, "Qwen/Qwen3-8B", nil) + require.NoError(t, err) + result, err := index.Lookup(ctx, requestKeys, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, result, 2) + + // Remove multiple blocks + removeEvent := &EventBatch{ + Events: []GenericEvent{ + &BlockRemovedEvent{ + BlockHashes: []uint64{900, 901}, + GroupIdx: 0, + }, + }, + } + pool.processEventBatch(ctx, removeEvent, "test-pod", "Qwen/Qwen3-8B") + + // Verify all removed + result, _ = index.Lookup(ctx, requestKeys, sets.Set[string]{}) //nolint:errcheck // Test cleanup - errors not critical + assert.Empty(t, result, "All blocks should be removed for non-HMA model") + }) +} diff --git a/pkg/kvevents/zmq_subscriber_test.go b/pkg/kvevents/zmq_subscriber_test.go index adbba7010..544634be0 100644 --- a/pkg/kvevents/zmq_subscriber_test.go +++ b/pkg/kvevents/zmq_subscriber_test.go @@ -120,7 +120,9 @@ func TestZMQSubscriber_ReceivesMessages(t *testing.T) { require.NoError(t, err) tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) require.NoError(t, err) - pool := kvevents.NewPool(kvevents.DefaultConfig(), index, tokenProcessor, engineadapter.NewVLLMAdapter()) + pool := kvevents.NewPool( + kvevents.DefaultConfig(), index, tokenProcessor, + engineadapter.NewVLLMAdapter()) pool.Start(ctx) // Start subscriber — remote=false means it binds (Listen). 
@@ -165,7 +167,9 @@ func TestZMQSubscriber_ShortSequenceFrameSkipped(t *testing.T) { require.NoError(t, err) tokenProcessor, err := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) require.NoError(t, err) - pool := kvevents.NewPool(kvevents.DefaultConfig(), index, tokenProcessor, engineadapter.NewVLLMAdapter()) + pool := kvevents.NewPool( + kvevents.DefaultConfig(), index, tokenProcessor, + engineadapter.NewVLLMAdapter()) pool.Start(ctx) // Pick an available ephemeral port to avoid conflicts with parallel tests or CI.