Commit 1449816

add chunked prefill and continuous batching writeup (#64)
Signed-off-by: Alex Chi Z <iskyzh@gmail.com>
1 parent 34fb3fe commit 1449816

File tree

9 files changed, +341 −24 lines changed


book/src/SUMMARY.md

Lines changed: 1 addition & 2 deletions

```diff
@@ -17,8 +17,7 @@
   - [Key-Value Cache](./week2-01-kv-cache.md)
   - [Quantized Matmul (2 Days)]()
   - [Flash Attention (2 Days)]()
-  - [Chunked Prefill]()
-  - [Continuous Batching]()
+  - [Continuous Batching (2 Days)](./week2-06-prefill-and-batch.md)
 - [Week 3: Serving]()
```
---
Lines changed: 113 additions & 0 deletions
# Week 2 Day 6 and 7: Chunked Prefill and Continuous Batching

In this chapter, we will implement **continuous batching**. The idea is to batch multiple requests together so we can make full use of the compute resources.

So far, we have assumed that the model only processes a single batch each time it is called. However, a single batch is usually not enough to saturate the compute resources. To address this, we can process multiple requests at the same time.

The first question is how to batch requests. A naive approach would be to select a fixed number of prompts (for example, 5) from the request queue and perform decoding as before. The problem is that different prompts produce sequences of different lengths. It is possible that 4 out of 5 requests finish decoding quickly, while the remaining one takes much longer. This leads to wasted compute resources and stalls all other requests.

A smarter approach is **continuous batching**. That is, we set the maximum number of requests we can process at once. When one request finishes, we replace its slot (i.e., its KV cache) with another request. In this way, the pipeline remains fully utilized.

Another challenge is how to handle decoding and prefilling at the same time. In this chapter, we adopt a simplified approach: we prefill one request, then decode one token for each request in progress. The general idea can be described with the following pseudocode:
```python
while requests_in_queue_or_in_progress:
    if prefill_request exists:
        prefill_request.try_prefill()  # perform a chunk of chunked prefill
        if prefill_request.ready:
            if kv_cache.try_add(prefill_request):
                prefill_request = next(requests)
    tokens = decode(model, kv_cache)
    requests.append(tokens)
```

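To make the slot-recycling behavior concrete, here is a toy, model-free simulation of the schedule above (all names here are hypothetical and only for illustration): each request needs a fixed number of decode steps, and as soon as a request finishes, its slot is refilled from the queue on the next iteration.

```python
from collections import deque


def simulate(decode_steps: list[int], batch_size: int) -> list[tuple[int, int]]:
    """Return (request id, step at which it finished) in completion order."""
    queue = deque(enumerate(decode_steps))  # (request id, remaining steps)
    slots: list[tuple[int, int] | None] = [None] * batch_size
    finished = []
    step = 0
    while queue or any(s is not None for s in slots):
        # refill idle slots from the queue (the "continuous" part)
        for i in range(batch_size):
            if slots[i] is None and queue:
                slots[i] = queue.popleft()
        # one decode step for every active slot
        step += 1
        for i in range(batch_size):
            if slots[i] is None:
                continue
            rid, remaining = slots[i]
            remaining -= 1
            if remaining == 0:
                finished.append((rid, step))
                slots[i] = None  # slot freed; refilled on the next iteration
            else:
                slots[i] = (rid, remaining)
    return finished


print(simulate([2, 5, 1], batch_size=2))  # [(0, 2), (2, 3), (1, 5)]
```

Note how request 2 starts as soon as request 0 finishes, instead of waiting for the whole batch to drain.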
We will also implement **chunked prefill** in this chapter. Prefilling a long prompt can take a significant amount of time. Since we are interleaving prefills and decodes, we want to reduce the latency of producing the next token. Ideally, the time slots for prefill and decode should be roughly equal. To achieve this, we can prefill a portion of the request at a time, using multiple slots to finish the entire prefill.

For prefilling, this essentially means providing a chunk of tokens to the model to populate the KV cache. For example:

```python
# assume prompt_tokens is a list of 400 tokens and the prefill chunk size is 128
_step(model, prompt_tokens[0:128], 0, kv_cache)      # offset = 0
_step(model, prompt_tokens[128:256], 128, kv_cache)  # offset = 128
_step(model, prompt_tokens[256:384], 256, kv_cache)  # offset = 256
_step(model, prompt_tokens[384:400], 384, kv_cache)  # offset = 384
```

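The chunk boundaries above can be computed with a small helper (a hypothetical function, not part of the starter code):

```python
def prefill_chunks(num_tokens: int, chunk_size: int) -> list[tuple[int, int]]:
    """Split a prompt into (start, end) offset pairs of at most chunk_size tokens."""
    return [
        (start, min(start + chunk_size, num_tokens))
        for start in range(0, num_tokens, chunk_size)
    ]


print(prefill_chunks(400, 128))  # [(0, 128), (128, 256), (256, 384), (384, 400)]
```
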
Note that the causal mask generated during prefilling has the shape `LxS`, where `L` is the number of new tokens and `S` is the total sequence length including the tokens already in the KV cache. For example, assume we already have 2 tokens in the KV cache and want to prefill 3 tokens, so `L = 3` and `S = 5`. The mask should look like this:

```
0 0 0 -inf -inf
0 0 0 0 -inf
0 0 0 0 0
```

This is the same masking logic you implemented in Week 1.

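A minimal sketch of this mask for the general `L != S` case (plain Python floats rather than `mx.array`, purely for illustration): row `i` may attend to positions `0..S-L+i`.

```python
def causal_mask(L: int, S: int) -> list[list[float]]:
    """L x S causal mask: 0.0 where attention is allowed, -inf where it is not."""
    return [
        [0.0 if j <= S - L + i else float("-inf") for j in range(S)]
        for i in range(L)
    ]


for row in causal_mask(3, 5):
    print(row)
```

With `L == S` this reduces to the usual lower-triangular mask from Week 1.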
## Task 1: Batch RoPE and Causal Mask for Prefill

```
src/tiny_llm/positional_encoding.py
src/tiny_llm/attention.py::causal_mask
```

Ensure your RoPE implementation accepts a list of offsets. Also, make sure your mask implementation correctly handles the case where `L != S`.

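One way to picture the per-offset requirement: every request applies the same rotation schedule, but starting at its own position offset. A numpy sketch of this idea (illustrative only; the actual implementation uses mlx arrays and your Week 1 RoPE code):

```python
import numpy as np


def rope(x: np.ndarray, offsets: list[int], base: float = 10000.0) -> np.ndarray:
    """Apply RoPE to x of shape (B, L, D); request b's tokens start at offsets[b]."""
    B, L, D = x.shape
    inv_freq = 1.0 / (base ** (np.arange(0, D, 2) / D))  # (D/2,)
    out = np.empty_like(x)
    for b, off in enumerate(offsets):
        pos = np.arange(off, off + L)[:, None]           # absolute positions (L, 1)
        theta = pos * inv_freq[None, :]                  # (L, D/2)
        cos, sin = np.cos(theta), np.sin(theta)
        x1, x2 = x[b, :, 0::2], x[b, :, 1::2]
        out[b, :, 0::2] = x1 * cos - x2 * sin            # rotate each 2D pair
        out[b, :, 1::2] = x1 * sin + x2 * cos
    return out
```

At offset 0 the rotation angle is 0, so the first token is unchanged; two requests at the same offset produce identical rotations regardless of their batch position.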
## Task 2: Batch KV Cache

```
src/tiny_llm/kv_cache.py::BatchingKvCache
```

The batch KV cache is a collection of KV caches, one for each request. A challenge here is generating a `BxHxLxS` mask for the batch, since requests can have different lengths.

```
S = max(S_i of the batch)
L = mask_length (input parameter)
keys: 1, H, S_i, D
values: 1, H, S_i, D
batched_keys: B, H, S, D
batched_values: B, H, S, D
mask: B, 1, L, S
```

You should fill the `batched_keys` and `batched_values` arrays so that each request’s data is aligned at the end:

```python
batched_keys[i, :, (S - S_i):S, :] = keys[i, :, :, :]
batched_values[i, :, (S - S_i):S, :] = values[i, :, :, :]
mask[i, :, 0:L, (S - S_i):S] = causal_mask(L, S_i)
```

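To see the end-alignment concretely, here is a toy numpy demo (shapes chosen arbitrarily for illustration) that right-aligns two requests of different lengths into one batched buffer and masks out the front padding:

```python
import numpy as np

lengths = [2, 4]                       # S_i for each request
B, H, D, S = len(lengths), 1, 1, max(lengths)
L = 1                                  # decoding one token per request
batched_keys = np.zeros((B, H, S, D))
mask = np.full((B, 1, L, S), -np.inf)
for i, S_i in enumerate(lengths):
    keys_i = np.ones((1, H, S_i, D)) * (i + 1)  # stand-in KV data for request i
    batched_keys[i, :, S - S_i:S, :] = keys_i[0]  # align at the end
    mask[i, :, :, S - S_i:S] = 0.0                # attend only to real tokens

print(batched_keys[0, 0, :, 0])  # front-padded: [0. 0. 1. 1.]
print(batched_keys[1, 0, :, 0])  # full length:  [2. 2. 2. 2.]
```

The `-inf` entries in the mask ensure the padded positions contribute nothing after softmax.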
## Task 3: Handle Batches in the Model

```
src/tiny_llm/qwen2_week2.py
```

Ensure your model can handle multiple requests simultaneously. You should also use the masks returned by the batch KV cache.

## Task 4: Batch Generate

```
src/tiny_llm/batch.py
```

Implement `try_prefill` so that it prefills an entire request at once. Then implement the rest of the code as described in the starter code.

## Task 5: Chunked Prefill

```
src/tiny_llm/batch.py
```

Modify `try_prefill` so that it performs prefilling in chunks, rather than all at once.

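As a rough illustration of the state a chunked `try_prefill` needs to track, here is a model-free sketch (the `FakePrefill` class and its fields are hypothetical stand-ins for `Request`; this is not the reference solution):

```python
class FakePrefill:
    """Tracks chunked prefill progress over a token list."""

    def __init__(self, tokens: list[int], prefill_max_step: int = 128):
        self.tokens = tokens
        self.prefill_max_step = prefill_max_step
        self.offset = 0
        self.is_prefill_done = False

    def try_prefill(self):
        if self.is_prefill_done:
            raise ValueError("prefill called after done")
        end = min(self.offset + self.prefill_max_step, len(self.tokens))
        # a real implementation would run the chunk through the model here,
        # e.g. _step(model, chunk, [self.offset], kv_cache)
        self.offset = end
        if self.offset == len(self.tokens):
            self.is_prefill_done = True


r = FakePrefill(list(range(300)), prefill_max_step=128)
steps = 0
while not r.is_prefill_done:
    r.try_prefill()
    steps += 1
print(steps, r.offset)  # 300 tokens in 128-token chunks: 3 steps, offset 300
```

Each call advances `offset` by at most one chunk, so decode steps for other requests can be interleaved between calls.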
You can test your implementation by running:

```bash
pdm run batch-main
```

This will use the `qwen2-0.5b` model with a batch size of 5 to process a fixed set of prompts.

{{#include copyright.md}}

src/tiny_llm/attention.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -53,5 +53,6 @@ def flash_attention(
     key: mx.array,
     value: mx.array,
     scale: float | None = None,
+    mask: mx.array | None = None,
 ) -> mx.array:
     pass
```

src/tiny_llm/batch.py

Lines changed: 171 additions & 0 deletions
```python
import mlx.core as mx
from mlx_lm.tokenizer_utils import TokenizerWrapper
from .kv_cache import *
from .qwen2_week2 import Qwen2ModelWeek2
from typing import Callable
from datetime import datetime


def _step(model, y, offsets, kv_cache):
    logits = model(y, offsets, kv_cache)
    logits = logits[:, -1, :]
    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    sampler = lambda x: mx.argmax(x, axis=-1)
    y = sampler(logprobs)
    return y


class Request:
    def __init__(
        self,
        model: any,
        tokenizer: TokenizerWrapper,
        prompt: str,
        prefill_max_step: int = 128,
        prompt_idx: int = 0,
    ):
        self.prompt = prompt
        self.kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)]
        self.model = model
        self.detokenizer = tokenizer.detokenizer.__class__(tokenizer._tokenizer)
        self.prefill_tokens = mx.array(
            tokenizer.encode(prompt, add_special_tokens=False)
        )
        self.prefill_max_step = prefill_max_step
        self.is_done = False
        self.is_prefill_done = False
        self.eos_token_id = tokenizer.eos_token_id
        self.next_token = None
        self.offset = 0
        self.prompt_idx = prompt_idx

    def try_prefill(self):
        """
        Prefill this request up to prefill_max_step tokens; returns None if prefill is not done.
        """
        if self.is_prefill_done:
            raise ValueError("prefill called after done")
        # TODO: in task 4, prefill the full request at once; in task 5, prefill a chunk at a time

    def decode_done(self, token, update_offset=True):
        if self.is_done:
            raise ValueError("decode called after done")
        if token == self.eos_token_id:
            self.is_done = True
            return
        # TODO: update the offset and add the token to the detokenizer

    def text(self):
        return self.detokenizer.text


def _print_progress(
    requests: list[Request | None],
    is_idle: list[bool],
    pending_prefill_request: Request | None,
    queue_size: int,
    progress_cnt: int,
    start_time: datetime,
):
    print(f" --- {datetime.now() - start_time}")
    animation_frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
    animation_frame = animation_frames[progress_cnt % len(animation_frames)]
    for i in range(len(requests)):
        if is_idle[i]:
            print(f"  Decode #{i}: idle", flush=True)
        else:
            text_preview = requests[i].text()[-80:].replace("\n", " ")
            print(
                f"{animation_frame} Decode [req {requests[i].prompt_idx}, {requests[i].offset}]: {text_preview}",
                flush=True,
            )
    if pending_prefill_request is not None:
        if pending_prefill_request.is_prefill_done:
            print(
                f"  Prefill [req {pending_prefill_request.prompt_idx}]: done, waiting for slot, {queue_size} requests in queue",
                flush=True,
            )
            return
        percentage = (
            pending_prefill_request.offset / pending_prefill_request.prefill_tokens.size
        ) * 100
        print(
            f"{animation_frame} Prefill [req {pending_prefill_request.prompt_idx}]: {percentage:.2f}% ({pending_prefill_request.prefill_tokens.size - pending_prefill_request.offset} remaining tokens)",
            flush=True,
        )
    else:
        print(f"  Prefill: idle, {queue_size} requests in queue", flush=True)


def batch_generate(
    model: any,
    tokenizer: TokenizerWrapper,
    prompts: list[str],
    max_seq_len=512,
    batch_size=5,
    prefill_step=128,
):
    decode_requests: list[Request] = [None] * batch_size
    is_idle = [True] * batch_size
    kv_cache = [
        BatchingKvCache(max_active_requests=batch_size, max_seq_len=max_seq_len)
        for _ in range(model.num_hidden_layers)
    ]
    result = []
    pending_prefill_request = None
    next_request_idx = 0
    progress_cnt = 0
    start_time = datetime.now()

    while True:
        if len(prompts) == 0 and all(is_idle):
            break
        # prefill until no idle slots
        if len(prompts) > 0 and pending_prefill_request is None:
            prompt = prompts.pop(0)
            pending_prefill_request = Request(
                model, tokenizer, prompt, prefill_step, next_request_idx
            )
            next_request_idx += 1

        # In every iteration, we do a prefill first
        if pending_prefill_request is not None:
            made_progress = False
            if not pending_prefill_request.is_prefill_done:
                pending_prefill_request.try_prefill()
                made_progress = True
            if pending_prefill_request.is_prefill_done:
                # TODO: find an idle slot and add the request to the decode requests
                pass
            if made_progress:
                _print_progress(
                    decode_requests,
                    is_idle,
                    pending_prefill_request,
                    len(prompts),
                    progress_cnt,
                    start_time,
                )
                progress_cnt += 1

        # After the prefill request moves forward one step, we do the decode
        if not all(is_idle):
            next_tokens = []
            offsets = []
            # TODO: collect the next tokens and offsets from the decode requests
            # (next_tokens should end up as an mx.array before the reshape below)
            next_tokens = _step(model, next_tokens.reshape(-1, 1), offsets, kv_cache)
            for i in range(batch_size):
                # TODO: check if the decode has finished by comparing EOS or the seq
                # length. If so, remove the request from the decode requests and add
                # the result to the result list; otherwise, call `decode_done` to
                # update the offset and add the token to the detokenizer
                pass
            _print_progress(
                decode_requests,
                is_idle,
                pending_prefill_request,
                len(prompts),
                progress_cnt,
                start_time,
            )
            progress_cnt += 1
    return result
```
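The `_step` helper above greedily samples the next token from the last position's logits. A numpy sketch of the same sampling logic (illustrative only; the real helper operates on mlx arrays):

```python
import numpy as np


def greedy_next_token(logits: np.ndarray) -> np.ndarray:
    """logits: (B, L, V). Keep the last position, normalize, argmax per batch row."""
    last = logits[:, -1, :]                                            # (B, V)
    logprobs = last - np.log(np.sum(np.exp(last), axis=-1, keepdims=True))
    return np.argmax(logprobs, axis=-1)                                # (B,)


logits = np.array([[[0.1, 2.0, 0.3]], [[1.5, 0.2, 0.1]]])  # B=2, L=1, V=3
print(greedy_next_token(logits))  # [1 0]
```

Since argmax is unaffected by the shift, the log-softmax normalization matters only if you later switch to probabilistic sampling.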

src/tiny_llm/generate.py

Lines changed: 0 additions & 11 deletions

```diff
@@ -20,14 +20,3 @@ def simple_generate_with_kv_cache(
 ) -> str:
     def _step(model, y, offset, kv_cache):
         pass
-
-
-def batch_generate(
-    model: any,
-    tokenizer: TokenizerWrapper,
-    prompts: list[str],
-    max_seq_len=512,
-    batch_size=5,
-    prefill_step=128,
-):
-    pass
```

src/tiny_llm/kv_cache.py

Lines changed: 52 additions & 7 deletions

```diff
@@ -31,12 +31,52 @@ def update_and_fetch(
 
 class BatchingKvCache(TinyKvCache):
     def __init__(self, max_active_requests: int, max_seq_len: int):
-        pass
+        self.max_active_requests = max_active_requests
+        self.max_seq_len = max_seq_len
+        self.kv_caches: list[TinyKvCache] = [None] * max_active_requests
+        self.HD = None
 
     def update_and_fetch(
-        self, key: mx.array, value: mx.array
-    ) -> tuple[mx.array, mx.array, int]:
-        pass
+        self,
+        keys: mx.array,
+        values: mx.array,
+        mask_length: int | None = None,
+        mask: mx.array | str | None = None,
+    ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]:
+        B, H, S, D = keys.shape
+        assert keys.shape == values.shape
+        assert S <= self.max_seq_len
+        assert self.HD == (H, D), f"expect {self.HD} but got {H, D}"
+        assert B == self.max_active_requests
+        # Step 1: append the result to the cache
+        data = []
+        for b in range(B):
+            if self.kv_caches[b] is None:
+                data.append(None)
+                continue
+            key, value = keys[b : b + 1], values[b : b + 1]
+            new_key, new_value, seq_len, mask = self.kv_caches[b].update_and_fetch(
+                key, value
+            )
+            data.append((new_key[0], new_value[0], seq_len, mask))
+
+        # Step 2: compute seq_len of this batch
+        def get_seq_len(data):
+            if data is None:
+                return 0
+            _, _, seq_len, _ = data
+            return seq_len
+
+        seq_len = max(map(get_seq_len, data))
+
+        # Step 3: generate masks and a single array of keys and values
+        keys = mx.zeros((self.max_active_requests, H, seq_len, D), dtype=key.dtype)
+        values = mx.zeros((self.max_active_requests, H, seq_len, D), dtype=value.dtype)
+        masks = mx.full(
+            (self.max_active_requests, mask_length, seq_len), -mx.inf, dtype=key.dtype
+        )
+        # TODO: generate masks and a single array of keys and values
+        return keys, values, None, masks.reshape(B, 1, mask_length, seq_len)
 
     def add_request(self, prefilled: TinyKvCache, id: int):
         pass
@@ -47,9 +87,14 @@ def remove_request(self, id: int):
 
 class TinyKvFullCache(TinyKvCache):
     def __init__(self):
-        pass
+        self.key_values = None
+        self.offset = 0
 
     def update_and_fetch(
-        self, key: mx.array, value: mx.array
-    ) -> tuple[mx.array, mx.array, int]:
+        self,
+        key: mx.array,
+        value: mx.array,
+        mask_length: int | None = None,
+        mask: mx.array | str | None = None,
+    ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]:
         pass
```
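The `TinyKvFullCache` exercise boils down to concatenating new keys and values along the sequence axis and tracking the offset. A numpy sketch of that idea (a hypothetical stand-in for the mlx version, not the reference solution):

```python
import numpy as np


class FullCacheSketch:
    """Append-only KV cache: concatenates along the sequence axis (axis=2)."""

    def __init__(self):
        self.key_values = None  # (keys, values), each of shape (B, H, S, D)
        self.offset = 0         # number of tokens cached so far

    def update_and_fetch(self, key, value):
        if self.key_values is None:
            self.key_values = (key, value)
        else:
            ks, vs = self.key_values
            self.key_values = (
                np.concatenate([ks, key], axis=2),
                np.concatenate([vs, value], axis=2),
            )
        self.offset += key.shape[2]
        return self.key_values[0], self.key_values[1], self.offset


cache = FullCacheSketch()
k = np.ones((1, 2, 3, 4))   # prefill 3 tokens
cache.update_and_fetch(k, k)
k2 = np.ones((1, 2, 1, 4))  # decode 1 more token
keys, values, offset = cache.update_and_fetch(k2, k2)
print(keys.shape, offset)   # (1, 2, 4, 4) 4
```
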
