from typing import Any, List, Optional

from ..llmapi.llm import LLM
from ..llmapi.llm_args import RayPlacementConfig


class AsyncLLM(LLM):
    """AsyncLLM is a subclass of LLM that supports the asynchronous setup, release and
    resume operations that are necessary for RL or agentic scenarios.

    Currently, RL APIs are only supported with the Ray orchestrator.
    """

    def __init__(
        self,
        placement_groups: Optional[List[Any]] = None,
        placement_bundle_indices: Optional[List[List[int]]] = None,
        per_worker_gpu_share: Optional[float] = None,
        *args,
        **kwargs,
    ):
        kwargs["orchestrator_type"] = "ray"
        kwargs["ray_placement_config"] = RayPlacementConfig(
            defer_workers_init=True,
            placement_groups=placement_groups,
            placement_bundle_indices=placement_bundle_indices,
            per_worker_gpu_share=per_worker_gpu_share,
        )
        # WAR: RL integration needs to use NCCL AllReduce for TP > 1 due to a bug in
        # TRTLLM's AllReduce that causes convergence issues when using multiple
        # rollout instances.
        kwargs["allreduce_strategy"] = "NCCL"
        if "ray_worker_extension_cls" not in kwargs:
            kwargs["ray_worker_extension_cls"] = "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension"
        super().__init__(*args, **kwargs)
        self._async_initialized = False
        self._paused = False

    async def setup_async(self):
        """Set up the LLM asynchronously."""
        if not self._async_initialized:
            await self._executor.init_workers_async()
            await self._executor.setup_engine_remote_async()
            self._async_initialized = True
        return self
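
    # Typical construction pattern (illustrative sketch; the model path and
    # `my_placement_groups` below are placeholders, not values from this file).
    # Because __init__ defers Ray worker initialization, the instance is not usable
    # until setup_async() completes; awaiting the instance itself (see __await__
    # below) does the same thing:
    #
    #   llm = AsyncLLM(model="/path/to/model", placement_groups=my_placement_groups)
    #   await llm.setup_async()
    #   # or, equivalently:
    #   llm = await AsyncLLM(model="/path/to/model")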

    async def release(self, tags: list[str]):
        """Release the GPU memory used by the LLM asynchronously.

        Args:
            tags: List of memory tag strings to release (e.g., ["model", "kv_cache"]).
        """
        await self.collective_rpc("sleep", args=(tags,))

    async def resume(self, tags: list[str]):
        """Resume the GPU memory used by the LLM asynchronously.

        Args:
            tags: List of memory tag strings to resume (e.g., ["model", "kv_cache"]).
        """
        await self.collective_rpc("wakeup", args=(tags,))

    async def update_weights(self, weights: dict[str, str]):
        """Update the weights of the LLM asynchronously.

        Args:
            weights: Dictionary mapping device UUIDs to IPC handles for weight tensors.
        """
        await self.collective_rpc("update_weights", args=(weights,))

    async def collective_rpc(
        self,
        method: str,
        args: tuple[Any, ...] = (),
        kwargs: Optional[dict] = None,
        unique_reply_rank: Optional[int] = None,  # TODO: deprecate this in the future
        target_ranks: int | list[int] | None = None,
    ) -> list[Any]:
        """Execute an asynchronous RPC call on all GPU workers. Currently, this is only supported for RayExecutor.

        Args:
            method (str): The name of the worker method to execute.
            args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
            kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
            unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply.
            target_ranks (int | list[int] | None): The ranks of the workers that will be used to send the reply.

        Returns:
            list[Any]: A list of results from each worker.
        """
        return await self._executor.collective_rpc_async(
            method, args, kwargs, unique_reply_rank=unique_reply_rank, target_ranks=target_ranks
        )
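
    # Example call (illustrative; "report_device_id" is a hypothetical method on the
    # configured ray_worker_extension_cls, not something defined in this file):
    #
    #   all_replies = await llm.collective_rpc("report_device_id")
    #   rank0_reply = await llm.collective_rpc("report_device_id", unique_reply_rank=0)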

    def generate_async(self, *args, **kwargs):
        if self._paused:
            raise RuntimeError(
                "AsyncLLM is paused. Call resume_generation() before submitting new requests."
            )
        return super().generate_async(*args, **kwargs)

    async def pause_generation(self) -> None:
        """Abort all in-flight requests and block new ones until resume_generation() is called."""
        self._paused = True
        self._executor.abort_all_requests()

    async def resume_generation(self) -> None:
        """Allow new generation requests after a pause_generation() call."""
        self._paused = False

    def __await__(self):
        return self.setup_async().__await__()

    def __enter__(self):
        raise RuntimeError("Please use 'async with AsyncLLM' instead")

    async def __aenter__(self):
        await self.setup_async()
        return super().__enter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        return super().__exit__(exc_type, exc_val, exc_tb)
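

# Illustrative RL rollout lifecycle (a sketch, not part of the original module).
# The model path, the memory tags, and the contents of `weight_ipc_handles` are
# assumptions; consult the RL integration utilities (e.g. rlhf_utils.WorkerExtension)
# for the exact protocol expected on the worker side.
async def _example_rollout_step(weight_ipc_handles: dict[str, str]) -> None:
    async with AsyncLLM(model="/path/to/model") as llm:
        # ... submit rollout requests via llm.generate_async(...) and collect results ...

        # Quiesce the engine and free GPU memory before the training step.
        await llm.pause_generation()
        await llm.release(["model", "kv_cache"])

        # After the trainer has produced new weights, restore GPU memory,
        # push the updated weights, and accept new generation requests again.
        await llm.resume(["model", "kv_cache"])
        await llm.update_weights(weight_ipc_handles)
        await llm.resume_generation()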