Skip to content

Commit caccf89

Browse files
committed
rename ServerAdapter
1 parent 5ad352e commit caccf89

File tree

4 files changed

+16
-16
lines changed

4 files changed

+16
-16
lines changed

verl/workers/rollout/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:

 _ROLLOUT_REGISTRY = {
     ("vllm", "async"): "verl.workers.rollout.vllm_rollout.vLLMAsyncRollout",
     ("sglang", "async"): "verl.workers.rollout.sglang_rollout.sglang_rollout.ServerAdapter",
-    ("trtllm", "async"): "verl.workers.rollout.trtllm_rollout.trtllm_rollout.TRTLLMAsyncRollout",
+    ("trtllm", "async"): "verl.workers.rollout.trtllm_rollout.trtllm_rollout.ServerAdapter",
 }


verl/workers/rollout/trtllm_rollout/trtllm_async_rollout.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ flowchart TB
     space2[" "]
     style space2 fill:none,stroke:none

-    subgraph AsyncRollout["<b>TRTLLMAsyncRollout<br/>(per DP rank)</b>"]
+    subgraph AsyncRollout["<b>ServerAdapter<br/>(per DP rank)</b>"]
         DPLeader["<b>• DP Leader coordination</b>"]
         IPCMgmt["<b>• IPC handle management</b>"]
         HTTPAdapter["<b>• HTTP adapter for server communication</b>"]
     end

     AsyncRollout -->|<b>HTTP/REST API</b>| HTTPServer
@@ -223,7 +223,7 @@ flowchart TB
 - Validate placement group configurations


-### 3.3 `TRTLLMAsyncRollout`
+### 3.3 `ServerAdapter`

 **Purpose**: Rollout worker that handles weight updates, memory management, and generation via HTTP adapter.

@@ -256,7 +256,7 @@ Each DP rank has one leader (the first TP rank within that DP group), and that l
 ```mermaid
 sequenceDiagram
     participant Client as Client/Actor
-    participant Rollout as TRTLLMAsyncRollout
+    participant Rollout as ServerAdapter
     participant Adapter as AsyncHttpAdapter
     participant Server as TRTLLMHttpServer
     participant AsyncLLM as AsyncLLM Engine

verl/workers/rollout/trtllm_rollout/trtllm_async_server.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
 from verl.utils.device import is_cuda_available
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
-from verl.workers.rollout.trtllm_rollout.trtllm_rollout import TRTLLMAsyncRollout
+from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter
 from verl.workers.rollout.utils import is_valid_ipv6_address, run_unvicorn

 logger = logging.getLogger(__file__)
@@ -184,20 +184,20 @@ async def wake_up(self):
             # Call all workers to switch between trainer mode and rollout mode.
             await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers])
         elif self.rollout_mode == RolloutMode.COLOCATED:
-            await self.llm.resume(tags=TRTLLMAsyncRollout.get_full_tags())
+            await self.llm.resume(tags=ServerAdapter.get_full_tags())
         elif self.rollout_mode == RolloutMode.STANDALONE:
             logger.info("skip wake_up in standalone mode")

     async def sleep(self):
         if self.rollout_mode == RolloutMode.HYBRID:
             await asyncio.gather(*[worker.sleep.remote() for worker in self.workers])
         elif self.rollout_mode == RolloutMode.COLOCATED:
-            await self.llm.release(tags=TRTLLMAsyncRollout.get_full_tags())
+            await self.llm.release(tags=ServerAdapter.get_full_tags())
         elif self.rollout_mode == RolloutMode.STANDALONE:
             logger.info("skip sleep in standalone mode")


-_rollout_worker_actor_cls = ray.remote(TRTLLMAsyncRollout)
+_rollout_worker_actor_cls = ray.remote(ServerAdapter)


 class TRTLLMReplica(RolloutReplica):

verl/workers/rollout/trtllm_rollout/trtllm_rollout.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ async def update_weights(self, weights: dict[str, str]):
         return await self._make_async_request("update_weights", {"weights": weights})


-class TRTLLMAsyncRollout(BaseRollout):
+class ServerAdapter(BaseRollout):
     _WEIGHTS_TAGS = [
         "sampler",
         "drafter",
@@ -291,7 +291,7 @@ class TRTLLMAsyncRollout(BaseRollout):

     @staticmethod
     def get_full_tags() -> list[str]:
-        return TRTLLMAsyncRollout._WEIGHTS_TAGS + ["kv_cache"]
+        return ServerAdapter._WEIGHTS_TAGS + ["kv_cache"]

     def __init__(
         self, config: RolloutConfig, model_config: HFModelConfig, device_mesh: DeviceMesh, replica_rank: int = -1
@@ -322,7 +322,7 @@ def __init__(
             logger.info(f"exclude_dp_size = {self.hybrid_device_mesh['exclude_dp'].size()}")
             self.gpu_id = ray.get_gpu_ids()[0]
             self.replica_rank = self.hybrid_device_mesh["dp"].get_local_rank()
-            assert len(ray.get_gpu_ids()) == 1, "TRTLLMAsyncRollout should run on a single GPU node"
+            assert len(ray.get_gpu_ids()) == 1, "ServerAdapter should run on a single GPU node"
         else:
             rank = int(os.environ["RANK"])
             self.replica_rank = replica_rank
@@ -332,7 +332,7 @@ def __init__(
         assert self.replica_rank >= 0, "replica_rank is not set"
         assert self.is_leader_rank is not None, "is_leader_rank is not set"

-        print(f"TRTLLMAsyncRollout, replica_rank: {self.replica_rank}, is_leader_rank: {self.is_leader_rank}")
+        print(f"ServerAdapter, replica_rank: {self.replica_rank}, is_leader_rank: {self.is_leader_rank}")

         self.node_ip = ray.util.get_node_ip_address().strip("[]")

0 commit comments

Comments
 (0)