diff --git a/tests/models/test_engine.py b/tests/models/test_engine.py
index a4ec6df7432..bb3433a3f33 100644
--- a/tests/models/test_engine.py
+++ b/tests/models/test_engine.py
@@ -48,7 +48,7 @@
     McoreEngineConfig,
    McoreOptimizerConfig,
 )
-from verl.workers.engine_workers import ActorWorker, CriticWorker, TrainingWorker, TrainingWorkerConfig
+from verl.workers.engine_workers import CriticWorker, TrainingWorker, TrainingWorkerConfig
 from verl.workers.utils.losses import ppo_loss, sft_loss
 from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding
 
@@ -206,121 +206,6 @@ def test_engine(strategy):
     ray.shutdown()
 
 
-@pytest.mark.parametrize("strategy", ["megatron", "fsdp", "fsdp2"])
-def test_actor_engine(strategy):
-    ray.init()
-
-    path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B")
-    model_config = HFModelConfig(path=path)
-
-    if strategy == "megatron":
-        engine_config = McoreEngineConfig(
-            forward_only=False,
-            use_mbridge=False,
-            tensor_model_parallel_size=2,
-            pipeline_model_parallel_size=2,
-            context_parallel_size=2,
-        )
-        optimizer_config = McoreOptimizerConfig(lr_decay_steps=10)
-    elif strategy in ["fsdp", "fsdp2"]:
-        engine_config = FSDPEngineConfig(
-            forward_only=False, fsdp_size=4, strategy=strategy, ulysses_sequence_parallel_size=2
-        )
-        optimizer_config = FSDPOptimizerConfig()
-    else:
-        raise NotImplementedError(f"strategy {strategy} is not supported")
-
-    config = ActorConfig(
-        model_config=model_config,
-        engine=engine_config,
-        strategy=strategy,
-        ppo_micro_batch_size_per_gpu=256,
-        ppo_mini_batch_size=4,
-        optim=optimizer_config,
-        use_dynamic_bsz=True,
-        rollout_n=1,
-    )
-    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorWorker), config=config)
-    resource_pool = RayResourcePool(process_on_nodes=[8])
-    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
-    # init model
-    wg.init_model()
-
-    batch_size = 8
-    seqlen = 32
-
-    response_length = seqlen // 2
-
-    torch.manual_seed(1)
-    np.random.seed(1)
-
-    input_ids = torch.randint(0, model_config.hf_config.vocab_size, (batch_size, seqlen))
-    attention_mask = create_random_mask(
-        input_ids=input_ids, max_ratio_of_valid_token=0.8, max_ratio_of_left_padding=0.2, min_ratio_of_valid_token=0.6
-    )
-    position_ids = compute_position_id_with_mask(attention_mask)
-
-    global_token_num = torch.sum(attention_mask, dim=-1).tolist()
-
-    print(input_ids.float().mean(), attention_mask.float().mean())
-
-    responses = input_ids[:, response_length:]
-    response_mask = attention_mask[:, response_length:]
-
-    assert torch.all(response_mask[:, 0] == 1)
-
-    data = DataProto.from_single_dict(
-        {
-            "input_ids": input_ids,
-            "prompts": input_ids[:, :response_length],
-            "attention_mask": attention_mask,
-            "position_ids": position_ids,
-            "responses": responses,
-            "response_mask": response_mask,
-        },
-        meta_info={"temperature": 1.0, "global_token_num": global_token_num},
-    )
-
-    # sft_loss_ = partial(sft_loss, config=config)
-
-    # eval
-    output = wg.compute_log_prob(data)
-
-    # load hf model and compare results with hf model
-    hf_model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
-    hf_output = hf_model(input_ids, attention_mask=attention_mask)
-    hf_logprobs = logprobs_from_logits_naive(
-        hf_output.logits[:, -response_length - 1 : -1, :].float(), input_ids[:, -response_length:]
-    )
-    hf_logprobs_mean = torch.mean(hf_logprobs * response_mask)
-    mcore_logprobs_mean = torch.mean(output.batch["old_log_probs"] * response_mask)
-
-    torch.testing.assert_close(hf_logprobs_mean, mcore_logprobs_mean, atol=1e-3, rtol=1e-2)
-
-    data = data.union(output)
-
-    # TODO: sft_loss_ is not compatible with ActorWorker until we replace DataProto with torch.jagged TensorDict
-    # wg.set_loss_fn(sft_loss_)
-
-    # train for one step
-    # metrics = wg.update_actor(data)
-    # print(metrics)
-
-    # add ppo data
-    data.batch["advantages"] = torch.rand_like(responses, dtype=torch.float32)
-    data.batch["ref_log_prob"] = torch.rand_like(responses, dtype=torch.float32)
-
-    # set ppo loss
-    ppo_loss_ = partial(ppo_loss, config=config)
-    wg.set_loss_fn(ppo_loss_)
-
-    # update again
-    ppo_metrics = wg.update_actor(data)
-    print(ppo_metrics)
-
-    ray.shutdown()
-
-
 def create_model():
     from transformers import Qwen3Config
 
diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index 89849b8f124..6c906f2ac68 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -88,6 +88,7 @@ actor_rollout_ref:
     kl_loss_type: low_var_kl
     ppo_epochs: 1
     shuffle: false
+    data_loader_seed: 42
     checkpoint:
       _target_: verl.trainer.config.CheckpointConfig
       save_contents:
@@ -127,7 +128,6 @@ actor_rollout_ref:
       mode: disabled
       record_file: null
       replay_file: null
-    data_loader_seed: 42
     load_weight: true
   ref:
     rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index 7ded715fe73..d4a88acbe8e 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -75,6 +75,7 @@ actor_rollout_ref:
     kl_loss_type: low_var_kl
     ppo_epochs: 1
     shuffle: false
+    data_loader_seed: 42
     checkpoint:
       _target_: verl.trainer.config.CheckpointConfig
       save_contents:
diff --git a/verl/trainer/config/actor/actor.yaml b/verl/trainer/config/actor/actor.yaml
index f5f1d15eee5..283095a1527 100644
--- a/verl/trainer/config/actor/actor.yaml
+++ b/verl/trainer/config/actor/actor.yaml
@@ -103,6 +103,9 @@ ppo_epochs: 1
 # Shuffle training data across PPO epochs
 shuffle: false
 
+# The seed used to construct mini-batches
+data_loader_seed: 42
+
 # checkpoint configs
 checkpoint:
diff --git a/verl/trainer/config/actor/megatron_actor.yaml b/verl/trainer/config/actor/megatron_actor.yaml
index a632fe4380b..fde70c363c4 100644
--- a/verl/trainer/config/actor/megatron_actor.yaml
+++ b/verl/trainer/config/actor/megatron_actor.yaml
@@ -15,6 +15,4 @@ _target_: verl.workers.config.McoreActorConfig
 
 strategy: megatron
 
-data_loader_seed: 42
-
 load_weight: True
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index e439a76d361..401e07cd81d 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -24,6 +24,7 @@
 from collections import defaultdict
 from copy import deepcopy
 from dataclasses import dataclass, field
+from itertools import chain
 from pprint import pprint
 from typing import Optional
 
@@ -31,6 +32,7 @@
 import ray
 import torch
 from omegaconf import OmegaConf, open_dict
+from tensordict import NonTensorData
 from torch.utils.data import Dataset, Sampler
 from torchdata.stateful_dataloader import StatefulDataLoader
 from tqdm import tqdm
@@ -51,14 +53,17 @@
 )
 from verl.trainer.ppo.reward import compute_reward, compute_reward_async
 from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
+from verl.utils import tensordict_utils as tu
 from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
 from verl.utils.config import omega_conf_to_dataclass
 from verl.utils.debug import marked_timer
 from verl.utils.metric import reduce_metrics
+from verl.utils.py_functional import append_to_dict
 from verl.utils.rollout_skip import RolloutSkip
 from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
 from verl.utils.torch_functional import masked_mean
 from verl.utils.tracking import ValidationGenerationsLogger
+from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding
 
 
 @dataclass
@@ -343,6 +348,8 @@ def __init__(
         if self.config.algorithm.use_kl_in_reward:
             self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)
 
+        self.use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+
         self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
 
     def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
@@ -772,6 +779,9 @@ def init_workers(self):
         self.actor_rollout_wg = all_wg[str(actor_role)]
         self.actor_rollout_wg.init_model()
 
