|
6 | 6 | from typing import Dict, Optional, Type, cast |
7 | 7 |
|
8 | 8 | import equinox as eqx |
| 9 | +import jax |
9 | 10 | import jax.random as jrandom |
10 | 11 |
|
11 | 12 | import haliax as hax |
|
16 | 17 | from haliax.state_dict import ModuleWithStateDictSerialization |
17 | 18 |
|
18 | 19 | from levanter.compat.hf_checkpoints import HFCheckpointConverter |
| 20 | +from levanter.inference.page_table import PageBatchInfo, PageTableSpec |
19 | 21 | from levanter.layers.attention import Attention, AttentionConfig, AttentionMask |
| 22 | +from levanter.layers.kv_cache import KvPageCache, ListCache |
20 | 23 | from levanter.layers.rotary import RotaryEmbeddingsConfig |
21 | 24 | from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaLMHeadModel, LlamaMlp, LlamaTransformer |
22 | 25 | from levanter.models.lm_model import LmConfig, LmHeadModel |
@@ -185,6 +188,32 @@ def __call__( |
185 | 188 | output = residual + mlp_output |
186 | 189 | return output |
187 | 190 |
|
| 191 | + @named_call |
| 192 | + def decode( |
| 193 | + self, |
| 194 | + x: NamedArray, |
| 195 | + kv_cache: KvPageCache, |
| 196 | + batch_info: PageBatchInfo, |
| 197 | + pos_ids: NamedArray, |
| 198 | + *, |
| 199 | + key=None, |
| 200 | + ) -> tuple[NamedArray, KvPageCache]: |
| 201 | + k_attn, k_mlp = maybe_rng_split(key, 2) |
| 202 | + |
| 203 | + residual = x |
| 204 | + x = self.input_layernorm(x) |
| 205 | + attn_output, kv_cache = self.self_attn.paged_decode(x, kv_cache, batch_info, pos_ids=pos_ids, key=k_attn) |
| 206 | + x = residual + attn_output |
| 207 | + |
| 208 | + residual = x |
| 209 | + x = self.post_attention_layernorm(x) |
| 210 | + mlp_output = self.mlp(x, key=k_mlp) |
| 211 | + output = residual + mlp_output |
| 212 | + return output, kv_cache |
| 213 | + |
    def initial_cache(self, spec: PageTableSpec, *, dtype) -> KvPageCache:
        """Allocate an empty paged KV cache for this layer's attention module."""
        return self.self_attn.empty_page_cache(spec, dtype=dtype)
| 216 | + |
188 | 217 |
|
189 | 218 | # Modified transformer for Qwen |
190 | 219 | class QwenTransformer(LlamaTransformer): |
@@ -218,6 +247,42 @@ def __call__( |
218 | 247 | x = self.norm(x) |
219 | 248 | return x |
220 | 249 |
|
| 250 | + @named_call |
| 251 | + def decode( |
| 252 | + self, |
| 253 | + kv_cache: ListCache[KvPageCache], |
| 254 | + x: NamedArray, |
| 255 | + batch_info: PageBatchInfo, |
| 256 | + pos_ids: NamedArray, |
| 257 | + *, |
| 258 | + key=None, |
| 259 | + ) -> tuple[NamedArray, ListCache[KvPageCache]]: |
| 260 | + keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None |
| 261 | + caches = list(kv_cache) |
| 262 | + updated_caches: list[KvPageCache] = [] |
| 263 | + |
| 264 | + for i in range(self.config.num_layers): |
| 265 | + with jax.named_scope("slice layer"): |
| 266 | + layer = hax.tree_util.tree_map(lambda l: l["layer", i], self.layers.stacked) # type: ignore |
| 267 | + with jax.named_scope("slice cache"): |
| 268 | + this_cache = caches[i] |
| 269 | + x, this_cache = layer.decode( |
| 270 | + x, |
| 271 | + this_cache, |
| 272 | + batch_info, |
| 273 | + pos_ids=pos_ids, |
| 274 | + key=keys[i] if keys is not None else None, |
| 275 | + ) |
| 276 | + with jax.named_scope("update cache"): |
| 277 | + updated_caches.append(this_cache) |
| 278 | + |
| 279 | + x = self.norm(x) |
| 280 | + return x, ListCache(updated_caches) |
| 281 | + |
    def initial_cache(self, spec: PageTableSpec, *, dtype) -> ListCache[KvPageCache]:
        """Build one empty paged KV cache per transformer layer, wrapped in a ListCache."""
        caches = [layer.initial_cache(spec, dtype=dtype) for layer in self.layers.unstacked()]
        return ListCache(caches)
| 285 | + |
221 | 286 |
|
222 | 287 | # Modified LM head model for Qwen |
223 | 288 | class QwenLMHeadModel(LmHeadModel[QwenConfig], ModuleWithStateDictSerialization): |
@@ -289,6 +354,30 @@ def init(cls, Vocab: Axis, config: QwenConfig, *, key) -> "QwenLMHeadModel": |
289 | 354 | def _state_dict_key_map(self) -> Dict[str, Optional[str]]: |
290 | 355 | return {"transformer": "model", "embeddings": None} |
291 | 356 |
|
    def initial_cache(self, spec: PageTableSpec, *, dtype) -> ListCache[KvPageCache]:
        """Allocate the transformer's per-layer paged KV caches, auto-sharded across the mesh."""
        return hax.auto_sharded(self.transformer.initial_cache(spec, dtype=dtype))
| 359 | + |
| 360 | + @named_call |
| 361 | + def decode( |
| 362 | + self, |
| 363 | + input_ids: NamedArray, |
| 364 | + kv_cache: ListCache[KvPageCache], |
| 365 | + batch_info: PageBatchInfo, |
| 366 | + pos_ids: NamedArray, |
| 367 | + *, |
| 368 | + key=None, |
| 369 | + ) -> tuple[NamedArray, ListCache[KvPageCache]]: |
| 370 | + x = self.embeddings.embed(input_ids) |
| 371 | + k_t = maybe_rng_split(key, 1)[0] if key is not None else None |
| 372 | + x, new_state = self.transformer.decode(kv_cache, x, batch_info, pos_ids, key=k_t) |
| 373 | + |
| 374 | + if self.lm_head is not None: |
| 375 | + logits = self.lm_head(x, key=None) |
| 376 | + else: |
| 377 | + logits = self.embeddings.unembed(x) |
| 378 | + |
| 379 | + return logits, new_state |
| 380 | + |
292 | 381 |
|
293 | 382 | # ===================== |
294 | 383 | # Qwen-3 Configuration |
|
0 commit comments