@@ -95,6 +95,8 @@ def __init__(
9595 overlap_k : int = 0 ,
9696 # Attention sink (global memory prefix)
9797 sink_tokens : int = 0 ,
98+ # Chunked prefill (0 = disabled, single-shot prefill)
99+ prefill_chunk_size : int = 0 ,
98100 ):
99101 if device is None :
100102 device = default_device ()
@@ -116,6 +118,7 @@ def __init__(
116118 self .recompute_strategy = RecomputeStrategy (recompute_strategy )
117119 self .overlap_k = overlap_k
118120 self .sink_tokens = sink_tokens
121+ self .prefill_chunk_size = int (prefill_chunk_size )
119122
120123 # Pre-compute boundary token IDs for adaptive splitting
121124 self ._boundary_tokens : Set [int ] = (
@@ -689,24 +692,38 @@ def _decode_with_kv(
689692
690693 # ----- encode live tokens (prompt tail) -------------------------
691694 if live_ids :
692- input_ids = torch .tensor ([live_ids ], dtype = torch .long , device = self .device )
693- pos_ids = torch .arange (
694- cached_len , cached_len + len (live_ids ),
695- dtype = torch .long , device = self .device ,
696- ).unsqueeze (0 )
695+ n_live = len (live_ids )
696+ cs = self .prefill_chunk_size
697+ # cs <= 0 → single-shot prefill (legacy behavior)
698+ chunk_step = n_live if cs <= 0 else min (cs , n_live )
699+
700+ out = None
701+ cur = 0
702+ while cur < n_live :
703+ end = min (cur + chunk_step , n_live )
704+ slice_ids = live_ids [cur :end ]
705+ input_ids = torch .tensor ([slice_ids ], dtype = torch .long , device = self .device )
706+ pos_ids = torch .arange (
707+ cached_len + cur , cached_len + end ,
708+ dtype = torch .long , device = self .device ,
709+ ).unsqueeze (0 )
710+
711+ # last_logit_only is fine for non-final chunks too — we just
712+ # don't read those logits, and trimming saves a bit of memory.
713+ with torch .no_grad (), last_logit_only (self .model ):
714+ out = self .model (
715+ input_ids = input_ids ,
716+ past_key_values = self ._as_cache (past_kv ),
717+ position_ids = pos_ids ,
718+ use_cache = True ,
719+ )
720+ past_kv = self ._normalize_past_kv (out .past_key_values )
721+ cur = end
697722
698- with torch .no_grad (), last_logit_only (self .model ):
699- out = self .model (
700- input_ids = input_ids ,
701- past_key_values = self ._as_cache (past_kv ),
702- position_ids = pos_ids ,
703- use_cache = True ,
704- )
705723 first_token_time = time .perf_counter ()
706724 # Capture first-token logits for comparison with baseline
707725 import numpy as np
708726 first_token_logits = out .logits [0 , - 1 , :].cpu ().float ().numpy ()
709- past_kv = self ._normalize_past_kv (out .past_key_values )
710727 next_token = self ._sample (out .logits [:, - 1 , :], temperature , do_sample )
711728 generated .append (next_token )
712729 if on_token is not None :
0 commit comments