Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkg/kvcache/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ func NewIndexerForTest(tp kvblock.TokenProcessor, idx kvblock.Index, scorer KVBl
tokenizersPool: pool,
}
}

// ContainsGroup exports the private containsGroup function for testing.
// It delegates directly to containsGroup with no additional logic;
// presumably it reports whether groupID occurs in groups — confirm against
// the containsGroup implementation.
func ContainsGroup(groups []int, groupID int) bool {
	return containsGroup(groups, groupID)
}
32 changes: 32 additions & 0 deletions pkg/kvcache/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ import (
"github.com/llm-d/llm-d-kv-cache/pkg/utils/logging"
)

// AttentionGroupConfig defines the attention window size and type for a group.
// A zero BlockSize means callers should fall back to ModelConfig.BlockSize.
type AttentionGroupConfig struct {
	WindowSize    int    `json:"windowSize"`    // Attention window size (0 or omit for full attention = no constraint)
	AttentionType string `json:"attentionType"` // Attention type (e.g., "full", "sliding", "block_sparse")
	BlockSize     int    `json:"blockSize"`     // KV block size for this group (if 0, falls back to ModelConfig.BlockSize)
}

// ModelConfig holds model-specific configuration including block size and
// attention groups. Entries are matched by exact Name in
// Config.GetModelConfig.
type ModelConfig struct {
	Name            string                 `json:"name"`            // Model name
	BlockSize       int                    `json:"blockSize"`       // Default KV block size (used when AttentionGroupConfig.BlockSize is 0)
	AttentionGroups []AttentionGroupConfig `json:"attentionGroups"` // Multiple attention groups with different window sizes and block sizes
}

// Config holds the configuration for the Indexer module.
// The configuration cover the different components found in the Indexer
// module.
Expand All @@ -42,6 +56,7 @@ type Config struct {
KVBlockScorerConfig *KVBlockScorerConfig // not exported
TokenizersPoolConfig *tokenization.Config `json:"tokenizersPoolConfig"`
BackendConfigs []*KVCacheBackendConfig `json:"kvCacheBackendConfigs"`
ModelConfigs []*ModelConfig
}

// NewDefaultConfig returns a default configuration for the Indexer module.
Expand All @@ -56,6 +71,7 @@ func NewDefaultConfig() (*Config, error) {
KVBlockScorerConfig: DefaultKVBlockScorerConfig(),
TokenizersPoolConfig: tokenizerPoolConfig,
BackendConfigs: DefaultKVCacheBackendConfig(),
ModelConfigs: nil, // No default model configs - must be explicitly configured
}, nil
}

Expand Down Expand Up @@ -293,3 +309,19 @@ func (k *Indexer) SetTokenizer(tokenizer tokenization.Tokenizer, modelName strin
// blockSize returns the KV block size reported by the indexer's token
// processor.
func (k *Indexer) blockSize() int {
	return k.tokenProcessor.BlockSize()
}

// GetModelConfig returns the model configuration for the given model name.
// Returns nil if no configuration is found for the model (including when
// ModelConfigs is nil or empty).
//
// Lookup is a linear scan; the list of model configs is expected to be small.
func (c *Config) GetModelConfig(modelName string) *ModelConfig {
	// Ranging over a nil slice is a no-op, so no explicit nil check is needed.
	for _, mc := range c.ModelConfigs {
		if mc.Name == modelName {
			return mc
		}
	}
	return nil
}
125 changes: 125 additions & 0 deletions pkg/kvcache/indexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,128 @@ func TestGetPodScores_TruncateZero(t *testing.T) {
assert.Equal(t, []uint32{1, 2}, tp.receivedTokens,
"token processor should receive all tokens when limit is zero")
}

// TestModelConfig checks basic sanity of ModelConfig values against the
// wantValid expectation. The original assertions compared each field to
// itself (always true), so the table's wantValid flag was never exercised.
func TestModelConfig(t *testing.T) {
	tests := []struct {
		name      string
		config    *kvcache.ModelConfig
		wantValid bool
	}{
		{
			name: "valid model config",
			config: &kvcache.ModelConfig{
				Name:      "llama-3-8b",
				BlockSize: 16,
				AttentionGroups: []kvcache.AttentionGroupConfig{
					{WindowSize: 2048, AttentionType: "sliding"},
					{WindowSize: 0, AttentionType: "full"},
				},
			},
			wantValid: true,
		},
		{
			name: "single attention group",
			config: &kvcache.ModelConfig{
				Name:      "gpt-4",
				BlockSize: 16,
				AttentionGroups: []kvcache.AttentionGroupConfig{
					{WindowSize: 0, AttentionType: "full"},
				},
			},
			wantValid: true,
		},
		{
			name: "no attention groups",
			config: &kvcache.ModelConfig{
				Name:            "test-model",
				BlockSize:       16,
				AttentionGroups: []kvcache.AttentionGroupConfig{},
			},
			wantValid: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// A config is considered valid when it names a model and has a
			// positive default block size; attention groups may be empty.
			valid := tt.config.Name != "" && tt.config.BlockSize > 0
			assert.Equal(t, tt.wantValid, valid)
		})
	}
}

