Skip to content

Commit 9949b80

Browse files
committed
hma prefix routing
Signed-off-by: Kapil Jain <kapiljain1989@gmail.com> # Conflicts: # pkg/kvcache/indexer.go
1 parent 738f18b commit 9949b80

File tree

12 files changed

+1609
-69
lines changed

12 files changed

+1609
-69
lines changed

pkg/kvcache/export_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,8 @@ func NewIndexerForTest(tp kvblock.TokenProcessor, idx kvblock.Index, scorer KVBl
3030
tokenizersPool: pool,
3131
}
3232
}
33+
34+
// ContainsGroup exports the private containsGroup function for testing.
35+
func ContainsGroup(groups []int, groupID int) bool {
36+
return containsGroup(groups, groupID)
37+
}

pkg/kvcache/indexer.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,20 @@ import (
3434
"github.com/llm-d/llm-d-kv-cache/pkg/utils/logging"
3535
)
3636

37+
// AttentionGroupConfig describes one attention group of a model: its window
// size, its attention type and its KV block size.
type AttentionGroupConfig struct {
	// WindowSize is the attention window size. Zero (or omitting the field)
	// means full attention, i.e. no window constraint.
	WindowSize int `json:"windowSize"`
	// AttentionType names the attention type, e.g. "full", "sliding" or
	// "block_sparse".
	AttentionType string `json:"attentionType"`
	// BlockSize is the KV block size for this group. Zero means the group
	// falls back to ModelConfig.BlockSize.
	BlockSize int `json:"blockSize"`
}
43+
44+
// ModelConfig holds model-specific configuration including block size and attention groups.
45+
type ModelConfig struct {
46+
Name string `json:"name"` // Model name
47+
BlockSize int `json:"blockSize"` // Default KV block size (used when AttentionGroupConfig.BlockSize is 0)
48+
AttentionGroups []AttentionGroupConfig `json:"attentionGroups"` // Multiple attention groups with different window sizes and block sizes
49+
}
50+
3751
// Config holds the configuration for the Indexer module.
3852
// The configuration covers the different components found in the Indexer
3953
// module.
@@ -42,6 +56,7 @@ type Config struct {
4256
KVBlockScorerConfig *KVBlockScorerConfig // not exported
4357
TokenizersPoolConfig *tokenization.Config `json:"tokenizersPoolConfig"`
4458
BackendConfigs []*KVCacheBackendConfig `json:"kvCacheBackendConfigs"`
59+
ModelConfigs []*ModelConfig
4560
}
4661

4762
// NewDefaultConfig returns a default configuration for the Indexer module.
@@ -56,6 +71,7 @@ func NewDefaultConfig() (*Config, error) {
5671
KVBlockScorerConfig: DefaultKVBlockScorerConfig(),
5772
TokenizersPoolConfig: tokenizerPoolConfig,
5873
BackendConfigs: DefaultKVCacheBackendConfig(),
74+
ModelConfigs: nil, // No default model configs - must be explicitly configured
5975
}, nil
6076
}
6177

