
Commit 02c75d8

Add paged-decode interface to Qwen, matching llama/apertus
Adds decode() and initial_cache() to QwenDecoderLayer, QwenTransformer, and QwenLMHeadModel, mirroring the paged-KV decode interface already implemented in llama.py and apertus.py. This lets Qwen models plug into the paged-KV inference engine.
1 parent a0c36c1 commit 02c75d8

2 files changed: 119 additions & 0 deletions


lib/levanter/src/levanter/models/qwen.py

Lines changed: 89 additions & 0 deletions
@@ -6,6 +6,7 @@
 from typing import Dict, Optional, Type, cast

 import equinox as eqx
+import jax
 import jax.random as jrandom

 import haliax as hax
@@ -16,7 +17,9 @@
 from haliax.state_dict import ModuleWithStateDictSerialization

 from levanter.compat.hf_checkpoints import HFCheckpointConverter
+from levanter.inference.page_table import PageBatchInfo, PageTableSpec
 from levanter.layers.attention import Attention, AttentionConfig, AttentionMask
+from levanter.layers.kv_cache import KvPageCache, ListCache
 from levanter.layers.rotary import RotaryEmbeddingsConfig
 from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaLMHeadModel, LlamaMlp, LlamaTransformer
 from levanter.models.lm_model import LmConfig, LmHeadModel
@@ -185,6 +188,32 @@ def __call__(
         output = residual + mlp_output
         return output

+    @named_call
+    def decode(
+        self,
+        x: NamedArray,
+        kv_cache: KvPageCache,
+        batch_info: PageBatchInfo,
+        pos_ids: NamedArray,
+        *,
+        key=None,
+    ) -> tuple[NamedArray, KvPageCache]:
+        k_attn, k_mlp = maybe_rng_split(key, 2)
+
+        residual = x
+        x = self.input_layernorm(x)
+        attn_output, kv_cache = self.self_attn.paged_decode(x, kv_cache, batch_info, pos_ids=pos_ids, key=k_attn)
+        x = residual + attn_output
+
+        residual = x
+        x = self.post_attention_layernorm(x)
+        mlp_output = self.mlp(x, key=k_mlp)
+        output = residual + mlp_output
+        return output, kv_cache
+
+    def initial_cache(self, spec: PageTableSpec, *, dtype) -> KvPageCache:
+        return self.self_attn.empty_page_cache(spec, dtype=dtype)
+

 # Modified transformer for Qwen
 class QwenTransformer(LlamaTransformer):
@@ -218,6 +247,42 @@ def __call__(
         x = self.norm(x)
         return x

+    @named_call
+    def decode(
+        self,
+        kv_cache: ListCache[KvPageCache],
+        x: NamedArray,
+        batch_info: PageBatchInfo,
+        pos_ids: NamedArray,
+        *,
+        key=None,
+    ) -> tuple[NamedArray, ListCache[KvPageCache]]:
+        keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None
+        caches = list(kv_cache)
+        updated_caches: list[KvPageCache] = []
+
+        for i in range(self.config.num_layers):
+            with jax.named_scope("slice layer"):
+                layer = hax.tree_util.tree_map(lambda l: l["layer", i], self.layers.stacked)  # type: ignore
+            with jax.named_scope("slice cache"):
+                this_cache = caches[i]
+            x, this_cache = layer.decode(
+                x,
+                this_cache,
+                batch_info,
+                pos_ids=pos_ids,
+                key=keys[i] if keys is not None else None,
+            )
+            with jax.named_scope("update cache"):
+                updated_caches.append(this_cache)
+
+        x = self.norm(x)
+        return x, ListCache(updated_caches)
+
+    def initial_cache(self, spec: PageTableSpec, *, dtype) -> ListCache[KvPageCache]:
+        caches = [layer.initial_cache(spec, dtype=dtype) for layer in self.layers.unstacked()]
+        return ListCache(caches)
+

 # Modified LM head model for Qwen
 class QwenLMHeadModel(LmHeadModel[QwenConfig], ModuleWithStateDictSerialization):
@@ -289,6 +354,30 @@ def init(cls, Vocab: Axis, config: QwenConfig, *, key) -> "QwenLMHeadModel":
289354
def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
290355
return {"transformer": "model", "embeddings": None}
291356

357+
def initial_cache(self, spec: PageTableSpec, *, dtype) -> ListCache[KvPageCache]:
358+
return hax.auto_sharded(self.transformer.initial_cache(spec, dtype=dtype))
359+
360+
@named_call
361+
def decode(
362+
self,
363+
input_ids: NamedArray,
364+
kv_cache: ListCache[KvPageCache],
365+
batch_info: PageBatchInfo,
366+
pos_ids: NamedArray,
367+
*,
368+
key=None,
369+
) -> tuple[NamedArray, ListCache[KvPageCache]]:
370+
x = self.embeddings.embed(input_ids)
371+
k_t = maybe_rng_split(key, 1)[0] if key is not None else None
372+
x, new_state = self.transformer.decode(kv_cache, x, batch_info, pos_ids, key=k_t)
373+
374+
if self.lm_head is not None:
375+
logits = self.lm_head(x, key=None)
376+
else:
377+
logits = self.embeddings.unembed(x)
378+
379+
return logits, new_state
380+
292381

293382
# =====================
294383
# Qwen-3 Configuration
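
For orientation, here is a rough sketch of how a caller might drive the new interface over several decode steps. The setup (PageTable.init(8, 2, 4, 2), SequenceTable, allocate_for_seq) is copied from the test added below; the multi-step greedy loop that feeds the sampled token back in with an incremented position id is an assumption about how an inference engine would use decode(), not code from this commit.

import jax.numpy as jnp
from jax import random
import haliax as hax

from levanter.inference.jit_scheduler import SequenceTable
from levanter.inference.page_table import PageTable
from levanter.models.qwen import QwenConfig, QwenLMHeadModel

# Hypothetical driver loop; setup values mirror the test below.
Vocab = hax.Axis("vocab", 64)
config = QwenConfig()  # assumption: default config; the test builds a tiny one from an HF config helper
model = QwenLMHeadModel.init(Vocab, config, key=random.PRNGKey(0))

page_table = PageTable.init(8, 2, 4, 2)
cache = model.initial_cache(page_table.spec(), dtype=jnp.bfloat16)  # one KvPageCache per layer, wrapped in a ListCache

sequences = SequenceTable.init(page_table.max_seqs, page_table.pages_per_seq, page_table.page_size)
sequences, slot_arr = sequences.reserve_slot(0)
slot_ids = hax.named(jnp.array([int(slot_arr)], dtype=jnp.int32), axis=("position",))

token = 1  # starting token id (illustrative)
for step in range(4):
    token_ids = hax.named(jnp.array([token], dtype=jnp.int32), axis=("position",))
    pos_ids = hax.named(jnp.array([step], dtype=jnp.int32), axis=("position",))
    # Reserve pages for this position and build the per-step batch metadata.
    sequences, page_table, batch_info = sequences.allocate_for_seq(page_table, slot_ids, pos_ids)
    # One paged-decode step: logits for the new position plus the updated cache, threaded into the next step.
    logits, cache = model.decode(token_ids, cache, batch_info, pos_ids)
    token = int(jnp.argmax(logits.array[-1]))  # greedy sampling, purely illustrative

The test that follows exercises the same flow for a single decode step and checks the logits axes and the per-layer cache count.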

lib/levanter/tests/test_qwen2.py

Lines changed: 30 additions & 0 deletions
@@ -5,10 +5,13 @@
 import tempfile

 import numpy as np
+import jax.numpy as jnp
 from jax import random

 import haliax as hax

+from levanter.inference.jit_scheduler import SequenceTable
+from levanter.inference.page_table import PageTable
 from levanter.layers.attention import AttentionMask
 from levanter.models.qwen import QwenConfig, QwenLMHeadModel
 from test_utils import skip_if_no_torch, use_test_mesh
@@ -118,3 +121,30 @@ def compute(model, input):
     torch_out2 = torch_out2.logits[0].detach().cpu().numpy()
     assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}"
     np.testing.assert_allclose(torch_out2, jax_out, rtol=1e-4, atol=2e-4)
+
+
+def test_qwen_supports_paged_kv_inference_interface():
+    vocab_size = 64
+    Vocab = hax.Axis("vocab", vocab_size)
+    config = QwenConfig.from_hf_config(get_config(vocab_size))
+    key = random.PRNGKey(0)
+
+    with use_test_mesh():
+        model = QwenLMHeadModel.init(Vocab, config, key=key)
+
+        page_table = PageTable.init(8, 2, 4, 2)
+        cache = model.initial_cache(page_table.spec(), dtype=jnp.bfloat16)
+
+        sequences = SequenceTable.init(page_table.max_seqs, page_table.pages_per_seq, page_table.page_size)
+        sequences, slot_arr = sequences.reserve_slot(0)
+        slot_id = int(slot_arr)
+
+        token_ids = hax.named(jnp.array([1], dtype=jnp.int32), axis=("position",))
+        slot_ids = hax.named(jnp.array([slot_id], dtype=jnp.int32), axis=("position",))
+        pos_ids = hax.named(jnp.array([0], dtype=jnp.int32), axis=("position",))
+        sequences, _page_table, batch_info = sequences.allocate_for_seq(page_table, slot_ids, pos_ids)
+
+        logits, updated_cache = model.decode(token_ids, cache, batch_info, pos_ids)
+
+        assert logits.axes == (hax.Axis("position", 1), Vocab)
+        assert len(updated_cache) == config.num_layers

0 commit comments
