fix(sglang): decode token_ids as typed []uint32 with bigram support #603

Open
ryanx-sir wants to merge 3 commits into llm-d:main from ryanx-sir:fix/sglang-bigram-token-ids

@ryanx-sir

Summary

SGLangAdapter currently fails to decode BlockStored events when the underlying SGLang engine has EAGLE-family speculative decoding enabled (--speculative-algorithm EAGLE | EAGLE3 | FROZEN_KV_MTP). In that mode SGLang emits token_ids as a nested array of bigram pairs rather than a flat int list, which the existing []uint32 msgpack target cannot satisfy — every BlockStored event from such pods is dropped with a decode error.

This PR teaches sglang_adapter.go to handle both wire shapes via a typed msgpack.CustomDecoder, with no []any boxing on the hot path.

The two shapes

SGLang's mem_cache/events.py:_record_store_event materialises token_ids differently depending on whether the request's RadixKey is in bigram mode:

| Mode | Wire payload | Trigger |
|---|---|---|
| Flat | [t0, t1, t2, ...] | normal inference |
| Bigram | [[t0, t1], [t1, t2], [t2, t3], ...] | is_bigram = is_eagle == True, i.e. EAGLE / EAGLE3 / FROZEN_KV_MTP |

is_bigram is set in radix_cache.py, hiradix_cache.py, unified_radix_cache.py, and swa_radix_cache.py whenever the cache was constructed with is_eagle=True. The tuple form has been the wire contract since events.py was introduced in sgl-project/sglang#23678 (the in-code comment explicitly notes it preserves the historical EAGLE event payload after #23106 made bigram an O(1) flag).

Pairs overlap on the boundary token, so collapsing each pair to its second element gives back the engine's raw[1:] token sequence — exactly the prefix the indexer needs.
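
To make the overlap concrete, here is a minimal Go sketch of the relationship between the two shapes. The helper names toBigrams and collapseSecond are hypothetical and exist only for this illustration; they are not part of the adapter or of SGLang.

```go
package main

import "fmt"

// toBigrams builds the overlapping pair form [[t0,t1],[t1,t2],...] from a raw
// token sequence. Hypothetical helper that mimics the bigram wire shape.
func toBigrams(raw []uint32) [][2]uint32 {
	if len(raw) < 2 {
		return nil
	}
	pairs := make([][2]uint32, 0, len(raw)-1)
	for i := 0; i+1 < len(raw); i++ {
		pairs = append(pairs, [2]uint32{raw[i], raw[i+1]})
	}
	return pairs
}

// collapseSecond keeps only the second element of every pair.
func collapseSecond(pairs [][2]uint32) []uint32 {
	out := make([]uint32, 0, len(pairs))
	for _, p := range pairs {
		out = append(out, p[1])
	}
	return out
}

func main() {
	raw := []uint32{10, 11, 12, 13}
	pairs := toBigrams(raw)
	fmt.Println(pairs)                 // [[10 11] [11 12] [12 13]]
	fmt.Println(collapseSecond(pairs)) // [11 12 13], i.e. raw[1:]
}
```

Note that the first raw token (10 here) appears only as the first element of the first pair, so a second-element collapse never emits it.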

The fix

A new sglangTokenIDs (alias of []uint32) implements msgpack.CustomDecoder:

  1. DecodeArrayLen to read the outer count.
  2. PeekCode on the first element — FixArray | Array16 | Array32 ⇒ bigram, otherwise flat.
  3. Bigram branch: per pair, Skip() the previous token and DecodeUint32() the current; defensive Skip() for any future trailing fields.
  4. Flat branch: straight DecodeUint32() per element.
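
The steps above can be sketched without the msgpack library. The following is a stdlib-only illustration that reads the msgpack type codes by hand instead of going through msgpack.CustomDecoder, so every function name in it (isArrayCode, readArrayLen, readUint32, decodeTokenIDs) is hypothetical and it is not the PR's code. It assumes tokens are encoded as unsigned msgpack integers and that any extra pair elements are also unsigned integers.

```go
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

var errShort = errors.New("truncated msgpack input")

// isArrayCode reports whether c starts a msgpack array
// (fixarray 0x90-0x9f, array16 0xdc, array32 0xdd).
func isArrayCode(c byte) bool {
	return c&0xf0 == 0x90 || c == 0xdc || c == 0xdd
}

// readArrayLen consumes an array header and returns the count and the rest.
func readArrayLen(b []byte) (int, []byte, error) {
	switch {
	case len(b) == 0:
		return 0, nil, errShort
	case b[0]&0xf0 == 0x90:
		return int(b[0] & 0x0f), b[1:], nil
	case b[0] == 0xdc && len(b) >= 3:
		return int(binary.BigEndian.Uint16(b[1:])), b[3:], nil
	case b[0] == 0xdd && len(b) >= 5:
		return int(binary.BigEndian.Uint32(b[1:])), b[5:], nil
	}
	return 0, nil, fmt.Errorf("bad or truncated array header: 0x%02x", b[0])
}

// readUint32 consumes one unsigned integer (positive fixint, uint8/16/32).
func readUint32(b []byte) (uint32, []byte, error) {
	switch {
	case len(b) == 0:
		return 0, nil, errShort
	case b[0] <= 0x7f:
		return uint32(b[0]), b[1:], nil
	case b[0] == 0xcc && len(b) >= 2:
		return uint32(b[1]), b[2:], nil
	case b[0] == 0xcd && len(b) >= 3:
		return uint32(binary.BigEndian.Uint16(b[1:])), b[3:], nil
	case b[0] == 0xce && len(b) >= 5:
		return binary.BigEndian.Uint32(b[1:]), b[5:], nil
	}
	return 0, nil, fmt.Errorf("bad or truncated uint: 0x%02x", b[0])
}

// decodeTokenIDs sniffs the first element's type code and decodes either
// shape into []uint32. The bigram branch keeps the second element of each
// pair and ignores any trailing elements.
func decodeTokenIDs(b []byte) ([]uint32, error) {
	count, b, err := readArrayLen(b)
	if err != nil {
		return nil, err
	}
	out := make([]uint32, 0, count)
	if count == 0 {
		return out, nil
	}
	bigram := len(b) > 0 && isArrayCode(b[0])
	for i := 0; i < count; i++ {
		inner := 1 // a flat element is treated as a one-element group
		if bigram {
			if inner, b, err = readArrayLen(b); err != nil {
				return nil, err
			}
			if inner < 2 {
				return nil, fmt.Errorf("pair %d has %d elements", i, inner)
			}
		}
		var tok uint32
		for j := 0; j < inner; j++ {
			var v uint32
			if v, b, err = readUint32(b); err != nil {
				return nil, err
			}
			if !bigram || j == 1 {
				tok = v
			}
		}
		out = append(out, tok)
	}
	return out, nil
}

func main() {
	flat := []byte{0x93, 0x01, 0x02, 0xcd, 0x01, 0x2c}
	bigram := []byte{0x93, 0x92, 0x0a, 0x0b, 0x92, 0x0b, 0x0c, 0x92, 0x0c, 0x0d}
	fmt.Println(decodeTokenIDs(flat))   // [1 2 300] <nil>
	fmt.Println(decodeTokenIDs(bigram)) // [11 12 13] <nil>
}
```

Because an integer code and an array code can never be the same byte, a single peek at the first element is enough to pick the branch; no per-pod configuration is involved.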

This eliminates the prior []any + flattenTokenIDs([]any) + toUint32(any) boxing chain (three layers of reflect / type-switch per token) and replaces it with a single typed allocation per event. The dead helpers go away.

Test plan

  • New TestSGLangBlockStored_BigramTokenIDs round-trips [[10,11],[11,12],[12,13]] and asserts the decoder yields [11, 12, 13] (= raw[1:]).
  • All existing TestSGLang* cases (FullFields / 7Fields / MinimalFields / TooFewFields / BlockRemoved / AllBlocksCleared / UnknownTag / ShardingKey / ParseMessage) remain green — flat-mode wire is unchanged.
  • go test ./pkg/kvevents/engineadapter/ -count=1 passes.
  • go vet ./pkg/kvevents/engineadapter/... clean.

Notes for reviewers

  • No config flag was added. Shape detection is purely content-driven on the wire (the two msgpack codes are disjoint), so it works for mixed-pod pools where some pods run EAGLE and others don't — no operator-side knob to misconfigure.
  • The BlockStored.token_ids: list[int] type annotation in SGLang's kv_events.py is misleading; the bigram payload violates it. That mismatch is what led the original adapter (PR #443, "feat: support SGLang KV events in VLLMAdapter"), modelled on the "SGLang ≈ vLLM wire format" assumption, to type the field as []uint32. PR #484 ("fix: single-pass []any decode for forward/backward compat with vLLM event schema") fixed an analogous issue on the vLLM side; this is the SGLang counterpart it explicitly deferred.

@github-actions github-actions bot added the size/L label (denotes a PR that changes 100-499 lines, ignoring generated files) May 23, 2026
@github-actions

Unsigned commits detected! Please sign your commits.

For instructions on how to set up GPG/SSH signing and verify your commits, please see GitHub Documentation.

@ryanx-sir ryanx-sir force-pushed the fix/sglang-bigram-token-ids branch from fed018f to bde7c48 Compare May 23, 2026 16:25
SGLang emits BlockStored.token_ids in two distinct wire shapes:

  * flat       [t0, t1, ...]              (normal inference)
  * bigram     [[t0, t1], [t1, t2], ...]  (EAGLE / EAGLE3 /
                                           FROZEN_KV_MTP speculative
                                           decoding)

The bigram path is taken whenever the engine's RadixKey carries
is_bigram=is_eagle -- see sglang/python/sglang/srt/mem_cache/radix_cache.py,
hiradix_cache.py, unified_radix_cache.py, swa_radix_cache.py, and the
materialisation in mem_cache/events.py:_record_store_event. Pairs
overlap on the boundary token, so collapsing each pair to its second
element reproduces the engine's raw[1:] token sequence.

Previously the adapter typed TokenIds as []uint32, which silently
failed msgpack decoding when EAGLE was on (a nested array cannot
satisfy the []uint32 target). This change introduces a sglangTokenIDs
type with a custom DecodeMsgpack that sniffs the first element's wire
code (FixArray/Array16/Array32 -> bigram, otherwise flat) and decodes
straight into []uint32. No []any boxing, no reflect type-switch, one
allocation per event.

Adds TestSGLangBlockStored_BigramTokenIDs covering the EAGLE path;
existing flat-format tests remain green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: jiruixian <ryanx.sir@gmail.com>
@ryanx-sir ryanx-sir force-pushed the fix/sglang-bigram-token-ids branch from bde7c48 to b4cd838 Compare May 23, 2026 16:27
@vMaroon
Member

vMaroon commented May 24, 2026

@bongwoobak can you give this a look?

CI lint failed on the previous commit:

  pkg/kvevents/engineadapter/sglang_adapter.go:132
    variable name 'n' is too short for the scope of its usage (varnamelen)
  pkg/kvevents/engineadapter/sglang_adapter.go:150
    `if isBigram` has complex nested blocks (complexity: 7) (nestif)

Rename `n` -> `count` and lift the bigram per-pair decode into two
helpers (decodeBigramTokenIDs / decodeBigramPair) so the bigram branch
in DecodeMsgpack is a single call instead of two levels of nested loops
with per-step error wrapping. Behavior is unchanged.

Signed-off-by: Ryan Xu <ryanx.sir@gmail.com>
Signed-off-by: jiruixian <ryanx.sir@gmail.com>
@ryanx-sir ryanx-sir force-pushed the fix/sglang-bigram-token-ids branch from 9b481aa to 3622bb9 Compare May 24, 2026 12:23
Member

@bongwoobak bongwoobak left a comment

Since SGLang stores KV cache differently when EAGLE-family speculative decoding is enabled, this fix is needed. That said, I've left a few concerns that I think should be addressed in this PR.

Bottom line: with this PR alone, SGLang + EAGLE/MTP pods won't produce any prefix cache hits unless the request-key side is also updated.

// decodeBigramTokenIDs decodes the bigram-shaped token_ids array
// ([[t0,t1],[t1,t2],...]) into out by keeping the second element of each pair,
// which collapses overlapping pairs back to the underlying token sequence.
func decodeBigramTokenIDs(dec *msgpack.Decoder, out []uint32) error {

Decode succeeds, but EAGLE prefix matching still won't work.

llm-d's hashing side has no bigram awareness (token_processor.go:hash() is plain FNV/CBOR over the chunk, and grep "EAGLE|bigram" pkg/ returns 0 matches).

So with this PR, EAGLE pods store hash(raw[1:17]) while the router hashes incoming requests as hash(raw[0:16]): same prompt, different hashes between EAGLE and non-EAGLE pods.

llm-d's hashing side needs a paired path (shift-by-1 or pair-hash).
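
A small sketch of why the one-token shift breaks matching. The hashBlock function below is a hypothetical stand-in that runs FNV-1a over little-endian token bytes; it is not the encoding token_processor.go uses, and the block size of 16 is taken from the raw[0:16] example above. Only the point that a shifted window feeds different bytes to the hash carries over.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// hashBlock is a hypothetical block hash: FNV-1a (64-bit) over the
// little-endian bytes of each token. The real indexer hash differs.
func hashBlock(tokens []uint32) uint64 {
	h := fnv.New64a()
	var buf [4]byte
	for _, t := range tokens {
		binary.LittleEndian.PutUint32(buf[:], t)
		h.Write(buf[:])
	}
	return h.Sum64()
}

func main() {
	// 17 made-up tokens so that both 16-token windows exist.
	raw := make([]uint32, 17)
	for i := range raw {
		raw[i] = uint32(100 + i)
	}
	request := hashBlock(raw[0:16]) // what the request path would hash
	stored := hashBlock(raw[1:17])  // what a second-element collapse would store
	fmt.Printf("request=%016x stored=%016x equal=%v\n", request, stored, request == stored)
}
```

The two windows share 15 of their 16 tokens, but a block hash covers the whole chunk, so the lookup key for the stored block does not match the key computed for the request.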

// nested [[prev, curr], ...] payload that SGLang emits when EAGLE-family
// speculative decoding is enabled. Each pair's second element is kept, which
// matches the engine's raw[1:] token sequence.
func TestSGLangBlockStored_BigramTokenIDs(t *testing.T) {

Could you add tests for the new error/skip branches? Specifically: inner < 2 (malformed pair), inner > 2 (the defensive Skip loop), and empty token_ids.

…nches

The previous bigram decoder kept only pair[i][1] for each pair, dropping
pair[0][0]. For a page with N pairs covering raw[0:N+1], that emitted
raw[1:N+1], which made the canonical block hash diverge from the flat-token
request path: EAGLE pods registered hash([t1..t16]) while incoming requests
were hashed as hash([t0..t15]).

Fix this adapter-side by decoding pair[0][0] as the page-head and writing
both halves of the first pair (then only the tail of each subsequent pair).
The result is raw[0:N+1]; the trailing overlap token is dropped by
kvblock.chunkTokens as a partial block, so the recomputed canonical block
hashes match the flat-token path with no changes to token_processor.go.

Also adds tests for the branches the reviewer flagged: empty token_ids
(short-circuit before peek), inner<2 (malformed pair → error), and inner>2
(defensive skip path tolerates extras).

Refs: llm-d#603 review feedback from @bongwoobak

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: jiruixian <ryanx.sir@gmail.com>
@ryanx-sir
Author

Thanks for the careful read @bongwoobak — you were right that the first revision broke prefix matching for EAGLE pods. I pushed 2b9332f which addresses both comments.

On "shift-by-1 vs pair-hash": I went with shift-by-1, but kept it adapter-local instead of touching token_processor.go. The reasoning is that SGLang's bigram payload of N pairs covers N+1 raw tokens (raw[0:N+1]), and the previous decoder was dropping pair[0][0], emitting only raw[1:N+1]. The fix is to decode the page-head too — for i == 0 we now keep both halves of the pair, for i > 0 we keep only the tail (which is what the prior code did for every pair).

The trailing pair[N-1][1] is the overlap token shared with the next page; kvblock.chunkTokens already drops it as a partial block (// no partial blocks), so the recomputed canonical block hashes line up exactly with what the request side produces for the same flat prompt. No bigram awareness is needed in token_processor.go, EAGLE and non-EAGLE pods produce identical hashes for the same prefix, and they can coexist in the same InferencePool.
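
A minimal sketch of the head-preserving collapse and the partial-block drop described above, using a block size of 4 to keep the example short. Both collapseWithHead and chunkFull are hypothetical names; chunkFull only imitates the "no partial blocks" behaviour attributed to kvblock.chunkTokens and is not that function.

```go
package main

import "fmt"

// collapseWithHead turns [[t0,t1],[t1,t2],...] into [t0,t1,t2,...]:
// both halves of the first pair, then only the tail of each later pair.
func collapseWithHead(pairs [][2]uint32) []uint32 {
	if len(pairs) == 0 {
		return nil
	}
	out := make([]uint32, 0, len(pairs)+1)
	out = append(out, pairs[0][0]) // the page head that was previously dropped
	for _, p := range pairs {
		out = append(out, p[1])
	}
	return out
}

// chunkFull splits tokens into blocks of exactly size elements and drops
// any partial tail. Hypothetical stand-in for the indexer's chunking.
func chunkFull(tokens []uint32, size int) [][]uint32 {
	if size <= 0 {
		return nil
	}
	var blocks [][]uint32
	for i := 0; i+size <= len(tokens); i += size {
		blocks = append(blocks, tokens[i:i+size])
	}
	return blocks
}

func main() {
	// N = 4 pairs cover raw[0:5]; the fifth token is the overlap with the next page.
	pairs := [][2]uint32{{10, 11}, {11, 12}, {12, 13}, {13, 14}}
	tokens := collapseWithHead(pairs)
	fmt.Println(tokens)               // [10 11 12 13 14]
	fmt.Println(chunkFull(tokens, 4)) // [[10 11 12 13]]
}
```

In this sketch the surviving block is exactly raw[0:4], the same chunk the flat path would produce for the same prompt, which is why no change on the hashing side is needed under these assumptions.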

I considered the pair-hash route, but it would require teaching token_processor.go two hashing modes, plumbing per-pool engine identity to the request path, and either splitting EAGLE/non-EAGLE pods into separate pools or doing dual lookups per request — much wider blast radius than this one-engine quirk warrants. Happy to revisit if you'd prefer the hashing-side variant.

On tests: added three tests for the branches you flagged — TestSGLangBlockStored_BigramMalformedPair (inner < 2), TestSGLangBlockStored_BigramOversizedPair (inner > 2 defensive skip), and TestSGLangBlockStored_EmptyTokenIDs. The existing bigram test was updated to assert the full raw[0:N+1] sequence.
