-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathblock_manager.py
More file actions
179 lines (159 loc) · 6.85 KB
/
block_manager.py
File metadata and controls
179 lines (159 loc) · 6.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
from collections import deque
import numpy as np
import xxhash
from atom.config import Config
from atom.model_engine.sequence import Sequence
class Block:
    """Bookkeeping record for a single KV-cache block.

    Fields:
        block_id:  index of this block in the owning pool.
        ref_count: reference count (0 after construction, 1 after ``reset``).
        hash:      recorded content hash, or -1 when none is recorded.
        token_ids: token ids recorded together with ``hash`` ([] when none).
    """

    def __init__(self, block_id):
        self.block_id = block_id
        self._clear(ref_count=0)

    def _clear(self, ref_count):
        # Shared by __init__ and reset: forget any recorded content and set
        # the reference count to the requested starting value.
        self.ref_count = ref_count
        self.hash, self.token_ids = -1, []

    def update(self, hash: int, token_ids: list[int]):
        """Record ``hash`` and ``token_ids`` as this block's content.

        ``token_ids`` is stored by reference, not copied.
        """
        self.hash, self.token_ids = hash, token_ids

    def reset(self):
        """Forget recorded content and set ``ref_count`` to 1 (a single holder)."""
        self._clear(ref_count=1)
