Skip to content

Commit b31cb9f

Browse files
committed
fix: set block_size to 1
1 parent 487a1ee commit b31cb9f

6 files changed

Lines changed: 10 additions & 20 deletions

File tree

src/parallax/server/block_radix_cache.py

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -164,11 +164,6 @@ def insert_block(
164164

165165
self.num_cached_blocks += 1
166166

167-
logger.debug(
168-
f"Inserted new block: block_id={block_id}, "
169-
f"tokens={token_ids[:5]}..., total_cached={self.num_cached_blocks}"
170-
)
171-
172167
if self.num_cached_blocks > self.max_cached_blocks:
173168
self._evict_lru_blocks(self.num_cached_blocks - self.max_cached_blocks)
174169

@@ -190,11 +185,6 @@ def decrease_lock_ref(self, nodes: List[BlockTreeNode]):
190185
if node.lock_ref > 0:
191186
node.lock_ref -= 1
192187

193-
if node.lock_ref == 0:
194-
logger.debug(
195-
f"Node {node.node_id} (block_id={node.block_id}) ref count = 0, evictable"
196-
)
197-
198188
def register_request(self, request_id: str, nodes: List[BlockTreeNode]):
199189
"""Register nodes used by request."""
200190
self.request_to_nodes[request_id] = nodes

src/parallax/server/cache_manager.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
linear_num_v_heads: Optional[int] = None,
4444
# Prefix Cache Config
4545
enable_prefix_cache: bool = False,
46-
max_cached_blocks: int = 1000,
46+
max_cached_blocks: Optional[int] = None,
4747
sliding_window: Optional[int] = None,
4848
):
4949
self.num_layers = num_layers
@@ -83,6 +83,9 @@ def __init__(
8383
num_gpu_blocks = 0
8484

8585
self.num_gpu_blocks = num_gpu_blocks
86+
self.max_cached_blocks = (
87+
self.num_gpu_blocks if max_cached_blocks is None else max_cached_blocks
88+
)
8689

8790
# 1. Initialize Allocators
8891
self.allocator = (
@@ -121,10 +124,10 @@ def __init__(
121124
if enable_prefix_cache and self.needs_blocks:
122125
self.prefix_cache = BlockRadixCache(
123126
block_size=block_size,
124-
max_cached_blocks=max_cached_blocks,
127+
max_cached_blocks=self.max_cached_blocks,
125128
on_block_evict=self._on_prefix_block_evict,
126129
)
127-
logger.info(f"Prefix cache enabled with max_cached_blocks={max_cached_blocks}")
130+
logger.info(f"Prefix cache enabled with max_cached_blocks={self.max_cached_blocks}")
128131

129132
# Mapping: request_id -> token_ids (for prefix matching)
130133
self.request_token_ids: Dict[str, List[int]] = {}
@@ -548,11 +551,6 @@ def insert_full_blocks_to_cache(self, request_id: str):
548551
parent_path.append(new_node)
549552
registered_nodes.append(new_node)
550553

551-
logger.debug(
552-
f"Request {request_id}: Inserted block {block_idx} "
553-
f"(block_id={block_id}) to prefix cache"
554-
)
555-
556554
if registered_nodes:
557555
if request_id in self.prefix_cache.request_to_nodes:
558556
old_nodes = self.prefix_cache.request_to_nodes[request_id]

src/parallax/server/executor/mlx_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def __init__(
178178
sliding_window = None
179179

180180
# Validate and adjust block size for Metal backend
181-
supported_block_sizes = [8, 16, 32, 64]
181+
supported_block_sizes = [1, 8, 16, 32, 64]
182182
if kv_block_size not in supported_block_sizes:
183183
nearest_block_size = min(supported_block_sizes, key=lambda x: abs(x - kv_block_size))
184184
logger.warning(

src/parallax/server/server_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def parse_args() -> argparse.Namespace:
102102
)
103103

104104
parser.add_argument(
105-
"--kv-block-size", type=int, default=32, help="Block size for KV cache management"
105+
"--kv-block-size", type=int, default=1, help="Block size for KV cache management"
106106
)
107107

108108
parser.add_argument(

src/parallax_extensions/kernels/paged_attention.metal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,6 +1456,8 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
14561456

14571457
#define instantiate_paged_attention_block_size(type, cache_type, num_threads, \
14581458
num_simd_lanes, partition_size) \
1459+
instantiate_paged_attention_heads(type, cache_type, 1, num_threads, \
1460+
num_simd_lanes, partition_size); \
14591461
instantiate_paged_attention_heads(type, cache_type, 8, num_threads, \
14601462
num_simd_lanes, partition_size); \
14611463
instantiate_paged_attention_heads(type, cache_type, 16, num_threads, \
2.98 MB
Binary file not shown.

0 commit comments

Comments
 (0)