// TestConfig_GetModelConfig exercises lookups against a Config carrying two
// model entries, covering both hits, a miss, and the empty-name case.
func TestConfig_GetModelConfig(t *testing.T) {
	cfg := &kvcache.Config{
		ModelConfigs: []*kvcache.ModelConfig{
			{
				Name:      "llama-3-8b",
				BlockSize: 16,
				AttentionGroups: []kvcache.AttentionGroupConfig{
					{WindowSize: 2048, AttentionType: "sliding"},
					{WindowSize: 0, AttentionType: "full"},
				},
			},
			{
				Name:      "gpt-4",
				BlockSize: 32,
				AttentionGroups: []kvcache.AttentionGroupConfig{
					{WindowSize: 0, AttentionType: "full"},
				},
			},
		},
	}

	cases := []struct {
		name      string
		modelName string
		want      *kvcache.ModelConfig // nil means the lookup should miss
	}{
		{name: "find llama-3-8b", modelName: "llama-3-8b", want: cfg.ModelConfigs[0]},
		{name: "find gpt-4", modelName: "gpt-4", want: cfg.ModelConfigs[1]},
		{name: "model not found", modelName: "unknown-model", want: nil},
		{name: "empty string", modelName: "", want: nil},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			result := cfg.GetModelConfig(tc.modelName)
			if tc.want == nil {
				assert.Nil(t, result)
				return
			}
			require.NotNil(t, result)
			assert.Equal(t, tc.want.Name, result.Name)
			assert.Equal(t, tc.want.BlockSize, result.BlockSize)
		})
	}
}

