Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ sequenceDiagram
else BlockRemoved
Worker->>Index: Evict(key, podEntry)
else AllBlocksCleared
Note over Worker: No-op
Worker->>Index: Clear()
end
end
```
Expand Down
27 changes: 27 additions & 0 deletions examples/helper/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,33 @@ func SimulateRemoveEvent(ctx context.Context, publisher *Publisher) error {
return nil
}

func SimulateClearAllBlocksEvent(ctx context.Context, publisher *Publisher) error {
logger := log.FromContext(ctx)
logger.Info("@@@ Simulating vLLM engine clear all blocks...")

clearAllBlocksEvent := []any{
"AllBlocksCleared",
}

clearAllBlocksPayload, err := msgpack.Marshal(clearAllBlocksEvent)
if err != nil {
return fmt.Errorf("failed to marshal AllBlocksCleared event: %w", err)
}
clearAllBlockEventBatch := []any{
float64(time.Now().UnixNano()) / 1e9,
[]msgpack.RawMessage{clearAllBlocksPayload},
nil,
}

if err := publisher.PublishEvent(ctx, topic, clearAllBlockEventBatch); err != nil {
return fmt.Errorf("failed to publish AllBlocksCleared event: %w", err)
}
logger.Info("@@@ Published AllBlocksCleared event", "topic", topic)

time.Sleep(3 * time.Second)
return nil
}

func SetupEventsPool(ctx context.Context, kvBlockIndex kvblock.Index) (*kvevents.Pool, error) {
logger := log.FromContext(ctx)

Expand Down
26 changes: 26 additions & 0 deletions examples/kv_events/offline/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,32 @@ func RunEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish
}
logger.Info("@@@ Final pod scores after BlockRemoved events", "pods", pods)

// Simulate vLLM engine publishing BlockStored events
err = helper.SimulateProduceEvent(ctx, publisher)
if err != nil {
return err
}

// Query again to see the effect of the events
pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil)
if err != nil {
return err
}
logger.Info("@@@ Pod scores after BlockStored events", "pods", pods)

// Simulate vLLM engine publishing AllBlocksCleared event
err = helper.SimulateClearAllBlocksEvent(ctx, publisher)
if err != nil {
return err
}

// Query again to see the effect of the events
pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil)
if err != nil {
return err
}
logger.Info("@@@ Pod scores after AllBlocksCleared events", "pods", pods)

logger.Info("Events demo completed. Pool continues listening for more events...")
logger.Info("Press Ctrl+C to shutdown")

Expand Down
15 changes: 15 additions & 0 deletions examples/valkey_example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,21 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer)

logger.Info("Cache lookup after eviction", "keysFound", len(lookupAfterEvict))

// Clear the cache
logger.Info("Clearing the cache")
err = indexer.KVBlockIndex().Clear(ctx)
if err != nil {
return fmt.Errorf("failed to clear cache: %w", err)
}

// Lookup again after clear
lookupAfterClear, err := indexer.KVBlockIndex().Lookup(ctx, promptKeys, nil)
if err != nil {
return fmt.Errorf("failed to lookup after clear: %w", err)
}

logger.Info("Cache lookup after clear", "keysFound", len(lookupAfterClear))

// Final score check to see the difference
finalScores, err := indexer.GetPodScores(ctx, nil, prompt, modelName, []string{"demo-pod-1", "demo-pod-2"})
if err != nil {
Expand Down
101 changes: 99 additions & 2 deletions pkg/kvcache/kvblock/cost_aware_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type CostAwareMemoryIndexConfig struct {
Size string `json:"size,omitempty"`
}

// DefaultCostAwareMemoryIndexConfig returns a default configuration for the CostAwareMemoryIndex.
func DefaultCostAwareMemoryIndexConfig() *CostAwareMemoryIndexConfig {
return &CostAwareMemoryIndexConfig{
Size: "2GiB", // 2GiB default size
Expand Down Expand Up @@ -75,9 +76,15 @@ func NewCostAwareMemoryIndex(cfg *CostAwareMemoryIndexConfig) (*CostAwareMemoryI
return nil, fmt.Errorf("failed to initialize in-memory engine key map: %w", err)
}

podToRequestKeys, err := lru.New[string, map[BlockHash]BlockHash](defaultNumCounters)
if err != nil {
return nil, fmt.Errorf("failed to initialize pod-to-request-key map: %w", err)
}

return &CostAwareMemoryIndex{
data: cache,
requestKeys: requestKeys,
data: cache,
requestKeys: requestKeys,
podToRequestKeys: podToRequestKeys,
}, nil
}

Expand All @@ -95,8 +102,11 @@ type CostAwareMemoryIndex struct {
requestKeys *lru.Cache[BlockHash, BlockHash]
// mu protects concurrent access to the index operations
mu sync.RWMutex
// podToRequestKeys is a reverse index: podIdentifier -> [requestKey]: engineKey.
podToRequestKeys *lru.Cache[string, map[BlockHash]BlockHash]
}

// MaxCost returns the maximum cost of the cache.
func (m *CostAwareMemoryIndex) MaxCost() int64 {
return m.data.MaxCost()
}
Expand Down Expand Up @@ -193,8 +203,19 @@ func (m *CostAwareMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys
podCache = &CostPodCache{}
}

curEngineKey := EmptyBlockHash
if engineKeys != nil {
curEngineKey = engineKeys[i]
}

for _, entry := range entries {
podCache.Add(entry)
mappings, found := m.podToRequestKeys.Peek(entry.PodIdentifier)
if !found {
mappings = make(map[BlockHash]BlockHash)
}
mappings[requestKey] = curEngineKey
m.podToRequestKeys.Add(entry.PodIdentifier, mappings)
}

// Calculate the actual cost for this cache entry
Expand All @@ -206,6 +227,7 @@ func (m *CostAwareMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys
return nil
}

// Lookup looks up a list of request keys and returns the associated pod entries.
func (m *CostAwareMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHash,
podIdentifierSet sets.Set[string],
) (map[BlockHash][]PodEntry, error) {
Expand Down Expand Up @@ -308,6 +330,13 @@ func (m *CostAwareMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType

for _, entry := range entries {
podCache.Delete(entry)

if mappings, found := m.podToRequestKeys.Peek(entry.PodIdentifier); found {
delete(mappings, requestKey)
if len(mappings) == 0 {
m.podToRequestKeys.Remove(entry.PodIdentifier)
}
}
}
Comment thread
yash9263 marked this conversation as resolved.

if podCache.Len() == 0 {
Expand All @@ -333,3 +362,71 @@ func (m *CostAwareMemoryIndex) GetRequestKey(ctx context.Context, engineKey Bloc
}
return requestKey, nil
}

// Clear removes all entries from the index backend.
func (m *CostAwareMemoryIndex) Clear(ctx context.Context, podEntry PodEntry) error {
m.mu.Lock()
defer m.mu.Unlock()
traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvblock.CostAwareMemoryIndex.Clear")

// Remove all entries for the given pod identifier
mappings, found := m.podToRequestKeys.Get(podEntry.PodIdentifier)
if !found {
traceLogger.Info("pod not found in reverse index, nothing to clear", "podEntry", podEntry)
return nil
}

for requestKey, engineKey := range mappings {
Comment thread
yash9263 marked this conversation as resolved.
pod, found := m.data.Get(requestKey.String())
if !found {
traceLogger.Info("request key not found in cache, skipping", "requestKey", requestKey)
continue
}
podCacheLenBefore := pod.Len()
var toDelete []PodEntry
pod.cache.Range(func(key, value any) bool {
if entry, ok := key.(PodEntry); ok {
if podEntry.DeviceTier != "" && entry.DeviceTier != podEntry.DeviceTier {
return true
}
if entry.PodIdentifier == podEntry.PodIdentifier {
toDelete = append(toDelete, entry)
}
}
return true
})
Comment thread
yash9263 marked this conversation as resolved.

for _, entry := range toDelete {
pod.Delete(entry)
}

if pod.Len() == 0 {
m.data.Del(requestKey.String())
if _, hasEngineKey := m.requestKeys.Get(engineKey); hasEngineKey {
m.requestKeys.Remove(engineKey)
}
Comment thread
yash9263 marked this conversation as resolved.
Outdated
} else if podCacheLenBefore != pod.Len() {
m.data.Set(requestKey.String(), pod, pod.CalculateByteSize(requestKey.String()))
}
}

if podEntry.DeviceTier == "" {
m.podToRequestKeys.Remove(podEntry.PodIdentifier)
} else {
remaining := make(map[BlockHash]BlockHash)
for requestKey, engineKey := range mappings {
if _, found := m.data.Get(requestKey.String()); found {
remaining[requestKey] = engineKey
}
}
if len(remaining) == 0 {
m.podToRequestKeys.Remove(podEntry.PodIdentifier)
} else {
m.podToRequestKeys.Add(podEntry.PodIdentifier, remaining)
}
}

m.data.Wait()
traceLogger.Info("Cleared pod entries from InMemoryIndex", "podEntry", podEntry)
Comment thread
yash9263 marked this conversation as resolved.
Outdated
return nil
}
Loading