+        if self.ref_in_actor:
+            self.ref_policy_wg = self.actor_rollout_wg
+
         # create async rollout manager and request scheduler
         self.async_rollout_mode = False
         if self.config.actor_rollout_ref.rollout.mode == "async":
@@ -974,6 +984,120 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
         )
         metrics.update(global_balance_stats)
 
+    def _compute_ref_log_prob(self, batch: DataProto) -> DataProto:
+        if self.use_legacy_worker_impl == "disable":
+            # step 1: convert dataproto to tensordict
+            batch_td = batch.to_tensordict()
+            # step 2: convert from padding to no-padding
+            batch_td = left_right_2_no_padding(batch_td)
+            # step 3: add meta info
+            tu.assign_non_tensor(batch_td, calculate_entropy=False, compute_loss=False)
+            output = self.ref_policy_wg.compute_ref_log_prob(batch_td)
+            # gather output
+            log_probs = tu.get(output, "log_probs")
+            # step 4: convert from no-padding back to padding
+            log_probs = no_padding_2_padding(log_probs, batch_td)
+            # step 5: rebuild a tensordict and convert to dataproto
+            ref_log_prob = tu.get_tensordict({"ref_log_prob": log_probs.float()})
+            ref_log_prob = DataProto.from_tensordict(ref_log_prob)
+        else:
+            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
+
+        return ref_log_prob
+
+    def _compute_old_log_prob(self, batch: DataProto):
+        if self.use_legacy_worker_impl == "disable":
+            # TODO: remove steps 1, 2 and 4 once the whole training loop is tensordict-based and padding-free
+            # step 1: convert dataproto to tensordict
+            batch_td = batch.to_tensordict()
+            # step 2: convert from padding to no-padding
+            batch_td = left_right_2_no_padding(batch_td)
+            # step 3: add meta info
+            tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False)
+            output = self.actor_rollout_wg.compute_log_prob(batch_td)
+            # gather output
+            entropy = tu.get(output, "entropy")
+            log_probs = tu.get(output, "log_probs")
+            old_log_prob_mfu = tu.get(output, "metrics")["mfu"]
+            # step 4: convert from no-padding back to padding
+            entropy = no_padding_2_padding(entropy, batch_td)
+            log_probs = no_padding_2_padding(log_probs, batch_td)
+            # step 5: rebuild a tensordict and convert to dataproto
+            old_log_prob = tu.get_tensordict({"old_log_probs": log_probs.float(), "entropys": entropy.float()})
+            old_log_prob = DataProto.from_tensordict(old_log_prob)
+        else:
+            old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+            old_log_prob_mfu = 0
+        return old_log_prob, old_log_prob_mfu
+
+    def _update_actor(self, batch: DataProto) -> DataProto:
+        rollout_config = self.config.actor_rollout_ref.rollout
+        batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
+        # TODO: Make "temperature" single source of truth from generation.
+        batch.meta_info["temperature"] = rollout_config.temperature
+        # update actor
+        if self.use_legacy_worker_impl == "disable":
+            # step 1: convert dataproto to tensordict
+            batch_td = batch.to_tensordict()
+            # step 2: convert from padding to no-padding
+            batch_td = left_right_2_no_padding(batch_td)
+            calculate_entropy = self.config.actor_rollout_ref.actor.entropy_coeff != 0.0
+            ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
+            ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
+            ppo_epochs = self.config.actor_rollout_ref.actor.ppo_epochs
+            seed = self.config.actor_rollout_ref.actor.data_loader_seed
+            shuffle = self.config.actor_rollout_ref.actor.shuffle
+            tu.assign_non_tensor(batch_td, calculate_entropy=calculate_entropy, global_batch_size=ppo_mini_batch_size)
+
+            # make iterator
+            dataloader = tu.make_iterator(
+                batch_td,
+                mini_batch_size=ppo_mini_batch_size,
+                epochs=ppo_epochs,
+                seed=seed,
+                dataloader_kwargs={"shuffle": shuffle},
+            )
+            # manually wake up the actor
+            self.actor_rollout_wg.to("device")
+
+            # update
+            output_ref_lst = []
+            total_num_iterations = batch_td.shape[0] // ppo_mini_batch_size * ppo_epochs
+            for batch_idx, mini_batch_td in enumerate(dataloader):
+                # add global token num
+                global_token_num = mini_batch_td["input_ids"].offsets().diff().tolist()
+                tu.assign_non_tensor(
+                    mini_batch_td,
+                    global_token_num=NonTensorData(global_token_num),
+                    update_lr_scheduler=batch_idx == total_num_iterations - 1,
+                    disable_auto_offload=True,
+                )
+                actor_output_ref = self.actor_rollout_wg.train_batch(mini_batch_td)
+                output_ref_lst.append(actor_output_ref)
+
+            actor_output = [output_ref.get() for output_ref in output_ref_lst]
+            actor_output = [tu.get(output, "metrics") for output in actor_output]
+
+            # manually put the actor to sleep
+            self.actor_rollout_wg.to("cpu")
+
+            # each metric is a list of lists (dp and micro-batch); output[0] is the metric from dp rank 0
+            # flatten each metric
+            agg_actor_output = {}
+            for output in actor_output:
+                for key, val in output.items():
+                    # flatten dp and micro-batch
+                    if isinstance(val, list):
+                        output[key] = list(chain.from_iterable(val))
+                append_to_dict(agg_actor_output, output, prefix="actor/")
+
+            # modify key name
+            agg_actor_output["perf/mfu/actor"] = agg_actor_output.pop("actor/mfu")
+
+            actor_output = DataProto.from_single_dict(data={}, meta_info={"metrics": agg_actor_output})
+        else:
+            actor_output = self.actor_rollout_wg.update_actor(batch)
+        return actor_output
+
     def fit(self):
         """
         The training loop of PPO.
@@ -1143,7 +1267,7 @@ def fit(self):
                     )
                 else:  # Recompute old_log_probs
                     with marked_timer("old_log_prob", timing_raw, color="blue"):
-                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+                        old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch)
                         entropys = old_log_prob.batch["entropys"]
                         response_masks = batch.batch["response_mask"]
                         actor_config = self.config.actor_rollout_ref.actor
@@ -1153,7 +1277,10 @@ def fit(self):
                             loss_agg_mode=actor_config.loss_agg_mode,
                             loss_scale_factor=actor_config.loss_scale_factor,
                         )
-                        old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
+                        old_log_prob_metrics = {
+                            "actor/entropy": entropy_agg.detach().item(),
+                            "perf/mfu/actor_infer": old_log_prob_mfu,
+                        }
                         metrics.update(old_log_prob_metrics)
                         old_log_prob.batch.pop("entropys")
                         batch = batch.union(old_log_prob)
@@ -1168,10 +1295,7 @@ def fit(self):
                 if self.use_reference_policy:
                     # compute reference log_prob
                     with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
-                        if not self.ref_in_actor:
-                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
-                        else:
-                            ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
+                        ref_log_prob = self._compute_ref_log_prob(batch)
                         batch = batch.union(ref_log_prob)
 
                 # compute values
@@ -1240,11 +1364,7 @@ def fit(self):
                 if self.config.trainer.critic_warmup <= self.global_steps:
                     # update actor
                     with marked_timer("update_actor", timing_raw, color="red"):
-                        rollout_config = self.config.actor_rollout_ref.rollout
-                        batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
-                        # TODO: Make "temperature" single source of truth from generation.
-                        batch.meta_info["temperature"] = rollout_config.temperature
-                        actor_output = self.actor_rollout_wg.update_actor(batch)
+                        actor_output = self._update_actor(batch)
                     actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                     metrics.update(actor_output_metrics)
diff --git a/verl/utils/attention_utils.py b/verl/utils/attention_utils.py
index 8340155e761..ea9884307fc 100644
--- a/verl/utils/attention_utils.py
+++ b/verl/utils/attention_utils.py
@@ -20,14 +20,14 @@
 
 def _get_attention_functions() -> tuple[Callable, Callable, Callable, Callable]:
     """Dynamically import attention functions based on available hardware."""
-    from verl.utils.device import is_cuda_available, is_npu_available
+    from verl.utils.device import is_npu_available
 
     global _index_first_axis, _pad_input, _rearrange, _unpad_input
 
-    if is_cuda_available:
-        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
-    elif is_npu_available:
+    if is_npu_available:
         from verl.utils.npu_flash_attn_utils import index_first_axis, pad_input, rearrange, unpad_input
+    else:
+        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
 
     _index_first_axis, _pad_input, _rearrange, _unpad_input = index_first_axis, pad_input, rearrange, unpad_input
diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py
index d7f7dfde49e..b339a3a18ce 100644
--- a/verl/workers/engine_workers.py
+++ b/verl/workers/engine_workers.py
@@ -264,161 +264,6 @@ def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
         return self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)
 
 
-class ActorWorker(Worker, DistProfilerExtension):
-    """
-    This worker can be instantiated as a standalone actor or a standalone reference policy
-    or a hybrid engine based on the config.rollout
-    """
-
-    def __init__(self, config: ActorConfig):
-        self.config = config
-        Worker.__init__(self)
-        self.profiler_config = self.config.profiler
-        tool_config = self.profiler_config.tool_config
-        DistProfilerExtension.__init__(
-            self, DistProfiler(rank=self.rank, config=self.profiler_config, tool_config=tool_config)
-        )
-
-        initialize_global_process_group_ray(timeout_second=None)
-
-        self.loss_fn = partial(ppo_loss, config=self.config)
-
-    def _build_engine(self):
-        self.model_config = self.config.model_config
-        self.engine_config = self.config.engine
-        self.optimizer_config = self.config.optim
-        self.checkpoint_config = self.config.checkpoint
-
-        from verl.workers.engine import BaseEngine, EngineRegistry
-
-        self.engine: BaseEngine = EngineRegistry.new(
-            model_type="language_model",
-            backend=self.config.strategy,
-            model_config=self.model_config,
-            engine_config=self.engine_config,
-            optimizer_config=self.optimizer_config,
-            checkpoint_config=self.checkpoint_config,
-        )
-
-        # build dispatch info
-        self._register_dispatch_collect_info(
-            mesh_name="actor",
-            dp_rank=self.engine.get_data_parallel_rank(),
-            is_collect=self.engine.is_mp_src_rank_with_outputs(),
-        )
-
-        # aggregate with bon sampling
-        self.ppo_mini_batch_size = self.config.ppo_mini_batch_size * self.config.rollout_n
-        assert self.ppo_mini_batch_size % self.engine.get_data_parallel_size() == 0, (
-            f"{self.ppo_mini_batch_size=} is not divisible by {self.engine.get_data_parallel_size()=}"
-        )
-        self.ppo_mini_batch_size_per_dp = self.ppo_mini_batch_size // self.engine.get_data_parallel_size()
-
-        # setup flops counter
-        self.flops_counter = FlopsCounter(self.model_config.hf_config)
-
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def init_model(self):
-        self._build_engine()
-        self.engine.initialize()
-
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def set_loss_fn(self, loss_fn):
-        self.loss_fn = loss_fn
-
-    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
-    @DistProfiler.annotate(color="blue", role="actor_compute_log_prob")
-    def compute_log_prob(self, data: DataProto):
-        data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
-        data.meta_info["use_fused_kernels"] = self.config.use_fused_kernels
-        if "calculate_entropy" not in data.meta_info:
-            data.meta_info["calculate_entropy"] = True
-        calculate_entropy = data.meta_info["calculate_entropy"]
-
-        if self.config.use_dynamic_bsz:
-            data.meta_info["max_token_len_per_gpu"] = self.config.ppo_infer_max_token_len_per_gpu
-        else:
-            data.meta_info["micro_batch_size_per_gpu"] = self.config.ppo_infer_micro_batch_size_per_gpu
-
-        with self.engine.eval_mode():
-            # TODO: make worker API to accept TensorDict as well
-            data = data.to_tensordict()
-            data = left_right_2_no_padding(data)
-            output = self.engine.infer_batch(data)
-
-        if self.engine.is_mp_src_rank_with_outputs():
-            output = output["model_output"]
-            log_probs = output["log_probs"]
-            log_probs = no_padding_2_padding(log_probs, data)  # (bsz, response_length)
-
-            tensors = {"old_log_probs": log_probs.float()}
-            if calculate_entropy:
-                entropy = no_padding_2_padding(output["entropy"], data)  # (bsz, response_length)
-                tensors["entropys"] = entropy.float()
-
-            # in megatron, only last pp contains valid data and returned to the single controller
-            output = DataProto.from_dict(tensors=tensors)
-            output = output.to("cpu")
-
-        return output if self.engine.is_mp_src_rank_with_outputs() else None
-
-    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
-    @DistProfiler.annotate(color="red", role="actor_update")
-    def update_actor(self, data: DataProto):
-        data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
-        data.meta_info["use_fused_kernels"] = self.config.use_fused_kernels
-        data.meta_info["calculate_entropy"] = self.config.entropy_coeff != 0.0
-        if self.config.use_dynamic_bsz:
-            data.meta_info["max_token_len_per_gpu"] = self.config.ppo_max_token_len_per_gpu
-        else:
-            data.meta_info["micro_batch_size_per_gpu"] = self.config.ppo_micro_batch_size_per_gpu
-
-        metrics = {}
-        # Support all hardwares
-        data = data.to(get_device_id())
-        # perform forward computation
-        with self.engine.train_mode():
-            dataloader = data.make_iterator(
-                mini_batch_size=self.ppo_mini_batch_size_per_dp,
-                epochs=self.config.ppo_epochs,
-                seed=self.config.data_loader_seed + self.engine.get_data_parallel_rank(),
-                dataloader_kwargs={"shuffle": self.config.shuffle},
-            )
-            with Timer(name="update_policy", logger=None) as timer:
-                for batch_idx, mini_batch in enumerate(dataloader):
-                    mini_batch.meta_info["global_batch_size"] = self.ppo_mini_batch_size
-                    # TODO: make worker API to accept TensorDict as well
-                    mini_batch = mini_batch.to_tensordict()
-                    mini_batch = left_right_2_no_padding(mini_batch)
-                    output = self.engine.train_batch(mini_batch, self.loss_fn)
-                    mini_batch_metrics = output.get("metrics", {})
-                    append_to_dict(metrics, mini_batch_metrics, prefix="actor/")
-
-            delta_time = timer.last
-
-        global_num_tokens = data.meta_info["global_token_num"]
-        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
-        metrics["perf/mfu/actor"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
-        metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3)
-        metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3)
-        metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
-
-        lr = self.engine.lr_scheduler_step()
-        metrics["actor/lr"] = lr
-
-        output = DataProto(batch=None, meta_info={"metrics": metrics})
-
-        return output
-
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
-        return self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)
-
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
-        return self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)
-
-
 class CriticWorker(Worker, DistProfilerExtension):
     """
     This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
