Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 56 additions & 56 deletions pkg/kvcache/kvblock/in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ type PodCache struct {
cache *lru.Cache[PodEntry, struct{}]
// mu protects the cache from concurrent access during check-and-set operations.
mu sync.Mutex
// removed indicates this PodCache has been evicted from the parent map.
// Checked by Add after acquiring mu to avoid writing into an orphaned cache.
removed bool
}

// Lookup receives a list of requestKeys and a set of pod identifiers,
Expand Down Expand Up @@ -165,47 +168,27 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block
m.engineToRequestKeys.Add(engineKeys[i], requestKey)
}

// 2. Store requestKey -> PodCache mapping
var podCache *PodCache
var found bool

// Try to get existing cache first
podCache, found = m.data.Get(requestKey)
//nolint:nestif // double-checked locking pattern
if !found {
// Create new cache
cache, err := lru.New[PodEntry, struct{}](m.podCacheSize)
if err != nil {
return fmt.Errorf("failed to create pod cache for key %s: %w", requestKey.String(), err)
}

newPodCache := &PodCache{
cache: cache,
// 2. Store requestKey -> PodCache mapping with retry on stale cache.
// A retry is needed only when a concurrent Evict marks the PodCache as
// removed between getOrCreatePodCache and Lock. The window is tiny, so
// this loop almost never iterates more than once.
for {
podCache := m.getOrCreatePodCache(requestKey)

podCache.mu.Lock()
if podCache.removed {
podCache.mu.Unlock()
continue // retry — this cache was evicted
}

// Try to add, but use existing if another thread added it first
// This is a bounded retry (1) - not perfectly safe but for practical use-cases and scenarios
// this should be sufficient
contains, _ := m.data.ContainsOrAdd(requestKey, newPodCache)
if contains {
podCache, found = m.data.Get(requestKey)
if !found { // Extremely irregular workload pattern - key evicted
m.data.Add(requestKey, newPodCache)
podCache = newPodCache
}
} else {
// We successfully added our cache
podCache = newPodCache
for _, entry := range entries {
podCache.cache.Add(entry, struct{}{})
}
}
podCache.mu.Unlock()

podCache.mu.Lock()
for _, entry := range entries {
podCache.cache.Add(entry, struct{}{})
traceLogger.Info("added pods to key", "requestKey", requestKey, "pods", entries)
break
}
podCache.mu.Unlock()

traceLogger.Info("added pods to key", "requestKey", requestKey, "pods", entries)
}

return nil
Expand Down Expand Up @@ -251,41 +234,36 @@ func (m *InMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType KeyTyp
}

podCache.mu.Lock()
prevLen := podCache.cache.Len()
for _, entry := range entries {
podCache.cache.Remove(entry)
}

isEmpty := podCache.cache.Len() == 0
podCache.mu.Unlock()

traceLogger.Info("evicted pods from key", "requestKey", requestKey, "key", key, "keyType", keyType, "pods", entries)

// Remove key from main cache if empty.
// Re-fetch and hold the lock through removal to prevent racing with Add.
if !isEmpty {
return nil
}

currentCache, stillExists := m.data.Get(requestKey)
if !stillExists || currentCache == nil {
return nil
}

currentCache.mu.Lock()
if currentCache.cache.Len() == 0 {
m.data.Remove(requestKey)
// Only mark as removed if this Evict actually emptied the cache.
// If the cache was already empty (prevLen == 0), a concurrent Add may have
// just created it — marking it removed would cause Add to spin.
if podCache.cache.Len() == 0 && prevLen > 0 {
podCache.removed = true
// Use Peek + pointer equality to avoid removing a replacement PodCache
// that a concurrent Add may have inserted.
if cur, ok := m.data.Peek(requestKey); ok && cur == podCache {
m.data.Remove(requestKey)
}
if hasEngineKeyMapping {
m.engineToRequestKeys.Remove(key)
}
traceLogger.Info("removed requestKey from index as no pods remain", "requestKey", requestKey, "key", key)
}
currentCache.mu.Unlock()
podCache.mu.Unlock()

traceLogger.Info("evicted pods from key", "requestKey", requestKey, "key", key, "keyType", keyType, "pods", entries)

return nil
}

// GetRequestKey returns the requestKey associated with the given engineKey.
// Returns an error if the engineKey mapping is missing (e.g., already evicted).
// No external lock needed — lru.Cache is internally thread-safe.
func (m *InMemoryIndex) GetRequestKey(ctx context.Context, engineKey BlockHash) (BlockHash, error) {
requestKey, found := m.engineToRequestKeys.Get(engineKey)
if !found {
Expand All @@ -294,6 +272,28 @@ func (m *InMemoryIndex) GetRequestKey(ctx context.Context, engineKey BlockHash)
return requestKey, nil
}

// getOrCreatePodCache returns the existing PodCache for requestKey,
// or creates and inserts a new one if none exists.
//
// NOTE: the returned PodCache may already have been marked removed by a
// concurrent Evict; callers must re-check the removed flag under mu before
// writing into it (Add does exactly this and retries on a stale cache).
func (m *InMemoryIndex) getOrCreatePodCache(requestKey BlockHash) *PodCache {
	// Fast path: cache already present. m.data is an LRU that is internally
	// thread-safe, so no external lock is needed for the lookup itself.
	if podCache, found := m.data.Get(requestKey); found {
		return podCache
	}

	// Slow path: build a fresh per-key LRU and race to insert it.
	cache, _ := lru.New[PodEntry, struct{}](m.podCacheSize) //nolint:errcheck // size is always > 0
	newPodCache := &PodCache{cache: cache}

	// Try to add atomically; if another goroutine beat us, use theirs.
	if contains, _ := m.data.ContainsOrAdd(requestKey, newPodCache); contains {
		if existing, ok := m.data.Get(requestKey); ok {
			return existing
		}
		// Key was evicted between ContainsOrAdd and Get — use ours.
		// NOTE(review): if yet another goroutine inserts its own cache between
		// that eviction and this Add, this call overwrites it and that
		// goroutine's writes can be orphaned — presumably acceptable for a
		// best-effort index; confirm against the index's consistency goals.
		m.data.Add(requestKey, newPodCache)
	}

	return newPodCache
}

// podsPerKeyPrintHelper formats a map of keys to pod names for printing.
func podsPerKeyPrintHelper(ks map[BlockHash][]PodEntry) string {
var b strings.Builder
Expand Down
147 changes: 146 additions & 1 deletion pkg/kvcache/kvblock/in_memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ limitations under the License.
package kvblock_test

import (
"fmt"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/sets"

. "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock"
"github.com/llm-d/llm-d-kv-cache/pkg/utils/logging"
Expand Down Expand Up @@ -223,7 +227,148 @@ func TestAddWithNilEngineKeys(t *testing.T) {
assert.Error(t, err, "GetRequestKey should fail since no engineKey mapping was created")
}

// TestPodEntryString tests the String() method with and without Annotation.
// TestConcurrentAddEvictToEmpty is a regression test for issue #421.
// It races an Add against an Evict that drains the PodCache under the same
// key. Before the fix, Evict could drop the PodCache from the map after Add
// had fetched it but before Add wrote into it, so the freshly added entries
// landed in an orphaned cache and were silently lost.
func TestConcurrentAddEvictToEmpty(t *testing.T) {
	ctx := logging.NewTestLoggerIntoContext(t.Context())

	const iterations = 500

	for iter := 0; iter < iterations; iter++ {
		cfg := DefaultInMemoryIndexConfig()
		cfg.PodCacheSize = 10
		index, err := NewInMemoryIndex(cfg)
		require.NoError(t, err)

		var (
			engineKey  = BlockHash(11111111)
			requestKey = BlockHash(22222222)
			seed       = PodEntry{PodIdentifier: "seed", DeviceTier: "gpu"}
			survivor   = PodEntry{PodIdentifier: fmt.Sprintf("survivor-%d", iter), DeviceTier: "gpu"}
		)

		// Seed a single pod so the subsequent Evict empties the cache.
		err = index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, []PodEntry{seed})
		require.NoError(t, err)

		var wg sync.WaitGroup
		wg.Add(2)

		// Writer 1: evict the seed pod, emptying the cache. This is what
		// triggers removal of the PodCache from the map.
		go func() {
			defer wg.Done()
			//nolint:errcheck // best-effort eviction in concurrent test
			index.Evict(ctx, engineKey, EngineKey, []PodEntry{seed})
		}()

		// Writer 2: concurrently add a new pod under the same key.
		go func() {
			defer wg.Done()
			//nolint:errcheck // best-effort add in concurrent test
			index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, []PodEntry{survivor})
		}()

		wg.Wait()

		// The survivor must remain visible. Pre-fix, it could be written into
		// an orphaned PodCache and vanish from lookups.
		podsPerKey, err := index.Lookup(ctx, []BlockHash{requestKey}, sets.Set[string]{})
		require.NoError(t, err)

		survivorSeen := false
		for _, pod := range podsPerKey[requestKey] {
			survivorSeen = survivorSeen || pod.PodIdentifier == survivor.PodIdentifier
		}
		assert.True(t, survivorSeen, "iteration %d: survivor pod %q was lost — Add wrote into an orphaned PodCache",
			iter, survivor.PodIdentifier)
	}
}

// TestConcurrentAddWithStaleEvicts verifies that a flood of Evicts on the same
// key cannot cause Add to spin indefinitely. This guards the prevLen > 0 check
// in Evict: an Evict that finds an already-empty PodCache must NOT mark it as
// removed, otherwise Add's retry loop would keep creating new PodCaches that
// get immediately invalidated.
func TestConcurrentAddWithStaleEvicts(t *testing.T) {
	ctx := logging.NewTestLoggerIntoContext(t.Context())

	cfg := DefaultInMemoryIndexConfig()
	cfg.PodCacheSize = 10
	index, err := NewInMemoryIndex(cfg)
	require.NoError(t, err)

	var (
		engineKey  = BlockHash(99999999)
		requestKey = BlockHash(88888888)
		target     = PodEntry{PodIdentifier: "target", DeviceTier: "gpu"}
		stale      = PodEntry{PodIdentifier: "stale", DeviceTier: "gpu"}
	)

	// Evict needs the engineKey→requestKey mapping to resolve the key, so
	// seed it with a throwaway Add, drain it, then seed it once more (the
	// first Evict drops the mapping when the cache empties).
	err = index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, []PodEntry{stale})
	require.NoError(t, err)
	err = index.Evict(ctx, engineKey, EngineKey, []PodEntry{stale})
	require.NoError(t, err)
	err = index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, []PodEntry{stale})
	require.NoError(t, err)

	// Hammer the key with continuous Evicts. If Evict incorrectly marked
	// freshly-created empty PodCaches as removed, the Add below would never
	// terminate.
	const evictors = 20
	stopCh := make(chan struct{})
	var wg sync.WaitGroup

	wg.Add(evictors)
	for range evictors {
		go func() {
			defer wg.Done()
			for {
				select {
				case <-stopCh:
					return
				default:
					//nolint:errcheck // best-effort eviction in concurrent test
					index.Evict(ctx, engineKey, EngineKey, []PodEntry{stale})
				}
			}
		}()
	}

	// Add must finish promptly despite the eviction storm.
	addDone := make(chan struct{})
	go func() {
		//nolint:errcheck // best-effort add in concurrent test
		index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, []PodEntry{target})
		close(addDone)
	}()

	// 2 seconds is generous — a non-spinning Add completes in microseconds.
	deadline := time.NewTimer(2 * time.Second)
	defer deadline.Stop()

	select {
	case <-addDone:
		close(stopCh)
		wg.Wait()

		podsPerKey, lookupErr := index.Lookup(ctx, []BlockHash{requestKey}, sets.Set[string]{})
		require.NoError(t, lookupErr)
		assert.Contains(t, podsPerKey[requestKey], target,
			"target pod must be present after Add completes")
	case <-deadline.C:
		close(stopCh)
		wg.Wait()
		t.Fatal("Add did not complete within 2s — likely spinning due to stale Evicts marking empty PodCaches as removed")
	}
}

func TestPodEntryString(t *testing.T) {
confirmed := PodEntry{PodIdentifier: "10.0.0.1:8080", DeviceTier: "gpu"}
assert.Equal(t, "10.0.0.1:8080@gpu", confirmed.String())
Expand Down
Loading