[trainer] feat: support moving ppo actor logics to single controller #4480
wuxibin89 merged 9 commits into verl-project:main
Conversation
Code Review
This pull request is a significant refactoring that moves the PPO actor logic into a more generic TrainingWorker, removing the specialized ActorWorker. This is a positive change towards a more unified and composable worker architecture. The changes also introduce a TensorDict-based data flow for the new implementation path, which is a modern approach for handling complex data in PyTorch. My review found one critical issue related to handling an optional dependency that could lead to a runtime crash. Otherwise, the refactoring appears well-structured, though complex.
verl/workers/utils/padding.py
Outdated
else:
    try:
        # for cuda and cpu
        from flash_attn.bert_padding import pad_input, unpad_input
    except Exception:
        pad_input, unpad_input = None, None
The current implementation sets pad_input and unpad_input to None if flash_attn fails to import. This will cause a TypeError later when these None values are called as functions in left_right_2_no_padding and no_padding_2_padding, as there are no checks for None before usage. This can crash the application if the optional flash-attn dependency is not installed.
A more robust approach is to define dummy functions that raise a helpful ImportError when called. This ensures that a clear error message is provided to the user, guiding them to install the missing dependency. Additionally, it's better practice to catch the specific ImportError rather than a broad Exception.
else:
    try:
        # for cuda and cpu
        from flash_attn.bert_padding import pad_input, unpad_input
    except ImportError:
        def _raise_import_error(*args, **kwargs):
            raise ImportError(
                "flash_attn is not installed or not compatible with your environment. "
                "Please install it to use padding removal features (e.g., `pip install flash-attn`)."
            )
        pad_input, unpad_input = _raise_import_error, _raise_import_error
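To see why the stub functions beat `None` sentinels: a missing dependency then surfaces as a descriptive `ImportError` at call time instead of an opaque `TypeError: 'NoneType' object is not callable`. A minimal, self-contained sketch of the pattern (the module name `_no_such_module` is deliberately fake so the `except` branch always runs here; it stands in for `flash_attn`):

```python
# Optional-import fallback pattern: bind stub callables instead of None.
try:
    from _no_such_module import pad_input, unpad_input  # stand-in for flash_attn
except ImportError:
    def _raise_import_error(*args, **kwargs):
        raise ImportError(
            "flash_attn is not installed. Install it to use padding removal "
            "features (e.g., `pip install flash-attn`)."
        )
    pad_input, unpad_input = _raise_import_error, _raise_import_error

# The names are always callables; calling one without the dependency
# raises a clear, actionable error.
try:
    pad_input()
except ImportError as e:
    print(f"caught: {e}")
```

This also keeps the call sites (`left_right_2_no_padding`, `no_padding_2_padding`) free of `if pad_input is None` checks.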
/gemini review
assert self.config.actor.use_dynamic_bsz == self.config.rollout.log_prob_use_dynamic_bsz

# assign engine configs
Why do we need to deal with engine config here but not in L507?
The current config system is very chaotic; we need to clean up all the legacy code to make this part clean.
Code Review
This PR is a significant refactoring that replaces ActorWorker with TrainingWorker to move PPO actor logic to a single controller. The changes are extensive, touching configuration, the main trainer loop, and worker implementations. The overall direction is good, centralizing logic in the trainer. However, I've found a critical issue where an API change in engine_workers.py breaks the caller in ray_trainer.py, which was not updated accordingly. I've also identified a robustness issue with manual device management that should be addressed.
def compute_ref_log_prob(self, data: TensorDict) -> TensorDict:
    return self.ref.infer_batch(data=data).cpu()
Changing the signature of compute_ref_log_prob from taking DataProto to TensorDict breaks the calling code in verl/trainer/ppo/ray_trainer.py. The trainer still passes a DataProto and is not equipped to handle the returned TensorDict.
This will lead to a runtime error. The logic for computing reference log probabilities in ray_trainer.py needs to be updated to be consistent with the new TrainingWorker-based implementation, similar to how old_log_prob and update_actor are handled.
You should add similar logic for ref_log_prob computation in ray_trainer.py around line 1207:
if self.use_reference_policy:
    # compute reference log_prob
    with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
        if self.use_legacy_worker_impl == "disable":
            batch_td = batch.to_tensordict()
            batch_td = left_right_2_no_padding(batch_td)
            tu.assign_non_tensor(batch_td, calculate_entropy=False, compute_loss=False)
            if not self.ref_in_actor:
                output = self.ref_policy_wg.compute_ref_log_prob(batch_td)
            else:
                output = self.actor_rollout_wg.compute_ref_log_prob(batch_td)
            log_probs = tu.get(output, "log_probs")
            log_probs = no_padding_2_padding(log_probs, batch_td)
            ref_log_prob = tu.get_tensordict({"ref_log_prob": log_probs.float()})
            ref_log_prob = DataProto.from_tensordict(ref_log_prob)
        else:
            if not self.ref_in_actor:
                ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
            else:
                ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
        batch = batch.union(ref_log_prob)

This will fix the type mismatch and correctly compute and assign the reference log probabilities.
verl/trainer/ppo/ray_trainer.py
Outdated
self.actor_rollout_wg.to("device")

# update
output_ref_lst = []
total_num_iterations = batch_td.shape[0] // ppo_mini_batch_size * ppo_epochs
for batch_idx, mini_batch_td in enumerate(dataloader):
    # add global token num
    global_token_num = mini_batch_td["input_ids"].offsets().diff().tolist()
    tu.assign_non_tensor(
        mini_batch_td,
        global_token_num=NonTensorData(global_token_num),
        update_lr_scheduler=batch_idx == total_num_iterations - 1,
        disable_auto_offload=True,
    )
    actor_output_ref = self.actor_rollout_wg.train_batch(mini_batch_td)
    output_ref_lst.append(actor_output_ref)

actor_output = [output_ref.get() for output_ref in output_ref_lst]
actor_output = [tu.get(output, "metrics") for output in actor_output]

# manually sleep actor
self.actor_rollout_wg.to("cpu")
The manual device management with to("device") and to("cpu") is a bit risky. If an exception occurs during the update loop (e.g., during output_ref.get()), the to("cpu") call will be skipped, leaving the actor model on the GPU. This could lead to out-of-memory errors in subsequent steps.
To make this more robust, you should wrap the update logic in a try...finally block to ensure to("cpu") is always called.
# manually wakeup actor
self.actor_rollout_wg.to("device")
try:
    # update
    output_ref_lst = []
    total_num_iterations = batch_td.shape[0] // ppo_mini_batch_size * ppo_epochs
    for batch_idx, mini_batch_td in enumerate(dataloader):
        # add global token num
        global_token_num = mini_batch_td["input_ids"].offsets().diff().tolist()
        tu.assign_non_tensor(
            mini_batch_td,
            global_token_num=NonTensorData(global_token_num),
            update_lr_scheduler=batch_idx == total_num_iterations - 1,
            disable_auto_offload=True,
        )
        actor_output_ref = self.actor_rollout_wg.train_batch(mini_batch_td)
        output_ref_lst.append(actor_output_ref)

    actor_output = [output_ref.get() for output_ref in output_ref_lst]
    actor_output = [tu.get(output, "metrics") for output in actor_output]
finally:
    # manually sleep actor
    self.actor_rollout_wg.to("cpu")
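If this wake/sleep pairing recurs at several call sites, a small context manager can centralize the guarantee. This is a hedged sketch, not the verl API: `on_device` and `FakeWorkerGroup` are hypothetical names, and the only assumption about the worker group is that it exposes a `.to(device)` method as shown in the diff above:

```python
from contextlib import contextmanager

@contextmanager
def on_device(wg, device="device", offload="cpu"):
    """Move `wg` to `device`, and guarantee it is moved back to `offload`
    even if the body raises."""
    wg.to(device)
    try:
        yield wg
    finally:
        wg.to(offload)

# Toy stand-in for a worker group, recording device moves.
class FakeWorkerGroup:
    def __init__(self):
        self.calls = []
    def to(self, device):
        self.calls.append(device)

wg = FakeWorkerGroup()
try:
    with on_device(wg):
        raise RuntimeError("update failed mid-loop")
except RuntimeError:
    pass
print(wg.calls)  # → ['device', 'cpu'] — the offload still happened
```

The `finally` inside the context manager is what makes the CPU offload unconditional, so subsequent steps cannot hit GPU OOM from a model left resident after a failed update.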
verl/trainer/ppo/ray_trainer.py
Outdated
batch.meta_info["temperature"] = rollout_config.temperature
actor_output = self.actor_rollout_wg.update_actor(batch)
# update actor
if self.use_legacy_worker_impl == "disable":
Should we move this to a separate function? The fit function is quite long.
Let me fix this and update compute_ref_log_prob as well.
/gemini review
Code Review
This pull request refactors the PPO actor logic from a dedicated ActorWorker to a more generic TrainingWorker, which is a significant and positive architectural change. The logic is moved into RayPPOTrainer helper methods, which now use TensorDict and new padding utilities, improving consistency. However, I've found two critical issues that need to be addressed:
- A large and important test for the PPO actor logic has been removed without a clear replacement, which poses a regression risk.
- The dynamic import logic for attention functions has been modified incorrectly, which will cause import errors in environments without NVIDIA GPUs.
I am having trouble creating individual review comments, so my feedback is included below.
tests/models/test_engine.py (209-323)
The test test_actor_engine has been removed entirely. This was a substantial test (over 100 lines) that specifically validated the PPO actor logic, including compute_log_prob and update_actor. While ActorWorker is being replaced by TrainingWorker, it's crucial to ensure that the equivalent functionality in TrainingWorker is still covered by tests to prevent regressions. The existing test_engine seems to be a general test for TrainingWorker, but it's not clear if it fully covers the PPO-specific actor logic that was tested in test_actor_engine. Please add back a similar test for the new implementation using TrainingWorker.
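For reference, the core of such a regression test is usually a parity check: the worker's `compute_log_prob` output is compared against a straightforward re-derivation. The sketch below is purely illustrative and uses no verl APIs — `ref_log_prob` and `worker_log_prob` are hypothetical stand-ins for the reference computation and the `TrainingWorker` path under test, with plain-Python math instead of tensors:

```python
import math

def logsumexp(xs):
    # numerically stable log-sum-exp
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ref_log_prob(logits, token_id):
    # straightforward reference: log-softmax at the chosen token
    return logits[token_id] - logsumexp(logits)

def worker_log_prob(logits, token_id):
    # stand-in for the implementation under test; an algebraically
    # equivalent but differently computed formulation
    m = max(logits)
    probs = [math.exp(x - m) for x in logits]
    return math.log(probs[token_id] / sum(probs))

logits = [0.1, 2.0, -1.5, 0.3]
for t in range(len(logits)):
    assert abs(ref_log_prob(logits, t) - worker_log_prob(logits, t)) < 1e-9
print("log-prob parity holds")
```

A real replacement test would do the same comparison with the actual `TrainingWorker` on a small model, covering both `compute_log_prob` and one `update_actor` step, as `test_actor_engine` did.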
verl/utils/attention_utils.py (23-30)
The logic for importing attention functions has been changed in a way that could cause an ImportError. The new code assumes that if NPU is not available, it can default to importing from flash_attn. However, flash_attn requires CUDA to be available. If neither NPU nor CUDA is available, this will fail. The previous implementation correctly checked for is_cuda_available. Please restore the check for CUDA availability to prevent crashes in environments without NVIDIA GPUs.
from verl.utils.device import is_cuda_available, is_npu_available

global _index_first_axis, _pad_input, _rearrange, _unpad_input
if is_npu_available:
    from verl.utils.npu_flash_attn_utils import index_first_axis, pad_input, rearrange, unpad_input
elif is_cuda_available:
    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
…erl-project#4480)

### What does this PR do?

- Support moving ppo actor logics to single controller using TrainingWorker
- Remove ActorWorker

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)