Commit d8195a8
[trainer] feat: support moving ppo actor logics to single controller (#4480)
### What does this PR do?

- Support moving PPO actor logic to the single controller using `TrainingWorker`
- Remove `ActorWorker`

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 5a2e0b1 commit d8195a8

File tree

9 files changed: +215 −317 lines
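Before the per-file diffs, a minimal sketch of the single-controller pattern this PR adopts: the trainer process now owns the PPO mini-batch loop and dispatches each mini-batch to the worker group, where previously an `ActorWorker` ran that loop internally. `WorkerGroup` below is a stand-in for verl's `RayWorkerGroup`, and the metric layout (dp × micro-batch lists) mirrors the new `_update_actor` in `ray_trainer.py`; everything else is illustrative, not verl's API.

```python
from itertools import chain

# Stand-in for verl's RayWorkerGroup: in verl, train_batch is a remote call whose
# result carries per-micro-batch metrics laid out as dp x micro-batch lists.
class WorkerGroup:
    def train_batch(self, mini_batch):
        return {"metrics": {"loss": [[0.1, 0.2]]}}

def update_actor(wg, mini_batches):
    # the trainer (single controller) owns the PPO mini-batch loop now,
    # instead of an ActorWorker running it internally
    outputs = [wg.train_batch(mb) for mb in mini_batches]
    metrics = {}
    for out in outputs:
        for key, val in out["metrics"].items():
            # flatten the dp x micro-batch lists, as the new _update_actor does
            metrics.setdefault(key, []).extend(chain.from_iterable(val))
    return metrics

print(update_actor(WorkerGroup(), mini_batches=[object(), object()]))
# -> {'loss': [0.1, 0.2, 0.1, 0.2]}
```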

tests/models/test_engine.py

Lines changed: 1 addition & 116 deletions
```diff
@@ -48,7 +48,7 @@
     McoreEngineConfig,
     McoreOptimizerConfig,
 )
-from verl.workers.engine_workers import ActorWorker, CriticWorker, TrainingWorker, TrainingWorkerConfig
+from verl.workers.engine_workers import CriticWorker, TrainingWorker, TrainingWorkerConfig
 from verl.workers.utils.losses import ppo_loss, sft_loss
 from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding

@@ -206,121 +206,6 @@ def test_engine(strategy):
     ray.shutdown()


-@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
-def test_actor_engine(strategy):
-    ray.init()
-
-    path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B")
-    model_config = HFModelConfig(path=path)
-
-    if strategy == "megatron":
-        engine_config = McoreEngineConfig(
-            forward_only=False,
-            use_mbridge=False,
-            tensor_model_parallel_size=2,
-            pipeline_model_parallel_size=2,
-            context_parallel_size=2,
-        )
-        optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)
-    elif strategy in ["fsdp", "fsdp2"]:
-        engine_config = FSDPEngineConfig(
-            forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2
-        )
-        optimizer_config = FSDPOptimizerConfig()
-    else:
-        raise NotImplementedError(f"strategy {strategy} is not supported")
-
-    config = ActorConfig(
-        model_config=model_config,
-        engine=engine_config,
-        strategy=strategy,
-        ppo_micro_batch_size_per_gpu=256,
-        ppo_mini_batch_size=4,
-        optim=optimizer_config,
-        use_dynamic_bsz=True,
-        rollout_n=1,
-    )
-    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorWorker), config=config)
-    resource_pool = RayResourcePool(process_on_nodes=[8])
-    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
-    # init model
-    wg.init_model()
-
-    batch_size = 8
-    seqlen = 32
-
-    response_length = seqlen // 2
-
-    torch.manual_seed(1)
-    np.random.seed(1)
-
-    input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))
-    attention_mask = create_random_mask(
-        input_ids=input_ids, max_ratio_of_valid_token=0.8, max_ratio_of_left_padding=0.2, min_ratio_of_valid_token=0.6
-    )
-    position_ids = compute_position_id_with_mask(attention_mask)
-
-    global_token_num = torch.sum(attention_mask, dim=-1).tolist()
-
-    print(input_ids.float().mean(), attention_mask.float().mean())
-
-    responses = input_ids[:, response_length:]
-    response_mask = attention_mask[:, response_length:]
-
-    assert torch.all(response_mask[:, 0] == 1)
-
-    data = DataProto.from_single_dict(
-        {
-            "input_ids": input_ids,
-            "prompts": input_ids[:, :response_length],
-            "attention_mask": attention_mask,
-            "position_ids": position_ids,
-            "responses": responses,
-            "response_mask": response_mask,
-        },
-        meta_info={"temperature": 1.0, "global_token_num": global_token_num},
-    )
-
-    # sft_loss_ = partial(sft_loss, config=config)
-
-    # eval
-    output = wg.compute_log_prob(data)
-
-    # load hf model and compare results with hf model
-    hf_model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
-    hf_output = hf_model(input_ids, attention_mask=attention_mask)
-    hf_logprobs = logprobs_from_logits_naive(
-        hf_output.logits[:, -response_length - 1 : -1, :].float(), input_ids[:, -response_length:]
-    )
-    hf_logprobs_mean = torch.mean(hf_logprobs * response_mask)
-    mcore_logprobs_mean = torch.mean(output.batch["old_log_probs"] * response_mask)
-
-    torch.testing.assert_close(hf_logprobs_mean, mcore_logprobs_mean, atol=1e-3, rtol=1e-2)
-
-    data = data.union(output)
-
-    # TODO: sft_loss_ is not compatible with ActorWorker until we replace DataProto with torch.jagged TensorDict
-    # wg.set_loss_fn(sft_loss_)
-
-    # train for one step
-    # metrics = wg.update_actor(data)
-    # print(metrics)
-
-    # add ppo data
-    data.batch["advantages"] = torch.rand_like(responses, dtype=torch.float32)
-    data.batch["ref_log_prob"] = torch.rand_like(responses, dtype=torch.float32)
-
-    # set ppo loss
-    ppo_loss_ = partial(ppo_loss, config=config)
-    wg.set_loss_fn(ppo_loss_)
-
-    # update again
-    ppo_metrics = wg.update_actor(data)
-    print(ppo_metrics)
-
-    ray.shutdown()
-
-
 def create_model():
     from transformers import Qwen3Config
```
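The deleted `test_actor_engine` validated engine log-probs against a HuggingFace reference before exercising the PPO loss. A self-contained sketch of that reference check, with `logprobs_from_logits_naive` re-implemented inline and the tolerances copied from the removed test; the random tensors stand in for real model outputs:

```python
import torch
import torch.nn.functional as F

def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq, vocab]; labels: [batch, seq]
    logp = F.log_softmax(logits.float(), dim=-1)
    return torch.gather(logp, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 4, 16)
labels = torch.randint(0, 16, (2, 4))
ref_logprobs = logprobs_from_logits(logits, labels)
# an engine under test must reproduce the reference within these tolerances
torch.testing.assert_close(ref_logprobs, ref_logprobs.clone(), atol=1e-3, rtol=1e-2)
```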

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -88,6 +88,7 @@ actor_rollout_ref:
       kl_loss_type: low_var_kl
       ppo_epochs: 1
       shuffle: false
+      data_loader_seed: 42
       checkpoint:
         _target_: verl.trainer.config.CheckpointConfig
         save_contents:
@@ -127,7 +128,6 @@ actor_rollout_ref:
         mode: disabled
         record_file: null
         replay_file: null
-      data_loader_seed: 42
       load_weight: true
   ref:
     rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
```

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -75,6 +75,7 @@ actor_rollout_ref:
       kl_loss_type: low_var_kl
       ppo_epochs: 1
       shuffle: false
+      data_loader_seed: 42
       checkpoint:
         _target_: verl.trainer.config.CheckpointConfig
         save_contents:
```

verl/trainer/config/actor/actor.yaml

Lines changed: 3 additions & 0 deletions
```diff
@@ -103,6 +103,9 @@ ppo_epochs: 1
 # Shuffle training data across PPO epochs
 shuffle: false

+# The seed used to construct mini-batch
+data_loader_seed: 42
+
 # checkpoint configs
 checkpoint:
```
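A small illustration of what the new `data_loader_seed` (together with `shuffle`) controls: a deterministically shuffled mini-batch iterator over the rollout batch. verl builds the real iterator via `tu.make_iterator` (see `ray_trainer.py` below); the plain `DataLoader` here is a stand-in, not verl's implementation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

seed, shuffle, mini_batch_size = 42, True, 4

dataset = TensorDataset(torch.arange(16))
# seeding the shuffle generator makes the mini-batch order reproducible
generator = torch.Generator().manual_seed(seed)
loader = DataLoader(dataset, batch_size=mini_batch_size, shuffle=shuffle, generator=generator)
print([batch[0].tolist() for batch in loader])  # identical across runs
```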

verl/trainer/config/actor/megatron_actor.yaml

Lines changed: 0 additions & 2 deletions
```diff
@@ -15,6 +15,4 @@ _target_: verl.workers.config.McoreActorConfig

 strategy: megatron

-data_loader_seed: 42
-
 load_weight: True
```
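Net effect of the YAML edits above: `data_loader_seed` moves out of the Megatron-only config into the shared `actor.yaml`, so FSDP and Megatron actors resolve the same key. A tiny OmegaConf check with the config trees abbreviated to the relevant keys (an illustration, not the full generated configs):

```python
from omegaconf import OmegaConf

# the seed now lives once in the shared actor config
base = OmegaConf.create({"actor": {"shuffle": False, "data_loader_seed": 42}})
megatron = OmegaConf.merge(base, OmegaConf.create({"actor": {"strategy": "megatron"}}))
fsdp = OmegaConf.merge(base, OmegaConf.create({"actor": {"strategy": "fsdp"}}))
assert megatron.actor.data_loader_seed == fsdp.actor.data_loader_seed == 42
```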

verl/trainer/ppo/ray_trainer.py

Lines changed: 131 additions & 11 deletions
```diff
@@ -24,13 +24,15 @@
 from collections import defaultdict
 from copy import deepcopy
 from dataclasses import dataclass, field
+from itertools import chain
 from pprint import pprint
 from typing import Optional

 import numpy as np
 import ray
 import torch
 from omegaconf import OmegaConf, open_dict
+from tensordict import NonTensorData
 from torch.utils.data import Dataset, Sampler
 from torchdata.stateful_dataloader import StatefulDataLoader
 from tqdm import tqdm
@@ -51,14 +53,17 @@
 )
 from verl.trainer.ppo.reward import compute_reward, compute_reward_async
 from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
+from verl.utils import tensordict_utils as tu
 from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
 from verl.utils.config import omega_conf_to_dataclass
 from verl.utils.debug import marked_timer
 from verl.utils.metric import reduce_metrics
+from verl.utils.py_functional import append_to_dict
 from verl.utils.rollout_skip import RolloutSkip
 from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
 from verl.utils.torch_functional import masked_mean
 from verl.utils.tracking import ValidationGenerationsLogger
+from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding


 @dataclass
@@ -343,6 +348,8 @@ def __init__(
         if self.config.algorithm.use_kl_in_reward:
             self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)

+        self.use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+
         self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

     def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
@@ -772,6 +779,9 @@ def init_workers(self):
         self.actor_rollout_wg = all_wg[str(actor_role)]
         self.actor_rollout_wg.init_model()

+        if self.ref_in_actor:
+            self.ref_policy_wg = self.actor_rollout_wg
+
         # create async rollout manager and request scheduler
         self.async_rollout_mode = False
         if self.config.actor_rollout_ref.rollout.mode == "async":
@@ -974,6 +984,120 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
         )
         metrics.update(global_balance_stats)

+    def _compute_ref_log_prob(self, batch: DataProto) -> DataProto:
+        if self.use_legacy_worker_impl == "disable":
+            # step 1: convert dataproto to tensordict.
+            batch_td = batch.to_tensordict()
+            # step 2: convert from padding to nopadding
+            batch_td = left_right_2_no_padding(batch_td)
+            # step 3: add meta info
+            tu.assign_non_tensor(batch_td, calculate_entropy=False, compute_loss=False)
+            output = self.ref_policy_wg.compute_ref_log_prob(batch_td)
+            # gather output
+            log_probs = tu.get(output, "log_probs")
+            # step 4. No padding to padding
+            log_probs = no_padding_2_padding(log_probs, batch_td)
+            # step 5: rebuild a tensordict and convert to dataproto
+            ref_log_prob = tu.get_tensordict({"ref_log_prob": log_probs.float()})
+            ref_log_prob = DataProto.from_tensordict(ref_log_prob)
+        else:
+            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
+
+        return ref_log_prob
+
+    def _compute_old_log_prob(self, batch: DataProto):
+        if self.use_legacy_worker_impl == "disable":
+            # TODO: remove step 1, 2, 4 after we make the whole training tensordict and padding free
+            # step 1: convert dataproto to tensordict.
+            batch_td = batch.to_tensordict()
+            # step 2: convert from padding to nopadding
+            batch_td = left_right_2_no_padding(batch_td)
+            # step 3: add meta info
+            tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False)
+            output = self.actor_rollout_wg.compute_log_prob(batch_td)
+            # gather output
+            entropy = tu.get(output, "entropy")
+            log_probs = tu.get(output, "log_probs")
+            old_log_prob_mfu = tu.get(output, "metrics")["mfu"]
+            # step 4. No padding to padding
+            entropy = no_padding_2_padding(entropy, batch_td)
+            log_probs = no_padding_2_padding(log_probs, batch_td)
+            # step 5: rebuild a tensordict and convert to dataproto
+            old_log_prob = tu.get_tensordict({"old_log_probs": log_probs.float(), "entropys": entropy.float()})
+            old_log_prob = DataProto.from_tensordict(old_log_prob)
+        else:
+            old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+            old_log_prob_mfu = 0
+        return old_log_prob, old_log_prob_mfu
+
+    def _update_actor(self, batch: DataProto) -> DataProto:
+        rollout_config = self.config.actor_rollout_ref.rollout
+        batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
+        # TODO: Make "temperature" single source of truth from generation.
+        batch.meta_info["temperature"] = rollout_config.temperature
+        # update actor
+        if self.use_legacy_worker_impl == "disable":
+            batch_td = batch.to_tensordict()
+            # step 2: convert from padding to no-padding
+            batch_td = left_right_2_no_padding(batch_td)
+            calculate_entropy = self.config.actor_rollout_ref.actor.entropy_coeff != 0.0
+            ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
+            ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
+            ppo_epochs = self.config.actor_rollout_ref.actor.ppo_epochs
+            seed = self.config.actor_rollout_ref.actor.data_loader_seed
+            shuffle = self.config.actor_rollout_ref.actor.shuffle
+            tu.assign_non_tensor(batch_td, calculate_entropy=calculate_entropy, global_batch_size=ppo_mini_batch_size)
+
+            # make iterator
+            dataloader = tu.make_iterator(
+                batch_td,
+                mini_batch_size=ppo_mini_batch_size,
+                epochs=ppo_epochs,
+                seed=seed,
+                dataloader_kwargs={"shuffle": shuffle},
+            )
+            # manually wakeup actor
+            self.actor_rollout_wg.to("device")
+
+            # update
+            output_ref_lst = []
+            total_num_iterations = batch_td.shape[0] // ppo_mini_batch_size * ppo_epochs
+            for batch_idx, mini_batch_td in enumerate(dataloader):
+                # add global token num
+                global_token_num = mini_batch_td["input_ids"].offsets().diff().tolist()
+                tu.assign_non_tensor(
+                    mini_batch_td,
+                    global_token_num=NonTensorData(global_token_num),
+                    update_lr_scheduler=batch_idx == total_num_iterations - 1,
+                    disable_auto_offload=True,
+                )
+                actor_output_ref = self.actor_rollout_wg.train_batch(mini_batch_td)
+                output_ref_lst.append(actor_output_ref)
+
+            actor_output = [output_ref.get() for output_ref in output_ref_lst]
+            actor_output = [tu.get(output, "metrics") for output in actor_output]
+
+            # manually sleep actor
+            self.actor_rollout_wg.to("cpu")
+
+            # each metric is a list of list (dp and micro-batch) (output[0] is metric in dp[0])
+            # flatten each metric
+            agg_actor_output = {}
+            for output in actor_output:
+                for key, val in output.items():
+                    # flatten dp and micro batch
+                    if isinstance(val, list):
+                        output[key] = list(chain.from_iterable(val))
+                append_to_dict(agg_actor_output, output, prefix="actor/")
+
+            # modify key name
+            agg_actor_output["perf/mfu/actor"] = agg_actor_output.pop("actor/mfu")
+
+            actor_output = DataProto.from_single_dict(data={}, meta_info={"metrics": agg_actor_output})
+        else:
+            actor_output = self.actor_rollout_wg.update_actor(batch)
+        return actor_output
+
     def fit(self):
         """
         The training loop of PPO.
@@ -1143,7 +1267,7 @@ def fit(self):
                         )
                     else:  # Recompute old_log_probs
                         with marked_timer("old_log_prob", timing_raw, color="blue"):
-                            old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+                            old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch)
                             entropys = old_log_prob.batch["entropys"]
                             response_masks = batch.batch["response_mask"]
                             actor_config = self.config.actor_rollout_ref.actor
@@ -1153,7 +1277,10 @@ def fit(self):
                                 loss_agg_mode=actor_config.loss_agg_mode,
                                 loss_scale_factor=actor_config.loss_scale_factor,
                             )
-                            old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
+                            old_log_prob_metrics = {
+                                "actor/entropy": entropy_agg.detach().item(),
+                                "perf/mfu/actor_infer": old_log_prob_mfu,
+                            }
                             metrics.update(old_log_prob_metrics)
                             old_log_prob.batch.pop("entropys")
                             batch = batch.union(old_log_prob)
@@ -1168,10 +1295,7 @@ def fit(self):
                     if self.use_reference_policy:
                         # compute reference log_prob
                         with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
-                            if not self.ref_in_actor:
-                                ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
-                            else:
-                                ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
+                            ref_log_prob = self._compute_ref_log_prob(batch)
                         batch = batch.union(ref_log_prob)

                     # compute values
@@ -1240,11 +1364,7 @@ def fit(self):
                     if self.config.trainer.critic_warmup <= self.global_steps:
                         # update actor
                         with marked_timer("update_actor", timing_raw, color="red"):
-                            rollout_config = self.config.actor_rollout_ref.rollout
-                            batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
-                            # TODO: Make "temperature" single source of truth from generation.
-                            batch.meta_info["temperature"] = rollout_config.temperature
-                            actor_output = self.actor_rollout_wg.update_actor(batch)
+                            actor_output = self._update_actor(batch)
                             actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                             metrics.update(actor_output_metrics)
```
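The new trainer-side helpers above convert between padded and "no padding" layouts around every worker call (`left_right_2_no_padding` / `no_padding_2_padding`) and read per-sample token counts from `input_ids.offsets()`. A standalone sketch of that round trip using PyTorch jagged nested tensors; the layout here is an assumption for illustration, not verl's implementation:

```python
import torch

input_ids = torch.tensor([[0, 0, 5, 6, 7],
                          [9, 8, 7, 6, 5]])           # left-padded batch
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# padding -> no padding: keep only the valid tokens of each row
rows = [ids[mask.bool()] for ids, mask in zip(input_ids, attention_mask)]
jagged = torch.nested.nested_tensor(rows, layout=torch.jagged)

# per-sample token counts, as _update_actor computes global_token_num
print(jagged.offsets().diff().tolist())  # [3, 5]

# no padding -> padding: scatter packed values back into a padded tensor
padded = torch.zeros_like(input_ids)
padded[attention_mask.bool()] = jagged.values()
assert torch.equal(padded[attention_mask.bool()], torch.cat(rows))
```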
