
Commit cc7e283
[trtllm] fix: minor fixes to trtllm rollout (verl-project#5095)
### What does this PR do?

1. Revert verl-project#5085: TRTLLM needs this actor-level env var, otherwise the TRTLLM CI breaks (as shown in the CI check of PR verl-project#5085). The CPU affinity env var also affected end-to-end TRTLLM throughput.
2. Tentatively revert the empty-cache call in `update_weights()`. It appears to affect the KL loss and needs further investigation: https://wandb.ai/nvidia/verify-mr-5032

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.
### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
Parent: f904162 · Commit: cc7e283

File tree

2 files changed: +1 addition, −4 deletions

**verl/workers/rollout/trtllm_rollout/trtllm_async_server.py** (1 addition, 2 deletions)

```diff
@@ -329,13 +329,12 @@ async def launch_servers(self):
             else f"trtllm_server_reward_{self.replica_rank}"
         )

-        runtime_env_vars = {"TLLM_NUMA_AWARE_WORKER_AFFINITY": "0"}
         server = TRTLLMHttpServer.options(
             scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                 node_id=node_id,
                 soft=False,
             ),
-            runtime_env={"env_vars": runtime_env_vars},
+            runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
             name=name,
         ).remote(
             config=self.config,
```
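The diff above restores an actor-level environment variable passed through Ray's `runtime_env={"env_vars": ...}`, which injects the variable into that actor's worker process only. As a minimal stand-in that does not require a Ray cluster, the same effect can be illustrated with a child process; `spawn_with_env_vars` is a hypothetical helper for this sketch, not part of verl or Ray:

```python
import os
import subprocess
import sys

def spawn_with_env_vars(env_vars, code):
    """Run a child Python process with extra environment variables,
    loosely mimicking how Ray injects an actor-level
    runtime_env["env_vars"] mapping into the worker process."""
    env = dict(os.environ)
    env.update(env_vars)  # actor-level vars override inherited ones
    result = subprocess.run(
        [sys.executable, "-c", code], env=env, capture_output=True, text=True
    )
    return result.stdout.strip()

# The variable the diff restores for the TRTLLM server actor:
value = spawn_with_env_vars(
    {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"},
    "import os; print(os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'])",
)
print(value)  # -> 1
```

Scoping the variable per actor (rather than process-wide) matters here because, per the PR description, the previously used CPU affinity variable leaked into throughput-sensitive TRTLLM workers.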

**verl/workers/rollout/trtllm_rollout/trtllm_rollout.py** (0 additions, 2 deletions)

```diff
@@ -31,7 +31,6 @@
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.multiprocessing.reductions import reduce_tensor

-from verl.utils.memory_utils import aggressive_empty_cache
 from verl.utils.net_utils import is_valid_ipv6_address
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.base import BaseRollout
@@ -425,4 +424,3 @@ async def flush():
         # Finalize update weights
         await self._adapter.update_weights(None)
         await asyncio.to_thread(dist.barrier, group=self.hybrid_device_mesh["exclude_dp"].get_group())
-        aggressive_empty_cache(force_sync=False)
```
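The reverted call flushes the CUDA caching allocator right after the weight update. The following is a hypothetical sketch of what such a helper typically does; verl's actual `verl.utils.memory_utils.aggressive_empty_cache` may differ in detail:

```python
import gc

def aggressive_empty_cache(force_sync: bool = False) -> None:
    """Sketch of an "aggressive" GPU cache flush (assumption: the real
    verl helper may behave differently)."""
    gc.collect()  # drop Python references so cached CUDA blocks become freeable
    try:
        import torch
    except ImportError:
        return  # CPU-only environment: nothing to flush
    if torch.cuda.is_available():
        if force_sync:
            torch.cuda.synchronize()  # wait for in-flight kernels before flushing
        torch.cuda.empty_cache()  # return cached allocator blocks to the driver
```

Per the PR description, flushing at this point appeared to affect the KL loss (see the linked wandb run), so the call is tentatively removed pending further investigation.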