// TestConfig_GetModelConfig_NilModelConfigs verifies that a lookup on a
// Config with no model configs returns nil rather than panicking.
func TestConfig_GetModelConfig_NilModelConfigs(t *testing.T) {
	cfg := &kvcache.Config{ModelConfigs: nil}
	assert.Nil(t, cfg.GetModelConfig("any-model"))
}
89 changes: 73 additions & 16 deletions pkg/kvcache/kvblock/cost_aware_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,73 @@ func (m *CostAwareMemoryIndex) MaxCost() int64 {

// CostPodCache wraps a sync.Map of PodEntry and provides cost calculation for
// memory usage estimation. The range as scraped contained a stale pre-image
// duplicate of the cache field, removed here so the struct compiles.
type CostPodCache struct {
	// cache maps a string key derived from (PodIdentifier, DeviceTier,
	// Speculative) via podCacheKey to a *PodEntry.
	cache sync.Map // map[string]*PodEntry (key: "podID@tier")
	// size tracks the number of entries in cache for O(1) Len().
	size atomic.Int64
}

// Add adds a PodEntry to the cache.
// Add adds or updates a PodEntry in the cache, merging StoredGroups if the entry exists.
func (c *CostPodCache) Add(entry PodEntry) {
if _, loaded := c.cache.LoadOrStore(entry, struct{}{}); !loaded {
cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)

// Try to load existing entry
if existingVal, loaded := c.cache.Load(cacheKey); loaded {
if existingEntry, ok := existingVal.(*PodEntry); ok {
// Merge StoredGroups
existingEntry.StoredGroups = mergeGroupsUnique(existingEntry.StoredGroups, entry.StoredGroups)
// Store updated entry
c.cache.Store(cacheKey, existingEntry)
}
} else {
// Create new entry
newEntry := &PodEntry{
PodIdentifier: entry.PodIdentifier,
DeviceTier: entry.DeviceTier,
Speculative: entry.Speculative,
StoredGroups: mergeGroupsUnique(nil, entry.StoredGroups),
}
c.cache.Store(cacheKey, newEntry)
c.size.Add(1)
}
}

// Delete removes a PodEntry from the cache.
// Delete removes a PodEntry from the cache entirely.
func (c *CostPodCache) Delete(entry PodEntry) {
if _, loaded := c.cache.LoadAndDelete(entry); loaded {
cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)
if _, loaded := c.cache.LoadAndDelete(cacheKey); loaded {
c.size.Add(-1)
}
}

// RemoveGroups removes entry.StoredGroups from the cached PodEntry matching
// entry's (PodIdentifier, DeviceTier, Speculative) key.
// If no groups remain, the entry is deleted and size is decremented.
//
// Returns true only when the entry was fully deleted; false when the entry
// was not found, held an unexpected value type, or still has groups left.
//
// NOTE(review): the Load followed by Delete/Store is not atomic, and the
// loaded *PodEntry is mutated in place — a concurrent Add on the same key can
// be lost or race with this method. Confirm callers serialize per-key access.
func (c *CostPodCache) RemoveGroups(entry PodEntry) bool {
	cacheKey := podCacheKey(entry.PodIdentifier, entry.DeviceTier, entry.Speculative)

	existingVal, loaded := c.cache.Load(cacheKey)
	if !loaded {
		// Nothing cached under this key; nothing to remove.
		return false
	}

	existingEntry, ok := existingVal.(*PodEntry)
	if !ok {
		// Defensive: unexpected value type in the map; leave it untouched.
		return false
	}

	// Remove specified groups
	updatedGroups := removeGroups(existingEntry.StoredGroups, entry.StoredGroups)

	if len(updatedGroups) == 0 {
		// No groups left, delete the entry
		c.cache.Delete(cacheKey)
		c.size.Add(-1)
		return true
	}

	// Update with remaining groups
	existingEntry.StoredGroups = updatedGroups
	c.cache.Store(cacheKey, existingEntry)
	return false
}

// Len returns the number of entries in the cache.
Expand All @@ -141,16 +191,22 @@ func (c *CostPodCache) CalculateByteSize(keyStr string) int64 {

// Count entries and calculate their size
c.cache.Range(func(key, value interface{}) bool {
entry, ok := key.(PodEntry)
if !ok {
// key is now a string, value is *PodEntry
keyStr, okKey := key.(string)
entry, okEntry := value.(*PodEntry)
if !okKey || !okEntry {
return true
}

entryCount++
totalBytes += int64(len(entry.PodIdentifier)) // PodIdentifier string content
totalBytes += int64(len(entry.DeviceTier)) // DeviceTier string content
totalBytes += 32 // string headers (16 bytes each for 2 strings)
totalBytes += 8 // struct padding/alignment
totalBytes += int64(len(keyStr)) // cache key string
totalBytes += int64(len(entry.PodIdentifier)) // PodIdentifier string content
totalBytes += int64(len(entry.DeviceTier)) // DeviceTier string content
totalBytes += int64(len(entry.StoredGroups) * 8) // StoredGroups slice (8 bytes per int)
totalBytes += 32 // string headers (16 bytes each for 2 strings)
totalBytes += 24 // slice header for StoredGroups
totalBytes += 8 // pointer to PodEntry
totalBytes += 8 // struct padding/alignment
return true
})

Expand Down Expand Up @@ -234,17 +290,17 @@ func (m *CostAwareMemoryIndex) Lookup(ctx context.Context, requestKeys []BlockHa
if podIdentifierSet.Len() == 0 {
// If no pod identifiers are provided, return all pods
pods.cache.Range(func(k, value interface{}) bool {
if pod, ok := k.(PodEntry); ok {
podsPerKey[key] = append(podsPerKey[key], pod)
if pod, ok := value.(*PodEntry); ok {
podsPerKey[key] = append(podsPerKey[key], *pod)
}
return true
})
} else {
// Filter pods based on the provided pod identifiers
pods.cache.Range(func(k, value interface{}) bool {
if pod, ok := k.(PodEntry); ok {
if pod, ok := value.(*PodEntry); ok {
if podIdentifierSet.Has(pod.PodIdentifier) {
podsPerKey[key] = append(podsPerKey[key], pod)
podsPerKey[key] = append(podsPerKey[key], *pod)
}
}
return true
Expand Down Expand Up @@ -307,7 +363,8 @@ func (m *CostAwareMemoryIndex) Evict(ctx context.Context, key BlockHash, keyType
podCacheLenBefore := podCache.Len()

for _, entry := range entries {
podCache.Delete(entry)
// Remove groups from the entry; if no groups remain, the entry is deleted
podCache.RemoveGroups(entry)
}

if podCache.Len() == 0 {
Expand Down
Loading