diff --git a/docs/architecture.md b/docs/architecture.md index 3de075b7a..0caa91817 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -96,7 +96,7 @@ sequenceDiagram else BlockRemoved Worker->>Index: Evict(key, podEntry) else AllBlocksCleared - Note over Worker: No-op + Worker->>Index: Clear() end end ``` diff --git a/examples/helper/events.go b/examples/helper/events.go index faa6c4e9d..04161cc2b 100644 --- a/examples/helper/events.go +++ b/examples/helper/events.go @@ -110,6 +110,33 @@ func SimulateRemoveEvent(ctx context.Context, publisher *Publisher) error { return nil } +func SimulateClearAllBlocksEvent(ctx context.Context, publisher *Publisher) error { + logger := log.FromContext(ctx) + logger.Info("@@@ Simulating vLLM engine clear all blocks...") + + clearAllBlocksEvent := []any{ + "AllBlocksCleared", + } + + clearAllBlocksPayload, err := msgpack.Marshal(clearAllBlocksEvent) + if err != nil { + return fmt.Errorf("failed to marshal AllBlocksCleared event: %w", err) + } + clearAllBlockEventBatch := []any{ + float64(time.Now().UnixNano()) / 1e9, + []msgpack.RawMessage{clearAllBlocksPayload}, + nil, + } + + if err := publisher.PublishEvent(ctx, topic, clearAllBlockEventBatch); err != nil { + return fmt.Errorf("failed to publish AllBlocksCleared event: %w", err) + } + logger.Info("@@@ Published AllBlocksCleared event", "topic", topic) + + time.Sleep(3 * time.Second) + return nil +} + func SetupEventsPool(ctx context.Context, kvBlockIndex kvblock.Index) (*kvevents.Pool, error) { logger := log.FromContext(ctx) diff --git a/examples/kv_events/offline/main.go b/examples/kv_events/offline/main.go index 8e6555565..b0e165a05 100644 --- a/examples/kv_events/offline/main.go +++ b/examples/kv_events/offline/main.go @@ -141,6 +141,32 @@ func RunEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish } logger.Info("@@@ Final pod scores after BlockRemoved events", "pods", pods) + // Simulate vLLM engine publishing BlockStored events + err = 
helper.SimulateProduceEvent(ctx, publisher) + if err != nil { + return err + } + + // Query again to see the effect of the events + pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil) + if err != nil { + return err + } + logger.Info("@@@ Pod scores after BlockStored events", "pods", pods) + + // Simulate vLLM engine publishing AllBlocksCleared event + err = helper.SimulateClearAllBlocksEvent(ctx, publisher) + if err != nil { + return err + } + + // Query again to see the effect of the events + pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil) + if err != nil { + return err + } + logger.Info("@@@ Pod scores after AllBlocksCleared events", "pods", pods) + logger.Info("Events demo completed. Pool continues listening for more events...") logger.Info("Press Ctrl+C to shutdown") diff --git a/examples/valkey_example/main.go b/examples/valkey_example/main.go index d3c28cf53..b27c31c8a 100644 --- a/examples/valkey_example/main.go +++ b/examples/valkey_example/main.go @@ -195,6 +195,21 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer) logger.Info("Cache lookup after eviction", "keysFound", len(lookupAfterEvict)) + // Clear the cache + logger.Info("Clearing the cache") + err = indexer.KVBlockIndex().Clear(ctx, podEntries[0]) + if err != nil { + return fmt.Errorf("failed to clear cache: %w", err) + } + + // Lookup again after clear + lookupAfterClear, err := indexer.KVBlockIndex().Lookup(ctx, promptKeys, nil) + if err != nil { + return fmt.Errorf("failed to lookup after clear: %w", err) + } + + logger.Info("Cache lookup after clear", "keysFound", len(lookupAfterClear)) + // Final score check to see the difference finalScores, err := indexer.GetPodScores(ctx, nil, prompt, modelName, []string{"demo-pod-1", "demo-pod-2"}) if err != nil { diff --git a/pkg/kvcache/kvblock/cost_aware_memory.go b/pkg/kvcache/kvblock/cost_aware_memory.go 
index e3f2817a8..0175026e5 100644 --- a/pkg/kvcache/kvblock/cost_aware_memory.go +++ b/pkg/kvcache/kvblock/cost_aware_memory.go @@ -43,6 +43,7 @@ type CostAwareMemoryIndexConfig struct { Size string `json:"size,omitempty"` } +// DefaultCostAwareMemoryIndexConfig returns a default configuration for the CostAwareMemoryIndex. func DefaultCostAwareMemoryIndexConfig() *CostAwareMemoryIndexConfig { return &CostAwareMemoryIndexConfig{ Size: "2GiB", // 2GiB default size @@ -75,9 +76,15 @@ func NewCostAwareMemoryIndex(cfg *CostAwareMemoryIndexConfig) (*CostAwareMemoryI return nil, fmt.Errorf("failed to initialize in-memory engine key map: %w", err) } + podToRequestKeys, err := lru.New[string, map[BlockHash]BlockHash](defaultNumCounters) + if err != nil { + return nil, fmt.Errorf("failed to initialize pod-to-request-key map: %w", err) + } + return &CostAwareMemoryIndex{ - data: cache, - requestKeys: requestKeys, + data: cache, + requestKeys: requestKeys, + podToRequestKeys: podToRequestKeys, }, nil } @@ -95,8 +102,11 @@ type CostAwareMemoryIndex struct { requestKeys *lru.Cache[BlockHash, BlockHash] // mu protects concurrent access to the index operations mu sync.RWMutex + // podToRequestKeys is a reverse index: podIdentifier -> [requestKey]: engineKey. + podToRequestKeys *lru.Cache[string, map[BlockHash]BlockHash] } +// MaxCost returns the maximum cost of the cache. 
func (m *CostAwareMemoryIndex) MaxCost() int64 { return m.data.MaxCost() } @@ -193,8 +203,19 @@ func (m *CostAwareMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys podCache = &CostPodCache{} } + curEngineKey := EmptyBlockHash + if engineKeys != nil { + curEngineKey = engineKeys[i] + } + for _, entry := range entries { podCache.Add(entry) + mappings, found := m.podToRequestKeys.Peek(entry.PodIdentifier) + if !found { + mappings = make(map[BlockHash]BlockHash) + } + mappings[requestKey] = curEngineKey + m.podToRequestKeys.Add(entry.PodIdentifier, mappings) } // Calculate the actual cost for this cache entry @@ -206,6 +227,7 @@ func (m *CostAwareMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys return nil } +// Lookup looks up a list of request keys and returns the associated pod entries. func (m *CostAwareMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHash, podIdentifierSet sets.Set[string], ) (map[BlockHash][]PodEntry, error) { @@ -308,6 +330,13 @@ func (m *CostAwareMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType for _, entry := range entries { podCache.Delete(entry) + + if mappings, found := m.podToRequestKeys.Peek(entry.PodIdentifier); found { + delete(mappings, requestKey) + if len(mappings) == 0 { + m.podToRequestKeys.Remove(entry.PodIdentifier) + } + } } if podCache.Len() == 0 { @@ -333,3 +362,66 @@ func (m *CostAwareMemoryIndex) GetRequestKey(ctx context.Context, engineKey Bloc } return requestKey, nil } + +// Clear removes all entries for the given podEntry from the index. 
+func (m *CostAwareMemoryIndex) Clear(ctx context.Context, podEntry PodEntry) error { + m.mu.Lock() + defer m.mu.Unlock() + traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvblock.CostAwareMemoryIndex.Clear") + + // Remove all entries for the given pod identifier + mappings, found := m.podToRequestKeys.Get(podEntry.PodIdentifier) + if !found { + traceLogger.Info("pod not found in reverse index, nothing to clear", "podEntry", podEntry) + return nil + } + + remaining := make(map[BlockHash]BlockHash) + for requestKey, engineKey := range mappings { + pod, found := m.data.Get(requestKey.String()) + if !found { + traceLogger.Info("request key not found in cache, skipping", "requestKey", requestKey) + continue + } + podCacheLenBefore := pod.Len() + var toDelete []PodEntry + pod.cache.Range(func(key, value any) bool { + if entry, ok := key.(PodEntry); ok { + if podEntry.DeviceTier != "" && entry.DeviceTier != podEntry.DeviceTier { + return true + } + if entry.PodIdentifier == podEntry.PodIdentifier { + toDelete = append(toDelete, entry) + } + } + return true + }) + + for _, entry := range toDelete { + pod.Delete(entry) + } + + switch { + case pod.Len() == 0: + m.data.Del(requestKey.String()) + if engineKey != EmptyBlockHash { + m.requestKeys.Remove(engineKey) + } + case podCacheLenBefore != pod.Len(): + m.data.Set(requestKey.String(), pod, pod.CalculateByteSize(requestKey.String())) + remaining[requestKey] = engineKey + default: + remaining[requestKey] = engineKey + } + } + + if podEntry.DeviceTier == "" || len(remaining) == 0 { + m.podToRequestKeys.Remove(podEntry.PodIdentifier) + } else { + m.podToRequestKeys.Add(podEntry.PodIdentifier, remaining) + } + + m.data.Wait() + traceLogger.Info("Cleared pod entries from CostAwareMemoryIndex", "podEntry", podEntry) + return nil +} diff --git a/pkg/kvcache/kvblock/in_memory.go b/pkg/kvcache/kvblock/in_memory.go index bc3a50c54..471585fd9 100644 --- a/pkg/kvcache/kvblock/in_memory.go +++ 
b/pkg/kvcache/kvblock/in_memory.go @@ -67,19 +67,32 @@ func NewInMemoryIndex(cfg *InMemoryIndexConfig) (*InMemoryIndex, error) { return nil, fmt.Errorf("failed to initialize in-memory engine key map: %w", err) } + podToRequestKeys, err := lru.New[string, *podMapping](cfg.Size) + if err != nil { + return nil, fmt.Errorf("failed to initialize pod-to-request-key map: %w", err) + } + return &InMemoryIndex{ data: cache, engineToRequestKeys: engineToRequestKeys, + podToRequestKeys: podToRequestKeys, podCacheSize: cfg.PodCacheSize, }, nil } +type podMapping struct { + mappings map[BlockHash]BlockHash + mu sync.Mutex +} + // InMemoryIndex is an in-memory implementation of the Index interface. type InMemoryIndex struct { // data holds the mapping of requestKeys to sets of pod identifiers. data *lru.Cache[BlockHash, *PodCache] // engineToRequestKeys holds the mapping of engineKeys to requestKeys. engineToRequestKeys *lru.Cache[BlockHash, BlockHash] + // podToRequestKeys is a reverse index: podIdentifier -> [requestKey]: engineKey. + podToRequestKeys *lru.Cache[string, *podMapping] // podCacheSize is the maximum number of pod entries per key. podCacheSize int } @@ -161,7 +174,9 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block for i, requestKey := range requestKeys { // 1. Store engineKey -> requestKey mapping (only if engineKeys provided) + curEngineKey := EmptyBlockHash if engineKeys != nil { + curEngineKey = engineKeys[i] m.engineToRequestKeys.Add(engineKeys[i], requestKey) } @@ -202,6 +217,20 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block podCache.mu.Lock() for _, entry := range entries { podCache.cache.Add(entry, struct{}{}) + + // 3. 
Maintain reverse index: podIdentifier -> [requestKey]: engineKey + pm, ok := m.podToRequestKeys.Peek(entry.PodIdentifier) + if !ok { + pm = &podMapping{ + mappings: make(map[BlockHash]BlockHash), + } + if contained, _ := m.podToRequestKeys.ContainsOrAdd(entry.PodIdentifier, pm); contained { + pm, _ = m.podToRequestKeys.Peek(entry.PodIdentifier) + } + } + pm.mu.Lock() + pm.mappings[requestKey] = curEngineKey + pm.mu.Unlock() } podCache.mu.Unlock() @@ -253,6 +282,16 @@ func (m *InMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType KeyTyp podCache.mu.Lock() for _, entry := range entries { podCache.cache.Remove(entry) + + if pm, ok := m.podToRequestKeys.Peek(entry.PodIdentifier); ok { + pm.mu.Lock() + delete(pm.mappings, requestKey) + + if len(pm.mappings) == 0 { + m.podToRequestKeys.Remove(entry.PodIdentifier) + } + pm.mu.Unlock() + } } isEmpty := podCache.cache.Len() == 0 @@ -304,3 +343,92 @@ func podsPerKeyPrintHelper(ks map[BlockHash][]PodEntry) string { } return b.String() } + +// Clear removes all entries for the given podEntry from the index. 
+func (m *InMemoryIndex) Clear(ctx context.Context, podEntry PodEntry) error { + traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvblock.InMemoryIndex.Clear") + // Remove all entries for the given pod identifier + pm, found := m.podToRequestKeys.Get(podEntry.PodIdentifier) + if !found { + traceLogger.Info("pod not found in reverse index, nothing to clear", "podEntry", podEntry) + return nil + } + pm.mu.Lock() + snapshot := make(map[BlockHash]BlockHash, len(pm.mappings)) + for requestKey, engineKey := range pm.mappings { + snapshot[requestKey] = engineKey + } + pm.mu.Unlock() + + for requestKey, engineKey := range snapshot { + cache, exists := m.data.Peek(requestKey) + if !exists || cache == nil { + continue + } + + cache.mu.Lock() + var toDelete []PodEntry + for _, entry := range cache.cache.Keys() { + if entry.PodIdentifier != podEntry.PodIdentifier { + continue + } + if podEntry.DeviceTier != "" && entry.DeviceTier != podEntry.DeviceTier { + continue + } + toDelete = append(toDelete, entry) + } + + for _, entry := range toDelete { + cache.cache.Remove(entry) + } + + isEmpty := cache.cache.Len() == 0 + cache.mu.Unlock() + + if !isEmpty { + continue + } + + currentCache, stillExists := m.data.Get(requestKey) + var engineKeyExists bool + if engineKey != EmptyBlockHash { + engineKeyExists = m.engineToRequestKeys.Contains(engineKey) + } + if !stillExists || currentCache == nil { + if engineKeyExists { + m.engineToRequestKeys.Remove(engineKey) + } + continue + } + currentCache.mu.Lock() + if currentCache.cache.Len() == 0 { + m.data.Remove(requestKey) + if engineKeyExists { + m.engineToRequestKeys.Remove(engineKey) + } + } + currentCache.mu.Unlock() + } + + if podEntry.DeviceTier == "" { + m.podToRequestKeys.Remove(podEntry.PodIdentifier) + } else { + remaining := make(map[BlockHash]BlockHash) + pm.mu.Lock() + for requestKey, engineKey := range pm.mappings { + if exists := m.data.Contains(requestKey); exists { + remaining[requestKey] = engineKey + } + } + 
+ if len(remaining) == 0 { + m.podToRequestKeys.Remove(podEntry.PodIdentifier) + } else { + pm.mappings = remaining + } + pm.mu.Unlock() + } + + traceLogger.Info("Cleared pod entries from InMemoryIndex", "podEntry", podEntry) + return nil +} diff --git a/pkg/kvcache/kvblock/index.go b/pkg/kvcache/kvblock/index.go index 9e0a49457..c28edc566 100644 --- a/pkg/kvcache/kvblock/index.go +++ b/pkg/kvcache/kvblock/index.go @@ -138,6 +138,10 @@ type Index interface { Evict(ctx context.Context, key BlockHash, keyType KeyType, entries []PodEntry) error // GetRequestKey returns the requestKey associated with the given engineKey. GetRequestKey(ctx context.Context, engineKey BlockHash) (BlockHash, error) + // Clear removes all entries for the given podEntry from the index backend. + // If podEntry.DeviceTier is empty, all tiers for that pod identifier are removed. + // If podEntry.DeviceTier is set, only entries matching that exact tier are removed. + Clear(ctx context.Context, podEntry PodEntry) error } // KeyType indicates whether a key passed to Evict is an engine key or a request key. diff --git a/pkg/kvcache/kvblock/index_test.go b/pkg/kvcache/kvblock/index_test.go index 510a7f743..4375013a0 100644 --- a/pkg/kvcache/kvblock/index_test.go +++ b/pkg/kvcache/kvblock/index_test.go @@ -68,6 +68,11 @@ func testCommonIndexBehavior(t *testing.T, indexFactory func(t *testing.T) Index index := indexFactory(t) testAddWithNilEngineKeys(t, ctx, index) }) + + t.Run("Clear", func(t *testing.T) { + index := indexFactory(t) + testClear(t, ctx, index) + }) } // testBasicAddAndLookup tests basic Add and Lookup functionality. 
@@ -218,6 +223,155 @@ func testEvictBasic(t *testing.T, ctx context.Context, index Index) { assert.ElementsMatch(t, expected, podsPerKey[requestKey]) } +func testClear(t *testing.T, ctx context.Context, index Index) { + t.Helper() + + t.Run("ClearOneOfManyPods", func(t *testing.T) { + engineKey := BlockHash(17434655) + requestKey := BlockHash(59244875) + entries := []PodEntry{ + {PodIdentifier: "pod1", DeviceTier: "gpu"}, + {PodIdentifier: "pod2", DeviceTier: "gpu"}, + {PodIdentifier: "pod3", DeviceTier: "cpu"}, + } + + err := index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, entries) + require.NoError(t, err) + + // Clear only pod1 — pod2 and pod3 must survive + err = index.Clear(ctx, PodEntry{PodIdentifier: "pod1", DeviceTier: "gpu"}) + require.NoError(t, err) + + podsPerKey, err := index.Lookup(ctx, []BlockHash{requestKey}, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, podsPerKey, 1, "request key should still exist with remaining pods") + assert.Contains(t, podsPerKey, requestKey) + assert.ElementsMatch(t, podsPerKey[requestKey], []PodEntry{ + {PodIdentifier: "pod2", DeviceTier: "gpu"}, + {PodIdentifier: "pod3", DeviceTier: "cpu"}, + }) + _, err = index.GetRequestKey(ctx, engineKey) + require.NoError(t, err) + }) + + t.Run("ClearWithDeviceTierFilter", func(t *testing.T) { + engineKey2 := BlockHash(11111111) + requestKey2 := BlockHash(22222222) + entries := []PodEntry{ + {PodIdentifier: "pod4", DeviceTier: "gpu"}, + {PodIdentifier: "pod4", DeviceTier: "cpu"}, // same pod, different tier + } + + err := index.Add(ctx, []BlockHash{engineKey2}, []BlockHash{requestKey2}, entries) + require.NoError(t, err) + + // Clear pod4 only on gpu tier — cpu tier entry must survive + err = index.Clear(ctx, PodEntry{PodIdentifier: "pod4", DeviceTier: "gpu"}) + require.NoError(t, err) + + podsPerKey, err := index.Lookup(ctx, []BlockHash{requestKey2}, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, podsPerKey, 1) + assert.ElementsMatch(t, 
podsPerKey[requestKey2], []PodEntry{ + {PodIdentifier: "pod4", DeviceTier: "cpu"}, + }) + assert.NotElementsMatch(t, podsPerKey[requestKey2], []PodEntry{ + {PodIdentifier: "pod4", DeviceTier: "gpu"}, + }) + + // engine key should survive + _, err = index.GetRequestKey(ctx, engineKey2) + require.NoError(t, err) + }) + + t.Run("ClearLastPodRemovesKey", func(t *testing.T) { + engineKey3 := BlockHash(33333333) + requestKey3 := BlockHash(44444444) + entries := []PodEntry{ + {PodIdentifier: "pod5", DeviceTier: "gpu"}, + } + + err := index.Add(ctx, []BlockHash{engineKey3}, []BlockHash{requestKey3}, entries) + require.NoError(t, err) + + // Clear the only pod — the request key should disappear + err = index.Clear(ctx, PodEntry{PodIdentifier: "pod5", DeviceTier: "gpu"}) + require.NoError(t, err) + + podsPerKey, err := index.Lookup(ctx, []BlockHash{requestKey3}, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, podsPerKey, 0, "request key should be removed when no pods remain") + + // engine key should be removed together with the emptied request key + _, err = index.GetRequestKey(ctx, engineKey3) + require.Error(t, err) + }) + + t.Run("ClearNoDeviceTier", func(t *testing.T) { + engineKey4 := BlockHash(77777777) + requestKey4 := BlockHash(55555555) + entries := []PodEntry{ + {PodIdentifier: "pod6", DeviceTier: "gpu"}, + {PodIdentifier: "pod6", DeviceTier: "cpu"}, + } + + err := index.Add(ctx, []BlockHash{engineKey4}, []BlockHash{requestKey4}, entries) + require.NoError(t, err) + + // Clear pod6 with no device tier specified — both gpu and cpu entries should be removed + err = index.Clear(ctx, PodEntry{PodIdentifier: "pod6"}) + require.NoError(t, err) + + podsPerKey, err := index.Lookup(ctx, []BlockHash{requestKey4}, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, podsPerKey, 0, "request key should be removed when no pods remain") + + // engine key should be removed together with the emptied request key + _, err = index.GetRequestKey(ctx, engineKey4) + require.Error(t, err) + }) + + t.Run("ClearWithNilEngineKeys", func(t 
*testing.T) { + // Speculative entries are added with nil engineKeys. + // Clear must still remove those entries via the podToRequestKey reverse index. + requestKey5 := BlockHash(66666666) + entries := []PodEntry{ + {PodIdentifier: "pod7", DeviceTier: "gpu", Speculative: true}, + {PodIdentifier: "pod8", DeviceTier: "gpu", Speculative: true}, + } + + err := index.Add(ctx, nil, []BlockHash{requestKey5}, entries) + require.NoError(t, err) + + // Verify entries are visible before clearing + podsPerKey, err := index.Lookup(ctx, []BlockHash{requestKey5}, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, podsPerKey[requestKey5], 2) + + // Clear pod7 — only pod8 should remain + err = index.Clear(ctx, PodEntry{PodIdentifier: "pod7", DeviceTier: "gpu"}) + require.NoError(t, err) + + podsPerKey, err = index.Lookup(ctx, []BlockHash{requestKey5}, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, podsPerKey, 1, "request key should still exist with pod8") + assert.ElementsMatch(t, podsPerKey[requestKey5], []PodEntry{ + {PodIdentifier: "pod8", DeviceTier: "gpu", Speculative: true}, + }) + + // Clear pod8 — request key should be removed entirely + err = index.Clear(ctx, PodEntry{PodIdentifier: "pod8", DeviceTier: "gpu"}) + require.NoError(t, err) + + podsPerKey, err = index.Lookup(ctx, []BlockHash{requestKey5}, sets.Set[string]{}) + require.NoError(t, err) + assert.Len(t, podsPerKey, 0, "request key should be removed when no pods remain") + + _, err = index.GetRequestKey(ctx, EmptyBlockHash) + require.Error(t, err) + }) +} + // testConcurrentOperations tests thread safety with concurrent operations. 
func testConcurrentOperations(t *testing.T, ctx context.Context, index Index) { t.Helper() @@ -234,7 +388,7 @@ func testConcurrentOperations(t *testing.T, ctx context.Context, index Index) { time.Sleep(time.Millisecond * time.Duration(id%10)) // Stagger start times defer wg.Done() for operationIndex := 0; operationIndex < 10; operationIndex++ { - switch operationIndex % 3 { + switch operationIndex % 4 { case 0: // Add entries := []PodEntry{{PodIdentifier: fmt.Sprintf("pod-%d-%d", id, operationIndex), DeviceTier: "gpu"}} if err := index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, entries); err != nil { @@ -250,6 +404,13 @@ func testConcurrentOperations(t *testing.T, ctx context.Context, index Index) { if err := index.Evict(ctx, engineKey, EngineKey, entries); err != nil { errChan <- err } + case 3: // Clear + if err := index.Clear(ctx, PodEntry{ + PodIdentifier: fmt.Sprintf("pod-%d-%d", id, operationIndex-2), + DeviceTier: "gpu", + }); err != nil { + errChan <- err + } } } }(goroutineID) diff --git a/pkg/kvcache/kvblock/instrumented_index.go b/pkg/kvcache/kvblock/instrumented_index.go index 91d7050a7..ce556d7a1 100644 --- a/pkg/kvcache/kvblock/instrumented_index.go +++ b/pkg/kvcache/kvblock/instrumented_index.go @@ -27,7 +27,7 @@ type instrumentedIndex struct { } // NewInstrumentedIndex wraps an Index and emits metrics for Add, Evict, and -// Lookup. +// Lookup, Clear. 
func NewInstrumentedIndex(next Index) Index { return &instrumentedIndex{next: next} } @@ -90,3 +90,9 @@ func recordHitMetrics(keyToPods map[BlockHash][]PodEntry) { metrics.MaxPodHitCount.Add(float64(maxHit)) metrics.LookupHits.Add(float64(maxHit)) } + +func (m *instrumentedIndex) Clear(ctx context.Context, podEntry PodEntry) error { + err := m.next.Clear(ctx, podEntry) + metrics.Clear.Add(1) + return err +} diff --git a/pkg/kvcache/kvblock/redis.go b/pkg/kvcache/kvblock/redis.go index 921b47767..71948602d 100644 --- a/pkg/kvcache/kvblock/redis.go +++ b/pkg/kvcache/kvblock/redis.go @@ -111,6 +111,11 @@ func NewRedisIndex(config *RedisIndexConfig) (Index, error) { return nil, fmt.Errorf("failed to connect to %s: %w", config.BackendType, err) } + // Pre-load Lua scripts so EvalSha calls in pipelines never get NOSCRIPT errors. + if err := clearPodEntryScript.Load(context.Background(), redisClient).Err(); err != nil { + return nil, fmt.Errorf("failed to load Lua scripts on %s: %w", config.BackendType, err) + } + return &RedisIndex{ RedisClient: redisClient, BackendType: config.BackendType, @@ -243,12 +248,18 @@ func (r *RedisIndex) Add(ctx context.Context, engineKeys, requestKeys []BlockHas redisKey := requestKey.String() // Store engineKey -> requestKey mapping (only if engineKeys provided) + engineKeyStr := "" if engineKeys != nil { pipe.Set(ctx, redisEngineKey(engineKeys[i]), redisKey, 0) + engineKeyStr = engineKeys[i].String() } for _, entry := range entries { // Use HSet to add the pod identifier as a field in the hash pipe.HSet(ctx, redisKey, entry.String(), "") + // Store reverse-index: pod: hash + // field = entry.String() (e.g. 
"10.0.0.1:8080@gpu") + // value = ":" (engineKey may be empty) + pipe.HSet(ctx, podIdentifierKey(entry.PodIdentifier), entry.String(), redisKey+":"+engineKeyStr) } } @@ -291,6 +302,8 @@ func (r *RedisIndex) Evict(ctx context.Context, key BlockHash, keyType KeyType, for _, entry := range entries { // Use HDel to remove the pod identifier field from the hash pipe.HDel(ctx, redisKey, entry.String()) + // Remove the corresponding field from the pod reverse-index hash. + pipe.HDel(ctx, podIdentifierKey(entry.PodIdentifier), entry.String()) } if _, err := pipe.Exec(ctx); err != nil { @@ -322,5 +335,104 @@ func (r *RedisIndex) GetRequestKey(ctx context.Context, engineKey BlockHash) (Bl } func redisEngineKey(engineKey BlockHash) string { + if engineKey == EmptyBlockHash { + return "" + } return "engine:" + engineKey.String() } + +func podIdentifierKey(podIdentifier string) string { + return "pod:" + podIdentifier +} + +// clearPodEntryScript atomically removes a single pod-entry field from the +// request-key hash AND from the pod reverse-index hash, then prunes the +// engine-key string and request-key hash when they become empty. +// +// KEYS[1] = request-key hash (e.g. "10633516") +// KEYS[2] = engine-key string (e.g. "engine:55269488") — may be "" to skip +// KEYS[3] = pod reverse-index hash (e.g. "pod:10.0.0.1:8080") +// ARGV[1] = pod entry field (e.g. "10.0.0.1:8080@gpu"). +var clearPodEntryScript = redis.NewScript(` + redis.call('HDEL', KEYS[1], ARGV[1]) + redis.call('HDEL', KEYS[3], ARGV[1]) + if redis.call('HLEN', KEYS[3]) == 0 then + redis.call('DEL', KEYS[3]) + end + if KEYS[2] ~= '' and redis.call('HLEN', KEYS[1]) == 0 then + redis.call('DEL', KEYS[2]) + redis.call('DEL', KEYS[1]) + end + return 1 +`) + +// Clear removes all index entries for the given podEntry. +// +// The pod reverse-index hash (pod:) stores: +// +// field = entry.String() e.g. 
"10.0.0.1:8080@gpu" +// value = ":" (engineKey may be empty for speculative entries) +func (r *RedisIndex) Clear(ctx context.Context, podEntry PodEntry) error { + logger := log.FromContext(ctx).WithName("kvblock.RedisIndex.Clear") + + podKey := podIdentifierKey(podEntry.PodIdentifier) + + // HGETALL returns all {entryString -> "requestKey:engineKey"} pairs in one RTT. + fields, err := r.RedisClient.HGetAll(ctx, podKey).Result() + if err != nil { + return fmt.Errorf("failed to get pod reverse-index for %s: %w", podEntry.PodIdentifier, err) + } + if len(fields) == 0 { + logger.Info("pod not found in reverse index, nothing to clear", "podEntry", podEntry) + return nil + } + + pipe := r.RedisClient.Pipeline() + for entryStr, meta := range fields { + // Filter by DeviceTier when specified. + // entryStr format: "@[speculative]" + if podEntry.DeviceTier != "" { + parts := strings.SplitN(entryStr, "@", 2) + if len(parts) < 2 { + continue + } + tier := parts[1] + // Strip optional [speculative] suffix before comparing + if idx := strings.Index(tier, "["); idx != -1 { + tier = tier[:idx] + } + if tier != podEntry.DeviceTier { + continue + } + } + + // meta = ":" (engineKey part may be empty) + sep := strings.LastIndex(meta, ":") + requestKeyStr := meta + engineKeyRedisKey := "" + if sep >= 0 { + requestKeyStr = meta[:sep] + if ek := meta[sep+1:]; ek != "" { + engineKeyRedisKey = redisEngineKey(BlockHash(mustParseUint64(ek))) + } + } + + pipe.EvalSha(ctx, clearPodEntryScript.Hash(), []string{requestKeyStr, engineKeyRedisKey, podKey}, entryStr) + } + + if _, err := pipe.Exec(ctx); err != nil { + return fmt.Errorf("failed to clear pod entries from Redis index: %w", err) + } + + logger.Info("Cleared pod entries from Redis index", "podEntry", podEntry) + return nil +} + +// mustParseUint64 parses a uint64 string, returning 0 on failure. 
+func mustParseUint64(s string) uint64 { + v, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return 0 + } + return v +} diff --git a/pkg/kvcache/kvblock/traced_index.go b/pkg/kvcache/kvblock/traced_index.go index 66fd0c993..6c96fd932 100644 --- a/pkg/kvcache/kvblock/traced_index.go +++ b/pkg/kvcache/kvblock/traced_index.go @@ -86,3 +86,7 @@ func (t *tracedIndex) Lookup( func (t *tracedIndex) GetRequestKey(ctx context.Context, engineKey BlockHash) (BlockHash, error) { return t.next.GetRequestKey(ctx, engineKey) } + +func (t *tracedIndex) Clear(ctx context.Context, podEntry PodEntry) error { + return t.next.Clear(ctx, podEntry) +} diff --git a/pkg/kvcache/metrics/collector.go b/pkg/kvcache/metrics/collector.go index fe68500ac..539282c2a 100644 --- a/pkg/kvcache/metrics/collector.go +++ b/pkg/kvcache/metrics/collector.go @@ -34,6 +34,10 @@ var ( Namespace: "kvcache", Subsystem: "index", Name: "evictions_total", Help: "Total number of KV-block evictions", }) + Clear = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "kvcache", Subsystem: "index", Name: "clear_total", + Help: "Total number of KV-block clears", + }) // LookupRequests counts how many Lookup() calls have been made. LookupRequests = prometheus.NewCounter(prometheus.CounterOpts{ @@ -77,7 +81,7 @@ var ( // Collectors returns a slice of all registered Prometheus collectors. 
func Collectors() []prometheus.Collector { return []prometheus.Collector{ - Admissions, Evictions, + Admissions, Evictions, Clear, LookupRequests, LookupHits, LookupLatency, RenderChatTemplateLatency, TokenizationLatency, TokenizedTokensCount, } @@ -125,6 +129,12 @@ func logMetrics(ctx context.Context) { } lookups := m.GetCounter().GetValue() + err = Clear.Write(&m) + if err != nil { + return + } + clears := m.GetCounter().GetValue() + var hitsMetric dto.Metric err = LookupHits.Write(&hitsMetric) if err != nil { @@ -149,6 +159,7 @@ func logMetrics(ctx context.Context) { "admissions", admissions, "evictions", evictions, "lookups", lookups, + "clears", clears, "hits", hits, "latency_count", latencyCount, "latency_sum", latencySum, diff --git a/pkg/kvevents/engineadapter/vllm_adapter.go b/pkg/kvevents/engineadapter/vllm_adapter.go index 8e7085d4b..50f4d1e0e 100644 --- a/pkg/kvevents/engineadapter/vllm_adapter.go +++ b/pkg/kvevents/engineadapter/vllm_adapter.go @@ -268,8 +268,19 @@ func (v *VLLMAdapter) convertBlockRemovedEvent(fields []any) (kvevents.GenericEv } // convertAllBlocksClearedEvent converts a decoded []any into an AllBlocksClearedEvent. -func (v *VLLMAdapter) convertAllBlocksClearedEvent(_ []any) (kvevents.GenericEvent, error) { - return &kvevents.AllBlocksClearedEvent{}, nil +func (v *VLLMAdapter) convertAllBlocksClearedEvent(fields []any) (kvevents.GenericEvent, error) { + deviceTier := "" + if raw := fieldAt(fields, 1); raw != nil { + s, ok := raw.(string) + if !ok { + return nil, fmt.Errorf("AllBlocksCleared: medium is not a string: %T", raw) + } + deviceTier = s + } + + return &kvevents.AllBlocksClearedEvent{ + DeviceTier: deviceTier, + }, nil } // toUint32Slice converts a msgpack-decoded []any of integers to []uint32. 
diff --git a/pkg/kvevents/engineadapter/vllm_adapter_test.go b/pkg/kvevents/engineadapter/vllm_adapter_test.go index e0f168f5f..36d748b0f 100644 --- a/pkg/kvevents/engineadapter/vllm_adapter_test.go +++ b/pkg/kvevents/engineadapter/vllm_adapter_test.go @@ -370,7 +370,7 @@ func TestVLLMBlockRemoved(t *testing.T) { func TestVLLMAllBlocksCleared(t *testing.T) { adapter := NewVLLMAdapter() - vllmEvent := []any{"AllBlocksCleared"} + vllmEvent := []any{"AllBlocksCleared", nil} rawBytes, err := msgpack.Marshal(vllmEvent) require.NoError(t, err) diff --git a/pkg/kvevents/pool.go b/pkg/kvevents/pool.go index 149a8c63a..3d92de59a 100644 --- a/pkg/kvevents/pool.go +++ b/pkg/kvevents/pool.go @@ -322,11 +322,16 @@ func (p *Pool) processEventBatch(ctx context.Context, batch *EventBatch, podIden } case *AllBlocksClearedEvent: - debugLogger.Info("All blocks cleared event received", - "podIdentifier", podIdentifier, - "deviceTier", ev.DeviceTier, - "modelName", modelName) - + if err := p.index.Clear(ctx, kvblock.PodEntry{ + PodIdentifier: podIdentifier, + DeviceTier: strings.ToLower(ev.DeviceTier), + }); err != nil { + debugLogger.Error(err, "Failed to clear all blocks", + "deviceTier", strings.ToLower(ev.DeviceTier), + "podIdentifier", podIdentifier, + "modelName", modelName) + continue + } default: debugLogger.Info("Unknown event", "podIdentifier", podIdentifier, "event", genericEvent) }