class BlockManager:
    """Pool of fixed-size KV-cache blocks handed out to sequences.

    Free ids live in a FIFO ``deque`` (``free_block_ids``) mirrored by a
    ``set`` (``free_block_ids_set``). The set is authoritative. The deque may
    hold stale ids: for example, a cached free block reclaimed through a
    prefix-cache hit stays in the deque. ``_pop_free_block`` skips such
    entries lazily.

    ``allocate`` hashes every full block with a chained content hash and
    records it in ``hash_to_block_id``. When ``enable_prefix_caching`` is set,
    a new sequence whose leading blocks match those hashes (and token ids)
    reuses the existing blocks, sharing in-use ones via ``ref_count``.
    """

    def __init__(self, config: Config):
        block_size = config.kv_cache_block_size
        num_blocks = config.num_kvcache_blocks
        assert num_blocks > 0
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        # chained content hash -> id of the block that last recorded it
        self.hash_to_block_id: dict[int, int] = dict()
        # FIFO of free ids; may contain stale entries (see _pop_free_block)
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        # authoritative membership of the free pool
        self.free_block_ids_set: set[int] = set(range(num_blocks))
        self.used_block_ids: set[int] = set()
        self.enable_prefix_caching = config.enable_prefix_caching

    @classmethod
    def compute_hash(cls, token_ids: list[int], prefix: int = -1):
        """Return the xxh64 digest of ``token_ids`` chained onto ``prefix``.

        ``prefix`` is the previous block's hash; -1 means "no prefix". Because
        of the chaining, a block's hash depends on every token before it.
        """
        h = xxhash.xxh64()
        if prefix != -1:
            h.update(prefix.to_bytes(8, "little"))
        h.update(np.array(token_ids).tobytes())
        return h.intdigest()

    def _full_block_hash(self, token_ids: list[int], prefix: int) -> int:
        """Return the chained hash of a full block, or -1 for a partial block."""
        if len(token_ids) != self.block_size:
            return -1
        return self.compute_hash(token_ids, prefix)

    def _find_cached_block(self, h: int, token_ids: list[int]) -> int:
        """Return the id of a cached block holding exactly ``token_ids``, else -1.

        The token comparison guards against hash collisions and stale entries.
        """
        block_id = self.hash_to_block_id.get(h, -1)
        if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
            return -1
        return block_id

    def _pop_free_block(self) -> int:
        """Pop the next free block id from the FIFO queue (lazy cleanup).

        Deque entries that are no longer in ``free_block_ids_set`` are stale
        and are discarded.

        Raises:
            AssertionError: if no free block remains.
        """
        while self.free_block_ids:
            block_id = self.free_block_ids.popleft()
            if block_id in self.free_block_ids_set:
                self.free_block_ids_set.discard(block_id)
                return block_id
        raise AssertionError("No free blocks available")

    def _allocate_block(self, block_id: int) -> Block:
        """Claim ``block_id`` for use and return its ``Block``.

        Drops the block's hash-index entry (if it still points here), resets
        the block (``ref_count`` becomes 1) and moves the id to the used set.
        The id may remain in the deque; ``_pop_free_block`` will skip it.
        """
        block = self.blocks[block_id]
        assert block.ref_count == 0
        # Evict stale hash entry before resetting
        if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
            del self.hash_to_block_id[block.hash]
        block.reset()
        self.free_block_ids_set.discard(block_id)
        self.used_block_ids.add(block_id)
        return block

    def _deallocate_block(self, block_id: int):
        """Return an unreferenced block to the free pool.

        The block's ``hash``/``token_ids`` and its hash-index entry are kept,
        so the prefix cache can still hit it until it is reallocated.
        """
        assert self.blocks[block_id].ref_count == 0
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)
        self.free_block_ids_set.add(block_id)

    def can_allocate(self, seq: Sequence) -> bool:
        """Return True if ``allocate(seq)`` will find enough free blocks.

        Without prefix caching, every block (plus mamba blocks) needs a free
        slot. With prefix caching, the leading run of cache hits is free only
        for blocks already in use. A hit on a cached block that is currently
        free still removes it from the free pool. Every block from the first
        miss onward, and every mamba block, needs a fresh free block.
        """
        if not self.enable_prefix_caching:
            return len(self.free_block_ids_set) >= seq.num_blocks + seq.num_mamba_blocks
        # Dry-run the lookup performed by allocate().
        needed_free = seq.num_mamba_blocks
        h = -1
        for i in range(seq.num_blocks):
            token_ids = seq.block(i)
            h = self._full_block_hash(token_ids, h)
            block_id = self._find_cached_block(h, token_ids)
            if block_id == -1:
                # First miss: this block and all later ones are misses.
                needed_free += seq.num_blocks - i
                break
            if block_id not in self.used_block_ids:
                # Cached but free: reusing it consumes a free-pool slot.
                needed_free += 1
        return len(self.free_block_ids_set) >= needed_free

    def allocate(self, seq: Sequence):
        """Fill ``seq.block_table`` (and ``seq.mamba_block_table`` for mamba).

        While the leading blocks hit the prefix cache they are reused and
        ``seq.num_cached_tokens`` grows by ``block_size`` per hit. From the
        first miss onward, fresh blocks are popped from the free pool. Full
        blocks are hashed and registered in ``hash_to_block_id``.

        Raises:
            AssertionError: if the free pool runs out (check ``can_allocate``).
        """
        assert not seq.block_table
        h = -1
        cache_miss = False
        for i in range(seq.num_blocks):
            token_ids = seq.block(i)
            h = self._full_block_hash(token_ids, h)
            if not cache_miss:
                block_id = (
                    self._find_cached_block(h, token_ids)
                    if self.enable_prefix_caching
                    else -1
                )
                cache_miss = block_id == -1
            if cache_miss:
                block_id = self._pop_free_block()
                block = self._allocate_block(block_id)
            else:
                seq.num_cached_tokens += self.block_size
                if block_id in self.used_block_ids:
                    block = self.blocks[block_id]
                    block.ref_count += 1
                else:
                    block = self._allocate_block(block_id)
            if h != -1:
                block.update(h, token_ids)
                self.hash_to_block_id[h] = block_id
            seq.block_table.append(block_id)
        # Mamba-like models get dedicated state blocks (no prefix caching).
        if seq.mamba_enabled:
            for _ in range(seq.num_mamba_blocks):
                # Pop rather than peek: the deque front may be a stale or
                # just-claimed id, and peeking never advances the queue.
                block_id = self._pop_free_block()
                self._allocate_block(block_id)
                seq.mamba_block_table.append(block_id)

    def deallocate(self, seq: Sequence):
        """Release every block held by ``seq`` and clear its block tables.

        Shared blocks only lose one reference. A block returns to the free
        pool when its ``ref_count`` reaches 0. Iterating in reverse appends
        tail blocks to the free FIFO before prefix blocks, so the prefix
        blocks (likely the shareable ones) are recycled last.
        """
        for block_id in reversed(seq.block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1
            if block.ref_count == 0:
                self._deallocate_block(block_id)
        seq.num_cached_tokens = 0
        seq.block_table.clear()
        if seq.mamba_enabled:
            for block_id in reversed(seq.mamba_block_table):
                # Mamba blocks are never shared; force the count to 0.
                self.blocks[block_id].ref_count = 0
                self._deallocate_block(block_id)
            seq.mamba_block_table.clear()

    def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool:
        """Return True if the free pool covers ``len(seq) + num_new_tokens`` tokens."""
        needed_blocks = (
            len(seq) + num_new_tokens + self.block_size - 1
        ) // self.block_size
        new_blocks_needed = max(0, needed_blocks - len(seq.block_table))
        return len(self.free_block_ids_set) >= new_blocks_needed

    def may_append(self, seq: Sequence, num_new_tokens: int = 1):
        """Grow ``seq.block_table`` so it covers all ``len(seq)`` tokens.

        ``num_new_tokens`` is accepted for interface compatibility. The
        required block count is derived from ``len(seq)`` alone, so a step of
        ``block_size`` or more tokens ending exactly on a block boundary is
        still covered.
        """
        block_table = seq.block_table
        needed_blocks = (len(seq) + self.block_size - 1) // self.block_size
        while len(block_table) < needed_blocks:
            # Decode-generated blocks: tokens are not final yet (they depend
            # on sampling / speculative verification), so no hash is computed
            # here; the block is allocated unhashed.
            block_id = self._pop_free_block()
            self._allocate_block(block_id)
            block_table.append(block_id)