@@ -599,10 +444,19 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
         Worker.__init__(self)
         self.config = config
         self.role = role
-        self.actor: ActorWorker = None
-        self.ref: ActorWorker = None
+        self.actor: TrainingWorker = None
+        self.ref: TrainingWorker = None
         self.rollout: BaseRollout = None
 
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def set_loss_fn(self, loss_fn):
+        self.actor.set_loss_fn(loss_fn=loss_fn)
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def to(self, device, model=True, optimizer=True, grad=True):
+        """Manual control of load/offload"""
+        self.actor.to(device=device, model=model, optimizer=optimizer, grad=grad)
+
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
         model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model)
@@ -621,9 +475,25 @@ def init_model(self):
             ref_config: ActorConfig = omega_conf_to_dataclass(self.config.ref)
             ref_config.model_config = model_config
 
-            self.ref = ActorWorker(ref_config)
-            self.ref.init_model()
-            self.ref.engine.to("cpu")
+            # construct TrainingWorkerConfig
+            ref_training_config = TrainingWorkerConfig(
+                model_type="language_model",
+                model_config=ref_config.model_config,
+                engine_config=ref_config.engine,
+                optimizer_config=ref_config.optim,
+                checkpoint_config=ref_config.checkpoint,
+            )
+
+            # assign engine configs
+            ref_training_config.engine_config.use_dynamic_bsz = self.config.ref.use_dynamic_bsz
+            ref_training_config.engine_config.infer_max_token_len_per_gpu = self.config.ref.ppo_max_token_len_per_gpu
+            ref_training_config.engine_config.infer_micro_batch_size_per_gpu = (
+                self.config.ref.ppo_micro_batch_size_per_gpu
+            )
+            ref_training_config.engine_config.use_remove_padding = model_config.use_remove_padding
+
+            self.ref = TrainingWorker(config=ref_training_config)
+            self.ref.reset()
             self.set_dispatch_collect(mesh_name="ref", **self.ref.get_dispatch_collect())
 
         # 2. build actor model
