Skip to content

Commit eaa7c73

Browse files
authored
feat: add speculative indexing support for PodEntry and Indexer (#369)
* feat: speculative indexing support
* refactor: flatten nested cleanup logic in InMemoryIndex.Evict
* refactor: replace Annotation string with Annotations struct in PodEntry
* fix: lint
* refactor: use strconv for BlockHash string representation
* feat: add speculative indexing support to RedisIndex
* refactor: simplify nested if in Evict() with early returns
* refactor: simplify PodEntry Annotations to Speculative bool
* feat: add explicit KeyType parameter to Evict() interface
* fix: suppress nilerr lint in Redis Evict for missing engine key
* fix: handle nil engineKeys in CostAwareMemoryIndex.Add and add common test
* refactor: delegate GetPodScores to ScoreTokens
  GetPodScores duplicated the tokenize→lookup→score logic that ScoreTokens already provides. Reduce it to a thin wrapper that tokenizes, truncates, and delegates to ScoreTokens. Fixes rebase regression from PR 415.
* fix: restore traceLogger lines in ScoreTokens
* fix: restore block hit ratio span attributes in ScoreTokens
1 parent 36d7117 commit eaa7c73

File tree

12 files changed

+353
-77
lines changed

12 files changed

+353
-77
lines changed

examples/valkey_example/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer)
188188

189189
// Demonstrate eviction
190190
logger.Info("Demonstrating cache eviction")
191-
err = indexer.KVBlockIndex().Evict(ctx, promptKeys[0], podEntries[:1])
191+
err = indexer.KVBlockIndex().Evict(ctx, promptKeys[0], kvblock.EngineKey, podEntries[:1])
192192
if err != nil {
193193
return fmt.Errorf("failed to evict cache entry: %w", err)
194194
}

pkg/kvcache/indexer.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,36 @@ func (k *Indexer) KVBlockIndex() kvblock.Index {
123123
return k.kvBlockIndex
124124
}
125125

