117 changes: 1 addition & 116 deletions tests/models/test_engine.py
@@ -48,7 +48,7 @@
McoreEngineConfig,
McoreOptimizerConfig,
)
from verl.workers.engine_workers import ActorWorker, CriticWorker, TrainingWorker, TrainingWorkerConfig
from verl.workers.engine_workers import CriticWorker, TrainingWorker, TrainingWorkerConfig
from verl.workers.utils.losses import ppo_loss, sft_loss
from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding

@@ -206,121 +206,6 @@ def test_engine(strategy):
ray.shutdown()


@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
def test_actor_engine(strategy):
ray.init()

path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B")
model_config = HFModelConfig(path=path)

if strategy == "megatron":
engine_config = McoreEngineConfig(
forward_only=False,
use_mbridge=False,
tensor_model_parallel_size=2,
pipeline_model_parallel_size=2,
context_parallel_size=2,
)
optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)
elif strategy in ["fsdp", "fsdp2"]:
engine_config = FSDPEngineConfig(
forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2
)
optimizer_config = FSDPOptimizerConfig()
else:
raise NotImplementedError(f"strategy {strategy} is not supported")

config = ActorConfig(
model_config=model_config,
engine=engine_config,
strategy=strategy,
ppo_micro_batch_size_per_gpu=256,
ppo_mini_batch_size=4,
optim=optimizer_config,
use_dynamic_bsz=True,
rollout_n=1,
)
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorWorker), config=config)
resource_pool = RayResourcePool(process_on_nodes=[8])
wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
# init model
wg.init_model()

batch_size = 8
seqlen = 32

response_length = seqlen // 2

torch.manual_seed(1)
np.random.seed(1)

input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))
attention_mask = create_random_mask(
input_ids=input_ids, max_ratio_of_valid_token=0.8, max_ratio_of_left_padding=0.2, min_ratio_of_valid_token=0.6
)
position_ids = compute_position_id_with_mask(attention_mask)

global_token_num = torch.sum(attention_mask, dim=-1).tolist()

print(input_ids.float().mean(), attention_mask.float().mean())

responses = input_ids[:, response_length:]
response_mask = attention_mask[:, response_length:]

assert torch.all(response_mask[:, 0] == 1)

data = DataProto.from_single_dict(
{
"input_ids": input_ids,
"prompts": input_ids[:, :response_length],
"attention_mask": attention_mask,
"position_ids": position_ids,
"responses": responses,
"response_mask": response_mask,
},
meta_info={"temperature": 1.0, "global_token_num": global_token_num},
)

# sft_loss_ = partial(sft_loss, config=config)

# eval
output = wg.compute_log_prob(data)

# load hf model and compare results with hf model
hf_model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
hf_output = hf_model(input_ids, attention_mask=attention_mask)
hf_logprobs = logprobs_from_logits_naive(
hf_output.logits[:, -response_length - 1 : -1, :].float(), input_ids[:, -response_length:]
)
hf_logprobs_mean = torch.mean(hf_logprobs * response_mask)
mcore_logprobs_mean = torch.mean(output.batch["old_log_probs"] * response_mask)

torch.testing.assert_close(hf_logprobs_mean, mcore_logprobs_mean, atol=1e-3, rtol=1e-2)

data = data.union(output)

# TODO: sft_loss_ is not compatible with ActorWorker until we replace DataProto with torch.jagged TensorDict
# wg.set_loss_fn(sft_loss_)

# train for one step
# metrics = wg.update_actor(data)
# print(metrics)

# add ppo data
data.batch["advantages"] = torch.rand_like(responses, dtype=torch.float32)
data.batch["ref_log_prob"] = torch.rand_like(responses, dtype=torch.float32)

# set ppo loss
ppo_loss_ = partial(ppo_loss, config=config)
wg.set_loss_fn(ppo_loss_)

# update again
ppo_metrics = wg.update_actor(data)
print(ppo_metrics)

ray.shutdown()


def create_model():
from transformers import Qwen3Config

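The removed test_actor_engine above checked the engine's log-probs against a Hugging Face reference via logprobs_from_logits_naive. For readers skimming the deleted code, here is a minimal sketch of that comparison idea; logprobs_from_logits below is a hypothetical stand-in written for illustration, not verl's helper.

```python
# Hypothetical sketch of the log-prob check in the deleted test: gather the
# log-probability of each realized next token from the model's logits.
import torch
import torch.nn.functional as F


def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq, vocab]; labels: [batch, seq]
    logp = F.log_softmax(logits.float(), dim=-1)
    return torch.gather(logp, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


# Usage mirroring the deleted test: the response tokens occupy the last
# response_length positions and are predicted by the logits one step earlier.
# hf_logprobs = logprobs_from_logits(
#     hf_output.logits[:, -response_length - 1 : -1, :], input_ids[:, -response_length:]
# )
```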
2 changes: 1 addition & 1 deletion verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -88,6 +88,7 @@ actor_rollout_ref:
kl_loss_type: low_var_kl
ppo_epochs: 1
shuffle: false
data_loader_seed: 42
checkpoint:
_target_: verl.trainer.config.CheckpointConfig
save_contents:
@@ -127,7 +128,6 @@
mode: disabled
record_file: null
replay_file: null
data_loader_seed: 42
load_weight: true
ref:
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -75,6 +75,7 @@ actor_rollout_ref:
kl_loss_type: low_var_kl
ppo_epochs: 1
shuffle: false
data_loader_seed: 42
checkpoint:
_target_: verl.trainer.config.CheckpointConfig
save_contents:
3 changes: 3 additions & 0 deletions verl/trainer/config/actor/actor.yaml
@@ -103,6 +103,9 @@ ppo_epochs: 1
# Shuffle training data across PPO epochs
shuffle: false

# The seed used to construct mini-batches
data_loader_seed: 42

# checkpoint configs
checkpoint:

2 changes: 0 additions & 2 deletions verl/trainer/config/actor/megatron_actor.yaml
@@ -15,6 +15,4 @@ _target_: verl.workers.config.McoreActorConfig

strategy: megatron

data_loader_seed: 42

load_weight: True
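Net effect of the YAML changes above: data_loader_seed moves out of the Megatron-only megatron_actor.yaml into the shared actor.yaml, so FSDP and Megatron actors resolve it from the same place (the generated trainer configs reflect this). A minimal sketch of the resulting access path, using a hand-built OmegaConf dict as a stand-in for the generated YAML:

```python
# Sketch only: the config dict below is a hand-written stand-in for the
# generated trainer YAML, not loaded from verl.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "actor_rollout_ref": {
            "actor": {"ppo_epochs": 1, "shuffle": False, "data_loader_seed": 42}
        }
    }
)

# Same access path regardless of strategy (fsdp, fsdp2, megatron).
seed = cfg.actor_rollout_ref.actor.data_loader_seed
print(seed)  # 42
```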
142 changes: 131 additions & 11 deletions verl/trainer/ppo/ray_trainer.py
@@ -24,13 +24,15 @@
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import chain
from pprint import pprint
from typing import Optional

import numpy as np
import ray
import torch
from omegaconf import OmegaConf, open_dict
from tensordict import NonTensorData
from torch.utils.data import Dataset, Sampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm
@@ -51,14 +53,17 @@
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.debug import marked_timer
from verl.utils.metric import reduce_metrics
from verl.utils.py_functional import append_to_dict
from verl.utils.rollout_skip import RolloutSkip
from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding


@dataclass
@@ -343,6 +348,8 @@ def __init__(
if self.config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)

self.use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")

self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
Expand Down Expand Up @@ -772,6 +779,9 @@ def init_workers(self):
self.actor_rollout_wg = all_wg[str(actor_role)]
self.actor_rollout_wg.init_model()

if self.ref_in_actor:
self.ref_policy_wg = self.actor_rollout_wg

# create async rollout manager and request scheduler
self.async_rollout_mode = False
if self.config.actor_rollout_ref.rollout.mode == "async":
@@ -974,6 +984,120 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle
)
metrics.update(global_balance_stats)

def _compute_ref_log_prob(self, batch: DataProto) -> DataProto:
if self.use_legacy_worker_impl == "disable":
# step 1: convert dataproto to tensordict.
batch_td = batch.to_tensordict()
# step 2: convert from padding to nopadding
batch_td = left_right_2_no_padding(batch_td)
# step 3: add meta info
tu.assign_non_tensor(batch_td, calculate_entropy=False, compute_loss=False)
output = self.ref_policy_wg.compute_ref_log_prob(batch_td)
# gather output
log_probs = tu.get(output, "log_probs")
# step 4. No padding to padding
log_probs = no_padding_2_padding(log_probs, batch_td)
# step 5: rebuild a tensordict and convert to dataproto
ref_log_prob = tu.get_tensordict({"ref_log_prob": log_probs.float()})
ref_log_prob = DataProto.from_tensordict(ref_log_prob)
else:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)

return ref_log_prob

def _compute_old_log_prob(self, batch: DataProto):
if self.use_legacy_worker_impl == "disable":
# TODO: remove step 1, 2, 4 after we make the whole training tensordict and padding free
# step 1: convert dataproto to tensordict.
batch_td = batch.to_tensordict()
# step 2: convert from padding to nopadding
batch_td = left_right_2_no_padding(batch_td)
# step 3: add meta info
tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False)
output = self.actor_rollout_wg.compute_log_prob(batch_td)
# gather output
entropy = tu.get(output, "entropy")
log_probs = tu.get(output, "log_probs")
old_log_prob_mfu = tu.get(output, "metrics")["mfu"]
# step 4. No padding to padding
entropy = no_padding_2_padding(entropy, batch_td)
log_probs = no_padding_2_padding(log_probs, batch_td)
# step 5: rebuild a tensordict and convert to dataproto
old_log_prob = tu.get_tensordict({"old_log_probs": log_probs.float(), "entropys": entropy.float()})
old_log_prob = DataProto.from_tensordict(old_log_prob)
else:
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
old_log_prob_mfu = 0
return old_log_prob, old_log_prob_mfu
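Both helpers above, and _update_actor below, share the same round trip: the left/right-padded batch is packed into variable-length (no-padding) sequences before the worker call and padded back afterwards, and _update_actor derives global_token_num from the packed input_ids' offsets. A rough illustration of what that packing produces, assuming a recent PyTorch with jagged nested tensors; this is not verl's left_right_2_no_padding implementation:

```python
# Illustration only: pack left/right-padded input_ids into a jagged tensor.
import torch

input_ids = torch.tensor([[0, 0, 5, 6, 7],   # left-padded sequence (2 pad tokens)
                          [1, 2, 3, 4, 8]])  # fully valid sequence
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Keep only the valid tokens of each row -> variable-length sequences.
seqs = [row[mask.bool()] for row, mask in zip(input_ids, attention_mask)]
packed = torch.nested.nested_tensor(seqs, layout=torch.jagged)

# Per-sample token counts, the same quantity _update_actor reads via
# input_ids.offsets().diff() to build global_token_num.
print(packed.offsets().diff().tolist())  # [3, 5]
```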

def _update_actor(self, batch: DataProto) -> DataProto:
rollout_config = self.config.actor_rollout_ref.rollout
batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
# TODO: Make "temperature" single source of truth from generation.
batch.meta_info["temperature"] = rollout_config.temperature
# update actor
if self.use_legacy_worker_impl == "disable":
batch_td = batch.to_tensordict()
# step 2: convert from padding to no-padding
batch_td = left_right_2_no_padding(batch_td)
calculate_entropy = self.config.actor_rollout_ref.actor.entropy_coeff != 0.0
ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
ppo_epochs = self.config.actor_rollout_ref.actor.ppo_epochs
seed = self.config.actor_rollout_ref.actor.data_loader_seed
shuffle = self.config.actor_rollout_ref.actor.shuffle
tu.assign_non_tensor(batch_td, calculate_entropy=calculate_entropy, global_batch_size=ppo_mini_batch_size)

# make iterator
dataloader = tu.make_iterator(
batch_td,
mini_batch_size=ppo_mini_batch_size,
epochs=ppo_epochs,
seed=seed,
dataloader_kwargs={"shuffle": shuffle},
)
# manually wakeup actor
self.actor_rollout_wg.to("device")

# update
output_ref_lst = []
total_num_iterations = batch_td.shape[0] // ppo_mini_batch_size * ppo_epochs
for batch_idx, mini_batch_td in enumerate(dataloader):
# add global token num
global_token_num = mini_batch_td["input_ids"].offsets().diff().tolist()
tu.assign_non_tensor(
mini_batch_td,
global_token_num=NonTensorData(global_token_num),
update_lr_scheduler=batch_idx == total_num_iterations - 1,
disable_auto_offload=True,
)
actor_output_ref = self.actor_rollout_wg.train_batch(mini_batch_td)
output_ref_lst.append(actor_output_ref)

actor_output = [output_ref.get() for output_ref in output_ref_lst]
actor_output = [tu.get(output, "metrics") for output in actor_output]

# manually sleep actor
self.actor_rollout_wg.to("cpu")

# each metric is a list of lists (dp rank x micro-batch); output[0] holds the metrics from dp rank 0
# flatten each metric
agg_actor_output = {}
for output in actor_output:
for key, val in output.items():
# flatten dp and micro batch
if isinstance(val, list):
output[key] = list(chain.from_iterable(val))
append_to_dict(agg_actor_output, output, prefix="actor/")

# modify key name
agg_actor_output["perf/mfu/actor"] = agg_actor_output.pop("actor/mfu")

actor_output = DataProto.from_single_dict(data={}, meta_info={"metrics": agg_actor_output})
else:
actor_output = self.actor_rollout_wg.update_actor(batch)
return actor_output

def fit(self):
"""
The training loop of PPO.
@@ -1143,7 +1267,7 @@ def fit(self):
)
else: # Recompute old_log_probs
with marked_timer("old_log_prob", timing_raw, color="blue"):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
actor_config = self.config.actor_rollout_ref.actor
@@ -1153,7 +1277,10 @@
loss_agg_mode=actor_config.loss_agg_mode,
loss_scale_factor=actor_config.loss_scale_factor,
)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
old_log_prob_metrics = {
"actor/entropy": entropy_agg.detach().item(),
"perf/mfu/actor_infer": old_log_prob_mfu,
}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
@@ -1168,10 +1295,7 @@
if self.use_reference_policy:
# compute reference log_prob
with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
ref_log_prob = self._compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)

# compute values
@@ -1240,11 +1364,7 @@
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with marked_timer("update_actor", timing_raw, color="red"):
rollout_config = self.config.actor_rollout_ref.rollout
batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
# TODO: Make "temperature" single source of truth from generation.
batch.meta_info["temperature"] = rollout_config.temperature
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output = self._update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)

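One detail of _update_actor above worth unpacking: each train_batch call returns metrics whose values are lists of per-micro-batch lists, one dict per dp rank, and the loop flattens them before aggregation (the real code then routes the flattened dicts through verl's append_to_dict with an "actor/" prefix). A small self-contained sketch of that flattening with made-up values:

```python
# Sketch with hypothetical metric values: flatten dp-rank x micro-batch lists.
from itertools import chain

per_dp_outputs = [
    {"pg_loss": [[0.10, 0.20], [0.30]]},  # dp rank 0: two micro-batches, then one
    {"pg_loss": [[0.40]]},                # dp rank 1: a single micro-batch
]

agg = {}
for output in per_dp_outputs:
    for key, val in output.items():
        if isinstance(val, list):
            output[key] = list(chain.from_iterable(val))  # drop micro-batch nesting
    for key, val in output.items():
        agg.setdefault(f"actor/{key}", []).extend(val)

print(agg)  # {'actor/pg_loss': [0.1, 0.2, 0.3, 0.4]}
```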