Skip to content

Commit bd0172a

Browse files
committed
parse hma kv event metadata
Signed-off-by: Sage Ahrac <sagiahrak@gmail.com>
1 parent 8cf550e commit bd0172a

5 files changed

Lines changed: 262 additions & 21 deletions

File tree

pkg/kvevents/engineadapter/sglang_adapter.go

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -188,6 +188,7 @@ func (s *SGLangAdapter) convertBlockStoredEvent(rawEventBytes []byte) (kvevents.
188188
BlockHashes: blockHashes,
189189
Tokens: event.TokenIds,
190190
ParentHash: parentHash,
191+
BlockSize: event.BlockSize,
191192
DeviceTier: deviceTier,
192193
LoraID: event.LoraID,
193194
LoraName: event.LoraName,

pkg/kvevents/engineadapter/sglang_adapter_test.go

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,7 @@ func TestSGLangParseMessage_Valid(t *testing.T) {
7171
require.True(t, ok)
7272
assert.Equal(t, []uint64{100, 101}, blockStored.BlockHashes)
7373
assert.Equal(t, uint64(99), blockStored.ParentHash)
74+
assert.Equal(t, 16, blockStored.BlockSize)
7475
}
7576

7677
// TestSGLangParseMessage_InvalidPayload tests error handling for invalid msgpack data.
@@ -114,6 +115,7 @@ func TestSGLangBlockStored_FullFields(t *testing.T) {
114115
assert.Equal(t, []uint64{100, 101}, blockStored.BlockHashes)
115116
assert.Equal(t, uint64(99), blockStored.ParentHash)
116117
assert.Equal(t, []uint32{1, 2, 3}, blockStored.Tokens)
118+
assert.Equal(t, 16, blockStored.BlockSize)
117119
assert.Equal(t, "gpu", blockStored.DeviceTier)
118120
assert.Nil(t, blockStored.LoraID)
119121
assert.Nil(t, blockStored.LoraName)
@@ -146,6 +148,7 @@ func TestSGLangBlockStored_7Fields(t *testing.T) {
146148
assert.Equal(t, []uint64{300, 301}, blockStored.BlockHashes)
147149
assert.Equal(t, uint64(299), blockStored.ParentHash)
148150
assert.Equal(t, []uint32{7, 8, 9}, blockStored.Tokens)
151+
assert.Equal(t, 64, blockStored.BlockSize)
149152
assert.Equal(t, "GPU", blockStored.DeviceTier)
150153
assert.Nil(t, blockStored.LoraID)
151154
assert.Nil(t, blockStored.LoraName, "SGLang does not send lora_name")
@@ -177,6 +180,7 @@ func TestSGLangBlockStored_MinimalFields(t *testing.T) {
177180
assert.Equal(t, []uint64{400}, blockStored.BlockHashes)
178181
assert.Equal(t, uint64(399), blockStored.ParentHash)
179182
assert.Equal(t, []uint32{10, 11}, blockStored.Tokens)
183+
assert.Equal(t, 128, blockStored.BlockSize)
180184
assert.Equal(t, "", blockStored.DeviceTier, "medium should default to empty")
181185
assert.Nil(t, blockStored.LoraID)
182186
assert.Nil(t, blockStored.LoraName)

pkg/kvevents/engineadapter/vllm_adapter.go

Lines changed: 72 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -132,15 +132,18 @@ func fieldAt(fields []any, i int) any {
132132
// convertBlockStoredEvent converts a decoded []any into a BlockStoredEvent.
133133
// vLLM field positions (array_like=True, tag=True):
134134
//
135-
// [0] tag string (consumed by decodeVLLMEvent)
136-
// [1] block_hashes []hash
137-
// [2] parent_hash hash|nil
138-
// [3] token_ids []uint32
139-
// [4] block_size int (consumed but not stored)
140-
// [5] lora_id int|nil (optional, omit_defaults)
141-
// [6] medium string|nil (optional, omit_defaults)
142-
// [7] lora_name string|nil (optional, omit_defaults)
143-
// [8] extra_keys [][]any|nil (optional, omit_defaults)
135+
// [0] tag string (consumed by decodeVLLMEvent)
136+
// [1] block_hashes []hash
137+
// [2] parent_hash hash|nil
138+
// [3] token_ids []uint32
139+
// [4] block_size int
140+
// [5] lora_id int|nil (optional, omit_defaults)
141+
// [6] medium string|nil (optional, omit_defaults)
142+
// [7] lora_name string|nil (optional, omit_defaults)
143+
// [8] extra_keys [][]any|nil (optional, omit_defaults)
144+
// [9] group_idx int|nil (optional, HMA)
145+
// [10] kv_cache_spec_kind string|nil (optional, HMA)
146+
// [11] kv_cache_spec_sliding_window int|nil (optional, HMA)
144147
//
145148
// Trailing fields may be absent in older vLLM versions. Extra trailing fields
146149
// from newer vLLM versions are silently ignored.
@@ -176,7 +179,11 @@ func (v *VLLMAdapter) convertBlockStoredEvent(fields []any) (kvevents.GenericEve
176179
return nil, fmt.Errorf("BlockStored: %w", err)
177180
}
178181

179-
// [4] block_size — consumed but not stored in domain event
182+
// [4] block_size
183+
blockSize, err := toInt(fields[4])
184+
if err != nil {
185+
return nil, fmt.Errorf("BlockStored: block_size: %w", err)
186+
}
180187

181188
// [5] lora_id (optional)
182189
var loraID *int
@@ -221,14 +228,48 @@ func (v *VLLMAdapter) convertBlockStoredEvent(fields []any) (kvevents.GenericEve
221228
}
222229
}
223230

231+
var groupIdx *int
232+
if raw := fieldAt(fields, 9); raw != nil {
233+
group, err := toInt(raw)
234+
if err != nil {
235+
return nil, fmt.Errorf("BlockStored: group_idx: %w", err)
236+
}
237+
if group < 0 {
238+
return nil, fmt.Errorf("BlockStored: group_idx: negative value: %d", group)
239+
}
240+
groupIdx = &group
241+
}
242+
243+
var specKind kvevents.KVCacheSpecKind
244+
if raw := fieldAt(fields, 10); raw != nil {
245+
s, ok := raw.(string)
246+
if !ok {
247+
return nil, fmt.Errorf("BlockStored: kv_cache_spec_kind is not a string: %T", raw)
248+
}
249+
specKind = kvevents.KVCacheSpecKind(s)
250+
}
251+
252+
var slidingWindow *int
253+
if raw := fieldAt(fields, 11); raw != nil {
254+
window, err := toInt(raw)
255+
if err != nil {
256+
return nil, fmt.Errorf("BlockStored: kv_cache_spec_sliding_window: %w", err)
257+
}
258+
slidingWindow = &window
259+
}
260+
224261
return &kvevents.BlockStoredEvent{
225-
BlockHashes: blockHashes,
226-
Tokens: tokens,
227-
ParentHash: parentHash,
228-
DeviceTier: deviceTier,
229-
LoraID: loraID,
230-
LoraName: loraName,
231-
ExtraKeys: extraKeys,
262+
BlockHashes: blockHashes,
263+
Tokens: tokens,
264+
ParentHash: parentHash,
265+
BlockSize: blockSize,
266+
DeviceTier: deviceTier,
267+
LoraID: loraID,
268+
LoraName: loraName,
269+
ExtraKeys: extraKeys,
270+
GroupIdx: groupIdx,
271+
KVCacheSpecKind: specKind,
272+
KVCacheSpecSlidingWindowSize: slidingWindow,
232273
}, nil
233274
}
234275

@@ -238,6 +279,7 @@ func (v *VLLMAdapter) convertBlockStoredEvent(fields []any) (kvevents.GenericEve
238279
// [0] tag string
239280
// [1] block_hashes []hash
240281
// [2] medium string|nil (optional, omit_defaults)
282+
// [3] group_idx int|nil (optional, HMA)
241283
func (v *VLLMAdapter) convertBlockRemovedEvent(fields []any) (kvevents.GenericEvent, error) {
242284
if len(fields) < 2 {
243285
return nil, fmt.Errorf("BlockRemoved: need at least 2 fields, got %d", len(fields))
@@ -261,9 +303,22 @@ func (v *VLLMAdapter) convertBlockRemovedEvent(fields []any) (kvevents.GenericEv
261303
deviceTier = s
262304
}
263305

306+
var groupIdx *int
307+
if raw := fieldAt(fields, 3); raw != nil {
308+
group, err := toInt(raw)
309+
if err != nil {
310+
return nil, fmt.Errorf("BlockRemoved: group_idx: %w", err)
311+
}
312+
if group < 0 {
313+
return nil, fmt.Errorf("BlockRemoved: group_idx: negative value: %d", group)
314+
}
315+
groupIdx = &group
316+
}
317+
264318
return &kvevents.BlockRemovedEvent{
265319
BlockHashes: blockHashes,
266320
DeviceTier: deviceTier,
321+
GroupIdx: groupIdx,
267322
}, nil
268323
}
269324

pkg/kvevents/engineadapter/vllm_adapter_test.go

Lines changed: 160 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,40 @@ func TestVLLMBlockStoredWithLora(t *testing.T) {
158158
assert.Equal(t, [][]any{{"uuid-A", "salt"}, nil}, blockStored.ExtraKeys)
159159
}
160160

161+
func TestVLLMBlockStoredWithHMAMetadata(t *testing.T) {
162+
adapter := NewVLLMAdapter()
163+
164+
vllmEvent := []any{
165+
"BlockStored",
166+
[]any{uint64(700), uint64(701)},
167+
uint64(699),
168+
[]uint32{1, 2, 3, 4},
169+
16,
170+
nil,
171+
"gpu",
172+
nil,
173+
nil,
174+
uint64(1),
175+
"sliding_window",
176+
128,
177+
}
178+
179+
rawBytes, err := msgpack.Marshal(vllmEvent)
180+
require.NoError(t, err)
181+
182+
event, err := adapter.decodeVLLMEvent(rawBytes)
183+
require.NoError(t, err)
184+
185+
blockStored, ok := event.(*kvevents.BlockStoredEvent)
186+
require.True(t, ok)
187+
assert.Equal(t, 16, blockStored.BlockSize)
188+
require.NotNil(t, blockStored.GroupIdx)
189+
assert.Equal(t, 1, *blockStored.GroupIdx)
190+
assert.Equal(t, kvevents.KVCacheSpecKindSlidingWindow, blockStored.KVCacheSpecKind)
191+
require.NotNil(t, blockStored.KVCacheSpecSlidingWindowSize)
192+
assert.Equal(t, 128, *blockStored.KVCacheSpecSlidingWindowSize)
193+
}
194+
161195
// TestDecodeVLLMEvent_BlockStoredMissingTrailingFields tests backward compatibility
162196
// when trailing optional fields are absent (older vLLM with omit_defaults=True).
163197
func TestDecodeVLLMEvent_BlockStoredMissingTrailingFields(t *testing.T) {
@@ -236,7 +270,7 @@ func TestDecodeVLLMEvent_BlockStoredMissingTrailingFields(t *testing.T) {
236270
func TestDecodeVLLMEvent_BlockStoredExtraTrailingFields(t *testing.T) {
237271
adapter := NewVLLMAdapter()
238272

239-
// Simulate a future vLLM version with extra_keys and another unknown field
273+
// Simulate a future vLLM version with HMA metadata plus another unknown field.
240274
vllmEvent := []any{
241275
"BlockStored",
242276
[]any{uint64(400), uint64(401)},
@@ -247,7 +281,10 @@ func TestDecodeVLLMEvent_BlockStoredExtraTrailingFields(t *testing.T) {
247281
"gpu",
248282
"my-lora",
249283
[]any{[]any{"extra", "keys"}}, // [8] extra_keys
250-
"completely-unknown-field", // [9] future unknown — silently ignored
284+
uint64(0), // [9] group_idx
285+
"full_attention", // [10] kv_cache_spec_kind
286+
nil, // [11] kv_cache_spec_sliding_window
287+
"completely-unknown-field", // [12] future unknown — silently ignored
251288
}
252289

253290
rawBytes, err := msgpack.Marshal(vllmEvent)
@@ -267,6 +304,9 @@ func TestDecodeVLLMEvent_BlockStoredExtraTrailingFields(t *testing.T) {
267304
assert.Equal(t, "my-lora", *blockStored.LoraName)
268305
require.NotNil(t, blockStored.ExtraKeys)
269306
assert.Equal(t, [][]any{{"extra", "keys"}}, blockStored.ExtraKeys)
307+
require.NotNil(t, blockStored.GroupIdx)
308+
assert.Equal(t, 0, *blockStored.GroupIdx)
309+
assert.Equal(t, kvevents.KVCacheSpecKindFullAttention, blockStored.KVCacheSpecKind)
270310
}
271311

272312
// TestDecodeVLLMEvent_BlockRemovedExtraTrailingFields tests forward compatibility for BlockRemoved.
@@ -277,8 +317,8 @@ func TestDecodeVLLMEvent_BlockRemovedExtraTrailingFields(t *testing.T) {
277317
"BlockRemoved",
278318
[]any{uint64(500)},
279319
"cpu",
280-
"future-field-1",
281-
"future-field-2",
320+
uint64(0), // [3] group_idx
321+
"future-field-1", // [4] future unknown — silently ignored
282322
}
283323

284324
rawBytes, err := msgpack.Marshal(vllmEvent)
@@ -291,6 +331,8 @@ func TestDecodeVLLMEvent_BlockRemovedExtraTrailingFields(t *testing.T) {
291331
require.True(t, ok)
292332
assert.Equal(t, []uint64{500}, blockRemoved.BlockHashes)
293333
assert.Equal(t, "cpu", blockRemoved.DeviceTier)
334+
require.NotNil(t, blockRemoved.GroupIdx)
335+
assert.Equal(t, 0, *blockRemoved.GroupIdx)
294336
}
295337

296338
// TestDecodeVLLMEvent_BlockRemovedMissingMedium tests backward compat for BlockRemoved.
@@ -312,6 +354,120 @@ func TestDecodeVLLMEvent_BlockRemovedMissingMedium(t *testing.T) {
312354
require.True(t, ok)
313355
assert.Equal(t, []uint64{600}, blockRemoved.BlockHashes)
314356
assert.Equal(t, "", blockRemoved.DeviceTier)
357+
assert.Nil(t, blockRemoved.GroupIdx)
358+
}
359+
360+
func TestDecodeVLLMEvent_BlockRemovedWithGroupIdx(t *testing.T) {
361+
adapter := NewVLLMAdapter()
362+
363+
vllmEvent := []any{
364+
"BlockRemoved",
365+
[]any{uint64(700)},
366+
"gpu",
367+
uint64(1),
368+
}
369+
370+
rawBytes, err := msgpack.Marshal(vllmEvent)
371+
require.NoError(t, err)
372+
373+
event, err := adapter.decodeVLLMEvent(rawBytes)
374+
require.NoError(t, err)
375+
376+
blockRemoved, ok := event.(*kvevents.BlockRemovedEvent)
377+
require.True(t, ok)
378+
require.NotNil(t, blockRemoved.GroupIdx)
379+
assert.Equal(t, 1, *blockRemoved.GroupIdx)
380+
}
381+
382+
func TestDecodeVLLMEvent_BlockStoredInvalidHMAMetadata(t *testing.T) {
383+
adapter := NewVLLMAdapter()
384+
385+
tests := []struct {
386+
name string
387+
event []any
388+
wantErr string
389+
}{
390+
{
391+
name: "negative group idx",
392+
event: []any{
393+
"BlockStored",
394+
[]any{uint64(700)},
395+
uint64(699),
396+
[]uint32{1, 2},
397+
16,
398+
nil,
399+
"gpu",
400+
nil,
401+
nil,
402+
int64(-1),
403+
},
404+
wantErr: "group_idx",
405+
},
406+
{
407+
name: "non-string spec kind",
408+
event: []any{
409+
"BlockStored",
410+
[]any{uint64(700)},
411+
uint64(699),
412+
[]uint32{1, 2},
413+
16,
414+
nil,
415+
"gpu",
416+
nil,
417+
nil,
418+
uint64(0),
419+
uint64(123),
420+
},
421+
wantErr: "kv_cache_spec_kind",
422+
},
423+
{
424+
name: "non-numeric sliding window",
425+
event: []any{
426+
"BlockStored",
427+
[]any{uint64(700)},
428+
uint64(699),
429+
[]uint32{1, 2},
430+
16,
431+
nil,
432+
"gpu",
433+
nil,
434+
nil,
435+
uint64(0),
436+
"sliding_window",
437+
"bad-window",
438+
},
439+
wantErr: "kv_cache_spec_sliding_window",
440+
},
441+
}
442+
443+
for _, tt := range tests {
444+
t.Run(tt.name, func(t *testing.T) {
445+
rawBytes, err := msgpack.Marshal(tt.event)
446+
require.NoError(t, err)
447+
448+
_, err = adapter.decodeVLLMEvent(rawBytes)
449+
require.Error(t, err)
450+
assert.Contains(t, err.Error(), tt.wantErr)
451+
})
452+
}
453+
}
454+
455+
func TestDecodeVLLMEvent_BlockRemovedInvalidGroupIdx(t *testing.T) {
456+
adapter := NewVLLMAdapter()
457+
458+
vllmEvent := []any{
459+
"BlockRemoved",
460+
[]any{uint64(700)},
461+
"gpu",
462+
int64(-1),
463+
}
464+
465+
rawBytes, err := msgpack.Marshal(vllmEvent)
466+
require.NoError(t, err)
467+
468+
_, err = adapter.decodeVLLMEvent(rawBytes)
469+
require.Error(t, err)
470+
assert.Contains(t, err.Error(), "group_idx")
315471
}
316472

317473
func intPtr(v int) *int {

0 commit comments

Comments
 (0)