Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ sequenceDiagram
else BlockRemoved
Worker->>Index: Evict(key, podEntry)
else AllBlocksCleared
Note over Worker: No-op
Worker->>Index: Clear()
end
end
```
Expand Down
27 changes: 27 additions & 0 deletions examples/helper/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,33 @@ func SimulateRemoveEvent(ctx context.Context, publisher *Publisher) error {
return nil
}

func SimulateClearAllBlocksEvent(ctx context.Context, publisher *Publisher) error {
logger := log.FromContext(ctx)
logger.Info("@@@ Simulating vLLM engine clear all blocks...")

clearAllBlocksEvent := []any{
"AllBlocksCleared",
}

clearAllBlocksPayload, err := msgpack.Marshal(clearAllBlocksEvent)
if err != nil {
return fmt.Errorf("failed to marshal AllBlocksCleared event: %w", err)
}
clearAllBlockEventBatch := []any{
float64(time.Now().UnixNano()) / 1e9,
[]msgpack.RawMessage{clearAllBlocksPayload},
nil,
}

if err := publisher.PublishEvent(ctx, topic, clearAllBlockEventBatch); err != nil {
return fmt.Errorf("failed to publish AllBlocksCleared event: %w", err)
}
logger.Info("@@@ Published AllBlocksCleared event", "topic", topic)

time.Sleep(3 * time.Second)
return nil
}

func SetupEventsPool(ctx context.Context, kvBlockIndex kvblock.Index) (*kvevents.Pool, error) {
logger := log.FromContext(ctx)

Expand Down
26 changes: 26 additions & 0 deletions examples/kv_events/offline/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,32 @@ func RunEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish
}
logger.Info("@@@ Final pod scores after BlockRemoved events", "pods", pods)

// Simulate vLLM engine publishing BlockStored events
err = helper.SimulateProduceEvent(ctx, publisher)
if err != nil {
return err
}

// Query again to see the effect of the events
pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil)
if err != nil {
return err
}
logger.Info("@@@ Pod scores after BlockStored events", "pods", pods)

// Simulate vLLM engine publishing AllBlocksCleared event
err = helper.SimulateClearAllBlocksEvent(ctx, publisher)
if err != nil {
return err
}

// Query again to see the effect of the events
pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil)
if err != nil {
return err
}
logger.Info("@@@ Pod scores after AllBlocksCleared events", "pods", pods)

logger.Info("Events demo completed. Pool continues listening for more events...")
logger.Info("Press Ctrl+C to shutdown")

Expand Down
15 changes: 15 additions & 0 deletions examples/valkey_example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,21 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer)

logger.Info("Cache lookup after eviction", "keysFound", len(lookupAfterEvict))

// Clear the cache
logger.Info("Clearing the cache")
err = indexer.KVBlockIndex().Clear(ctx)
if err != nil {
return fmt.Errorf("failed to clear cache: %w", err)
}

// Lookup again after clear
lookupAfterClear, err := indexer.KVBlockIndex().Lookup(ctx, promptKeys, nil)
if err != nil {
return fmt.Errorf("failed to lookup after clear: %w", err)
}

logger.Info("Cache lookup after clear", "keysFound", len(lookupAfterClear))

// Final score check to see the difference
finalScores, err := indexer.GetPodScores(ctx, nil, prompt, modelName, []string{"demo-pod-1", "demo-pod-2"})
if err != nil {
Expand Down
85 changes: 83 additions & 2 deletions pkg/kvcache/kvblock/cost_aware_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type CostAwareMemoryIndexConfig struct {
Size string `json:"size,omitempty"`
}

// DefaultCostAwareMemoryIndexConfig returns a default configuration for the CostAwareMemoryIndex.
func DefaultCostAwareMemoryIndexConfig() *CostAwareMemoryIndexConfig {
return &CostAwareMemoryIndexConfig{
Size: "2GiB", // 2GiB default size
Expand Down Expand Up @@ -75,9 +76,15 @@ func NewCostAwareMemoryIndex(cfg *CostAwareMemoryIndexConfig) (*CostAwareMemoryI
return nil, fmt.Errorf("failed to initialize in-memory engine key map: %w", err)
}

podToRequestKey, err := lru.New[string, map[BlockHash]BlockHash](defaultNumCounters)
if err != nil {
return nil, fmt.Errorf("failed to initialize pod-to-request-key map: %w", err)
}

return &CostAwareMemoryIndex{
data: cache,
requestKeys: requestKeys,
data: cache,
requestKeys: requestKeys,
podToRequestKey: podToRequestKey,
}, nil
}

Expand All @@ -95,8 +102,11 @@ type CostAwareMemoryIndex struct {
requestKeys *lru.Cache[BlockHash, BlockHash]
// mu protects concurrent access to the index operations
mu sync.RWMutex
// podToRequestKey is a reverse index: podIdentifier -> [requestKey]: engineKey.
podToRequestKey *lru.Cache[string, map[BlockHash]BlockHash]
}

// MaxCost returns the maximum cost of the cache.
func (m *CostAwareMemoryIndex) MaxCost() int64 {
return m.data.MaxCost()
}
Expand Down Expand Up @@ -193,8 +203,19 @@ func (m *CostAwareMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys
podCache = &CostPodCache{}
}

curEngineKey := EmptyBlockHash
if engineKeys != nil {
curEngineKey = engineKeys[i]
}

for _, entry := range entries {
podCache.Add(entry)
mappings, found := m.podToRequestKey.Peek(entry.PodIdentifier)
if !found {
mappings = make(map[BlockHash]BlockHash)
}
mappings[requestKey] = curEngineKey
m.podToRequestKey.Add(entry.PodIdentifier, mappings)
}

// Calculate the actual cost for this cache entry
Expand All @@ -206,6 +227,7 @@ func (m *CostAwareMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys
return nil
}

// Lookup looks up a list of request keys and returns the associated pod entries.
func (m *CostAwareMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHash,
podIdentifierSet sets.Set[string],
) (map[BlockHash][]PodEntry, error) {
Expand Down Expand Up @@ -308,6 +330,13 @@ func (m *CostAwareMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType

for _, entry := range entries {
podCache.Delete(entry)

if mappings, found := m.podToRequestKey.Peek(entry.PodIdentifier); found {
delete(mappings, requestKey)
if len(mappings) == 0 {
m.podToRequestKey.Remove(entry.PodIdentifier)
}
}
}

if podCache.Len() == 0 {
Expand All @@ -333,3 +362,55 @@ func (m *CostAwareMemoryIndex) GetRequestKey(ctx context.Context, engineKey Bloc
}
return requestKey, nil
}

// Clear removes all entries from the index backend.
func (m *CostAwareMemoryIndex) Clear(ctx context.Context, podEntry PodEntry) error {
m.mu.Lock()
defer m.mu.Unlock()
traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvblock.CostAwareMemoryIndex.Clear")

// Remove all entries for the given pod identifier
mappings, found := m.podToRequestKey.Get(podEntry.PodIdentifier)
if !found {
traceLogger.Info("pod not found in reverse index, nothing to clear", "podEntry", podEntry)
return nil
}
for requestKey, engineKey := range mappings {
pod, found := m.data.Get(requestKey.String())
if !found {
traceLogger.Info("request key not found in cache, skipping", "requestKey", requestKey)
continue
}
podCacheLenBefore := pod.Len()
var toDelete []PodEntry
pod.cache.Range(func(key, value any) bool {
if entry, ok := key.(PodEntry); ok {
if podEntry.DeviceTier != "" && entry.DeviceTier != podEntry.DeviceTier {
return true
}
if entry.PodIdentifier == podEntry.PodIdentifier {
toDelete = append(toDelete, entry)
}
}
return true
})

for _, entry := range toDelete {
pod.Delete(entry)
}

if pod.Len() == 0 {
m.data.Del(requestKey.String())
if _, hasEngineKey := m.requestKeys.Get(engineKey); hasEngineKey {
m.requestKeys.Remove(engineKey)
}
} else if podCacheLenBefore != pod.Len() {
m.data.Set(requestKey.String(), pod, pod.CalculateByteSize(requestKey.String()))
}
m.podToRequestKey.Remove(podEntry.PodIdentifier)
}

m.data.Wait()
traceLogger.Info("Cleared pod entries from InMemoryIndex", "podEntry", podEntry)
return nil
}
106 changes: 106 additions & 0 deletions pkg/kvcache/kvblock/in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,15 @@ func NewInMemoryIndex(cfg *InMemoryIndexConfig) (*InMemoryIndex, error) {
return nil, fmt.Errorf("failed to initialize in-memory engine key map: %w", err)
}

podToRequestKeys, err := lru.New[string, map[BlockHash]BlockHash](cfg.Size)
if err != nil {
return nil, fmt.Errorf("failed to initialize pod-to-request-key map: %w", err)
}

return &InMemoryIndex{
data: cache,
engineToRequestKeys: engineToRequestKeys,
podToRequestKeys: podToRequestKeys,
podCacheSize: cfg.PodCacheSize,
}, nil
}
Expand All @@ -80,6 +86,8 @@ type InMemoryIndex struct {
data *lru.Cache[BlockHash, *PodCache]
// engineToRequestKeys holds the mapping of engineKeys to requestKeys.
engineToRequestKeys *lru.Cache[BlockHash, BlockHash]
// podToRequestKeys is a reverse index: podIdentifier -> [requestKey]: engineKey.
podToRequestKeys *lru.Cache[string, map[BlockHash]BlockHash]
// podCacheSize is the maximum number of pod entries per key.
podCacheSize int
}
Expand Down Expand Up @@ -161,7 +169,9 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block

for i, requestKey := range requestKeys {
// 1. Store engineKey -> requestKey mapping (only if engineKeys provided)
curEngineKey := EmptyBlockHash
if engineKeys != nil {
curEngineKey = engineKeys[i]
m.engineToRequestKeys.Add(engineKeys[i], requestKey)
}

Expand Down Expand Up @@ -202,6 +212,14 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block
podCache.mu.Lock()
for _, entry := range entries {
podCache.cache.Add(entry, struct{}{})

// 3. Maintain reverse index: podIdentifier -> [requestKey]: engineKey
mappings, ok := m.podToRequestKeys.Peek(entry.PodIdentifier)
if !ok {
mappings = make(map[BlockHash]BlockHash)
}
mappings[requestKey] = curEngineKey
m.podToRequestKeys.Add(entry.PodIdentifier, mappings)
}
podCache.mu.Unlock()

Expand Down Expand Up @@ -253,6 +271,14 @@ func (m *InMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType KeyTyp
podCache.mu.Lock()
for _, entry := range entries {
podCache.cache.Remove(entry)

if mappings, ok := m.podToRequestKeys.Peek(entry.PodIdentifier); ok {
delete(mappings, requestKey)

if len(mappings) == 0 {
m.podToRequestKeys.Remove(entry.PodIdentifier)
}
}
}

isEmpty := podCache.cache.Len() == 0
Expand Down Expand Up @@ -304,3 +330,83 @@ func podsPerKeyPrintHelper(ks map[BlockHash][]PodEntry) string {
}
return b.String()
}

// Clear removes all entries for the given podEntry from the index.
func (m *InMemoryIndex) Clear(ctx context.Context, podEntry PodEntry) error {
traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvblock.InMemoryIndex.Clear")
// Remove all entries for the given pod identifier
mappings, found := m.podToRequestKeys.Get(podEntry.PodIdentifier)
if !found {
traceLogger.Info("pod not found in reverse index, nothing to clear", "podEntry", podEntry)
return nil
}

for requestKey, engineKey := range mappings {
cache, exists := m.data.Peek(requestKey)
if !exists || cache == nil {
continue
}

cache.mu.Lock()
var toDelete []PodEntry
for _, entry := range cache.cache.Keys() {
if entry.PodIdentifier != podEntry.PodIdentifier {
continue
}
if podEntry.DeviceTier != "" && entry.DeviceTier != podEntry.DeviceTier {
continue
}
toDelete = append(toDelete, entry)
}

for _, entry := range toDelete {
cache.cache.Remove(entry)
}

isEmpty := cache.cache.Len() == 0
cache.mu.Unlock()

if !isEmpty {
continue
}

currentCache, stillExists := m.data.Get(requestKey)
var engineKeyExists bool
if engineKey != EmptyBlockHash {
engineKeyExists = m.engineToRequestKeys.Contains(engineKey)
}
if !stillExists || currentCache == nil {
if engineKeyExists {
m.engineToRequestKeys.Remove(engineKey)
}
continue
}
currentCache.mu.Lock()
if currentCache.cache.Len() == 0 {
m.data.Remove(requestKey)
if engineKeyExists {
m.engineToRequestKeys.Remove(engineKey)
}
}
currentCache.mu.Unlock()
}

if podEntry.DeviceTier == "" {
m.podToRequestKeys.Remove(podEntry.PodIdentifier)
} else {
remaining := make(map[BlockHash]BlockHash)
for requestKey, engineKey := range mappings {
if exists := m.data.Contains(requestKey); exists {
remaining[requestKey] = engineKey
}
}
if len(remaining) == 0 {
m.podToRequestKeys.Remove(podEntry.PodIdentifier)
} else {
m.podToRequestKeys.Add(podEntry.PodIdentifier, remaining)
}
}

traceLogger.Info("Cleared pod entries from InMemoryIndex", "podEntry", podEntry)
return nil
}
4 changes: 4 additions & 0 deletions pkg/kvcache/kvblock/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ type Index interface {
Evict(ctx context.Context, key BlockHash, keyType KeyType, entries []PodEntry) error
// GetRequestKey returns the requestKey associated with the given engineKey.
GetRequestKey(ctx context.Context, engineKey BlockHash) (BlockHash, error)
// Clear removes all entries from the index backend.
// If podEntry.DeviceTier is empty, all tiers for that pod identifier are removed.
// If podEntry.DeviceTier is set, only entries matching that exact tier are removed.
Clear(ctx context.Context, podEntry PodEntry) error
}

// KeyType indicates whether a key passed to Evict is an engine key or a request key.
Expand Down
Loading