Skip to content

Commit bd67cbf

Browse files
committed
prefill chunking
1 parent cbe8d4e commit bd67cbf

2 files changed

Lines changed: 36 additions & 13 deletions

File tree

src/kvboost/engine.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def __init__(
9595
overlap_k: int = 0,
9696
# Attention sink (global memory prefix)
9797
sink_tokens: int = 0,
98+
# Chunked prefill (0 = disabled, single-shot prefill)
99+
prefill_chunk_size: int = 0,
98100
):
99101
if device is None:
100102
device = default_device()
@@ -116,6 +118,7 @@ def __init__(
116118
self.recompute_strategy = RecomputeStrategy(recompute_strategy)
117119
self.overlap_k = overlap_k
118120
self.sink_tokens = sink_tokens
121+
self.prefill_chunk_size = int(prefill_chunk_size)
119122

120123
# Pre-compute boundary token IDs for adaptive splitting
121124
self._boundary_tokens: Set[int] = (
@@ -689,24 +692,38 @@ def _decode_with_kv(
689692

690693
# ----- encode live tokens (prompt tail) -------------------------
691694
if live_ids:
692-
input_ids = torch.tensor([live_ids], dtype=torch.long, device=self.device)
693-
pos_ids = torch.arange(
694-
cached_len, cached_len + len(live_ids),
695-
dtype=torch.long, device=self.device,
696-
).unsqueeze(0)
695+
n_live = len(live_ids)
696+
cs = self.prefill_chunk_size
697+
# cs <= 0 → single-shot prefill (legacy behavior)
698+
chunk_step = n_live if cs <= 0 else min(cs, n_live)
699+
700+
out = None
701+
cur = 0
702+
while cur < n_live:
703+
end = min(cur + chunk_step, n_live)
704+
slice_ids = live_ids[cur:end]
705+
input_ids = torch.tensor([slice_ids], dtype=torch.long, device=self.device)
706+
pos_ids = torch.arange(
707+
cached_len + cur, cached_len + end,
708+
dtype=torch.long, device=self.device,
709+
).unsqueeze(0)
710+
711+
# last_logit_only is fine for non-final chunks too — we just
712+
# don't read those logits, and trimming saves a bit of memory.
713+
with torch.no_grad(), last_logit_only(self.model):
714+
out = self.model(
715+
input_ids=input_ids,
716+
past_key_values=self._as_cache(past_kv),
717+
position_ids=pos_ids,
718+
use_cache=True,
719+
)
720+
past_kv = self._normalize_past_kv(out.past_key_values)
721+
cur = end
697722

698-
with torch.no_grad(), last_logit_only(self.model):
699-
out = self.model(
700-
input_ids=input_ids,
701-
past_key_values=self._as_cache(past_kv),
702-
position_ids=pos_ids,
703-
use_cache=True,
704-
)
705723
first_token_time = time.perf_counter()
706724
# Capture first-token logits for comparison with baseline
707725
import numpy as np
708726
first_token_logits = out.logits[0, -1, :].cpu().float().numpy()
709-
past_kv = self._normalize_past_kv(out.past_key_values)
710727
next_token = self._sample(out.logits[:, -1, :], temperature, do_sample)
711728
generated.append(next_token)
712729
if on_token is not None:

src/kvboost/server/__main__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ def parse_args():
103103
help="KV quantization bits (16=off, 8=int8, 4=int4)")
104104
p.add_argument("--sink-tokens", type=int, default=0)
105105
p.add_argument("--overlap-k", type=int, default=0)
106+
p.add_argument("--prefill-chunk-size", type=int, default=0,
107+
help="Process the prompt in slices of N tokens during prefill, "
108+
"growing past_key_values between iterations. 0 = single-shot "
109+
"(legacy). Set to e.g. 512 or 1024 to fit long prompts on "
110+
"small GPUs by capping peak FFN/attention activation memory.")
106111

107112
# CPU paged backend
108113
p.add_argument("--block-size", type=int, default=16, help="Tokens per paged block")
@@ -269,6 +274,7 @@ def load_engine(args):
269274
kv_cache_bits=args.kv_cache_bits,
270275
sink_tokens=args.sink_tokens,
271276
overlap_k=args.overlap_k,
277+
prefill_chunk_size=args.prefill_chunk_size,
272278
device=device,
273279
)
274280

0 commit comments

Comments
 (0)