Skip to content

Commit cece48c

Browse files
committed
fix: Close data race in InMemoryIndex Add/Evict with per-PodCache Mutex and removed flag
Signed-off-by: Guangya Liu <gyliu513@gmail.com>
1 parent 8b7855c commit cece48c

File tree

2 files changed

+197
-58
lines changed

2 files changed

+197
-58
lines changed

pkg/kvcache/kvblock/in_memory.go

Lines changed: 56 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,12 @@ var _ Index = &InMemoryIndex{}
8989
// PodCache represents a cache for pod entries.
9090
type PodCache struct {
9191
// cache is an LRU cache that maps PodEntry to their last access time.
92-
// thread-safe.
9392
cache *lru.Cache[PodEntry, struct{}]
9493
// mu protects the cache from concurrent access during check-and-set operations.
9594
mu sync.Mutex
95+
// removed indicates this PodCache has been evicted from the parent map.
96+
// Checked by Add after acquiring mu to avoid writing into an orphaned cache.
97+
removed bool
9698
}
9799

98100
// Lookup receives a list of requestKeys and a set of pod identifiers,
@@ -165,47 +167,27 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block
165167
m.engineToRequestKeys.Add(engineKeys[i], requestKey)
166168
}
167169

168-
// 2. Store requestKey -> PodCache mapping
169-
var podCache *PodCache
170-
var found bool
171-
172-
// Try to get existing cache first
173-
podCache, found = m.data.Get(requestKey)
174-
//nolint:nestif // double-checked locking pattern
175-
if !found {
176-
// Create new cache
177-
cache, err := lru.New[PodEntry, struct{}](m.podCacheSize)
178-
if err != nil {
179-
return fmt.Errorf("failed to create pod cache for key %s: %w", requestKey.String(), err)
180-
}
181-
182-
newPodCache := &PodCache{
183-
cache: cache,
170+
// 2. Store requestKey -> PodCache mapping with retry on stale cache.
171+
// A retry is needed only when a concurrent Evict marks the PodCache as
172+
// removed between getOrCreatePodCache and Lock. The window is tiny, so
173+
// this loop almost never iterates more than once.
174+
for {
175+
podCache := m.getOrCreatePodCache(requestKey)
176+
177+
podCache.mu.Lock()
178+
if podCache.removed {
179+
podCache.mu.Unlock()
180+
continue // retry — this cache was evicted
184181
}
185182

186-
// Try to add, but use existing if another thread added it first
187-
// This is a bounded retry (1) - not perfectly safe but for practical use-cases and scenarios
188-
// this should be sufficient
189-
contains, _ := m.data.ContainsOrAdd(requestKey, newPodCache)
190-
if contains {
191-
podCache, found = m.data.Get(requestKey)
192-
if !found { // Extremely irregular workload pattern - key evicted
193-
m.data.Add(requestKey, newPodCache)
194-
podCache = newPodCache
195-
}
196-
} else {
197-
// We successfully added our cache
198-
podCache = newPodCache
183+
for _, entry := range entries {
184+
podCache.cache.Add(entry, struct{}{})
199185
}
200-
}
186+
podCache.mu.Unlock()
201187

202-
podCache.mu.Lock()
203-
for _, entry := range entries {
204-
podCache.cache.Add(entry, struct{}{})
188+
traceLogger.Info("added pods to key", "requestKey", requestKey, "pods", entries)
189+
break
205190
}
206-
podCache.mu.Unlock()
207-
208-
traceLogger.Info("added pods to key", "requestKey", requestKey, "pods", entries)
209191
}
210192

211193
return nil
@@ -251,41 +233,36 @@ func (m *InMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType KeyTyp
251233
}
252234

253235
podCache.mu.Lock()
236+
prevLen := podCache.cache.Len()
254237
for _, entry := range entries {
255238
podCache.cache.Remove(entry)
256239
}
257240

258-
isEmpty := podCache.cache.Len() == 0
259-
podCache.mu.Unlock()
260-
261-
traceLogger.Info("evicted pods from key", "requestKey", requestKey, "key", key, "keyType", keyType, "pods", entries)
262-
263-
// Remove key from main cache if empty.
264-
// Re-fetch and hold the lock through removal to prevent racing with Add.
265-
if !isEmpty {
266-
return nil
267-
}
268-
269-
currentCache, stillExists := m.data.Get(requestKey)
270-
if !stillExists || currentCache == nil {
271-
return nil
272-
}
273-
274-
currentCache.mu.Lock()
275-
if currentCache.cache.Len() == 0 {
276-
m.data.Remove(requestKey)
241+
// Only mark as removed if this Evict actually emptied the cache.
242+
// If the cache was already empty (prevLen == 0), a concurrent Add may have
243+
// just created it — marking it removed would cause Add to spin.
244+
if podCache.cache.Len() == 0 && prevLen > 0 {
245+
podCache.removed = true
246+
// Use Peek + pointer equality to avoid removing a replacement PodCache
247+
// that a concurrent Add may have inserted.
248+
if cur, ok := m.data.Peek(requestKey); ok && cur == podCache {
249+
m.data.Remove(requestKey)
250+
}
277251
if hasEngineKeyMapping {
278252
m.engineToRequestKeys.Remove(key)
279253
}
280254
traceLogger.Info("removed requestKey from index as no pods remain", "requestKey", requestKey, "key", key)
281255
}
282-
currentCache.mu.Unlock()
256+
podCache.mu.Unlock()
257+
258+
traceLogger.Info("evicted pods from key", "requestKey", requestKey, "key", key, "keyType", keyType, "pods", entries)
283259

284260
return nil
285261
}
286262

287263
// GetRequestKey returns the requestKey associated with the given engineKey.
288264
// Returns an error if the engineKey mapping is missing (e.g., already evicted).
265+
// No external lock needed — lru.Cache is internally thread-safe.
289266
func (m *InMemoryIndex) GetRequestKey(ctx context.Context, engineKey BlockHash) (BlockHash, error) {
290267
requestKey, found := m.engineToRequestKeys.Get(engineKey)
291268
if !found {
@@ -294,6 +271,28 @@ func (m *InMemoryIndex) GetRequestKey(ctx context.Context, engineKey BlockHash)
294271
return requestKey, nil
295272
}
296273

274+
// getOrCreatePodCache returns the existing PodCache for requestKey,
275+
// or creates and inserts a new one if none exists.
276+
func (m *InMemoryIndex) getOrCreatePodCache(requestKey BlockHash) *PodCache {
277+
if podCache, found := m.data.Get(requestKey); found {
278+
return podCache
279+
}
280+
281+
cache, _ := lru.New[PodEntry, struct{}](m.podCacheSize)
282+
newPodCache := &PodCache{cache: cache}
283+
284+
// Try to add atomically; if another goroutine beat us, use theirs.
285+
if contains, _ := m.data.ContainsOrAdd(requestKey, newPodCache); contains {
286+
if existing, ok := m.data.Get(requestKey); ok {
287+
return existing
288+
}
289+
// Key was evicted between ContainsOrAdd and Get — use ours.
290+
m.data.Add(requestKey, newPodCache)
291+
}
292+
293+
return newPodCache
294+
}
295+
297296
// podsPerKeyPrintHelper formats a map of keys to pod names for printing.
298297
func podsPerKeyPrintHelper(ks map[BlockHash][]PodEntry) string {
299298
var b strings.Builder

pkg/kvcache/kvblock/in_memory_test.go

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@ limitations under the License.
1717
package kvblock_test
1818

1919
import (
20+
"fmt"
21+
"sync"
2022
"testing"
23+
"time"
2124

2225
"github.com/stretchr/testify/assert"
2326
"github.com/stretchr/testify/require"
27+
"k8s.io/apimachinery/pkg/util/sets"
2428

2529
. "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock"
2630
"github.com/llm-d/llm-d-kv-cache/pkg/utils/logging"
@@ -223,7 +227,143 @@ func TestAddWithNilEngineKeys(t *testing.T) {
223227
assert.Error(t, err, "GetRequestKey should fail since no engineKey mapping was created")
224228
}
225229

226-
// TestPodEntryString tests the String() method with and without Annotation.
230+
// TestConcurrentAddEvictToEmpty is a regression test for issue #421.
231+
// It reproduces the race between a concurrent Add and an Evict that empties
232+
// the PodCache on the same key. Without the fix, the Evict could remove the
233+
// PodCache from the map after Add fetched it but before Add wrote into it,
234+
// causing the newly added entries to be orphaned and lost.
235+
func TestConcurrentAddEvictToEmpty(t *testing.T) {
236+
ctx := logging.NewTestLoggerIntoContext(t.Context())
237+
238+
const iterations = 500
239+
240+
for i := 0; i < iterations; i++ {
241+
cfg := DefaultInMemoryIndexConfig()
242+
cfg.PodCacheSize = 10
243+
index, err := NewInMemoryIndex(cfg)
244+
require.NoError(t, err)
245+
246+
engineKey := BlockHash(11111111)
247+
requestKey := BlockHash(22222222)
248+
seedPod := PodEntry{PodIdentifier: "seed", DeviceTier: "gpu"}
249+
survivorPod := PodEntry{PodIdentifier: fmt.Sprintf("survivor-%d", i), DeviceTier: "gpu"}
250+
251+
// Pre-populate with a single pod so Evict can empty the cache.
252+
err = index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, []PodEntry{seedPod})
253+
require.NoError(t, err)
254+
255+
var wg sync.WaitGroup
256+
wg.Add(2)
257+
258+
// Goroutine 1: Evict the seed pod, making the cache empty.
259+
// This triggers PodCache removal from the map.
260+
go func() {
261+
defer wg.Done()
262+
_ = index.Evict(ctx, engineKey, EngineKey, []PodEntry{seedPod})
263+
}()
264+
265+
// Goroutine 2: Add a new pod to the same key concurrently.
266+
go func() {
267+
defer wg.Done()
268+
_ = index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, []PodEntry{survivorPod})
269+
}()
270+
271+
wg.Wait()
272+
273+
// The survivor pod must be findable. Before the fix, if Evict ran
274+
// between Add's Get and Add's write, the survivor would be written
275+
// into an orphaned PodCache and lost.
276+
podsPerKey, err := index.Lookup(ctx, []BlockHash{requestKey}, sets.Set[string]{})
277+
require.NoError(t, err)
278+
279+
found := false
280+
for _, pod := range podsPerKey[requestKey] {
281+
if pod.PodIdentifier == survivorPod.PodIdentifier {
282+
found = true
283+
break
284+
}
285+
}
286+
assert.True(t, found, "iteration %d: survivor pod %q was lost — Add wrote into an orphaned PodCache",
287+
i, survivorPod.PodIdentifier)
288+
}
289+
}
290+
291+
// TestConcurrentAddWithStaleEvicts verifies that a flood of Evicts on the same
292+
// key cannot cause Add to spin indefinitely. This guards the prevLen > 0 check
293+
// in Evict: an Evict that finds an already-empty PodCache must NOT mark it as
294+
// removed, otherwise Add's retry loop would keep creating new PodCaches that
295+
// get immediately invalidated.
296+
func TestConcurrentAddWithStaleEvicts(t *testing.T) {
297+
ctx := logging.NewTestLoggerIntoContext(t.Context())
298+
299+
cfg := DefaultInMemoryIndexConfig()
300+
cfg.PodCacheSize = 10
301+
index, err := NewInMemoryIndex(cfg)
302+
require.NoError(t, err)
303+
304+
engineKey := BlockHash(99999999)
305+
requestKey := BlockHash(88888888)
306+
targetPod := PodEntry{PodIdentifier: "target", DeviceTier: "gpu"}
307+
stalePod := PodEntry{PodIdentifier: "stale", DeviceTier: "gpu"}
308+
309+
// We need the engineKey→requestKey mapping to exist so Evict can resolve
310+
// the engineKey. Seed it with a throwaway Add, then evict the entry.
311+
err = index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, []PodEntry{stalePod})
312+
require.NoError(t, err)
313+
err = index.Evict(ctx, engineKey, EngineKey, []PodEntry{stalePod})
314+
require.NoError(t, err)
315+
// Re-establish the engineKey mapping (Evict removed it when cache emptied).
316+
err = index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, []PodEntry{stalePod})
317+
require.NoError(t, err)
318+
319+
// Launch many goroutines that continuously Evict a pod from the same key.
320+
// If Evict incorrectly marks freshly-created empty PodCaches as removed,
321+
// the concurrent Add below would spin forever.
322+
const evictors = 20
323+
stop := make(chan struct{})
324+
var wg sync.WaitGroup
325+
326+
for i := 0; i < evictors; i++ {
327+
wg.Add(1)
328+
go func() {
329+
defer wg.Done()
330+
for {
331+
select {
332+
case <-stop:
333+
return
334+
default:
335+
_ = index.Evict(ctx, engineKey, EngineKey, []PodEntry{stalePod})
336+
}
337+
}
338+
}()
339+
}
340+
341+
// Add must complete promptly despite the flood of concurrent Evicts.
342+
done := make(chan struct{})
343+
go func() {
344+
_ = index.Add(ctx, []BlockHash{engineKey}, []BlockHash{requestKey}, []PodEntry{targetPod})
345+
close(done)
346+
}()
347+
348+
// 2 seconds is more than enough — a non-spinning Add finishes in microseconds.
349+
timeout := time.NewTimer(2 * time.Second)
350+
defer timeout.Stop()
351+
352+
select {
353+
case <-done:
354+
close(stop)
355+
wg.Wait()
356+
357+
podsPerKey, lookupErr := index.Lookup(ctx, []BlockHash{requestKey}, sets.Set[string]{})
358+
require.NoError(t, lookupErr)
359+
assert.Contains(t, podsPerKey[requestKey], targetPod,
360+
"target pod must be present after Add completes")
361+
case <-timeout.C:
362+
close(stop)
363+
wg.Wait()
364+
t.Fatal("Add did not complete within 2s — likely spinning due to stale Evicts marking empty PodCaches as removed")
365+
}
366+
}
227367
func TestPodEntryString(t *testing.T) {
228368
confirmed := PodEntry{PodIdentifier: "10.0.0.1:8080", DeviceTier: "gpu"}
229369
assert.Equal(t, "10.0.0.1:8080@gpu", confirmed.String())

0 commit comments

Comments
 (0)