@@ -293,3 +309,19 @@ func (k *Indexer) SetTokenizer(tokenizer tokenization.Tokenizer, modelName strin
293309
// blockSize returns whatever the indexer's token processor reports from its
// BlockSize method.
func (k *Indexer) blockSize() int {
	size := k.tokenProcessor.BlockSize()
	return size
}
312+
313+
// GetModelConfig returns the model configuration for the given model name.
314+
// Returns nil if no configuration is found for the model.
315+
func (c *Config) GetModelConfig(modelName string) *ModelConfig {
316+
if c.ModelConfigs == nil {
317+
return nil
318+
}
319+
320+
for _, mc := range c.ModelConfigs {
321+
if mc.Name == modelName {
322+
return mc
323+
}
324+
}
325+
326+
return nil
327+
}

pkg/kvcache/indexer_test.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,128 @@ func TestGetPodScores_TruncateZero(t *testing.T) {
370370
assert.Equal(t, []uint32{1, 2}, tp.receivedTokens,
371371
"token processor should receive all tokens when limit is zero")
372372
}
373+
374+
func TestModelConfig(t *testing.T) {
375+
tests := []struct {
376+
name string
377+
config *kvcache.ModelConfig
378+
wantValid bool
379+
}{
380+
{
381+
name: "valid model config",
382+
config: &kvcache.ModelConfig{
383+
Name: "llama-3-8b",
384+
BlockSize: 16,
385+
AttentionGroups: []kvcache.AttentionGroupConfig{
386+
{WindowSize: 2048, AttentionType: "sliding"},
387+
{WindowSize: 0, AttentionType: "full"},
388+
},
389+
},
390+
wantValid: true,
391+
},
392+
{
393+
name: "single attention group",
394+
config: &kvcache.ModelConfig{
395+
Name: "gpt-4",
396+
BlockSize: 16,
397+
AttentionGroups: []kvcache.AttentionGroupConfig{
398+
{WindowSize: 0, AttentionType: "full"},
399+
},
400+
},
401+
wantValid: true,
402+
},
403+
{
404+
name: "no attention groups",
405+
config: &kvcache.ModelConfig{
406+
Name: "test-model",
407+
BlockSize: 16,
408+
AttentionGroups: []kvcache.AttentionGroupConfig{},
409+
},
410+
wantValid: true,
411+
},
412+
}
413+
414+
for _, tt := range tests {
415+
t.Run(tt.name, func(t *testing.T) {
416+
assert.Equal(t, tt.config.Name, tt.config.Name)
417+
assert.Equal(t, tt.config.BlockSize, tt.config.BlockSize)
418+
assert.Equal(t, len(tt.config.AttentionGroups), len(tt.config.AttentionGroups))
419+
})
420+
}
421+
}
422+
423+
func TestConfig_GetModelConfig(t *testing.T) {
424+
config := &kvcache.Config{
425+
ModelConfigs: []*kvcache.ModelConfig{
426+
{
427+
Name: "llama-3-8b",
428+
BlockSize: 16,
429+
AttentionGroups: []kvcache.AttentionGroupConfig{
430+
{WindowSize: 2048, AttentionType: "sliding"},
431+
{WindowSize: 0, AttentionType: "full"},
432+
},
433+
},
434+
{
435+
Name: "gpt-4",
436+
BlockSize: 32,
437+
AttentionGroups: []kvcache.AttentionGroupConfig{
438+
{WindowSize: 0, AttentionType: "full"},
439+
},
440+
},
441+
},
442+
}
443+
444+
tests := []struct {
445+
name string
446+
modelName string
447+
wantFound bool
448+
wantModel *kvcache.ModelConfig
449+
}{
450+
{
451+
name: "find llama-3-8b",
452+
modelName: "llama-3-8b",
453+
wantFound: true,
454+
wantModel: config.ModelConfigs[0],
455+
},
456+
{
457+
name: "find gpt-4",
458+
modelName: "gpt-4",
459+
wantFound: true,
460+
wantModel: config.ModelConfigs[1],
461+
},
462+
{
463+
name: "model not found",
464+
modelName: "unknown-model",
465+
wantFound: false,
466+
wantModel: nil,
467+
},
468+
{
469+
name: "empty string",
470+
modelName: "",
471+
wantFound: false,
472+
wantModel: nil,
473+
},
474+
}
475+
476+
for _, tt := range tests {
477+
t.Run(tt.name, func(t *testing.T) {
478+
got := config.GetModelConfig(tt.modelName)
479+
if tt.wantFound {
480+
require.NotNil(t, got)
481+
assert.Equal(t, tt.wantModel.Name, got.Name)
482+
assert.Equal(t, tt.wantModel.BlockSize, got.BlockSize)
483+
} else {
484+
assert.Nil(t, got)
485+
}
486+
})
487+
}
488+
}
489+
490+
func TestConfig_GetModelConfig_NilModelConfigs(t *testing.T) {
491+
config := &kvcache.Config{
492+
ModelConfigs: nil,
493+
}
494+
495+
got := config.GetModelConfig("any-model")
496+
assert.Nil(t, got)
497+
}

pkg/kvcache/kvblock/cost_aware_memory.go

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,23 +103,73 @@ func (m *CostAwareMemoryIndex) MaxCost() int64 {
103103

104104
// CostPodCache wraps a sync.Map of pod entries and provides cost calculation
// for memory usage estimation.
type CostPodCache struct {
	// cache maps a string cache key to a *PodEntry. The key is whatever
	// podCacheKey builds from the entry's pod identifier, device tier and
	// speculative flag.
	cache sync.Map
	// size tracks the number of entries in cache so that Len() is O(1).
	size atomic.Int64
}
110110

111-
// Add adds a PodEntry to the cache.
111+
// Add adds or updates a PodEntry in the cache, merging StoredGroups if the entry exists.
112112
func (c *CostPodCache) Add(entry PodEntry) {
113-
if _, loaded := c.cache.LoadOrStore(entry, struct{}{}); !loaded {
113+
cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)
114+
115+
// Try to load existing entry
116+
if existingVal, loaded := c.cache.Load(cacheKey); loaded {
117+
if existingEntry, ok := existingVal.(*PodEntry); ok {
118+
// Merge StoredGroups
119+
existingEntry.StoredGroups = mergeGroupsUnique(existingEntry.StoredGroups, entry.StoredGroups)
120+
// Store updated entry
121+
c.cache.Store(cacheKey, existingEntry)
122+
}
123+
} else {
124+
// Create new entry
125+
newEntry := &PodEntry{
126+
PodIdentifier: entry.PodIdentifier,
127+
DeviceTier: entry.DeviceTier,
128+
Speculative: entry.Speculative,
129+
StoredGroups: mergeGroupsUnique(nil, entry.StoredGroups),
130+
}
131+
c.cache.Store(cacheKey, newEntry)
114132
c.size.Add(1)
115133
}
116134
}
117135

118-
// Delete removes a PodEntry from the cache.
136+
// Delete removes a PodEntry from the cache entirely.
119137
func (c *CostPodCache) Delete(entry PodEntry) {
120-
if _, loaded := c.cache.LoadAndDelete(entry); loaded {
138+
cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)
139+
if _, loaded := c.cache.LoadAndDelete(cacheKey); loaded {
140+
c.size.Add(-1)
141+
}
142+
}
143+
144+
// RemoveGroups removes specified groups from a PodEntry's StoredGroups.
145+
// If no groups remain, the entry is deleted.
146+
func (c *CostPodCache) RemoveGroups(entry PodEntry) bool {
147+
cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)
148+
149+
existingVal, loaded := c.cache.Load(cacheKey)
150+
if !loaded {
151+
return false
152+
}
153+
154+
existingEntry, ok := existingVal.(*PodEntry)
155+
if !ok {
156+
return false
157+
}
158+
159+
// Remove specified groups
160+
updatedGroups := removeGroups(existingEntry.StoredGroups, entry.StoredGroups)
161+
162+
if len(updatedGroups) == 0 {
163+
// No groups left, delete the entry
164+
c.cache.Delete(cacheKey)
121165
c.size.Add(-1)
166+
return true
122167
}
168+
169+
// Update with remaining groups
170+
existingEntry.StoredGroups = updatedGroups
171+
c.cache.Store(cacheKey, existingEntry)
172+
return false
123173
}
124174

125175
// Len returns the number of entries in the cache.
@@ -141,16 +191,22 @@ func (c *CostPodCache) CalculateByteSize(keyStr string) int64 {
141191

142192
// Count entries and calculate their size
143193
c.cache.Range(func(key, value interface{}) bool {
144-
entry, ok := key.(PodEntry)
145-
if !ok {
194+
// key is now a string, value is *PodEntry
195+
keyStr, okKey := key.(string)
196+
entry, okEntry := value.(*PodEntry)
197+
if !okKey || !okEntry {
146198
return true
147199
}
148200

149201
entryCount++
150-
totalBytes += int64(len(entry.PodIdentifier)) // PodIdentifier string content
151-
totalBytes += int64(len(entry.DeviceTier)) // DeviceTier string content
152-
totalBytes += 32 // string headers (16 bytes each for 2 strings)
153-
totalBytes += 8 // struct padding/alignment
202+
totalBytes += int64(len(keyStr)) // cache key string
203+
totalBytes += int64(len(entry.PodIdentifier)) // PodIdentifier string content
204+
totalBytes += int64(len(entry.DeviceTier)) // DeviceTier string content
205+
totalBytes += int64(len(entry.StoredGroups) * 8) // StoredGroups slice (8 bytes per int)
206+
totalBytes += 32 // string headers (16 bytes each for 2 strings)
207+
totalBytes += 24 // slice header for StoredGroups
208+
totalBytes += 8 // pointer to PodEntry
209+
totalBytes += 8 // struct padding/alignment
154210
return true
155211
})
156212

@@ -234,17 +290,17 @@ func (m *CostAwareMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHa
234290
if podIdentifierSet.Len() == 0 {
235291
// If no pod identifiers are provided, return all pods
236292
pods.cache.Range(func(k, value interface{}) bool {
237-
if pod, ok := k.(PodEntry); ok {
238-
podsPerKey[key] = append(podsPerKey[key], pod)
293+
if pod, ok := value.(*PodEntry); ok {
294+
podsPerKey[key] = append(podsPerKey[key], *pod)
239295
}
240296
return true
241297
})
242298
} else {
243299
// Filter pods based on the provided pod identifiers
244300
pods.cache.Range(func(k, value interface{}) bool {
245-
if pod, ok := k.(PodEntry); ok {
301+
if pod, ok := value.(*PodEntry); ok {
246302
if podIdentifierSet.Has(pod.PodIdentifier) {
247-
podsPerKey[key] = append(podsPerKey[key], pod)
303+
podsPerKey[key] = append(podsPerKey[key], *pod)
248304
}
249305
}
250306
return true
@@ -307,7 +363,8 @@ func (m *CostAwareMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType
307363
podCacheLenBefore := podCache.Len()
308364

309365
for _, entry := range entries {
310-
podCache.Delete(entry)
366+
// Remove groups from the entry; if no groups remain, the entry is deleted
367+
podCache.RemoveGroups(entry)
311368
}
312369

313370
if podCache.Len() == 0 {

0 commit comments

Comments
 (0)