Commit c8ac1fe

Dynamic Inference | Evict and re-compute context requests. (NVIDIA#2738)
1 parent 8b10a64 commit c8ac1fe

File tree

11 files changed: +408 -188 lines changed

examples/inference/gpt/gpt_dynamic_inference.py

Lines changed: 3 additions & 1 deletion
@@ -174,6 +174,7 @@ def get_inference_context(
         ),
         block_size_tokens=args.inference_dynamic_batching_block_size,
         buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
+        paused_buffer_size_gb=args.inference_dynamic_batching_paused_buffer_size_gb,
         max_requests=args.inference_dynamic_batching_max_requests,
         max_tokens=args.inference_dynamic_batching_max_tokens,
         tensor_model_parallel_size=args.tensor_model_parallel_size,
@@ -369,6 +370,7 @@ def _add_request():
         request.time_end = get_curr_time()
         request.state = "finished"
         request.request_id = finished_request.request_id
+        request.events = finished_request.events
 
         # Update prompt, in case engine has been suspended and resumed.
         request.prompt_tokens = finished_request.prompt_tokens.tolist()
@@ -543,7 +545,7 @@ def escape_str(s):
     # ---- Prompt summary line ----
     prompt_len = len(requests[request_idxs[0]].prompt_tokens)
     escaped_prompt_text = escape_str(prompt_text)
-    print(f"{unique_idx+1}/{len(unique_prompt_map)} [n {len(request_idxs)}, l {prompt_len}] {escaped_prompt_text}")
+    print(f"\n{unique_idx+1}/{len(unique_prompt_map)} [n {len(request_idxs)}, l {prompt_len}] {escaped_prompt_text}")
 
     # ---- Group all outputs for this prompt ----
     output_map = defaultdict(list)
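
Taken together, this file's change threads a new `paused_buffer_size_gb` setting from the argument namespace into the inference context. The short sketch below only restates the keyword arguments visible in the diff; the helper function and its return-a-dict shape are assumptions for illustration, not code from the commit.

# Sketch: the context kwargs that gpt_dynamic_inference.py now forwards,
# including the new paused-buffer budget. The Namespace field names come from
# the diff; the helper itself is hypothetical.
from argparse import Namespace

def collect_context_kwargs(args: Namespace) -> dict:
    return dict(
        block_size_tokens=args.inference_dynamic_batching_block_size,
        buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
        # New in this commit: a separate buffer budget for paused requests,
        # sized independently of the main KV buffer.
        paused_buffer_size_gb=args.inference_dynamic_batching_paused_buffer_size_gb,
        max_requests=args.inference_dynamic_batching_max_requests,
        max_tokens=args.inference_dynamic_batching_max_tokens,
        tensor_model_parallel_size=args.tensor_model_parallel_size,
    )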

examples/inference/gpt/utils.py

Lines changed: 3 additions & 3 deletions
@@ -72,7 +72,7 @@ def add_common_inference_args(parser: ArgumentParser) -> ArgumentParser:
         help="Add a deterministic number of requests per step. This arg is "
         "prioritized over `--incoming-requests-per-sec` below (which is non-"
         "deterministic). Note that the number of requests added per step is "
-        "additionally limited by the inference context's `max_active_requests`, "
+        "additionally limited by the inference context's `max_requests`, "
         "`max_tokens`, and KV buffer size.",
     )
     group.add_argument(
@@ -393,7 +393,7 @@ def build_dynamic_engine_setup_prefix(
 
     Args:
         args (Namespace): Command-line arguments for this run.
-        context (DynamicInferenceContext): Stores limits such as `max_active_requests`,
+        context (DynamicInferenceContext): Stores limits such as `max_requests`,
             `max_tokens`, and `gtd_request_count`.
         requests (List[DynamicInferenceRequest]): List of inference requests.
@@ -430,7 +430,7 @@ def build_dynamic_engine_setup_prefix(
     buffer_limits_str = (
         f"bf: {get_mem_size_str(args.inference_dynamic_batching_buffer_size_gb*1024**3)}, "
         f"{context.block_allocator.active_count} chunks "
-        f"[r {context.max_active_requests}, t {context.max_tokens}]"
+        f"[r {context.max_requests}, t {context.max_tokens}]"
     )
 
     parts = [
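
The rename from `max_active_requests` to `max_requests` also shows up in the setup-prefix string above. Below is a self-contained sketch of that formatting, with plain stand-ins for the real DynamicInferenceContext and `get_mem_size_str` (both assumed here); the literal numbers are illustrative only.

# Sketch of the buffer-limits line after the rename; SimpleNamespace objects
# replace the real context, and the values are made up for the example.
from types import SimpleNamespace

def format_buffer_limits(buffer_size_gb: float, context) -> str:
    buffer_bytes = buffer_size_gb * 1024**3
    return (
        f"bf: {buffer_bytes / 1024**3:.1f} gb, "
        f"{context.block_allocator.active_count} chunks "
        f"[r {context.max_requests}, t {context.max_tokens}]"  # was max_active_requests
    )

context = SimpleNamespace(
    max_requests=512,
    max_tokens=8192,
    block_allocator=SimpleNamespace(active_count=1024),
)
print(format_buffer_limits(16.0, context))  # bf: 16.0 gb, 1024 chunks [r 512, t 8192]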

megatron/core/inference/contexts/dynamic_block_allocator.py

Lines changed: 17 additions & 13 deletions
@@ -16,21 +16,20 @@ class BlockAllocator:
 
     Args:
         context (DynamicInferenceContext): Dynamic inference context.
-        active_count (int): Total number of active blocks available in the buffer.
-            The full buffer size is 2*active_count, to accommodate an equal-size
-            space for paused requests that live on the CPU.
+        total_count (int): Total number of blocks in the buffer.
+        paused_count (int): Number of paused blocks in the buffer. Must be less
+            than `total_count`.
     """
 
-    def __init__(self, context: "DynamicInferenceContext", total_count: int):
+    def __init__(self, context: "DynamicInferenceContext", total_count: int, paused_count: int):
 
         self.context = context
 
-        active_count = (total_count - 1) // 2  # -1 for dummy_block_idx (see below)
-        active_count = max(1, active_count)  # need at least one block
-        self.total_count = 2 * active_count + 1  # +1 for dummy_block_idx
-        self.total_avail = self.total_count - 1  # -1 for dummy_block_idx
-        self.active_count = active_count
-        self.paused_count = self.total_count - self.active_count - 1  # -1 for dummy_block_idx
+        self.total_count = total_count
+        self.total_avail = total_count - 1  # -1 for dummy_block_idx (see below)
+        self.paused_count = paused_count
+        self.active_count = total_count - paused_count - 1  # -1 for dummy_block_idx
+        assert self.active_count >= 1  # ensures paused_count < total_count - 1
         self.dummy_block_idx = self.total_count - 1
 
         # Initialize block pool as a "stack" data structure
@@ -40,10 +39,15 @@ def __init__(self, context: "DynamicInferenceContext", total_count: int):
 
     def __str__(self):
         return (
-            f"total avail {self.total_avail} / {self.total_count - 1}"
-            f"; active {self.active_count}"
+            f"using: total {self.get_total_used()}/{self.total_count - 1}"
+            f"; active {self.get_active_used()}/{self.active_count}"
+            f"; paused {self.get_paused_used()}/{self.paused_count}"
         )
 
+    def get_total_used(self):
+        """Compute number of total blocks used."""
+        return self.total_count - self.total_avail - 1
+
     def get_active_used(self):
         """Compute number of active blocks used."""
         return (
@@ -77,7 +81,7 @@ def is_memory_available(self, num_blocks: int) -> bool:
         Return:
             (bool) Is memory available?
         """
-        return self.get_active_avail() >= num_blocks
+        return self.total_avail >= num_blocks
 
     def allocate_memory_blocks(self, num_blocks: int) -> Optional[Tensor]:
         """Allocate memory blocks if available, else return None.