@@ -631,9 +501,41 @@ def init_model(self):
             actor_config: ActorConfig = omega_conf_to_dataclass(self.config.actor)
             actor_config.model_config = model_config
 
-            self.actor = ActorWorker(actor_config)
-            self.actor.init_model()
-            self.actor.engine.to("cpu")
+            actor_training_config = TrainingWorkerConfig(
+                model_type="language_model",
+                model_config=actor_config.model_config,
+                engine_config=actor_config.engine,
+                optimizer_config=actor_config.optim,
+                checkpoint_config=actor_config.checkpoint,
+            )
+
+            assert self.config.actor.use_dynamic_bsz == self.config.rollout.log_prob_use_dynamic_bsz
+
+            # assign engine configs
+            actor_training_config.engine_config.use_dynamic_bsz = self.config.actor.use_dynamic_bsz
+            actor_training_config.engine_config.infer_max_token_len_per_gpu = (
+                self.config.rollout.log_prob_max_token_len_per_gpu
+            )
+            actor_training_config.engine_config.infer_micro_batch_size_per_gpu = (
+                self.config.rollout.log_prob_micro_batch_size_per_gpu
+            )
+            actor_training_config.engine_config.max_token_len_per_gpu = self.config.actor.ppo_max_token_len_per_gpu
+            actor_training_config.engine_config.micro_batch_size_per_gpu = (
+                self.config.actor.ppo_micro_batch_size_per_gpu
+            )
+            actor_training_config.engine_config.use_remove_padding = model_config.use_remove_padding
+
+            if self.config.actor.use_dynamic_bsz:
+                assert self.config.rollout.log_prob_max_token_len_per_gpu is not None
+                assert self.config.actor.ppo_max_token_len_per_gpu is not None
+            else:
+                assert self.config.rollout.log_prob_micro_batch_size_per_gpu is not None
+                assert self.config.actor.ppo_micro_batch_size_per_gpu is not None
+
+            self.loss_fn = partial(ppo_loss, config=actor_config)
+            self.actor = TrainingWorker(config=actor_training_config)
+            self.actor.reset()
+            self.actor.set_loss_fn(self.loss_fn)
             self.set_dispatch_collect(mesh_name="actor", **self.actor.get_dispatch_collect())
 
         # 3. build rollout engine
