Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 73 additions & 16 deletions pkg/kvcache/kvblock/cost_aware_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,73 @@ func (m *CostAwareMemoryIndex) MaxCost() int64 {

// CostPodCache wraps a sync.Map of PodEntry and provides cost calculation for memory usage estimation.
type CostPodCache struct {
cache sync.Map // map[PodEntry]struct{}
cache sync.Map // map[string]*PodEntry (key: "podID@tier")
// size tracks the number of entries in cache for O(1) Len().
size atomic.Int64
}

// Add adds a PodEntry to the cache.
// Add adds or updates a PodEntry in the cache, merging StoredGroups if the entry exists.
func (c *CostPodCache) Add(entry PodEntry) {
if _, loaded := c.cache.LoadOrStore(entry, struct{}{}); !loaded {
cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)

// Try to load existing entry
if existingVal, loaded := c.cache.Load(cacheKey); loaded {
if existingEntry, ok := existingVal.(*PodEntry); ok {
// Merge StoredGroups
existingEntry.StoredGroups = mergeGroupsUnique(existingEntry.StoredGroups, entry.StoredGroups)
// Store updated entry
c.cache.Store(cacheKey, existingEntry)
}
} else {
// Create new entry
newEntry := &PodEntry{
PodIdentifier: entry.PodIdentifier,
DeviceTier: entry.DeviceTier,
Speculative: entry.Speculative,
StoredGroups: mergeGroupsUnique(nil, entry.StoredGroups),
}
c.cache.Store(cacheKey, newEntry)
c.size.Add(1)
}
}

// Delete removes a PodEntry from the cache.
// Delete removes a PodEntry from the cache entirely.
func (c *CostPodCache) Delete(entry PodEntry) {
if _, loaded := c.cache.LoadAndDelete(entry); loaded {
cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)
if _, loaded := c.cache.LoadAndDelete(cacheKey); loaded {
c.size.Add(-1)
}
}

// RemoveGroups removes the groups listed in entry.StoredGroups from the
// cached entry identified by entry's pod/tier/speculative key.
// If no groups remain afterwards, the whole entry is deleted from the cache
// and the size counter is decremented.
//
// It returns true iff the entry was deleted; false when the entry was not
// found, had an unexpected type, or still has groups remaining.
func (c *CostPodCache) RemoveGroups(entry PodEntry) bool {
	cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)

	existingVal, loaded := c.cache.Load(cacheKey)
	if !loaded {
		return false
	}

	existingEntry, ok := existingVal.(*PodEntry)
	if !ok {
		return false
	}

	updatedGroups := removeGroups(existingEntry.StoredGroups, entry.StoredGroups)

	if len(updatedGroups) == 0 {
		// No groups left: drop the entry and keep size in sync.
		// NOTE(review): the Load/Delete pair is not atomic — a concurrent Add
		// for the same key between them can be lost; confirm callers
		// serialize per key.
		c.cache.Delete(cacheKey)
		c.size.Add(-1)
		return true
	}

	// Update with the remaining groups. The map already holds this exact
	// *PodEntry pointer, so the field assignment is sufficient; the
	// original's extra Store(cacheKey, existingEntry) was a no-op and has
	// been dropped.
	existingEntry.StoredGroups = updatedGroups
	return false
}

// Len returns the number of entries in the cache.
Expand All @@ -141,16 +191,22 @@ func (c *CostPodCache) CalculateByteSize(keyStr string) int64 {

// Count entries and calculate their size
c.cache.Range(func(key, value interface{}) bool {
entry, ok := key.(PodEntry)
if !ok {
// key is now a string, value is *PodEntry
keyStr, okKey := key.(string)
entry, okEntry := value.(*PodEntry)
if !okKey || !okEntry {
return true
}

entryCount++
totalBytes += int64(len(entry.PodIdentifier)) // PodIdentifier string content
totalBytes += int64(len(entry.DeviceTier)) // DeviceTier string content
totalBytes += 32 // string headers (16 bytes each for 2 strings)
totalBytes += 8 // struct padding/alignment
totalBytes += int64(len(keyStr)) // cache key string
totalBytes += int64(len(entry.PodIdentifier)) // PodIdentifier string content
totalBytes += int64(len(entry.DeviceTier)) // DeviceTier string content
totalBytes += int64(len(entry.StoredGroups) * 8) // StoredGroups slice (8 bytes per int)
totalBytes += 32 // string headers (16 bytes each for 2 strings)
totalBytes += 24 // slice header for StoredGroups
totalBytes += 8 // pointer to PodEntry
totalBytes += 8 // struct padding/alignment
return true
})

Expand Down Expand Up @@ -234,17 +290,17 @@ func (m *CostAwareMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHa
if podIdentifierSet.Len() == 0 {
// If no pod identifiers are provided, return all pods
pods.cache.Range(func(k, value interface{}) bool {
if pod, ok := k.(PodEntry); ok {
podsPerKey[key] = append(podsPerKey[key], pod)
if pod, ok := value.(*PodEntry); ok {
podsPerKey[key] = append(podsPerKey[key], *pod)
}
return true
})
} else {
// Filter pods based on the provided pod identifiers
pods.cache.Range(func(k, value interface{}) bool {
if pod, ok := k.(PodEntry); ok {
if pod, ok := value.(*PodEntry); ok {
if podIdentifierSet.Has(pod.PodIdentifier) {
podsPerKey[key] = append(podsPerKey[key], pod)
podsPerKey[key] = append(podsPerKey[key], *pod)
}
}
return true
Expand Down Expand Up @@ -307,7 +363,8 @@ func (m *CostAwareMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType
podCacheLenBefore := podCache.Len()

for _, entry := range entries {
podCache.Delete(entry)
// Remove groups from the entry; if no groups remain, the entry is deleted
podCache.RemoveGroups(entry)
}

if podCache.Len() == 0 {
Expand Down
121 changes: 108 additions & 13 deletions pkg/kvcache/kvblock/in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ var _ Index = &InMemoryIndex{}

// PodCache represents a per-key cache of pod entries.
type PodCache struct {
	// cache is a thread-safe LRU cache mapping "podID@tier" keys (see
	// podCacheKey) to PodEntry pointers, which allows in-place updates of
	// StoredGroups without recreating entries.
	cache *lru.Cache[string, *PodEntry]
	// mu protects the cache from concurrent access during check-and-set
	// operations that span multiple cache calls (Get followed by Add/Remove).
	mu sync.Mutex
}
Expand Down Expand Up @@ -126,12 +126,14 @@ func (m *InMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHash,

if podIdentifierSet.Len() == 0 {
// If no pod identifiers are provided, return all pods
podsPerKey[requestKey] = pods.cache.Keys()
for _, podEntry := range pods.cache.Values() {
podsPerKey[requestKey] = append(podsPerKey[requestKey], *podEntry)
}
} else {
// Filter pods based on the provided pod identifiers
for _, pod := range pods.cache.Keys() {
if podIdentifierSet.Has(pod.PodIdentifier) {
podsPerKey[requestKey] = append(podsPerKey[requestKey], pod)
for _, podEntry := range pods.cache.Values() {
if podIdentifierSet.Has(podEntry.PodIdentifier) {
podsPerKey[requestKey] = append(podsPerKey[requestKey], *podEntry)
}
}
}
Expand Down Expand Up @@ -174,7 +176,7 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block
//nolint:nestif // double-checked locking pattern
if !found {
// Create new cache
cache, err := lru.New[PodEntry, struct{}](m.podCacheSize)
cache, err := lru.New[string, *PodEntry](m.podCacheSize)
if err != nil {
return fmt.Errorf("failed to create pod cache for key %s: %w", requestKey.String(), err)
}
Expand All @@ -201,11 +203,30 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block

podCache.mu.Lock()
for _, entry := range entries {
podCache.cache.Add(entry, struct{}{})
cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)

// Check if entry already exists
existingEntry, found := podCache.cache.Get(cacheKey)
if found {
// Merge StoredGroups, deduplicating and preserving order
existingEntry.StoredGroups = mergeGroupsUnique(existingEntry.StoredGroups, entry.StoredGroups)
// Re-add to update LRU position
podCache.cache.Add(cacheKey, existingEntry)
traceLogger.Info("updated existing pod entry with merged groups",
"requestKey", requestKey, "pod", existingEntry)
} else {
// Create new entry (copy to avoid mutation)
newEntry := &PodEntry{
PodIdentifier: entry.PodIdentifier,
DeviceTier: entry.DeviceTier,
Speculative: entry.Speculative,
StoredGroups: mergeGroupsUnique(nil, entry.StoredGroups),
}
podCache.cache.Add(cacheKey, newEntry)
traceLogger.Info("added new pod entry", "requestKey", requestKey, "pod", newEntry)
}
}
podCache.mu.Unlock()

traceLogger.Info("added pods to key", "requestKey", requestKey, "pods", entries)
}

return nil
Expand Down Expand Up @@ -252,13 +273,36 @@ func (m *InMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType KeyTyp

podCache.mu.Lock()
for _, entry := range entries {
podCache.cache.Remove(entry)
cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)

existingEntry, found := podCache.cache.Get(cacheKey)
if !found {
traceLogger.Info("pod entry not found for eviction, skipping",
"requestKey", requestKey, "podID", entry.PodIdentifier, "tier", entry.DeviceTier)
continue
}

// Remove the specified groups from StoredGroups
updatedGroups := removeGroups(existingEntry.StoredGroups, entry.StoredGroups)

if len(updatedGroups) == 0 {
// No groups left, remove the entire pod entry
podCache.cache.Remove(cacheKey)
traceLogger.Info("removed pod entry (no groups remaining)",
"requestKey", requestKey, "pod", existingEntry)
} else {
// Update with remaining groups
existingEntry.StoredGroups = updatedGroups
podCache.cache.Add(cacheKey, existingEntry)
traceLogger.Info("updated pod entry after group removal",
"requestKey", requestKey, "pod", existingEntry, "remainingGroups", updatedGroups)
}
}

isEmpty := podCache.cache.Len() == 0
podCache.mu.Unlock()

traceLogger.Info("evicted pods from key", "requestKey", requestKey, "key", key, "keyType", keyType, "pods", entries)
traceLogger.Info("processed eviction", "requestKey", requestKey, "key", key, "keyType", keyType, "entries", entries)

// Remove key from main cache if empty.
// Re-fetch and hold the lock through removal to prevent racing with Add.
Expand Down Expand Up @@ -294,6 +338,57 @@ func (m *InMemoryIndex) GetRequestKey(ctx context.Context, engineKey BlockHash)
return requestKey, nil
}

// podCacheKey builds the cache key that uniquely identifies a pod entry.
// The key has the form "podIdentifier@deviceTier", with a "[speculative]"
// suffix appended when the entry is speculative.
func podCacheKey(podIdentifier, deviceTier string, speculative bool) string {
	suffix := ""
	if speculative {
		suffix = "[speculative]"
	}
	return podIdentifier + "@" + deviceTier + suffix
}

// mergeGroupsUnique merges two group lists, removing duplicates and
// preserving order: the elements of existing come first, unchanged, followed
// by the elements of incoming that are not already present (first occurrence
// wins among duplicates within incoming).
//
// BUG FIX: the previous implementation only ever considered incoming[0],
// silently dropping every further incoming group; this version honors the
// documented contract and merges all of incoming.
//
// When nothing needs to be added, existing is returned as-is (no allocation);
// otherwise a fresh slice is allocated so existing's backing array is never
// mutated.
func mergeGroupsUnique(existing, incoming []int) []int {
	if len(incoming) == 0 {
		return existing
	}

	seen := make(map[int]struct{}, len(existing))
	for _, v := range existing {
		seen[v] = struct{}{}
	}

	var toAdd []int
	for _, v := range incoming {
		if _, ok := seen[v]; !ok {
			seen[v] = struct{}{}
			toAdd = append(toAdd, v)
		}
	}
	if len(toAdd) == 0 {
		return existing // everything already present, nothing to do
	}

	result := make([]int, 0, len(existing)+len(toAdd))
	result = append(result, existing...)
	return append(result, toAdd...)
}

// removeGroups deletes every group listed in toRemove from existing,
// preserving the relative order of the surviving elements.
//
// BUG FIX: the previous implementation removed only the first occurrence of
// toRemove[0] and ignored the rest of toRemove; this version removes all
// listed groups, matching the documented contract.
//
// The filter runs in place (existing's backing array is reused), so callers
// must use the returned slice and must not rely on existing afterwards —
// which matches how the original was used (callers reassign the result).
func removeGroups(existing, toRemove []int) []int {
	if len(toRemove) == 0 || len(existing) == 0 {
		return existing
	}

	drop := make(map[int]struct{}, len(toRemove))
	for _, v := range toRemove {
		drop[v] = struct{}{}
	}

	kept := existing[:0]
	for _, v := range existing {
		if _, ok := drop[v]; !ok {
			kept = append(kept, v)
		}
	}
	return kept
}

// podsPerKeyPrintHelper formats a map of keys to pod names for printing.
func podsPerKeyPrintHelper(ks map[BlockHash][]PodEntry) string {
var b strings.Builder
Expand Down
2 changes: 2 additions & 0 deletions pkg/kvcache/kvblock/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ type PodEntry struct {
DeviceTier string
// Speculative indicates the entry was added predictively before a KV event confirmed it.
Speculative bool
// StoredGroups tracks the group IDs that have stored this block.
StoredGroups []int
}

// String returns a string representation of the PodEntry.
Expand Down
Loading