126+
// ComputeBlockKeys computes the KV-block keys for a given prompt and model name.
127+
// This method extracts the tokenization and block key computation logic so that
128+
// callers (e.g., IGW::EPP::PrepareDataPlugin) can compute block keys once and reuse them
129+
// across multiple extension points without re-tokenizing.
130+
func (k *Indexer) ComputeBlockKeys(ctx context.Context, renderReq *types.RenderChatRequest, prompt, modelName string,
131+
) ([]kvblock.BlockHash, error) {
132+
traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvcache.ComputeBlockKeys")
133+
134+
// 1. tokenize prompt
135+
tokens := k.tokenizersPool.Tokenize(renderReq, prompt)
136+
137+
// 2. Truncate prompt (if set in the request)
138+
if renderReq != nil && renderReq.TruncatePromptTokens != nil {
139+
limit := *renderReq.TruncatePromptTokens
140+
if limit > 0 && len(tokens) > limit {
141+
tokens = tokens[len(tokens)-limit:]
142+
}
143+
}
144+
145+
// 3. get block keys
146+
blockKeys := k.tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, modelName)
147+
if len(blockKeys) == 0 {
148+
traceLogger.Info("no block keys found")
149+
return nil, nil
150+
}
151+
traceLogger.Info("computed block keys", "tokens", tokens, "block-keys", blockKeys)
152+
153+
return blockKeys, nil
154+
}
155+
126156
// GetPodScores retrieves the pod scores for a given prompt and model name.
127157
// The function receives the mentioned information and a list of relevant pod
128158
// identifiers. A Pod identifier should be its address.
@@ -159,9 +189,8 @@ func (k *Indexer) ScoreTokens(
159189
modelName string,
160190
podIdentifiers []string,
161191
) (map[string]float64, error) {
162-
// Start tracing span for main operation
163192
tracer := otel.Tracer(telemetry.InstrumentationName)
164-
ctx, span := tracer.Start(ctx, "llm_d.kv_cache.get_scores",
193+
ctx, span := tracer.Start(ctx, "llm_d.kv_cache.score_tokens",
165194
trace.WithSpanKind(trace.SpanKindInternal),
166195
)
167196
defer span.End()
@@ -170,21 +199,20 @@ func (k *Indexer) ScoreTokens(
170199

171200
blockKeys := k.tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, modelName)
172201

173-
// Set initial attributes
174202
span.SetAttributes(
175203
attribute.String("gen_ai.request.model", modelName),
176204
attribute.Int("llm_d.kv_cache.pod_count", len(podIdentifiers)),
177205
attribute.Int("llm_d.kv_cache.token_count", len(tokens)),
206+
attribute.Int("llm_d.kv_cache.block_keys.count", len(blockKeys)),
178207
)
179-
span.SetAttributes(attribute.Int("llm_d.kv_cache.block_keys.count", len(blockKeys)))
208+
180209
if len(blockKeys) == 0 {
181210
traceLogger.Info("no block keys found, returning empty scores")
182211
//nolint:nilnil // no need to return an error
183212
return nil, nil
184213
}
185214
traceLogger.Info("found tokens", "tokens", tokens, "block-keys", blockKeys)
186215

187-
// query kvblock indexer for pods
188216
keyToPods, err := k.kvBlockIndex.Lookup(ctx, blockKeys, sets.New(podIdentifiers...))
189217
if err != nil {
190218
span.SetStatus(codes.Error, err.Error())
@@ -209,13 +237,11 @@ func (k *Indexer) ScoreTokens(
209237
attribute.Int("llm_d.kv_cache.blocks_found", blocksFound),
210238
)
211239

212-
// 5. score pods
213240
podScores, err := k.kvBlockScorer.Score(ctx, blockKeys, keyToPods)
214241
if err != nil {
215242
span.SetStatus(codes.Error, err.Error())
216243
return nil, fmt.Errorf("failed to query kvblock scorer: %w", err)
217244
}
218-
traceLogger.Info("found pod scores", "pod-scores", podScores)
219245

220246
return podScores, nil
221247
}

pkg/kvcache/kvblock/cost_aware_memory.go

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -166,24 +166,26 @@ func (c *CostPodCache) CalculateByteSize(keyStr string) int64 {
166166
var _ Index = &CostAwareMemoryIndex{}
167167

168168
// Add adds a set of keys and their associated pod entries to the index backend.
169+
// If engineKeys is nil, only requestKey -> PodEntry mappings are created (no engineKey -> requestKey mapping).
170+
// This is used for speculative entries where engine keys are not yet known.
169171
func (m *CostAwareMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []BlockHash, entries []PodEntry) error {
170172
m.mu.Lock()
171173
defer m.mu.Unlock()
172174

173-
if len(engineKeys) == 0 || len(requestKeys) == 0 || len(entries) == 0 {
175+
if len(requestKeys) == 0 || len(entries) == 0 {
174176
return fmt.Errorf("no keys or entries provided for adding to index")
175177
}
176-
if len(engineKeys) != len(requestKeys) {
178+
if engineKeys != nil && len(engineKeys) != len(requestKeys) {
177179
return fmt.Errorf("mismatch between engine keys and request keys length")
178180
}
179181

180182
traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvblock.CostAwareMemoryIndex.Add")
181183

182184
for i, requestKey := range requestKeys {
183-
engineKey := engineKeys[i]
184-
185-
// Store engineKey -> requestKey mapping
186-
m.requestKeys.Add(engineKey, requestKey)
185+
// Store engineKey -> requestKey mapping (only if engineKeys provided)
186+
if engineKeys != nil {
187+
m.requestKeys.Add(engineKeys[i], requestKey)
188+
}
187189

188190
keyStr := requestKey.String()
189191
podCache, found := m.data.Get(keyStr)
@@ -198,7 +200,7 @@ func (m *CostAwareMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys
198200
// Calculate the actual cost for this cache entry
199201
cost := podCache.CalculateByteSize(keyStr)
200202
m.data.Set(keyStr, podCache, cost)
201-
traceLogger.Info("added pods to key", "requestKey", requestKey, "engineKey", engineKey, "pods", entries, "cost-bytes", cost)
203+
traceLogger.Info("added pods to key", "requestKey", requestKey, "pods", entries, "cost-bytes", cost)
202204
}
203205
m.data.Wait()
204206
return nil
@@ -260,7 +262,9 @@ func (m *CostAwareMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHa
260262
}
261263

262264
// Evict removes a key and its associated pod entries from the index backend.
263-
func (m *CostAwareMemoryIndex) Evict(ctx context.Context, engineKey BlockHash, entries []PodEntry) error {
265+
// keyType indicates whether the key is an EngineKey (requires engine→request lookup)
266+
// or a RequestKey (used directly for speculative entries without engineKey mapping).
267+
func (m *CostAwareMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType KeyType, entries []PodEntry) error {
264268
m.mu.Lock()
265269
defer m.mu.Unlock()
266270

@@ -270,17 +274,33 @@ func (m *CostAwareMemoryIndex) Evict(ctx context.Context, engineKey BlockHash, e
270274

271275
traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvblock.CostAwareMemoryIndex.Evict")
272276

273-
requestKey, found := m.requestKeys.Get(engineKey)
274-
if !found {
275-
traceLogger.Info("engineKey not found in index, nothing to evict", "engineKey", engineKey)
276-
return nil
277+
var requestKey BlockHash
278+
hasEngineKeyMapping := false
279+
280+
switch keyType {
281+
case EngineKey:
282+
rk, found := m.requestKeys.Get(key)
283+
if !found {
284+
traceLogger.Info("engineKey not found in mapping, nothing to evict", "engineKey", key)
285+
return nil
286+
}
287+
requestKey = rk
288+
hasEngineKeyMapping = true
289+
case RequestKey:
290+
requestKey = key
291+
default:
292+
return fmt.Errorf("unknown key type: %d", keyType)
277293
}
278294

279295
keyStr := requestKey.String()
280296
podCache, found := m.data.Get(keyStr)
281297
if !found || podCache == nil {
282-
traceLogger.Info("requestKey not found in index, cleaning up engineKey", "requestKey", requestKey, "engineKey", engineKey)
283-
m.requestKeys.Remove(engineKey)
298+
if hasEngineKeyMapping {
299+
traceLogger.Info("requestKey not found in index, cleaning up engineKey", "requestKey", requestKey, "engineKey", key)
300+
m.requestKeys.Remove(key)
301+
} else {
302+
traceLogger.Info("key not found in index, nothing to evict", "key", key)
303+
}
284304
return nil
285305
}
286306

@@ -292,12 +312,13 @@ func (m *CostAwareMemoryIndex) Evict(ctx context.Context, engineKey BlockHash, e
292312

293313
if podCache.Len() == 0 {
294314
m.data.Del(keyStr)
295-
m.requestKeys.Remove(engineKey)
296-
297-
traceLogger.Info("removed requestKey from index as no pods remain", "requestKey", requestKey, "engineKey", engineKey)
315+
if hasEngineKeyMapping {
316+
m.requestKeys.Remove(key)
317+
}
318+
traceLogger.Info("removed requestKey from index as no pods remain", "requestKey", requestKey, "key", key)
298319
} else if podCacheLenBefore != podCache.Len() {
299320
m.data.Set(keyStr, podCache, podCache.CalculateByteSize(keyStr))
300-
traceLogger.Info("evicted pods from engineKey", "requestKey", requestKey, "engineKey", engineKey, "pods", entries)
321+
traceLogger.Info("evicted pods from key", "requestKey", requestKey, "key", key, "keyType", keyType, "pods", entries)
301322
}
302323
m.data.Wait()
303324
return nil

pkg/kvcache/kvblock/in_memory.go

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -147,21 +147,23 @@ func (m *InMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHash,
147147
}
148148

149149
// Add adds a set of engineKeys/requestKeys and their associated pod entries to the index backend.
150+
// If engineKeys is nil, only requestKey -> PodEntry mappings are created (no engineKey -> requestKey mapping).
151+
// This is used for speculative entries where engine keys are not yet known.
150152
func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []BlockHash, entries []PodEntry) error {
151-
if len(engineKeys) == 0 || len(requestKeys) == 0 || len(entries) == 0 {
153+
if len(requestKeys) == 0 || len(entries) == 0 {
152154
return fmt.Errorf("no keys or entries provided for adding to index")
153155
}
154-
if len(engineKeys) != len(requestKeys) {
156+
if engineKeys != nil && len(engineKeys) != len(requestKeys) {
155157
return fmt.Errorf("mismatch between engine keys and request keys length")
156158
}
157159

158160
traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvblock.InMemoryIndex.Add")
159161

160162
for i, requestKey := range requestKeys {
161-
engineKey := engineKeys[i]
162-
163-
// 1. Store engineKey -> requestKey mapping
164-
m.engineToRequestKeys.Add(engineKey, requestKey)
163+
// 1. Store engineKey -> requestKey mapping (only if engineKeys provided)
164+
if engineKeys != nil {
165+
m.engineToRequestKeys.Add(engineKeys[i], requestKey)
166+
}
165167

166168
// 2. Store requestKey -> PodCache mapping
167169
var podCache *PodCache
@@ -203,30 +205,48 @@ func (m *InMemoryIndex) Add(ctx context.Context, engineKeys, requestKeys []Block
203205
}
204206
podCache.mu.Unlock()
205207

206-
traceLogger.Info("added pods to key", "requestKey", requestKey, "engineKey", engineKey, "pods", entries)
208+
traceLogger.Info("added pods to key", "requestKey", requestKey, "pods", entries)
207209
}
208210

209211
return nil
210212
}
211213

212-
// Evict removes a engineKey and its associated pod entries from the index backend.
213-
func (m *InMemoryIndex) Evict(ctx context.Context, engineKey BlockHash, entries []PodEntry) error {
214+
// Evict removes a key and its associated pod entries from the index backend.
215+
// keyType indicates whether the key is an EngineKey (requires engine→request lookup)
216+
// or a RequestKey (used directly for speculative entries without engineKey mapping).
217+
func (m *InMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType KeyType, entries []PodEntry) error {
214218
if len(entries) == 0 {
215219
return fmt.Errorf("no entries provided for eviction from index")
216220
}
217221

218222
traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvblock.InMemoryIndex.Evict")
219223

220-
requestKey, found := m.engineToRequestKeys.Get(engineKey)
221-
if !found {
222-
traceLogger.Info("engineKey not found in index, nothing to evict", "engineKey", engineKey)
223-
return nil
224+
var requestKey BlockHash
225+
hasEngineKeyMapping := false
226+
227+
switch keyType {
228+
case EngineKey:
229+
rk, found := m.engineToRequestKeys.Get(key)
230+
if !found {
231+
traceLogger.Info("engineKey not found in mapping, nothing to evict", "engineKey", key)
232+
return nil
233+
}
234+
requestKey = rk
235+
hasEngineKeyMapping = true
236+
case RequestKey:
237+
requestKey = key
238+
default:
239+
return fmt.Errorf("unknown key type: %d", keyType)
224240
}
225241

226242
podCache, found := m.data.Get(requestKey)
227243
if !found || podCache == nil {
228-
traceLogger.Info("requestKey not found in index, cleaning up engineKey", "requestKey", requestKey, "engineKey", engineKey)
229-
m.engineToRequestKeys.Remove(engineKey)
244+
if hasEngineKeyMapping {
245+
traceLogger.Info("requestKey not found in index, cleaning up engineKey", "requestKey", requestKey, "engineKey", key)
246+
m.engineToRequestKeys.Remove(key)
247+
} else {
248+
traceLogger.Info("key not found in index, nothing to evict", "key", key)
249+
}
230250
return nil
231251
}
232252

@@ -238,21 +258,28 @@ func (m *InMemoryIndex) Evict(ctx context.Context, engineKey BlockHash, entries
238258
isEmpty := podCache.cache.Len() == 0
239259
podCache.mu.Unlock()
240260

241-
traceLogger.Info("evicted pods from key", "requestKey", requestKey, "engineKey", engineKey, "pods", entries)
242-
243-
// Remove key from main cache if empty
244-
if isEmpty {
245-
// Re-fetch and hold the lock through removal to prevent racing with Add
246-
if currentCache, stillExists := m.data.Get(requestKey); stillExists && currentCache != nil {
247-
currentCache.mu.Lock()
248-
if currentCache.cache.Len() == 0 {
249-
m.data.Remove(requestKey)
250-
m.engineToRequestKeys.Remove(engineKey)
251-
traceLogger.Info("removed requestKey from index as no pods remain", "requestKey", requestKey, "engineKey", engineKey)
252-
}
253-
currentCache.mu.Unlock()
261+
traceLogger.Info("evicted pods from key", "requestKey", requestKey, "key", key, "keyType", keyType, "pods", entries)
262+
263+
// Remove key from main cache if empty.
264+
// Re-fetch and hold the lock through removal to prevent racing with Add.
265+
if !isEmpty {
266+
return nil
267+
}
268+
269+
currentCache, stillExists := m.data.Get(requestKey)
270+
if !stillExists || currentCache == nil {
271+
return nil
272+
}
273+
274+
currentCache.mu.Lock()
275+
if currentCache.cache.Len() == 0 {
276+
m.data.Remove(requestKey)
277+
if hasEngineKeyMapping {
278+
m.engineToRequestKeys.Remove(key)
254279
}
280+
traceLogger.Info("removed requestKey from index as no pods remain", "requestKey", requestKey, "key", key)
255281
}
282+
currentCache.mu.Unlock()
256283

257284
return nil
258285
}

0 commit comments

Comments
 (0)