@@ -673,22 +575,19 @@ def init_model(self):
 
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref"))
     @DistProfiler.annotate(color="olive", role="ref_compute_log_prob")
-    def compute_ref_log_prob(self, data: DataProto):
-        data.meta_info["calculate_entropy"] = False
-        output = self.ref.compute_log_prob(data)
-        if output is not None:
-            output.batch["ref_log_prob"] = output.batch.pop("old_log_probs")
-        return output
+    def compute_ref_log_prob(self, data: TensorDict) -> TensorDict:
+        return self.ref.infer_batch(data=data).cpu()
 
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @DistProfiler.annotate(color="blue", role="actor_compute_log_prob")
-    def compute_log_prob(self, data: DataProto):
-        return self.actor.compute_log_prob(data)
+    def compute_log_prob(self, data: TensorDict) -> TensorDict:
+        return self.actor.infer_batch(data).cpu()
 
-    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"), blocking=False)
     @DistProfiler.annotate(color="red", role="actor_update")
-    def update_actor(self, data: DataProto):
-        return self.actor.update_actor(data)
+    def train_batch(self, data: TensorDict) -> TensorDict:
+        output = self.actor.train_batch(data=data)
+        return output
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
diff --git a/verl/workers/utils/padding.py b/verl/workers/utils/padding.py
index e4ea2bd8c31..35541455833 100644
--- a/verl/workers/utils/padding.py
+++ b/verl/workers/utils/padding.py
@@ -16,15 +16,7 @@
 from tensordict import TensorDict
 
 from verl.utils import tensordict_utils as tu
-from verl.utils.device import (
-    is_cuda_available,
-    is_npu_available,
-)
-
-if is_cuda_available:
-    from flash_attn.bert_padding import pad_input, unpad_input
-elif is_npu_available:
-    from transformers.integrations.npu_flash_attention import pad_input, unpad_input
+from verl.utils.attention_utils import pad_input, unpad_input
 
 
 def left_right_2_no_padding(data: TensorDict) -> TensorDict: