@@ -16,21 +16,20 @@ class BlockAllocator:
1616
1717 Args:
1818 context (DynamicInferenceContext): Dynamic inference context.
19- active_count (int): Total number of active blocks available in the buffer.
20- The full buffer size is 2*active_count, to accommodate an equal-size
21- space for paused requests that live on the CPU .
19+ total_count (int): Total number of blocks in the buffer.
20+ paused_count (int): Number of paused blocks in the buffer. Must be less
21+ than `total_count` .
2222 """
2323
24- def __init__ (self , context : "DynamicInferenceContext" , total_count : int ):
24+ def __init__ (self , context : "DynamicInferenceContext" , total_count : int , paused_count : int ):
2525
2626 self .context = context
2727
28- active_count = (total_count - 1 ) // 2 # -1 for dummy_block_idx (see below)
29- active_count = max (1 , active_count ) # need at least one block
30- self .total_count = 2 * active_count + 1 # +1 for dummy_block_idx
31- self .total_avail = self .total_count - 1 # -1 for dummy_block_idx
32- self .active_count = active_count
33- self .paused_count = self .total_count - self .active_count - 1 # -1 for dummy_block_idx
28+ self .total_count = total_count
29+ self .total_avail = total_count - 1 # -1 for dummy_block_idx (see below)
30+ self .paused_count = paused_count
31+ self .active_count = total_count - paused_count - 1 # -1 for dummy_block_idx
32+ assert self .active_count >= 1 # ensures paused_count < total_count - 1
3433 self .dummy_block_idx = self .total_count - 1
3534
3635 # Initialize block pool as a "stack" data structure
@@ -40,10 +39,15 @@ def __init__(self, context: "DynamicInferenceContext", total_count: int):
4039
4140 def __str__ (self ):
4241 return (
43- f"total avail { self .total_avail } / { self .total_count - 1 } "
44- f"; active { self .active_count } "
42+ f"using: total { self .get_total_used ()} /{ self .total_count - 1 } "
43+ f"; active { self .get_active_used ()} /{ self .active_count } "
44+ f"; paused { self .get_paused_used ()} /{ self .paused_count } "
4545 )
4646
47+ def get_total_used (self ):
48+ """Compute number of total blocks used."""
49+ return self .total_count - self .total_avail - 1
50+
4751 def get_active_used (self ):
4852 """Compute number of active blocks used."""
4953 return (
@@ -77,7 +81,7 @@ def is_memory_available(self, num_blocks: int) -> bool:
7781 Return:
7882 (bool) Is memory available?
7983 """
80- return self .get_active_avail () >= num_blocks
84+ return self .total_avail >= num_blocks
8185
8286 def allocate_memory_blocks (self , num_blocks : int ) -> Optional [Tensor ]:
8387 """Allocate memory blocks if available, else return None.
0 commit comments