diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 91c0d21f2a3..dcc710d2270 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -6,7 +6,7 @@ - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off` + - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `vllm_omni`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. diff --git a/docs/algo/flowgrpo.md b/docs/algo/flowgrpo.md new file mode 100644 index 00000000000..bb719133994 --- /dev/null +++ b/docs/algo/flowgrpo.md @@ -0,0 +1,136 @@ +# Training Flow Matching Models via Online RL (Flow-GRPO) + +Flow-GRPO ([paper](https://arxiv.org/abs/2505.05470), [code](https://github.com/yifan123/flow_grpo)) is the first method to integrate online policy gradient reinforcement learning into **flow matching** generative models (e.g., Stable Diffusion 3, FLUX). It enables direct reward optimization for tasks such as compositional text-to-image generation, visual text rendering, and human preference alignment, without modifying the standard inference pipeline. + +Two core technical contributions make this possible: + +1. **ODE-to-SDE Conversion**: Flow matching models natively use a deterministic ODE sampler. Flow-GRPO converts this ODE into an equivalent SDE that preserves the model's marginal distribution at every timestep. This introduces the stochasticity required for group sampling and RL exploration. + +2. **Denoising Reduction**: Training on all denoising steps is expensive. Flow-GRPO reduces the number of *training* steps while keeping the original number of *inference* steps, significantly improving sampling efficiency without sacrificing reward performance. + +Empirically, RL-tuned SD3.5-M with Flow-GRPO raises GenEval accuracy from 63% to 95% and visual text rendering accuracy from 59% to 92%. + +## Key Components + +- **Flow Matching Backbone**: operates on continuous-time flow matching models (e.g., SD3.5, FLUX) rather than discrete-token LLMs. +- **ODE-to-SDE Rollout**: generates a group of diverse image trajectories by injecting controlled noise via SDE sampling at selected denoising steps. +- **Denoising Reduction**: trains on a reduced subset of denoising steps (configurable via `sde_window_size` and `sde_window_range`) while inference uses the full step count. +- **Image Reward Models**: rewards are assigned by external reward models (e.g., GenEval, OCR, PickScore, aesthetic score) rather than rule-based verifiers. +- **No Critic**: like GRPO for LLMs, no separate value network is trained; advantages are computed from group-relative rewards. + +## Key Differences: GRPO vs. Flow-GRPO + +| Dimension | GRPO (LLM) | Flow-GRPO (Diffusion) | +|---|---|---| +| **Model type** | Autoregressive language model | Flow matching / diffusion model | +| **Action space** | Discrete token sequences | Continuous denoising trajectories (SDE paths) | +| **Rollout mechanism** | Sample `n` token sequences per prompt | Convert ODE to SDE; sample `n` image trajectories per prompt via stochastic denoising | +| **Log-probability** | Standard next-token log-prob | Log-prob of the SDE noise prediction at each selected denoising step | +| **Training steps** | All decoding steps are trivially identical in cost | Denoising Reduction: train on a small window of steps, infer with full steps | +| **Reward signal** | Rule-based verifiers or LLM judges on text | Image reward models (GenEval, OCR, PickScore, aesthetic, etc.) | +| **KL regularization** | KL penalty added to reward or directly to loss | KL loss applied to SDE steps; `use_kl_loss=True` recommended | +| **CFG (guidance)** | Not applicable | CFG distillation occurs naturally; CFG can be disabled at both train and test time | +| **Advantage estimator** | `algorithm.adv_estimator=grpo` | `algorithm.adv_estimator=flow_grpo` | +| **Loss mode** | `actor_rollout_ref.actor.policy_loss.loss_mode` not diffusion-specific | `actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo` | + +## Configuration + +### Core parameters + +- `algorithm.adv_estimator`: Set to `flow_grpo` (instead of `grpo`). + +- `actor_rollout_ref.actor.policy_loss.loss_mode`: Set to `flow_grpo`. + +- `actor_rollout_ref.rollout.n`: Number of image trajectories to sample per prompt for group-relative advantage computation. Analogous to GRPO's group size; should be > 1 (default in examples: `16`). + +- `actor_rollout_ref.rollout.noise_level`: Controls the SDE noise injection level during rollout. Larger values increase diversity but may degrade image quality. Typical value: `1.2`. + +- `actor_rollout_ref.rollout.sde_window_size`: Number of denoising steps to train on per trajectory (Denoising Reduction). Reducing this from the full step count speeds up training significantly. + +- `actor_rollout_ref.rollout.sde_window_range`: The range of denoising steps from which the training window is sampled, e.g., `[0, 5]` to focus on early (high-noise) steps. + +- `actor_rollout_ref.rollout.val_kwargs.num_inference_steps`: Full number of denoising steps used during inference/evaluation. This is kept at its original value (e.g., `50`) and is independent of `sde_window_size`. + +- `actor_rollout_ref.rollout.guidance_scale`: Classifier-free guidance scale during rollout. Can be set to `1.0` (no CFG) because the RL process naturally performs CFG distillation. + +- `actor_rollout_ref.actor.use_kl_loss`: Set to `True` to add a KL divergence term between the trained policy and the reference policy to the loss. + +- `actor_rollout_ref.actor.kl_loss_coef`: Coefficient for the KL loss term. + +## Data Preprocessing + +All training scripts expect the dataset in parquet format. The examples use an OCR dataset from the [Flow-GRPO repository](https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr). The raw dataset consists of text files where each ground-truth answer is stored in the format `The image displays "xxx".`. Before running any training script, convert it to parquet format using the provided preprocessing script. + +### Step 1: Download the raw dataset + +Download the OCR dataset from the Flow-GRPO repository and place it at `~/dataset/ocr/` (or any path of your choice): + +```bash +# Clone or download from https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr +# Place the dataset directory at ~/dataset/ocr/ +# Expected structure: +# ~/dataset/ocr/ +# train/ (or train split files) +# test/ (or test split files) +``` + +### Step 2: Run the preprocessing script + +```bash +python examples/data_preprocess/qwenimage_ocr.py \ + --local_dataset_path ~/dataset/ocr \ + --local_save_dir ~/data/ocr +``` + +The output parquet files are consumed directly by all training scripts via `data.train_files` and `data.val_files`. + +## Variants + +### Flow-GRPO-Fast + +Flow-GRPO-Fast accelerates training by confining stochasticity to only one or two denoising steps per trajectory: + +1. Generate a deterministic ODE trajectory for each prompt. +2. At a randomly chosen intermediate step, inject noise and switch to SDE sampling to produce the group. +3. Continue the remaining steps with ODE sampling. + +This significantly reduces training cost: only the selected step(s) require gradient computation, and sampling before the branching point does not need group expansion. Flow-GRPO-Fast with 2 training steps matches full Flow-GRPO reward performance. + +```bash +bash examples/flowgrpo_trainer/run_flowgrpo_fast.sh +``` + +### Async Reward + +For reward models that are expensive to evaluate (e.g., a VLM judge), the reward model can be allocated its own dedicated GPU resource pool and run asynchronously alongside the policy. This avoids blocking policy training on reward computation. + +```bash +bash examples/flowgrpo_trainer/run_flowgrpo_async_reward.sh +``` + +### Full Fine-Tuning + +To fine-tune all model weights instead of using LoRA: + +```bash +bash examples/flowgrpo_trainer/run_flowgrpo_full_ft.sh +``` + +## Reference Example + +Standard LoRA training with OCR reward (Qwen-Image, 4 GPUs) with CFG and KL loss enabled: + +```bash +bash examples/flowgrpo_trainer/run_flowgrpo.sh +``` + +## Citation + +```bibtex +@article{liu2025flow, + title={Flow-GRPO: Training Flow Matching Models via Online RL}, + author={Liu, Jie and Liu, Gongye and Liang, Jiajun and Li, Yangguang and Liu, Jiaheng and Wang, Xintao and Wan, Pengfei and Zhang, Di and Ouyang, Wanli}, + journal={arXiv preprint arXiv:2505.05470}, + year={2025} +} +``` diff --git a/examples/data_preprocess/qwenimage_ocr.py b/examples/data_preprocess/qwenimage_ocr.py new file mode 100644 index 00000000000..3953e86a646 --- /dev/null +++ b/examples/data_preprocess/qwenimage_ocr.py @@ -0,0 +1,103 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the OCR dataset to parquet format (for Qwen-Image training). +You can obtain the raw dataset from https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr +""" + +import argparse +import os + +import datasets + +from verl.utils.hdfs_io import copy, makedirs + + +def extract_solution(solution_str): + # The solution is stored in the format: 'The image displays "xxx".' + return solution_str.split('"')[1] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default=None) + parser.add_argument("--hdfs_dir", default=None) + parser.add_argument( + "--local_dataset_path", default="~/dataset/ocr/", help="The local path to the raw dataset, if it exists." + ) + parser.add_argument( + "--local_save_dir", default="~/data/ocr", help="The save directory for the preprocessed dataset." + ) + + args = parser.parse_args() + if args.local_dataset_path is not None: + local_dataset_path = os.path.expanduser(args.local_dataset_path) + + data_source = "flow_grpo/ocr" + + if local_dataset_path is not None: + dataset = datasets.load_dataset(local_dataset_path) + else: + raise NotImplementedError( + "It is not existed in huggingface hub. " + "Please get dataset from https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr" + ) + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + system_prompt = ( + "Describe the image by detailing the color, shape, size, " + "texture, quantity, text, spatial relationships of the objects and background:" + ) + negative_user_prompt = " " + + def make_map_fn(split): + def process_fn(example, idx): + text = example.pop("text") + solution = extract_solution(text) + data = { + "data_source": data_source, + "prompt": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": text}, + ], + "negative_prompt": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": negative_user_prompt}, + ], + "ability": "ocr", + "reward_model": {"style": "model", "ground_truth": solution}, + "extra_info": {"split": split, "index": idx}, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + hdfs_dir = args.hdfs_dir + local_save_dir = args.local_dir + if local_save_dir is not None: + print("Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.") + else: + local_save_dir = args.local_save_dir + + train_dataset.to_parquet(os.path.join(local_save_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_save_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + copy(src=local_save_dir, dst=hdfs_dir) diff --git a/examples/flowgrpo_trainer/README.md b/examples/flowgrpo_trainer/README.md new file mode 100644 index 00000000000..bb719133994 --- /dev/null +++ b/examples/flowgrpo_trainer/README.md @@ -0,0 +1,136 @@ +# Training Flow Matching Models via Online RL (Flow-GRPO) + +Flow-GRPO ([paper](https://arxiv.org/abs/2505.05470), [code](https://github.com/yifan123/flow_grpo)) is the first method to integrate online policy gradient reinforcement learning into **flow matching** generative models (e.g., Stable Diffusion 3, FLUX). It enables direct reward optimization for tasks such as compositional text-to-image generation, visual text rendering, and human preference alignment, without modifying the standard inference pipeline. + +Two core technical contributions make this possible: + +1. **ODE-to-SDE Conversion**: Flow matching models natively use a deterministic ODE sampler. Flow-GRPO converts this ODE into an equivalent SDE that preserves the model's marginal distribution at every timestep. This introduces the stochasticity required for group sampling and RL exploration. + +2. **Denoising Reduction**: Training on all denoising steps is expensive. Flow-GRPO reduces the number of *training* steps while keeping the original number of *inference* steps, significantly improving sampling efficiency without sacrificing reward performance. + +Empirically, RL-tuned SD3.5-M with Flow-GRPO raises GenEval accuracy from 63% to 95% and visual text rendering accuracy from 59% to 92%. + +## Key Components + +- **Flow Matching Backbone**: operates on continuous-time flow matching models (e.g., SD3.5, FLUX) rather than discrete-token LLMs. +- **ODE-to-SDE Rollout**: generates a group of diverse image trajectories by injecting controlled noise via SDE sampling at selected denoising steps. +- **Denoising Reduction**: trains on a reduced subset of denoising steps (configurable via `sde_window_size` and `sde_window_range`) while inference uses the full step count. +- **Image Reward Models**: rewards are assigned by external reward models (e.g., GenEval, OCR, PickScore, aesthetic score) rather than rule-based verifiers. +- **No Critic**: like GRPO for LLMs, no separate value network is trained; advantages are computed from group-relative rewards. + +## Key Differences: GRPO vs. Flow-GRPO + +| Dimension | GRPO (LLM) | Flow-GRPO (Diffusion) | +|---|---|---| +| **Model type** | Autoregressive language model | Flow matching / diffusion model | +| **Action space** | Discrete token sequences | Continuous denoising trajectories (SDE paths) | +| **Rollout mechanism** | Sample `n` token sequences per prompt | Convert ODE to SDE; sample `n` image trajectories per prompt via stochastic denoising | +| **Log-probability** | Standard next-token log-prob | Log-prob of the SDE noise prediction at each selected denoising step | +| **Training steps** | All decoding steps are trivially identical in cost | Denoising Reduction: train on a small window of steps, infer with full steps | +| **Reward signal** | Rule-based verifiers or LLM judges on text | Image reward models (GenEval, OCR, PickScore, aesthetic, etc.) | +| **KL regularization** | KL penalty added to reward or directly to loss | KL loss applied to SDE steps; `use_kl_loss=True` recommended | +| **CFG (guidance)** | Not applicable | CFG distillation occurs naturally; CFG can be disabled at both train and test time | +| **Advantage estimator** | `algorithm.adv_estimator=grpo` | `algorithm.adv_estimator=flow_grpo` | +| **Loss mode** | `actor_rollout_ref.actor.policy_loss.loss_mode` not diffusion-specific | `actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo` | + +## Configuration + +### Core parameters + +- `algorithm.adv_estimator`: Set to `flow_grpo` (instead of `grpo`). + +- `actor_rollout_ref.actor.policy_loss.loss_mode`: Set to `flow_grpo`. + +- `actor_rollout_ref.rollout.n`: Number of image trajectories to sample per prompt for group-relative advantage computation. Analogous to GRPO's group size; should be > 1 (default in examples: `16`). + +- `actor_rollout_ref.rollout.noise_level`: Controls the SDE noise injection level during rollout. Larger values increase diversity but may degrade image quality. Typical value: `1.2`. + +- `actor_rollout_ref.rollout.sde_window_size`: Number of denoising steps to train on per trajectory (Denoising Reduction). Reducing this from the full step count speeds up training significantly. + +- `actor_rollout_ref.rollout.sde_window_range`: The range of denoising steps from which the training window is sampled, e.g., `[0, 5]` to focus on early (high-noise) steps. + +- `actor_rollout_ref.rollout.val_kwargs.num_inference_steps`: Full number of denoising steps used during inference/evaluation. This is kept at its original value (e.g., `50`) and is independent of `sde_window_size`. + +- `actor_rollout_ref.rollout.guidance_scale`: Classifier-free guidance scale during rollout. Can be set to `1.0` (no CFG) because the RL process naturally performs CFG distillation. + +- `actor_rollout_ref.actor.use_kl_loss`: Set to `True` to add a KL divergence term between the trained policy and the reference policy to the loss. + +- `actor_rollout_ref.actor.kl_loss_coef`: Coefficient for the KL loss term. + +## Data Preprocessing + +All training scripts expect the dataset in parquet format. The examples use an OCR dataset from the [Flow-GRPO repository](https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr). The raw dataset consists of text files where each ground-truth answer is stored in the format `The image displays "xxx".`. Before running any training script, convert it to parquet format using the provided preprocessing script. + +### Step 1: Download the raw dataset + +Download the OCR dataset from the Flow-GRPO repository and place it at `~/dataset/ocr/` (or any path of your choice): + +```bash +# Clone or download from https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr +# Place the dataset directory at ~/dataset/ocr/ +# Expected structure: +# ~/dataset/ocr/ +# train/ (or train split files) +# test/ (or test split files) +``` + +### Step 2: Run the preprocessing script + +```bash +python examples/data_preprocess/qwenimage_ocr.py \ + --local_dataset_path ~/dataset/ocr \ + --local_save_dir ~/data/ocr +``` + +The output parquet files are consumed directly by all training scripts via `data.train_files` and `data.val_files`. + +## Variants + +### Flow-GRPO-Fast + +Flow-GRPO-Fast accelerates training by confining stochasticity to only one or two denoising steps per trajectory: + +1. Generate a deterministic ODE trajectory for each prompt. +2. At a randomly chosen intermediate step, inject noise and switch to SDE sampling to produce the group. +3. Continue the remaining steps with ODE sampling. + +This significantly reduces training cost: only the selected step(s) require gradient computation, and sampling before the branching point does not need group expansion. Flow-GRPO-Fast with 2 training steps matches full Flow-GRPO reward performance. + +```bash +bash examples/flowgrpo_trainer/run_flowgrpo_fast.sh +``` + +### Async Reward + +For reward models that are expensive to evaluate (e.g., a VLM judge), the reward model can be allocated its own dedicated GPU resource pool and run asynchronously alongside the policy. This avoids blocking policy training on reward computation. + +```bash +bash examples/flowgrpo_trainer/run_flowgrpo_async_reward.sh +``` + +### Full Fine-Tuning + +To fine-tune all model weights instead of using LoRA: + +```bash +bash examples/flowgrpo_trainer/run_flowgrpo_full_ft.sh +``` + +## Reference Example + +Standard LoRA training with OCR reward (Qwen-Image, 4 GPUs) with CFG and KL loss enabled: + +```bash +bash examples/flowgrpo_trainer/run_flowgrpo.sh +``` + +## Citation + +```bibtex +@article{liu2025flow, + title={Flow-GRPO: Training Flow Matching Models via Online RL}, + author={Liu, Jie and Liu, Gongye and Liang, Jiajun and Li, Yangguang and Liu, Jiaheng and Wang, Xintao and Wan, Pengfei and Zhang, Di and Ouyang, Wanli}, + journal={arXiv preprint arXiv:2505.05470}, + year={2025} +} +``` diff --git a/examples/flowgrpo_trainer/run_flowgrpo.sh b/examples/flowgrpo_trainer/run_flowgrpo.sh new file mode 100644 index 00000000000..e52b70e692c --- /dev/null +++ b/examples/flowgrpo_trainer/run_flowgrpo.sh @@ -0,0 +1,74 @@ +# Qwen-Image lora, vllm_omni rollout +set -x + +ocr_train_path=$HOME/data/ocr/train.parquet +ocr_test_path=$HOME/data/ocr/test.parquet + +ENGINE=vllm_omni +REWARD_ENGINE=vllm + +reward_path=tests/experimental/reward_loop/reward_fn.py +reward_model_name=$HOME/models/Qwen/Qwen3-VL-8B-Instruct + + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_diffusion_trainer.yaml' \ + algorithm.adv_estimator=flow_grpo \ + data.train_files=$ocr_train_path \ + data.val_files=$ocr_test_path \ + data.train_batch_size=32 \ + data.max_prompt_length=1058 \ + data.filter_overlong_prompts=True \ + +data.apply_chat_template_kwargs.max_length=1058 \ + +data.apply_chat_template_kwargs.padding=True \ + +data.apply_chat_template_kwargs.truncation=True \ + actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen-Image \ + actor_rollout_ref.model.tokenizer_path=$HOME/models/Qwen/Qwen-Image/tokenizer \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=128 \ + actor_rollout_ref.model.target_modules="['to_q','to_k','to_v','to_out.0','add_q_proj','add_k_proj','add_v_proj','to_add_out','img_mlp.net.0.proj','img_mlp.net.2','txt_mlp.net.0.proj','txt_mlp.net.2']" \ + actor_rollout_ref.actor.optim.lr=3e-4 \ + actor_rollout_ref.actor.optim.weight_decay=0.0001 \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.04 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.rollout.guidance_scale=4.0 \ + actor_rollout_ref.rollout.agent.default_agent_loop=diffusion_single_turn_agent \ + actor_rollout_ref.rollout.agent.num_workers=4 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.rollout.max_model_len=1058 \ + actor_rollout_ref.rollout.noise_level=1.2 \ + actor_rollout_ref.rollout.sde_window_size=2 \ + actor_rollout_ref.rollout.sde_window_range="[0,5]" \ + actor_rollout_ref.rollout.val_kwargs.num_inference_steps=50 \ + +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.custom_pipeline=verl.utils.vllm_omni.pipelines.QwenImagePipelineWithLogProb \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + reward.num_workers=4 \ + reward.reward_manager.name=image \ + reward.reward_model.enable=True \ + reward.reward_model.model_path=$reward_model_name \ + reward.reward_model.rollout.name=$REWARD_ENGINE \ + reward.reward_model.rollout.tensor_model_parallel_size=4 \ + reward.custom_reward_function.path=$reward_path \ + reward.custom_reward_function.name=compute_score_ocr \ + trainer.use_legacy_worker_impl=disable \ + trainer.logger='["console", "wandb"]' \ + trainer.project_name=flow_grpo \ + trainer.experiment_name=qwen_image_ocr \ + trainer.log_val_generations=8 \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=30 \ + trainer.test_freq=30 \ + trainer.total_epochs=15 $@ diff --git a/examples/flowgrpo_trainer/run_flowgrpo_async_reward.sh b/examples/flowgrpo_trainer/run_flowgrpo_async_reward.sh new file mode 100644 index 00000000000..f7923046700 --- /dev/null +++ b/examples/flowgrpo_trainer/run_flowgrpo_async_reward.sh @@ -0,0 +1,78 @@ +# Qwen-Image lora, vllm_omni rollout +set -x + +ocr_train_path=$HOME/data/ocr/train.parquet +ocr_test_path=$HOME/data/ocr/test.parquet + +ENGINE=vllm_omni +REWARD_ENGINE=vllm + +reward_path=tests/experimental/reward_loop/reward_fn.py +reward_model_name=$HOME/models/Qwen/Qwen3-VL-8B-Instruct + + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_diffusion_trainer.yaml' \ + algorithm.adv_estimator=flow_grpo \ + data.train_files=$ocr_train_path \ + data.val_files=$ocr_test_path \ + data.train_batch_size=32 \ + data.max_prompt_length=1058 \ + data.filter_overlong_prompts=True \ + +data.apply_chat_template_kwargs.max_length=1058 \ + +data.apply_chat_template_kwargs.padding=True \ + +data.apply_chat_template_kwargs.truncation=True \ + actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen-Image \ + actor_rollout_ref.model.tokenizer_path=$HOME/models/Qwen/Qwen-Image/tokenizer \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=128 \ + actor_rollout_ref.model.target_modules="['to_q','to_k','to_v','to_out.0','add_q_proj','add_k_proj','add_v_proj','to_add_out','img_mlp.net.0.proj','img_mlp.net.2','txt_mlp.net.0.proj','txt_mlp.net.2']" \ + actor_rollout_ref.actor.optim.lr=3e-4 \ + actor_rollout_ref.actor.optim.weight_decay=0.0001 \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.rollout.guidance_scale=1.0 \ + actor_rollout_ref.rollout.agent.default_agent_loop=diffusion_single_turn_agent \ + actor_rollout_ref.rollout.agent.num_workers=4 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.rollout.max_model_len=1058 \ + actor_rollout_ref.rollout.noise_level=1.2 \ + actor_rollout_ref.rollout.sde_window_size=2 \ + actor_rollout_ref.rollout.sde_window_range="[0,5]" \ + actor_rollout_ref.rollout.val_kwargs.num_inference_steps=50 \ + +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.custom_pipeline=verl.utils.vllm_omni.pipelines.QwenImagePipelineWithLogProb \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + reward.num_workers=4 \ + reward.reward_manager.name=image \ + reward.reward_model.enable=True \ + reward.reward_model.model_path=$reward_model_name \ + reward.reward_model.rollout.name=$REWARD_ENGINE \ + reward.reward_model.enable_resource_pool=True \ + reward.reward_model.nnodes=1 \ + reward.reward_model.n_gpus_per_node=1 \ + reward.reward_model.rollout.gpu_memory_utilization=0.9 \ + reward.reward_model.rollout.free_cache_engine=False \ + reward.reward_model.rollout.tensor_model_parallel_size=1 \ + reward.reward_model.rollout.enforce_eager=False \ + reward.custom_reward_function.path=$reward_path \ + reward.custom_reward_function.name=compute_score_ocr \ + trainer.use_legacy_worker_impl=disable \ + trainer.logger='["console", "wandb"]' \ + trainer.project_name=flow_grpo \ + trainer.experiment_name=qwen_image_ocr \ + trainer.log_val_generations=8 \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=30 \ + trainer.test_freq=30 \ + trainer.total_epochs=15 $@ diff --git a/examples/flowgrpo_trainer/run_flowgrpo_fast.sh b/examples/flowgrpo_trainer/run_flowgrpo_fast.sh new file mode 100644 index 00000000000..57074241954 --- /dev/null +++ b/examples/flowgrpo_trainer/run_flowgrpo_fast.sh @@ -0,0 +1,72 @@ +# Qwen-Image lora, vllm_omni rollout +set -x + +ocr_train_path=$HOME/data/ocr/train.parquet +ocr_test_path=$HOME/data/ocr/test.parquet + +ENGINE=vllm_omni +REWARD_ENGINE=vllm + +reward_path=tests/experimental/reward_loop/reward_fn.py +reward_model_name=$HOME/models/Qwen/Qwen3-VL-8B-Instruct + + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_diffusion_trainer.yaml' \ + algorithm.adv_estimator=flow_grpo \ + data.train_files=$ocr_train_path \ + data.val_files=$ocr_test_path \ + data.train_batch_size=32 \ + data.max_prompt_length=1058 \ + data.filter_overlong_prompts=True \ + +data.apply_chat_template_kwargs.max_length=1058 \ + +data.apply_chat_template_kwargs.padding=True \ + +data.apply_chat_template_kwargs.truncation=True \ + actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen-Image \ + actor_rollout_ref.model.tokenizer_path=$HOME/models/Qwen/Qwen-Image/tokenizer \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=128 \ + actor_rollout_ref.model.target_modules="['to_q','to_k','to_v','to_out.0','add_q_proj','add_k_proj','add_v_proj','to_add_out','img_mlp.net.0.proj','img_mlp.net.2','txt_mlp.net.0.proj','txt_mlp.net.2']" \ + actor_rollout_ref.actor.optim.lr=3e-4 \ + actor_rollout_ref.actor.optim.weight_decay=0.0001 \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.rollout.guidance_scale=1.0 \ + actor_rollout_ref.rollout.agent.default_agent_loop=diffusion_single_turn_agent \ + actor_rollout_ref.rollout.agent.num_workers=4 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.rollout.max_model_len=1058 \ + actor_rollout_ref.rollout.noise_level=1.2 \ + actor_rollout_ref.rollout.sde_window_size=2 \ + actor_rollout_ref.rollout.sde_window_range="[0,5]" \ + actor_rollout_ref.rollout.val_kwargs.num_inference_steps=50 \ + +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.custom_pipeline=verl.utils.vllm_omni.pipelines.QwenImagePipelineWithLogProb \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + reward.num_workers=4 \ + reward.reward_manager.name=image \ + reward.reward_model.enable=True \ + reward.reward_model.model_path=$reward_model_name \ + reward.reward_model.rollout.name=$REWARD_ENGINE \ + reward.reward_model.rollout.tensor_model_parallel_size=4 \ + reward.custom_reward_function.path=$reward_path \ + reward.custom_reward_function.name=compute_score_ocr \ + trainer.use_legacy_worker_impl=disable \ + trainer.logger='["console", "wandb"]' \ + trainer.project_name=flow_grpo \ + trainer.experiment_name=qwen_image_ocr \ + trainer.log_val_generations=8 \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=30 \ + trainer.test_freq=30 \ + trainer.total_epochs=15 $@ diff --git a/examples/flowgrpo_trainer/run_flowgrpo_full_ft.sh b/examples/flowgrpo_trainer/run_flowgrpo_full_ft.sh new file mode 100644 index 00000000000..691d5ed2042 --- /dev/null +++ b/examples/flowgrpo_trainer/run_flowgrpo_full_ft.sh @@ -0,0 +1,69 @@ +# Qwen-Image full weight finetuning, vllm_omni rollout +set -x + +ocr_train_path=$HOME/data/ocr/train.parquet +ocr_test_path=$HOME/data/ocr/test.parquet + +ENGINE=vllm_omni +REWARD_ENGINE=vllm + +reward_path=tests/experimental/reward_loop/reward_fn.py +reward_model_name=$HOME/models/Qwen/Qwen3-VL-8B-Instruct + + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_diffusion_trainer.yaml' \ + algorithm.adv_estimator=flow_grpo \ + data.train_files=$ocr_train_path \ + data.val_files=$ocr_test_path \ + data.train_batch_size=32 \ + data.max_prompt_length=1058 \ + data.filter_overlong_prompts=True \ + +data.apply_chat_template_kwargs.max_length=1058 \ + +data.apply_chat_template_kwargs.padding=True \ + +data.apply_chat_template_kwargs.truncation=True \ + actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen-Image \ + actor_rollout_ref.model.tokenizer_path=$HOME/models/Qwen/Qwen-Image/tokenizer \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.actor.optim.weight_decay=0.0001 \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.rollout.guidance_scale=1.0 \ + actor_rollout_ref.rollout.agent.default_agent_loop=diffusion_single_turn_agent \ + actor_rollout_ref.rollout.agent.num_workers=4 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.rollout.max_model_len=1058 \ + actor_rollout_ref.rollout.noise_level=1.2 \ + actor_rollout_ref.rollout.sde_window_size=2 \ + actor_rollout_ref.rollout.sde_window_range="[0,5]" \ + actor_rollout_ref.rollout.val_kwargs.num_inference_steps=50 \ + +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.custom_pipeline=verl.utils.vllm_omni.pipelines.QwenImagePipelineWithLogProb \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + reward.num_workers=4 \ + reward.reward_manager.name=image \ + reward.reward_model.enable=True \ + reward.reward_model.model_path=$reward_model_name \ + reward.reward_model.rollout.name=$REWARD_ENGINE \ + reward.reward_model.rollout.tensor_model_parallel_size=4 \ + reward.custom_reward_function.path=$reward_path \ + reward.custom_reward_function.name=compute_score_ocr \ + trainer.use_legacy_worker_impl=disable \ + trainer.logger='["console", "wandb"]' \ + trainer.project_name=flow_grpo \ + trainer.experiment_name=qwen_image_ocr \ + trainer.log_val_generations=8 \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=30 \ + trainer.test_freq=30 \ + trainer.total_epochs=15 $@ diff --git a/scripts/generate_trainer_config.sh b/scripts/generate_trainer_config.sh index c4c89cdbdba..bfd3ba12ef3 100755 --- a/scripts/generate_trainer_config.sh +++ b/scripts/generate_trainer_config.sh @@ -6,6 +6,7 @@ set -euox pipefail CONFIG_SPECS=( "ppo_trainer:_generated_ppo_trainer.yaml:" "ppo_megatron_trainer:_generated_ppo_megatron_trainer.yaml:--config-name=ppo_megatron_trainer.yaml" + "ppo_diffusion_trainer:_generated_ppo_diffusion_trainer.yaml:--config-name=ppo_diffusion_trainer.yaml" "ppo_trainer:_generated_ppo_veomni_trainer.yaml:model_engine=veomni" "ppo_trainer:_generated_ppo_torchtitan_trainer.yaml:model_engine=torchtitan" ) diff --git a/tests/experimental/agent_loop/test_diffusion_agent_loop.py b/tests/experimental/agent_loop/test_diffusion_agent_loop.py new file mode 100644 index 00000000000..6f615b2db29 --- /dev/null +++ b/tests/experimental/agent_loop/test_diffusion_agent_loop.py @@ -0,0 +1,135 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import pytest +import ray +from omegaconf import DictConfig + +from verl.experimental.agent_loop.agent_loop import AgentLoopManager +from verl.protocol import DataProto + + +@pytest.fixture +def init_config() -> DictConfig: + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_diffusion_trainer") + + model_path = os.path.expanduser("~/models/tiny-random/Qwen-Image") + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.model.tokenizer_path = os.path.join(model_path, "tokenizer") + config.actor_rollout_ref.rollout.name = "vllm_omni" + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.enforce_eager = True + config.actor_rollout_ref.rollout.n = 4 + config.actor_rollout_ref.rollout.num_inference_steps = 10 + config.actor_rollout_ref.rollout.guidance_scale = 4.0 + config.actor_rollout_ref.rollout.agent.num_workers = 2 + config.actor_rollout_ref.rollout.agent.default_agent_loop = "diffusion_single_turn_agent" + config.actor_rollout_ref.rollout.noise_level = 1.0 + config.actor_rollout_ref.rollout.sde_window_size = 2 + config.actor_rollout_ref.rollout.sde_window_range = [0, 5] + config.actor_rollout_ref.rollout.calculate_log_probs = True + config.actor_rollout_ref.rollout.nnodes = 1 + + qwen_pipeline = "verl.utils.vllm_omni.pipelines.QwenImagePipelineWithLogProb" + config.actor_rollout_ref.rollout.engine_kwargs.vllm_omni = {"custom_pipeline": qwen_pipeline} + config.reward.reward_manager.name = "image" + config.trainer.n_gpus_per_node = 4 + + tokenizer_max_length = 1024 + prompt_template_encode_start_idx = 34 + max_length = tokenizer_max_length + prompt_template_encode_start_idx + + config.data.apply_chat_template_kwargs = dict(max_length=max_length, padding=True, truncation=True) + config.data.max_prompt_length = max_length + config.actor_rollout_ref.rollout.max_model_len = max_length + + # TODO (mike): test with TP later + config.actor_rollout_ref.rollout.tensor_model_parallel_size = 1 + return config + + +def test_single_turn(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + } + } + ) + + agent_loop_manager = AgentLoopManager.create(init_config) + + system_prompt = ( + "Describe the image by detailing the color, shape, size, texture, quantity, text, " + "spatial relationships of the objects and background:" + ) + user_prompts = ["A photo of cute cat with long fur and big eyes.", "A photo of cute dog with short hair."] + + raw_prompts = [] + for user_prompt in user_prompts: + raw_prompts.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + ) + + raw_negative_prompts = [] + for user_prompt in user_prompts: + raw_negative_prompts.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": " "}, + ] + ) + + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array(raw_prompts), + "raw_negative_prompt": np.array(raw_negative_prompts), + "data_source": np.array(["jpeg_compressibility"] * len(raw_prompts)), + "reward_model": np.array([{"style": "rule", "ground_truth": ""}] * len(raw_prompts)), + }, + ) + n = init_config.actor_rollout_ref.rollout.n + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + expected_batch_keys = [ + "responses", + "all_latents", + "all_timesteps", + "prompt_embeds", + "prompt_embeds_mask", + "input_ids", + "attention_mask", + "rollout_log_probs", + ] + for key in expected_batch_keys: + assert key in result.batch, f"Key {key} not found in result batch with keys {list(result.batch.keys())}." + + # check turns + num_turns = result.non_tensor_batch["__num_turns__"] + assert np.all(num_turns == 2) + + print("Test passed!") + ray.shutdown() diff --git a/tests/experimental/reward_loop/assets/ocr.jpg b/tests/experimental/reward_loop/assets/ocr.jpg new file mode 100644 index 00000000000..3d80bacfdf5 Binary files /dev/null and b/tests/experimental/reward_loop/assets/ocr.jpg differ diff --git a/tests/experimental/reward_loop/reward_fn.py b/tests/experimental/reward_loop/reward_fn.py index 27da6ff1884..6e24782e3d5 100644 --- a/tests/experimental/reward_loop/reward_fn.py +++ b/tests/experimental/reward_loop/reward_fn.py @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import json import os +from io import BytesIO import aiohttp +import numpy as np +import torch from openai.types.chat import ChatCompletion +from PIL import Image from transformers import PreTrainedTokenizer GRM_PROMPT_TEMPLATE = """ @@ -98,3 +103,86 @@ def compute_score_math_verify( model_output=solution_str, ground_truth=ground_truth, ) + + +def _pil_image_to_base64(image: Image.Image) -> str: + buffered = BytesIO() + image.save(buffered, format="PNG") + encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8") + base64_image = f"data:image;base64,{encoded_image_text}" + return base64_image + + +async def compute_score_ocr( + data_source: str, + solution_image: Image.Image | np.ndarray | torch.Tensor, + ground_truth: str, + extra_info: dict, + reward_router_address: str, + reward_model_tokenizer: PreTrainedTokenizer = None, + model_name: str = None, +): + """Compute the reward score.""" + import re + + import Levenshtein + + from verl.utils.ray_utils import get_event_loop + + # preprocess image to base64 + image = solution_image + if isinstance(image, torch.Tensor): + image = image.float().permute(1, 2, 0).cpu().numpy() + if isinstance(image, np.ndarray): + assert image.shape[-1] == 3, "must be in HWC format" + image = (image * 255).round().clip(0, 255).astype(np.uint8) + image = Image.fromarray(image) + assert isinstance(image, Image.Image) + + image_base64 = await get_event_loop().run_in_executor(None, _pil_image_to_base64, image) + + # prepare chat template + grm_prompt = "Please output only the text content from the image without any additional descriptions or formatting." + query = [ + { + "type": "image_url", + "image_url": {"url": image_base64}, + }, + {"type": "text", "text": grm_prompt}, + ] + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": query, + }, + ] + + sampling_params = {"temperature": 0.7, "top_p": 0.8, "max_tokens": 4096} + model_name = model_name or os.path.expanduser("~/models/Qwen/Qwen2.5-VL-3B-Instruct") + chat_complete_request = { + "messages": messages, + "model": model_name, + **sampling_params, + } + result = await chat_complete( + router_address=reward_router_address, + chat_complete_request=chat_complete_request, + ) + grm_response = result.choices[0].message.content + + # compute OCR score + text = grm_response + # remove any nonvisible characters and convert to lowercase + gt = re.sub(r"\s+", "", ground_truth).lower() + text = re.sub(r"\s+", "", text).lower() + if gt in text: + dist = 0 + else: + dist = Levenshtein.distance(text, gt) + + # recognized many unrelated characters, only add one character penalty + dist = min(dist, len(gt)) + score = 1 - dist / len(gt) + + return {"score": score, "acc": score == 1, "genrm_response": grm_response} diff --git a/tests/experimental/reward_loop/test_diffusion_reward_model_genrm.py b/tests/experimental/reward_loop/test_diffusion_reward_model_genrm.py new file mode 100644 index 00000000000..c20760d7e6d --- /dev/null +++ b/tests/experimental/reward_loop/test_diffusion_reward_model_genrm.py @@ -0,0 +1,111 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import ray +import torch +from hydra import compose, initialize_config_dir +from PIL import Image + +from verl.experimental.reward_loop import RewardLoopManager +from verl.protocol import DataProto +from verl.utils import hf_tokenizer + + +def create_data_samples(tokenizer) -> DataProto: + images = ["tests/experimental/reward_loop/assets/ocr.jpg"] + prompts = ['a photo of displaying "OCR"'] + pil_images = [np.array(Image.open(img).convert("RGB").resize((512, 512))) for img in images] + responses = [torch.tensor(img).permute(2, 0, 1) / 255.0 for img in pil_images] + data_source = ["ocr"] * len(images) + reward_info = [{"ground_truth": "OCR"}] * len(images) + extra_info = [{}] * len(images) + + responses = torch.stack(responses) + prompt_length = 1024 + pad_token_id = tokenizer.pad_token_id + prompt_ids = [] + for prompt in prompts: + prompt_tokens = tokenizer.encode(prompt) + padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens + prompt_ids.append(torch.tensor(padded_prompt)) + prompt_ids = torch.stack(prompt_ids) + + data = DataProto.from_dict( + tensors={ + "input_ids": prompt_ids, + "responses": responses, + }, + non_tensors={ + "data_source": data_source, + "reward_model": reward_info, + "extra_info": extra_info, + }, + ) + return data + + +def test_diffusion_reward_model_manager(): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + + rollout_model_name = os.path.expanduser("~/models/Qwen/Qwen-Image") + reward_model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-VL-3B-Instruct") + + config.actor_rollout_ref.model.path = rollout_model_name + config.actor_rollout_ref.model.tokenizer_path = os.path.join(rollout_model_name, "tokenizer") + config.reward.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py" + config.reward.custom_reward_function.name = "compute_score_ocr" + config.reward.num_workers = 1 + config.reward.reward_manager.name = "image" + config.reward.reward_model.enable = True + config.reward.reward_model.enable_resource_pool = True + config.reward.reward_model.n_gpus_per_node = 2 + config.reward.reward_model.nnodes = 1 + config.reward.reward_model.model_path = reward_model_name + config.reward.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") + config.reward.reward_model.rollout.gpu_memory_utilization = 0.9 + config.reward.reward_model.rollout.tensor_model_parallel_size = 2 + config.reward.reward_model.rollout.skip_tokenizer_init = False + config.reward.reward_model.rollout.prompt_length = 2048 + config.reward.reward_model.rollout.response_length = 4096 + + # 1. init reward model manager + reward_loop_manager = RewardLoopManager(config) + + # 2. init test data + rollout_tokenizer = hf_tokenizer(config.actor_rollout_ref.model.tokenizer_path) + data = create_data_samples(rollout_tokenizer) + + # 3. generate responses + outputs = reward_loop_manager.compute_rm_score(data) + + for idx, output in enumerate(outputs): + print(f"GRM Response {idx}:\n{output.non_tensor_batch['genrm_response']}\n") + print(f"Score:\n{output.non_tensor_batch['score']}\n") + print("=" * 50 + "\n") + + ray.shutdown() diff --git a/tests/models/test_diffusers_fsdp_engine.py b/tests/models/test_diffusers_fsdp_engine.py new file mode 100644 index 00000000000..cd4fc7e50e3 --- /dev/null +++ b/tests/models/test_diffusers_fsdp_engine.py @@ -0,0 +1,221 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import partial + +import numpy as np +import pytest +import ray +import torch + +from verl import DataProto +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.utils import tensordict_utils as tu +from verl.utils.diffusers.schedulers import FlowMatchSDEDiscreteScheduler +from verl.utils.diffusers.utils import set_timesteps +from verl.workers.config import DiffusersModelConfig, FSDPActorConfig, TrainingWorkerConfig +from verl.workers.engine_workers import TrainingWorker +from verl.workers.utils.losses import ppo_loss +from verl.workers.utils.padding import embeds_padding_2_no_padding + + +def create_training_config(model_type, strategy, device_count, model): + if device_count == 1: + cp = fsdp_size = 1 + else: + cp = 1 # TODO (mike): test with cp = 2 + fsdp_size = 4 + path = os.path.expanduser(model) + tokenizer_path = os.path.join(path, "tokenizer") + model_config = DiffusersModelConfig( + path=path, + tokenizer_path=tokenizer_path, + use_remove_padding=True, + ) + + if strategy in ["fsdp", "fsdp2"]: + from hydra import compose, initialize_config_dir + + from verl.utils.config import omega_conf_to_dataclass + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/model")): + cfg = compose( + config_name="diffusion_model", + overrides=[ + "path=" + path, + "tokenizer_path=" + tokenizer_path, + "lora_rank=8", + "lora_alpha=16", + ], + ) + model_config: DiffusersModelConfig = omega_conf_to_dataclass(cfg) + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/actor")): + cfg = compose( + config_name="dp_actor", + overrides=[ + "strategy=" + strategy, + "clip_ratio=0.0001", + "clip_ratio_high=5.0", + "ppo_mini_batch_size=4", + "ppo_micro_batch_size_per_gpu=4", + "optim.lr=1e-4", + "optim.weight_decay=0.0001", + "fsdp_config.param_offload=False", + "fsdp_config.optimizer_offload=False", + "fsdp_config.model_dtype='bfloat16'", + "fsdp_config.dtype='bfloat16'", + "+fsdp_config.mixed_precision.param_dtype='bfloat16'", + "fsdp_config.forward_only=False", + "fsdp_config.fsdp_size=" + str(fsdp_size), + "fsdp_config.ulysses_sequence_parallel_size=" + str(cp), + "policy_loss.loss_mode='flow_grpo'", + ], + ) + actor_config: FSDPActorConfig = omega_conf_to_dataclass(cfg) + + engine_config = actor_config.engine + optimizer_config = actor_config.optim + checkpoint_config = actor_config.checkpoint + else: + raise NotImplementedError(f"strategy {strategy} is not supported") + + training_config = TrainingWorkerConfig( + model_type=model_type, + model_config=model_config, + engine_config=engine_config, + optimizer_config=optimizer_config, + checkpoint_config=checkpoint_config, + ) + return training_config, actor_config + + +def create_data_samples(num_device: int, model_config: DiffusersModelConfig) -> DataProto: + from tensordict import TensorDict + + scheduler = FlowMatchSDEDiscreteScheduler.from_pretrained( + pretrained_model_name_or_path=model_config.local_path, subfolder="scheduler" + ) + set_timesteps(scheduler, model_config) + + batch_size = 8 * num_device + seq_len = 64 + img_size = 512 + latent_dim = 64 + encoder_latent_dim = 32 + inference_steps = 40 + vocab_size = 99 + vae_scale_factor = 8 + height, width = img_size, img_size + latent_height, latent_width = height // vae_scale_factor // 2, width // vae_scale_factor // 2 + num_diffusion_steps = 10 + timesteps = scheduler.timesteps[None].repeat(batch_size, 1) + + torch.manual_seed(1) + np.random.seed(1) + + batch = TensorDict( + { + "input_ids": torch.randint(0, vocab_size, (batch_size, seq_len)), + "attention_mask": torch.ones((batch_size, inference_steps)), + "response_mask": torch.ones((batch_size, inference_steps)), + "old_log_probs": torch.randn((batch_size, num_diffusion_steps)), + "advantages": torch.randn((batch_size, num_diffusion_steps)), + "responses": torch.randn((batch_size, 3, height, width)), + "all_latents": torch.randn((batch_size, inference_steps, latent_height * latent_width, latent_dim)), + "rollout_log_probs": torch.randn((batch_size, num_diffusion_steps)), + "all_timesteps": timesteps, + "prompt_embeds": torch.randn((batch_size, seq_len, encoder_latent_dim)), + "prompt_embeds_mask": torch.ones((batch_size, seq_len), dtype=torch.int32), + "negative_prompt_embeds": torch.randn((batch_size, seq_len, encoder_latent_dim)), + "negative_prompt_embeds_mask": torch.ones((batch_size, seq_len), dtype=torch.int32), + "loss_mask": torch.ones((batch_size, inference_steps), dtype=torch.int32), + }, + batch_size=batch_size, + ) + data = DataProto(batch=batch) + data.meta_info["global_token_num"] = torch.sum(data.batch["attention_mask"], dim=-1).tolist() + data.meta_info["use_dynamic_bsz"] = False + data.meta_info["micro_batch_size_per_gpu"] = 4 + data.meta_info["height"] = height + data.meta_info["width"] = width + + return data + + +@pytest.mark.parametrize("strategy", ["fsdp", "fsdp2"]) +def test_diffusers_fsdp_engine(strategy): + # Create configs + ray.init() + device_count = torch.cuda.device_count() + training_config, actor_config = create_training_config( + model_type="diffusion_model", + strategy=strategy, + device_count=device_count, + model="~/models/tiny-random/Qwen-Image", + ) + # init model + ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(TrainingWorker), config=training_config) + resource_pool = RayResourcePool(process_on_nodes=[device_count]) + wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) # TrainigWorker + wg.reset() + + # forward only without loss function + data_td = create_data_samples(device_count, training_config.model_config).to_tensordict() + data_td = embeds_padding_2_no_padding(data_td) + tu.assign_non_tensor( + data_td, + compute_loss=False, + image_height=training_config.model_config.get("image_height", 512), + image_width=training_config.model_config.get("image_width", 512), + vae_scale_factor=training_config.model_config.get("vae_scale_factor", 8), + ) + output = wg.infer_batch(data_td) + output_dict = output.get() + + print("Output:", output_dict) + for key in ["log_probs", "metrics"]: + assert key in output_dict + + # forward and backward with loss function + # set loss function + loss_fn = partial(ppo_loss, config=actor_config) + wg.set_loss_fn(loss_fn) + + # train batch + data_td = create_data_samples(device_count, training_config.model_config).to_tensordict() + data_td = embeds_padding_2_no_padding(data_td) + ppo_mini_batch_size = 4 + ppo_epochs = actor_config.ppo_epochs + seed = 42 + shuffle = actor_config.shuffle + tu.assign_non_tensor( + data_td, + global_batch_size=ppo_mini_batch_size * device_count, + mini_batch_size=ppo_mini_batch_size * device_count, + epochs=ppo_epochs, + seed=seed, + dataloader_kwargs={"shuffle": shuffle}, + image_height=training_config.model_config.get("image_height", 512), + image_width=training_config.model_config.get("image_width", 512), + vae_scale_factor=training_config.model_config.get("vae_scale_factor", 8), + ) + output = wg.train_mini_batch(data_td) + output_dict = output.get() + + print("Output:", output_dict) + assert "metrics" in output_dict.keys() + + ray.shutdown() diff --git a/tests/special_sanity/check_pr_title.py b/tests/special_sanity/check_pr_title.py index 1153d9d77af..26fb412cba7 100644 --- a/tests/special_sanity/check_pr_title.py +++ b/tests/special_sanity/check_pr_title.py @@ -19,7 +19,7 @@ pr_title = os.environ.get("PR_TITLE", "").strip() # Define rules -allowed_modules = ["fsdp", "megatron", "veomni", "sglang", "vllm", "trtllm", "rollout", "trainer"] +allowed_modules = ["fsdp", "megatron", "veomni", "sglang", "vllm", "vllm_omni", "trtllm", "rollout", "trainer"] allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"] allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"] allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt", "doc", "data", "cfg", "reward"] diff --git a/tests/trainer/ppo/test_flow_grpo_core_algos.py b/tests/trainer/ppo/test_flow_grpo_core_algos.py new file mode 100644 index 00000000000..8ebdd8ae315 --- /dev/null +++ b/tests/trainer/ppo/test_flow_grpo_core_algos.py @@ -0,0 +1,95 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import uuid + +import numpy as np +import pytest +import torch + +from verl.trainer.ppo.core_algos import ( + compute_flow_grpo_outcome_advantage, + compute_policy_loss_flow_grpo, +) +from verl.utils.config import omega_conf_to_dataclass + + +@pytest.mark.parametrize("norm_adv_by_std_in_grpo", [True, False]) +@pytest.mark.parametrize("global_std", [True, False]) +def test_flow_grpo_advantage_return(norm_adv_by_std_in_grpo: bool, global_std: bool) -> None: + """Test flow-GRPO advantage and return computation.""" + + # prepere input + batch_size = 8 + steps = 10 + token_level_rewards = torch.randn((batch_size, 1), dtype=torch.float32) + response_mask = torch.ones((batch_size, steps), dtype=torch.int32) + uid = np.array([uuid.uuid4().hex for _ in range(batch_size)]) + + advantages, returns = compute_flow_grpo_outcome_advantage( + token_level_rewards=token_level_rewards, + response_mask=response_mask, + index=uid, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + global_std=global_std, + ) + + assert advantages.shape == returns.shape == (batch_size, steps) + + +def test_compute_policy_loss_flow_grpo() -> None: + """Test flow-GRPO policy loss computation.""" + + # prepare input + batch_size = 8 + steps = 10 + rollout_log_probs = torch.randn((batch_size, steps), dtype=torch.float32) + current_log_probs = torch.randn((batch_size, steps), dtype=torch.float32) + advantages = torch.randn((batch_size, steps), dtype=torch.float32) + response_mask = torch.ones((batch_size, steps), dtype=torch.int32) + from hydra import compose, initialize_config_dir + + from verl.workers.config.actor import FSDPActorConfig + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/actor")): + cfg = compose( + config_name="dp_actor", + overrides=[ + "strategy=fsdp", + "clip_ratio=0.0001", + "clip_ratio_high=5.0", + "ppo_micro_batch_size_per_gpu=8", + ], + ) + actor_config: FSDPActorConfig = omega_conf_to_dataclass(cfg) + + for step in range(steps): + pg_loss, pg_metrics = compute_policy_loss_flow_grpo( + old_log_prob=rollout_log_probs[:, step], + log_prob=current_log_probs[:, step], + advantages=advantages[:, step], + response_mask=response_mask[:, step], + loss_agg_mode="token-mean", + config=actor_config, + ) + + assert pg_loss.shape == () + assert isinstance(pg_loss.item(), float) + assert "actor/ppo_kl" in pg_metrics.keys() + + +if __name__ == "__main__": + unittest.main() diff --git a/verl/experimental/agent_loop/__init__.py b/verl/experimental/agent_loop/__init__.py index d43683df3e4..e819dd134a5 100644 --- a/verl/experimental/agent_loop/__init__.py +++ b/verl/experimental/agent_loop/__init__.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .agent_loop import AgentLoopBase, AgentLoopManager, AgentLoopWorker, AsyncLLMServerManager +from .agent_loop import ( + AgentLoopBase, + AgentLoopManager, + AgentLoopWorker, + AsyncLLMServerManager, + DiffusionAgentLoopWorker, +) from .single_turn_agent_loop import SingleTurnAgentLoop from .tool_agent_loop import ToolAgentLoop _ = [SingleTurnAgentLoop, ToolAgentLoop] -__all__ = ["AgentLoopBase", "AgentLoopManager", "AsyncLLMServerManager", "AgentLoopWorker"] +__all__ = ["AgentLoopBase", "AgentLoopManager", "AsyncLLMServerManager", "AgentLoopWorker", "DiffusionAgentLoopWorker"] diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 5383ae4a2a5..c17ca93398c 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -23,6 +23,7 @@ import numpy as np import ray import torch +import torch.nn.functional as F from cachetools import LRUCache from omegaconf import DictConfig, OmegaConf from PIL import Image @@ -45,8 +46,8 @@ rollout_trace_op, ) from verl.utils.tokenizer import normalize_token_ids -from verl.workers.config import HFModelConfig, RolloutConfig -from verl.workers.rollout.replica import TokenOutput, get_rollout_replica_class +from verl.workers.config import DiffusersModelConfig, HFModelConfig, RolloutConfig +from verl.workers.rollout.replica import ImageOutput, TokenOutput, get_rollout_replica_class logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -141,7 +142,8 @@ async def generate( sampling_params: dict[str, Any], image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, - ) -> TokenOutput: + **kwargs: Any, + ) -> TokenOutput | ImageOutput: """Generate tokens from prompt ids. Args: @@ -150,7 +152,7 @@ async def generate( sampling_params (Dict[str, Any]): Sampling parameters for the chat completion. Returns: - TokenOutput: token output + TokenOutput | ImageOutput: token or image output """ server_id, server = await self._acquire_server(request_id) try: @@ -160,6 +162,7 @@ async def generate( sampling_params=sampling_params, image_data=image_data, video_data=video_data, + **kwargs, ) return output finally: @@ -226,6 +229,48 @@ class _InternalAgentLoopOutput(AgentLoopOutput): """Extra fields for dynamic addition.""" +class DiffusionAgentLoopOutput(BaseModel): + """Agent loop output.""" + + prompt_ids: list[int] + """Prompt token ids.""" + response_image: list[list[list[float]]] + """Response image (CHW format).""" + response_logprobs: Optional[list[float]] = None + """Log probabilities for the response tokens.""" + multi_modal_data: Optional[dict[str, Any]] = None + """Multi-modal data for multi-modal tools.""" + reward_score: Optional[float] = None + """Reward score for the trajectory.""" + num_turns: int = 0 + """Number of chat turns, including user, assistant, tool.""" + metrics: AgentLoopMetrics + """Auxiliary performance metrics""" + extra_fields: dict[str, Any] = {} + """Extra fields for dynamic addition.""" + + +class _InternalDiffusionAgentLoopOutput(DiffusionAgentLoopOutput): + """Internal agent loop output with padded sequences.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + prompt_ids: torch.Tensor + """Padded prompt token ids.""" + response_image: torch.Tensor + """Response image (NCHW format).""" + input_ids: torch.Tensor + """Padded input ids(prompt_ids).""" + attention_mask: torch.Tensor + """Padded attention mask.""" + response_logprobs: Optional[torch.Tensor] = None + """Log probabilities for the response tokens.""" + multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None + """Multi-modal inputs for processors (e.g., pixel_values, image_grid_thw).""" + extra_fields: dict[str, Any] = {} + """Extra fields for dynamic addition.""" + + class DictConfigWrap: """Wrapper for DictConfig to avoid hydra.utils.instantiate recursive resolve.""" @@ -866,6 +911,371 @@ def _postprocess( ) +class DiffusionAgentLoopWorker: + """Diffusion Agent loop worker takes a batch of messages and run each message in an agent loop. + + Args: + config (DictConfig): whole config for main entrypoint. + server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. + reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation. + """ + + def __init__( + self, + config: DictConfig, + servers: list[tuple[str, ray.actor.ActorHandle]], + load_balancer_handle: ray.actor.ActorHandle, + reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, + ): + """Initialize agent loop manager. + Args: + config (DictConfig): YAML config. + servers (list[tuple[str, ray.actor.ActorHandle]]): (address, handle) pairs for each LLM server. + load_balancer_handle (ray.actor.ActorHandle): shared global load balancer actor. + reward_loop_worker_handles (list[ray.actor.ActorHandle]): Actor handles for streaming reward computation. + """ + self.config = config + rollout_config, model_config = _get_rollout_and_model_config(config) + self.rollout_config: RolloutConfig = omega_conf_to_dataclass(rollout_config) + self.model_config: DiffusersModelConfig = omega_conf_to_dataclass(model_config) + + # for recipe to change + if not hasattr(self, "server_manager"): + self.server_manager = AsyncLLMServerManager( + config, + servers, + load_balancer_handle=load_balancer_handle, + ) + + self.dataset_cls = get_dataset_class(config.data) + self.reward_loop_worker_handles = reward_loop_worker_handles + + self.tokenizer = self.model_config.tokenizer + self.processor = self.model_config.processor + + agent_loop_config_path = self.rollout_config.agent.agent_loop_config_path + if agent_loop_config_path: + resolved_path = resolve_config_path(agent_loop_config_path) + agent_loop_configs = OmegaConf.load(resolved_path) + for agent_loop_config in agent_loop_configs: + _agent_loop_registry[agent_loop_config.name] = agent_loop_config + if self.model_config.get("custom_chat_template", None) is not None: + if self.model_config.processor is not None: + self.model_config.processor.chat_template = self.model_config.custom_chat_template + self.model_config.tokenizer.chat_template = self.model_config.custom_chat_template + + trace_config = self.rollout_config.trace + RolloutTraceConfig.init( + self.rollout_config.trace.project_name, + self.rollout_config.trace.experiment_name, + trace_config.get("backend"), + trace_config.get("token2text", False), + trace_config.get("max_samples_per_step_per_worker", None), + ) + + async def generate_sequences(self, batch: DataProto) -> DataProto: + """Generate sequences from agent loop. + + Args: + batch (DataProto): Input batch. + + Returns: + DataProto: Output batch. + - prompts: [bsz, prompt_length], prompt token ids from dataset. + - responses: [bsz, channel, height, width], output images from diffusion generation. + ... + """ + config = self.rollout_config + + # TODO (mike): it is for Qwen-Image only, need to generalize later + sampling_params = dict( + logprobs=config.calculate_log_probs, + height=config.image_height, + width=config.image_width, + true_cfg_scale=config.guidance_scale, + max_sequence_length=config.max_model_len, + sde_type=config.sde_type, + sde_window_size=config.sde_window_size, + sde_window_range=config.sde_window_range, + ) + + # override sampling params for validation + if batch.meta_info.get("validate", False): + sampling_params["num_inference_steps"] = config.val_kwargs.num_inference_steps + sampling_params["seed"] = config.val_kwargs.seed + sampling_params["noise_level"] = config.val_kwargs.noise_level + else: + sampling_params["num_inference_steps"] = config.num_inference_steps + sampling_params["noise_level"] = config.noise_level + + # by default, we assume it's a single turn agent + if "agent_name" not in batch.non_tensor_batch: + default_agent_loop = config.agent.default_agent_loop + batch.non_tensor_batch["agent_name"] = np.array([default_agent_loop] * len(batch), dtype=object) + + if "index" in batch.non_tensor_batch: + index = batch.non_tensor_batch["index"] + else: + index = np.arange(len(batch)) + + max_samples_per_worker = RolloutTraceConfig.get_instance().max_samples_per_step_per_worker + + # For n rollouts per sample, we trace all n rollouts for selected samples + # Note: This sampling happens per-worker, so total traces = max_samples_per_worker * num_workers * n + if max_samples_per_worker is not None: + unique_sample_indices = np.unique(index) + if max_samples_per_worker < len(unique_sample_indices): + selected_samples = set( + np.random.choice(unique_sample_indices, max_samples_per_worker, replace=False).tolist() + ) + traced_indices = set(i for i in range(len(batch)) if index[i] in selected_samples) + else: + traced_indices = set(range(len(batch))) + else: + traced_indices = set(range(len(batch))) + + trajectory_info = await get_trajectory_info( + batch.meta_info.get("global_steps", -1), index.tolist(), batch.meta_info.get("validate", False) + ) + + tasks = [] + for i in range(len(batch)): + trace_this_sample = i in traced_indices + kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()} + tasks.append( + asyncio.create_task( + self._run_agent_loop(sampling_params, trajectory_info[i], trace=trace_this_sample, **kwargs) + ) + ) + outputs = await asyncio.gather(*tasks) + + output = self._postprocess(outputs, input_non_tensor_batch=batch.non_tensor_batch) + + return output + + async def _run_agent_loop( + self, + sampling_params: dict[str, Any], + trajectory: dict[str, Any], + *, + agent_name: str, + trace: bool = True, + **kwargs, + ) -> _InternalDiffusionAgentLoopOutput: + with rollout_trace_attr( + step=trajectory["step"], + sample_index=trajectory["sample_index"], + rollout_n=trajectory["rollout_n"], + validate=trajectory["validate"], + name="agent_loop", + trace=trace, + ): + assert agent_name in _agent_loop_registry, ( + f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}" + ) + + agent_loop_config = _agent_loop_registry[agent_name] + agent_loop = hydra.utils.instantiate( + config=agent_loop_config, + trainer_config=DictConfigWrap(config=self.config), + server_manager=self.server_manager, + tokenizer=self.tokenizer, + processor=self.processor, + dataset_cls=self.dataset_cls, + data_config=DictConfigWrap(self.config.data), + ) + output: DiffusionAgentLoopOutput = await agent_loop.run(sampling_params, **kwargs) + return await self._agent_loop_postprocess(output, **kwargs) + + async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionAgentLoopOutput: + """Perform post-processing operations on the output of each individual agent loop.""" + # handling extra tensor ouputs from vllm-omni, like prompt embedding, etc. + extra_fields = {} + for k, v in output.extra_fields.items(): + if isinstance(v, torch.Tensor): + # handle prompt embedding padding + if k in ["prompt_embeds", "negative_prompt_embeds"]: + pad_tuple = (0, 0, 0, self.config.actor_rollout_ref.rollout.prompt_length - v.shape[0]) + v = F.pad(v, pad_tuple, value=0) + elif k in ["prompt_embeds_mask", "negative_prompt_embeds_mask"]: + pad_tuple = (0, self.config.actor_rollout_ref.rollout.prompt_length - v.shape[0]) + v = F.pad(v, pad_tuple, value=0) + extra_fields[k] = v.unsqueeze(0) + else: + extra_fields[k] = v + + extra_fields["raw_prompt"] = kwargs["raw_prompt"] + + # TODO(wuxibin): remove padding and use tensordict. + self.tokenizer.padding_side = "left" + prompt_output = self.tokenizer.pad( + {"input_ids": output.prompt_ids}, + padding="max_length", + max_length=self.rollout_config.prompt_length, + return_tensors="pt", + return_attention_mask=True, + ) + if prompt_output["input_ids"].dim() == 1: + prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0) + prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0) + + self.tokenizer.padding_side = "right" + + response_image = torch.tensor(output.response_image) + if response_image.dim() == 3: + response_image = response_image.unsqueeze(0) + + response_logprobs = None + if output.response_logprobs is not None: + response_logprobs = torch.tensor(output.response_logprobs).unsqueeze(0) + + attention_mask = prompt_output["attention_mask"] + input_ids = prompt_output["input_ids"] + + multi_modal_inputs = self._compute_multi_modal_inputs(output, input_ids) + await self._compute_score( + output, + prompts=input_ids, + responses=response_image, + attention_mask=attention_mask, + input_ids=input_ids, + kwargs=kwargs, + ) + + if "reward_extra_info" in output.extra_fields: + extra_fields["reward_extra_info"] = output.extra_fields["reward_extra_info"] + + return _InternalDiffusionAgentLoopOutput( + prompt_ids=input_ids, + response_image=response_image, + input_ids=input_ids, + attention_mask=attention_mask, + response_logprobs=response_logprobs, + multi_modal_inputs=multi_modal_inputs, + multi_modal_data=output.multi_modal_data, + reward_score=output.reward_score, + num_turns=output.num_turns, + metrics=output.metrics, + extra_fields=extra_fields, + ) + + def _compute_multi_modal_inputs(self, output, input_ids) -> dict[str, torch.Tensor]: + """Compute multi-modal inputs with image and video.""" + multi_modal_inputs = {} + if self.processor is None: + return multi_modal_inputs + + raise NotImplementedError("Multi-modal input processing not implemented yet.") + + async def _compute_score(self, output, prompts, responses, attention_mask, input_ids, kwargs): + """Compute reward score for single sample.""" + enable_async_reward = self.reward_loop_worker_handles is not None + + if output.reward_score is None and enable_async_reward: + batch = TensorDict( + { + "prompts": prompts, # [1, prompt_length] + "responses": responses, # [1, channel, height, width] + "attention_mask": attention_mask, # [1, prompt_length] + "input_ids": input_ids, # [1, prompt_length] + }, + batch_size=1, + ) + non_tensor_batch = { + **{k: np.array([v]) for k, v in kwargs.items()}, + "__num_turns__": np.array([output.num_turns]), + "tool_extra_fields": np.array([output.extra_fields], dtype=object), + } + + data = DataProto( + batch=batch, + non_tensor_batch=non_tensor_batch, + ) + selected_reward_loop_worker_handle = random.choice(self.reward_loop_worker_handles) + result = await selected_reward_loop_worker_handle.compute_score.remote(data) + output.reward_score = result["reward_score"] + output.extra_fields["reward_extra_info"] = result["reward_extra_info"] + + def _postprocess( + self, + inputs: list[_InternalDiffusionAgentLoopOutput], + input_non_tensor_batch: dict | None = None, + ) -> DataProto: + """Process the padded outputs from _run_agent_loop and combine them into a batch.""" + # Convert lists back to tensors and stack them to create a batch. + prompt_ids = torch.cat([input.prompt_ids for input in inputs], dim=0) + response_image = torch.cat([input.response_image for input in inputs], dim=0) + attention_mask = torch.cat([input.attention_mask for input in inputs], dim=0) + input_ids = torch.cat([input.input_ids for input in inputs], dim=0) + optional_outputs = {} + if inputs[0].response_logprobs is not None: + optional_outputs["rollout_log_probs"] = torch.cat([input.response_logprobs for input in inputs], dim=0) + + # Handle extra fields that are tensors + extra_keys = [k for k, v in inputs[0].extra_fields.items() if isinstance(v, torch.Tensor)] + for key in extra_keys: + optional_outputs[key] = torch.cat([input.extra_fields[key] for input in inputs], dim=0) + for input in inputs: + del input.extra_fields[key] + + batch = TensorDict( + { + "prompts": prompt_ids, # [bsz, prompt_length] + "responses": response_image, # [bsz, channel, height, width] + "input_ids": input_ids, # [bsz, prompt_length] + "attention_mask": attention_mask, # [bsz, prompt_length] + **optional_outputs, + }, + batch_size=len(inputs), + ) + + scores = [input.reward_score for input in inputs] + if all(score is not None for score in scores): + rm_scores = torch.tensor(scores, dtype=torch.float32).unsqueeze(-1) + batch["rm_scores"] = rm_scores + + non_tensor_batch = { + "__num_turns__": np.array([input.num_turns for input in inputs], dtype=np.int32), + } + if self.reward_loop_worker_handles is None and input_non_tensor_batch: + non_tensor_batch.update(input_non_tensor_batch) + + # add reward_extra_info to non_tensor_batch + reward_extra_infos = [input.extra_fields.get("reward_extra_info", {}) for input in inputs] + reward_extra_keys = list(reward_extra_infos[0].keys()) + for key in reward_extra_keys: + non_tensor_batch[key] = np.array([info[key] for info in reward_extra_infos]) + + # Add multi_modal_inputs to non_tensor_batch if any samples have them + multi_modal_inputs_list = [input.multi_modal_inputs for input in inputs] + if any(mmi is not None for mmi in multi_modal_inputs_list): + non_tensor_batch["multi_modal_inputs"] = np.array(multi_modal_inputs_list, dtype=object) + + metrics = [input.metrics.model_dump() for input in inputs] + # Collect extra fields from all inputs and convert them to np.ndarray + extra_fields = {} + all_keys = set(key for input_item in inputs for key in input_item.extra_fields) + for key in all_keys: + temp_arr = np.empty(len(inputs), dtype=object) + temp_arr[:] = [input.extra_fields.get(key) for input in inputs] + extra_fields[key] = temp_arr + + non_tensor_batch.update(extra_fields) + + # Only include reward_extra_keys in meta_info if rm_scores is in batch + # This avoids conflicts when reward_tensor is merged later in ray_trainer.py + if "rm_scores" in batch.keys(): + meta_info = {"metrics": metrics, "reward_extra_keys": reward_extra_keys} + else: + meta_info = {"metrics": metrics} + + return DataProto( + batch=batch, + non_tensor_batch=non_tensor_batch, + meta_info=meta_info, + ) + + async def get_trajectory_info(step, index, validate): """Get trajectory info. @@ -920,7 +1330,10 @@ def __init__( if not hasattr(self, "rollout_replica_class"): self.rollout_replica_class = get_rollout_replica_class(self.rollout_config.name) if not hasattr(self, "agent_loop_workers_class"): - self.agent_loop_workers_class = ray.remote(AgentLoopWorker) + if self.config.actor_rollout_ref.model.model_type == "diffusion_model": + self.agent_loop_workers_class = ray.remote(DiffusionAgentLoopWorker) + else: + self.agent_loop_workers_class = ray.remote(AgentLoopWorker) @classmethod @auto_await @@ -1059,14 +1472,16 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data # batch sequence generation is bounded by the slowest sample slowest = np.argmax(t_generate_sequences + t_tool_calls) - attention_mask = output.batch["attention_mask"][slowest] prompt_length = output.batch["prompts"].shape[1] timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest] timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest] - timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item() - timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item() timing["agent_loop/slowest/num_preempted"] = num_preempted[slowest] + if "attention_mask" in output.batch: + attention_mask = output.batch["attention_mask"][slowest] + timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item() + timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item() + return timing @auto_await diff --git a/verl/experimental/agent_loop/single_turn_agent_loop.py b/verl/experimental/agent_loop/single_turn_agent_loop.py index 6ad3aa429b3..ab045e3f4ad 100644 --- a/verl/experimental/agent_loop/single_turn_agent_loop.py +++ b/verl/experimental/agent_loop/single_turn_agent_loop.py @@ -16,7 +16,7 @@ from typing import Any from uuid import uuid4 -from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register +from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, DiffusionAgentLoopOutput, register from verl.utils.profiler import simple_timer logger = logging.getLogger(__file__) @@ -80,3 +80,52 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu output.extra_fields.update({"turn_scores": [], "tool_rewards": []}) return output + + +@register("diffusion_single_turn_agent") +class DiffusionSingleTurnAgentLoop(AgentLoopBase): + """Agent loop for diffusion model serving.""" + + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: + raw_prompt = kwargs["raw_prompt"] + + if self.config.actor_rollout_ref.rollout.guidance_scale > 0: + raw_negative_prompt = kwargs["raw_negative_prompt"] + else: + raw_negative_prompt = None + + # 1. extract images and videos from messages + multi_modal_data = await self.process_vision_info(raw_prompt) + images = multi_modal_data.get("images") + videos = multi_modal_data.get("videos") + + # 2. apply chat template and tokenize + prompt_ids = await self.apply_chat_template(raw_prompt, images=images, videos=videos) + + if raw_negative_prompt is not None: + negative_prompt_ids = await self.apply_chat_template(raw_negative_prompt, images=images, videos=videos) + + # 3. generate sequences + metrics = {} + with simple_timer("generate_sequences", metrics): + output = await self.server_manager.generate( + request_id=uuid4().hex, + prompt_ids=prompt_ids, + sampling_params=sampling_params, + image_data=images, + video_data=videos, + negative_prompt_ids=negative_prompt_ids, + ) + if metrics.get("num_preempted") is None: + metrics["num_preempted"] = output.num_preempted if output.num_preempted is not None else -1 + + output = DiffusionAgentLoopOutput( + prompt_ids=prompt_ids, + response_image=output.image, + response_logprobs=output.log_probs, + multi_modal_data=multi_modal_data, + num_turns=2, + metrics=metrics, + extra_fields=output.extra_info, + ) + return output diff --git a/verl/experimental/reward_loop/reward_loop.py b/verl/experimental/reward_loop/reward_loop.py index 151089ec5c6..c2ba1e7470c 100644 --- a/verl/experimental/reward_loop/reward_loop.py +++ b/verl/experimental/reward_loop/reward_loop.py @@ -13,14 +13,17 @@ # limitations under the License. import asyncio +import base64 import logging import os +from io import BytesIO import aiohttp import numpy as np import ray import torch from omegaconf import DictConfig, open_dict +from PIL import Image from tensordict import TensorDict from verl.protocol import DataProto @@ -28,6 +31,7 @@ from verl.trainer.ppo.reward import load_reward_manager from verl.utils import hf_tokenizer from verl.utils.fs import copy_to_local +from verl.utils.ray_utils import get_event_loop from .reward_model import RewardModelManager @@ -114,9 +118,10 @@ def __init__(self, config: DictConfig, reward_router_address: str = None): self.config = config self.reward_router_address = reward_router_address self._init_reward_fn() + self.loop = get_event_loop() def _init_reward_fn(self): - input_tokenizer_local_path = copy_to_local(self.config.actor_rollout_ref.model.path) + input_tokenizer_local_path = copy_to_local(self.config.actor_rollout_ref.model.tokenizer_path) self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=True) self.reward_model_tokenizer = None if self.config.reward.reward_model.enable: @@ -199,17 +204,32 @@ async def _preprocess_reward_inputs(self, data: DataProto) -> str: chat: list = list(data_item.non_tensor_batch["raw_prompt"]) # extract response - response_ids = data_item.batch["responses"] - response_length = response_ids.shape[-1] - valid_response_length = data_item.batch["attention_mask"][-response_length:].sum() - valid_response_ids = response_ids[:valid_response_length] + response = data_item.batch["responses"] + if response.ndim == 3: + # handling multi-modal response + response_image = response + if isinstance(response_image, torch.Tensor): + response_image = response_image.float().permute(1, 2, 0).cpu().numpy() + assert response_image.shape[-1] == 3, "must be in HWC format" + response_image = (response_image * 255).round().clip(0, 255).astype(np.uint8) + response_image = Image.fromarray(response_image) + + image_base64 = await self.loop.run_in_executor(None, self._pil_image_to_base64, response_image) + query = self.prepare_query_for_multi_modal(image_base64) + + chat.append({"role": "assistant", "content": query}) + else: + response_ids = response + response_length = response_ids.shape[-1] + valid_response_length = data_item.batch["attention_mask"][-response_length:].sum() + valid_response_ids = response_ids[:valid_response_length] - # decode - rollout_response = self.input_tokenizer.decode(valid_response_ids) - # remove bos and eos - rollout_response = rollout_response.replace(self.input_tokenizer.eos_token, "") + # decode + rollout_response = self.input_tokenizer.decode(valid_response_ids) + # remove bos and eos + rollout_response = rollout_response.replace(self.input_tokenizer.eos_token, "") - chat.append({"role": "assistant", "content": rollout_response}) + chat.append({"role": "assistant", "content": rollout_response}) rm_prompt = self.reward_model_tokenizer.apply_chat_template( chat, @@ -267,6 +287,22 @@ async def compute_score_disrm(self, data: DataProto) -> dict: return {"reward_score": rm_score} + def _pil_image_to_base64(self, image: Image.Image) -> str: + buffered = BytesIO() + image.save(buffered, format="PNG") + encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8") + base64_image = f"data:image;base64,{encoded_image_text}" + return base64_image + + def prepare_query_for_multi_modal(self, image_base64: str) -> list: + query = [ + { + "type": "image_url", + "image_url": {"url": image_base64}, + }, + ] + return query + class RewardLoopManager: """ @@ -281,7 +317,7 @@ def __init__(self, config: DictConfig, rm_resource_pool: RayResourcePool = None) self.reward_router_address = self.reward_model_manager.get_router_address() else: self.reward_model_manager = None - self.reward_router_address = None + self.reward_router_address = self.config.reward.reward_model.get("reward_router_address", None) self.reward_loop_workers_class = ray.remote(RewardLoopWorker) self._init_reward_loop_workers() @@ -319,12 +355,15 @@ def compute_rm_score(self, data: DataProto) -> DataProto: # compute rm score scores = [item["reward_score"] for item in outputs_flat] - prompt_length = data.batch["prompts"].size(1) - valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=1) - rm_scores = torch.zeros_like(data.batch["responses"], dtype=torch.float32) - rm_scores[torch.arange(rm_scores.size(0)), valid_response_length - 1] = torch.tensor( - scores, dtype=torch.float32 - ) + if self.config.reward.reward_manager.name == "image": + rm_scores = torch.tensor(scores, dtype=torch.float32).unsqueeze(-1) + else: + prompt_length = data.batch["prompts"].size(1) + valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=1) + rm_scores = torch.zeros_like(data.batch["responses"], dtype=torch.float32) + rm_scores[torch.arange(rm_scores.size(0)), valid_response_length - 1] = torch.tensor( + scores, dtype=torch.float32 + ) batch = TensorDict({"rm_scores": rm_scores}, batch_size=len(data)) reward_extra_infos = [output.get("reward_extra_info", {}) for output in outputs_flat] diff --git a/verl/experimental/reward_loop/reward_manager/__init__.py b/verl/experimental/reward_loop/reward_manager/__init__.py index 75a440a2324..dc0a541c09b 100644 --- a/verl/experimental/reward_loop/reward_manager/__init__.py +++ b/verl/experimental/reward_loop/reward_manager/__init__.py @@ -17,12 +17,14 @@ from .naive import NaiveRewardManager from .limited import RateLimitedRewardManager from .remote import RemoteRewardManager +from .image import ImageRewardManager __all__ = [ "DAPORewardManager", "NaiveRewardManager", "RateLimitedRewardManager", "RemoteRewardManager", + "ImageRewardManager", "register", "get_reward_manager_cls", ] diff --git a/verl/experimental/reward_loop/reward_manager/image.py b/verl/experimental/reward_loop/reward_manager/image.py new file mode 100644 index 00000000000..1852c83041d --- /dev/null +++ b/verl/experimental/reward_loop/reward_manager/image.py @@ -0,0 +1,92 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +from verl import DataProto +from verl.experimental.reward_loop.reward_manager import register +from verl.experimental.reward_loop.reward_manager.base import RewardManagerBase +from verl.utils.reward_score import default_compute_score_image + + +@register("image") +class ImageRewardManager(RewardManagerBase): + """The reward manager for image response.""" + + def __init__(self, config, tokenizer, compute_score, reward_router_address=None, reward_model_tokenizer=None): + super().__init__(config, tokenizer, compute_score) + self.compute_score = compute_score or default_compute_score_image + self.is_async_reward_score = inspect.iscoroutinefunction(self.compute_score) + self.reward_router_address = reward_router_address + self.reward_model_tokenizer = reward_model_tokenizer + + async def run_single(self, data: DataProto) -> dict: + assert len(data) == 1, "Only support single data item" + data_item = data[0] + response_image = data_item.batch["responses"] + data_source = data_item.non_tensor_batch["data_source"] + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] + extra_info = data_item.non_tensor_batch.get("extra_info", {}) + tool_extra_fields = data_item.non_tensor_batch.get("tool_extra_fields", None) + if tool_extra_fields is not None: + extra_info.update(tool_extra_fields.items()) + + num_turns = data_item.non_tensor_batch.get("__num_turns__", None) + rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {}) + extra_info["num_turns"] = num_turns + extra_info["rollout_reward_scores"] = rollout_reward_scores + + extra_reward_kwargs = ( + { + "reward_router_address": self.reward_router_address, + "reward_model_tokenizer": self.reward_model_tokenizer, + "model_name": self.config.reward.reward_model.model_path, + } + if self.reward_router_address is not None + else {} + ) + if self.is_async_reward_score: + result = await self.compute_score( + data_source=data_source, + solution_image=response_image, + ground_truth=ground_truth, + extra_info=extra_info, + **extra_reward_kwargs, + ) + else: + result = await self.loop.run_in_executor( + None, + lambda: self.compute_score( + data_source=data_source, + solution_image=response_image, + ground_truth=ground_truth, + extra_info=extra_info, + **extra_reward_kwargs, + ), + ) + + reward_extra_info = {} + + score: float + if isinstance(result, dict): + score = result["score"] + for key, value in result.items(): + reward_extra_info[key] = value + else: + score = result + reward_extra_info["acc"] = score + + reward = score + + return {"reward_score": reward, "reward_extra_info": reward_extra_info} diff --git a/verl/trainer/config/_generated_ppo_diffusion_trainer.yaml b/verl/trainer/config/_generated_ppo_diffusion_trainer.yaml new file mode 100644 index 00000000000..42f546d365f --- /dev/null +++ b/verl/trainer/config/_generated_ppo_diffusion_trainer.yaml @@ -0,0 +1,703 @@ +# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' +# in which it invokes 'python3 scripts/print_cfg.py --cfg job --config-name=ppo_diffusion_trainer.yaml' to flatten the 'verl/trainer/config/ppo_diffusion_trainer.yaml' config fields into a single file. +# Do not modify this file directly. +# The file is usually only for reference and never used. + +actor_rollout_ref: + actor: + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + zero_indexed_step: true + warmup_style: null + override_optimizer_config: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + dtype: bfloat16 + qat: + _target_: verl.workers.config.QATEngineConfig + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + _target_: verl.workers.config.FSDPActorConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.0001 + clip_ratio_low: 0.2 + clip_ratio_high: 5.0 + tau_pos: 1.0 + tau_neg: 1.05 + freeze_vision_tower: false + policy_loss: + _target_: verl.workers.config.PolicyLossConfig + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + loss_scale_factor: null + entropy_coeff: 0 + calculate_entropy: false + use_kl_loss: false + use_prefix_grouper: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + data_loader_seed: 42 + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + mbridge_config: {} + use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} + calculate_sum_pi_squared: false + sum_pi_squared_checkpointing: false + qat: + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + ref: + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: ${actor_rollout_ref.actor.strategy} + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: true + strategy: fsdp + dtype: bfloat16 + qat: + _target_: verl.workers.config.QATEngineConfig + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + _target_: verl.workers.config.FSDPActorConfig + ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + rollout: + _target_: verl.workers.config.DiffusionRolloutConfig + name: ??? + mode: async + nnodes: 0 + n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8} + prompt_length: ${oc.select:data.max_prompt_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + enforce_eager: false + cudagraph_capture_sizes: null + free_cache_engine: true + tensor_model_parallel_size: 2 + data_parallel_size: 1 + expert_parallel_size: 1 + pipeline_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + enable_chunked_prefill: true + enable_prefix_caching: true + logprobs_mode: processed_logprobs + scheduling_policy: fcfs + load_format: dummy + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + image_height: 512 + image_width: 512 + num_inference_steps: 10 + noise_level: 0.7 + guidance_scale: 4.5 + sde_type: sde + sde_window_size: null + sde_window_range: null + engine_kwargs: + vllm_omni: {} + val_kwargs: + _target_: verl.workers.config.DiffusionSamplingConfig + 'n': 1 + do_sample: false + num_inference_steps: 40 + noise_level: 0.0 + seed: 42 + multi_turn: + _target_: verl.workers.config.MultiTurnConfig + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + num_repeat_rollouts: null + calculate_log_probs: false + agent: + _target_: verl.workers.config.AgentLoopConfig + num_workers: 8 + default_agent_loop: single_turn_agent + agent_loop_config_path: null + custom_async_server: + _target_: verl.workers.config.CustomAsyncServerConfig + path: null + name: null + checkpoint_engine: + _target_: verl.workers.config.CheckpointEngineConfig + backend: naive + update_weights_bucket_megabytes: 2048 + engine_kwargs: {} + enable_rollout_routing_replay: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.contents,[]} + level: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.level,level0} + analysis: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.analysis,false} + discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.discrete,false} + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.contents,[]} + discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.discrete,false} + prometheus: + _target_: verl.workers.config.PrometheusConfig + enable: false + port: 9090 + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + quantization: null + quantization_config_file: null + layered_summon: false + model: + _target_: verl.workers.config.DiffusersModelConfig + path: ~/models/Qwen/Qwen-Image + tokenizer_path: null + use_shm: false + trust_remote_code: false + custom_chat_template: null + external_lib: null + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: true + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + lora_adapter_path: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + tiled_mlp: + enabled: false + num_shards: 4 + image_height: ${oc.select:actor_rollout_ref.rollout.image_height,512} + image_width: ${oc.select:actor_rollout_ref.rollout.image_width,512} + num_inference_steps: ${oc.select:actor_rollout_ref.rollout.num_inference_steps,10} + noise_level: ${oc.select:actor_rollout_ref.rollout.noise_level,0.7} + guidance_scale: ${oc.select:actor_rollout_ref.rollout.guidance_scale,1.0} + sde_type: ${oc.select:actor_rollout_ref.rollout.sde_type,sde} + model_type: diffusion_model + hybrid_engine: true + nccl_timeout: 600 +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + train_max_samples: -1 + val_max_samples: -1 + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, + null} + return_raw_input_ids: false + return_raw_chat: true + return_full_prompt: false + shuffle: true + seed: null + dataloader_num_workers: 8 + image_patch_size: 14 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null + apply_chat_template_kwargs: {} + data_source: prompt +critic: + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + zero_indexed_step: true + warmup_style: null + override_optimizer_config: null + model: + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + dtype: bfloat16 + qat: + _target_: verl.workers.config.QATEngineConfig + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: {} + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + _target_: verl.workers.config.FSDPCriticModelCfg + use_shm: false + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + tiled_mlp: + enabled: false + num_shards: 4 + _target_: verl.workers.config.FSDPCriticConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + enable: null + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + data_loader_seed: 42 + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + mbridge_config: {} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} + forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} + ulysses_sequence_parallel_size: 1 + grad_clip: 1.0 +custom_reward_function: + path: null + name: null +reward_model: + num_workers: null + reward_manager: null + enable: null + enable_resource_pool: null + n_gpus_per_node: null + nnodes: null + reward_loop_source: null + reward_loop_module_path: null + reward_loop_class_name: null + model: + path: null + external_lib: null + trust_remote_code: null + rollout: + name: null + dtype: null + gpu_memory_utilization: null + enforce_eager: null + cudagraph_capture_sizes: null + free_cache_engine: null + data_parallel_size: null + expert_parallel_size: null + tensor_model_parallel_size: null + max_num_batched_tokens: null + max_model_len: null + max_num_seqs: null + load_format: null + engine_kwargs: null + limit_images: null + enable_chunked_prefill: null + enable_prefix_caching: null + disable_log_stats: null + skip_tokenizer_init: null + prompt_length: null + response_length: null +sandbox_fusion: + url: null + max_concurrent: null + memory_limit_mb: null +reward: + num_workers: 8 + custom_reward_function: + path: null + name: compute_score + reward_manager: + _target_: verl.workers.config.reward_model.RewardManagerConfig + source: register + name: naive + module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager + reward_model: + enable: false + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + model_path: null + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + dtype: bfloat16 + gpu_memory_utilization: 0.5 + enforce_eager: true + cudagraph_capture_sizes: null + free_cache_engine: true + data_parallel_size: 1 + expert_parallel_size: 1 + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + load_format: auto + engine_kwargs: {} + limit_images: null + enable_chunked_prefill: true + enable_prefix_caching: true + disable_log_stats: true + skip_tokenizer_init: false + prompt_length: 2048 + response_length: 2048 + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 +algorithm: + rollout_correction: + rollout_is: null + rollout_is_threshold: 2.0 + rollout_rs: null + rollout_rs_threshold: null + bypass_mode: false + loss_type: ppo_clip + rollout_is_batch_normalize: false + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + reweight_method: pow + weight_pow: 2.0 + global_std: true +trainer: + balance_batch: true + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + rollout_data_dir: null + validation_data_dir: null + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + val_before_train: true + val_only: false + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + del_local_ckpt_after_load: false + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda + use_legacy_worker_impl: auto +global_profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: null + steps: null + profile_continuous_steps: false + save_path: outputs/profile + global_tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: false + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + torch_memory: + trace_alloc_max_entries: 100000 + stack_depth: 32 + context: all + stacks: all + kw_args: {} +transfer_queue: + enable: false +ray_kwargs: + ray_init: + num_cpus: null + timeline_json_file: null diff --git a/verl/trainer/config/algorithm.py b/verl/trainer/config/algorithm.py index 5aa650d7bf9..1c9497e4b77 100644 --- a/verl/trainer/config/algorithm.py +++ b/verl/trainer/config/algorithm.py @@ -612,3 +612,4 @@ class AlgoConfig(BaseConfig): # Rollout Correction: corrects off-policy issues (policy mismatch, model staleness, distribution shifts) # Set to None to disable, use RolloutCorrectionConfig presets (e.g., .tis(), .mis()), or pass dict rollout_correction: Optional[RolloutCorrectionConfig] = None + global_std: bool = True diff --git a/verl/trainer/config/model/diffusion_model.yaml b/verl/trainer/config/model/diffusion_model.yaml new file mode 100644 index 00000000000..12fd0241a41 --- /dev/null +++ b/verl/trainer/config/model/diffusion_model.yaml @@ -0,0 +1,95 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +_target_: verl.workers.config.DiffusersModelConfig + +# path to the huggingface model +path: ~/models/Qwen/Qwen-Image + +# config to the huggingface config. In case it is not the same as path +# hf_config_path: null + +# path to the huggingface tokenizer. In case it is not the same as path +tokenizer_path: null + +# whether to use shared memory for model loading +use_shm: False + +# whether to trust remote code. +trust_remote_code: False + +# custom chat template for the model +custom_chat_template: null + +# whether to use external libs for the model +external_lib: null + +# override hf config +# override_config: {} + +# whether to enable gradient checkpointing. Only valid when we use hf model definition +enable_gradient_checkpointing: True + +# whether to enable activation offload. Only valid when we use hf model definition +enable_activation_offload: False + +# whether to use remove padding. Only valid when we use hf model definition +use_remove_padding: True + +# Set to positive value to enable LoRA (e.g., 32) +lora_rank: 0 + +# LoRA scaling factor +lora_alpha: 16 + +# Target modules for LoRA adaptation +target_modules: all-linear + +# Exclude modules from LoRA adaptation +exclude_modules: null + +# Path to pre-trained LoRA adapter to load for continued training +lora_adapter_path: null + +# whether to use liger. Only valid when we use hf model definition +use_liger: False + +# whether to use fused kernels. +use_fused_kernels: False + +# fused kernel options. +fused_kernel_options: + + # the implementation backend for fused kernels. + impl_backend: torch + +# TiledMLP configuration for memory-efficient MLP computation. +# Reduces peak memory by processing MLP forward/backward in tiles. +tiled_mlp: + + # whether to enable TiledMLP + enabled: False + + # number of shards to split the input. Higher values reduce peak memory but may slightly impact performance. + num_shards: 4 + +# image height +image_height: ${oc.select:actor_rollout_ref.rollout.image_height,512} + +# image width +image_width: ${oc.select:actor_rollout_ref.rollout.image_width,512} + +# inference steps +num_inference_steps: ${oc.select:actor_rollout_ref.rollout.num_inference_steps,10} + +# noise in SDE +noise_level: ${oc.select:actor_rollout_ref.rollout.noise_level,0.7} + +# guidance scale for classifier-free guidance +guidance_scale: ${oc.select:actor_rollout_ref.rollout.guidance_scale,1.0} + +# SDE type during rollout +sde_type: ${oc.select:actor_rollout_ref.rollout.sde_type,sde} diff --git a/verl/trainer/config/ppo_diffusion_trainer.yaml b/verl/trainer/config/ppo_diffusion_trainer.yaml new file mode 100644 index 00000000000..3da7c98adfd --- /dev/null +++ b/verl/trainer/config/ppo_diffusion_trainer.yaml @@ -0,0 +1,332 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# specify the default per-component configs +defaults: + + - model_engine: dp + + # @.: + # actor_rollout_ref.actor: trainer/config/actor/dp_actor.yaml + - actor@actor_rollout_ref.actor: ${model_engine}_actor + + # data: trainer/config/data/legacy_data.yaml + - data@data: legacy_data + + # Reference model config. + # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. + - ref@actor_rollout_ref.ref: ${model_engine}_ref + + # Rollout model config. + - rollout@actor_rollout_ref.rollout: diffusion_rollout + + # Model config. + - model@actor_rollout_ref.model: diffusion_model + + # Critic model config. + - critic@critic: ${model_engine}_critic + + # legacy reward impl config, for backward compatibility + - legacy_reward_impl + + # Reward config. + - reward@reward: reward + + # Rollout correction config. + - algorithm@algorithm.rollout_correction: rollout_correction + + # load the reference default config, then apply the fields in the current yaml + # self config override anything above + - _self_ + +data: + + # get ground-truth based on data_source, now support ["ocr", "prompt"] + data_source: "prompt" + +# config for actor, rollout and reference model +actor_rollout_ref: + + # Model config + model: + model_type: "diffusion_model" + + # Whether it's a hybrid engine, currently only supports hybrid engine + hybrid_engine: true + + # Timeout for operations executed against the process group + nccl_timeout: 600 + + # Actor config + actor: + # PPO clip ratio + clip_ratio: 0.0001 + + # Maximum absolute value for advantage clipping + clip_ratio_high: 5.0 + + # Rollout model config. + rollout: + + # for huge model, layered summon can save memory (prevent OOM) but make it slower + layered_summon: False + +# config for the algorithm +algorithm: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.AlgoConfig + + # Discount factor for future rewards + gamma: 1.0 + + # Trade-off between bias and variance in the GAE estimator + lam: 1.0 + + # Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. + adv_estimator: gae + + # Whether to normalize advantages by std (specific to GRPO) + norm_adv_by_std_in_grpo: True + + # Whether to enable in-reward KL penalty + use_kl_in_reward: False + + # How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full" + kl_penalty: kl + + # KL control configuration + kl_ctrl: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.KLControlConfig + + # KL control type: "fixed" or "adaptive" + type: fixed + + # Initial coefficient for KL penalty + kl_coef: 0.001 + + # Horizon value for adaptive controller (if enabled) + horizon: 10000 + + # Target KL divergence (used for adaptive controller) + target_kl: 0.1 + + # Whether to enable preference feedback PPO + use_pf_ppo: False + + # Preference feedback PPO settings + pf_ppo: + + # Method for reweighting samples: "pow", "max_min", or "max_random" + reweight_method: pow + + # Power used for weight scaling in "pow" method + weight_pow: 2.0 + + # Whether to normalize advantages using global standard deviation + global_std: True + +# config for the trainer +trainer: + + # Whether to balance batch sizes across distributed workers + balance_batch: True + + # Number of epochs in training + total_epochs: 30 + + # Total training steps (can be set explicitly or derived from epochs) + total_training_steps: null + + # Project name for experiment tracking (e.g., wandb) + project_name: verl_examples + + # Experiment name for run identification in tracking tools + experiment_name: gsm8k + + # Logging backends to use: "console", "wandb", etc. + logger: ["console", "wandb"] + + # Number of generations to log during validation + log_val_generations: 0 + + # Directory for logging rollout data; no dump if null + rollout_data_dir: null + + # Directory for logging validation data; no dump if null + validation_data_dir: null + + # Number of nodes used in the training + nnodes: 1 + + # Number of GPUs per node + n_gpus_per_node: 8 + + # Save frequency (by iteration) for model checkpoints + save_freq: -1 + + # ESI refers to the elastic server instance used during training, similar to the training plan. For example, + # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training. + # To ensure a checkpoint is saved before ESI shuts down, the system will start saving a checkpoint in advance. + # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time. + # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety. + esi_redundant_time: 0 + + # Resume mode: "auto", "disable", or "resume_path" + # "auto": resume from last checkpoint if available + # "disable": start from scratch + # "resume_path": resume from a user-defined path + resume_mode: auto + + # Path to resume training from (only used when resume_mode is "resume_path") + resume_from_path: null + + # Whether to run validation before training begins + val_before_train: True + + # Whether to run validation only + val_only: False + + # Validation frequency (in training iterations) + test_freq: -1 + + # Number of iterations to warm up the critic before updating policy + critic_warmup: 0 + + # Default path to distributed filesystem for saving checkpoints + default_hdfs_dir: null + + # Whether to delete local checkpoints after loading + del_local_ckpt_after_load: False + + # Default local directory for saving checkpoints + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + + # Maximum number of actor checkpoints to keep + max_actor_ckpt_to_keep: null + + # Maximum number of critic checkpoints to keep + max_critic_ckpt_to_keep: null + + # Timeout (in seconds) for Ray worker to wait for registration + ray_wait_register_center_timeout: 300 + + # Device to run training on (e.g., "cuda", "cpu") + device: cuda + + # whether to use legacy worker implementation + # mode: "auto", "enable", or "disable" + use_legacy_worker_impl: auto + +# profiler configs +global_profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # Profiling tool: choose between nsys, npu, torch, torch_memory + tool: null + + # profile steps + steps: null + + # Whether to combine continuous steps into one database. + ## If True, worker.profiler.discrete must be False, [1,2] in one, [5] in another. + ## If False, [1] in one, [2] in another, [5] in another. + profile_continuous_steps: False + + # Path to save profiling contents + save_path: "outputs/profile" + + # Specific tool configs, can use +profiler.tool_config.[tool].xxx to config + global_tool_config: + + # nsys config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. + ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html + ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html + controller_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. + worker_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. + capture-range: "cudaProfilerApi" + + # Specify the desired behavior when a capture range ends. + # In verl we need the torch.cuda.profiler.start/stop pair to repeats n times. + # valid values are "repeat-shutdown:n" or null. + # For normal whole step profiling, n = len(profile_steps); + # but for discrete profiling, n = len(profile_steps) * Number(subtasks). + # Or you can just leave it null and the program will use n = len(profile_steps) * 6; + capture-range-end: null + + # Send signal to the target application's process group. We let the program to exit by itself. + kill: none + + # enable memory visualization for debugging memory usage + torch_memory: + + # Maximum number of allocation entries to record + trace_alloc_max_entries: 100_000 + + # The depth of the call stack to capture for each allocation + stack_depth: 32 + + # 'alloc': records only allocation events || 'state': records memory state changes || 'all': records both. + context: "all" + + # 'python': records Python stacks || 'cpp': records C++ stacks (available in some versions) || 'all': records both. + stacks: "all" + + # devices, record_context etc. + kw_args: {} + +# configs for TransferQueue +transfer_queue: + + # Whether to enable transfer queue + enable: False + +# configs related to ray +ray_kwargs: + + # configs related to ray initialization + ray_init: + + # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM. + num_cpus: null + + # Path to save Ray timeline JSON for performance profiling + timeline_json_file: null diff --git a/verl/trainer/config/rollout/diffusion_rollout.yaml b/verl/trainer/config/rollout/diffusion_rollout.yaml new file mode 100644 index 00000000000..669b6616b73 --- /dev/null +++ b/verl/trainer/config/rollout/diffusion_rollout.yaml @@ -0,0 +1,371 @@ +# Target class for this configuration +_target_: verl.workers.config.DiffusionRolloutConfig + +# actor_rollout_ref.rollout.name: vllm_omni/hf/vllm/sglang/trtllm. The default value will be removed in the future +name: ??? + +# sync: LLM, async: AsyncLLM +mode: async + +# Number of nodes for standalone rollout server, must be > 0 in one-step-off/fully async training. +nnodes: 0 + +# Number of GPUs per node for rollout server. +n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8} + +# typically the same as data max prompt length +# same as data.max_prompt_length if it exists +prompt_length: ${oc.select:data.max_prompt_length,512} + +# for vllm rollout +# Rollout model parameters type. Align with actor model's FSDP/Megatron type. +dtype: bfloat16 + +# Fraction of GPU memory used by vLLM/SGLang/TRTLLM for KV cache. +gpu_memory_utilization: 0.5 + +# Whether to disable CUDA graph. Default False to best performance. +enforce_eager: False + +# batch size of cudagraph to capture. Require enforce_eager: False to use this option +# Since cudagraph in inference engine can not be offloaded during update policy, +# you can use smaller batch size to save memory used in cuda graph, eg: [1 ,2, 4, 8, 16, 32] +# supported engines: vllm +cudagraph_capture_sizes: null + +# Whether to free engine KVCache after generation. +free_cache_engine: True + +# TP size for rollout. Not effective for hf +tensor_model_parallel_size: 2 + +# DP size for rollout +data_parallel_size: 1 + +# EP size for rollout +expert_parallel_size: 1 + +# PP size for rollout. +pipeline_model_parallel_size: 1 + +# max number of tokens in a batch +max_num_batched_tokens: 8192 + +# max length for rollout +max_model_len: null + +# max length of sequences +max_num_seqs: 1024 + +# may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. +enable_chunked_prefill: True + +# Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. +enable_prefix_caching: True + +# logprobs mode for rollout logprobs +logprobs_mode: processed_logprobs + +# scheduling policy for vllm rollout +scheduling_policy: fcfs + +# Which loader to use for rollout model weights: dummy, hf, megatron, etc. +# safetensors (for huge model, and set use_shm=True); dummy: randomly init model weight +load_format: dummy + +# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size. +log_prob_micro_batch_size: null + +# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. +log_prob_micro_batch_size_per_gpu: null + +# enable dynamic batch size (sequence packing) for log_prob computation +# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false +log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# max token length for log_prob computation +# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384 +log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + +# disable logging statistics +disable_log_stats: True + +# for hf rollout +# Whether to sample during training rollout. False uses greedy sampling. +do_sample: True + +# number of responses (i.e. num sample times). > 1 for grpo +n: 1 + +# image height for diffusion model rollout +image_height: 512 + +# image width for diffusion model rollout +image_width: 512 + +# number of inference steps for diffusion model rollout +num_inference_steps: 10 + +# noise level for diffusion model rollout +noise_level: 0.7 + +# guidance scale for classifier-free guidance +guidance_scale: 4.5 + +# SDE type during rollout +sde_type: "sde" + +# SDE window size +sde_window_size: null + +# SDE window range +sde_window_range: null + +# Extra inference engine arguments (vllm, sglang, trtllm), please refer vllm/sglang/trtllm official doc for detail +engine_kwargs: + + # vllm-omni engine config + vllm_omni: {} + +# Sampling parameters used during validation. +val_kwargs: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.DiffusionSamplingConfig + + # whether to repeat n times for validation + n: 1 + + # Whether to sample during training rollout. False uses greedy sampling. + do_sample: False + + # number of inference steps for diffusion model rollout + num_inference_steps: 40 + + # noise level for diffusion model rollout + noise_level: 0.0 + + # random seed for validation + seed: 42 + +# Multi-turn interaction config for tools or chat. +multi_turn: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.MultiTurnConfig + + # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well + enable: False + + # null for no limit (default max_length // 3) + max_assistant_turns: null + + # null for no tool + tool_config_path: null + + # null for no limit (default max_length // 3) + max_user_turns: null + + # max parallel call for tools in single turn + max_parallel_calls: 1 + + # max length of tool response + max_tool_response_length: 256 + + # truncate side of tool response: left, middle, right + tool_response_truncate_side: middle + + # null for no interaction + interaction_config_path: null + + # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. + # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, + # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts. + use_inference_chat_template: False + + # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. + # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. + # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. aThis behavior has already been validated for them. + # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: + # Qwen/QwQ-32B, Qwen/Qwen3-xxB + # - disable: disable tokenization sanity check + # - strict: enable strict tokenization sanity check (default) + # - ignore_strippable: ignore strippable tokens when checking tokenization sanity + tokenization_sanity_check_mode: strict + + # Format of the multi-turn interaction. Options: hermes, llama3_json, ... + format: hermes + + # Number of repeat rollouts for each interaction + num_repeat_rollouts: null + +# support logging rollout prob for debugging purpose +# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling +calculate_log_probs: False + +# [Experimental] agent loop based rollout configs +agent: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.AgentLoopConfig + + # Number of agent loop workers + num_workers: 8 + + # default agent loop to use if `agent_name` not set in RL dataset + default_agent_loop: single_turn_agent + + # custom agent loop config path, which should contain list of configs to initialize AgentLoop instances. + # https://hydra.cc/docs/advanced/instantiate_objects/overview/ + # + # - name: react_agent + # _target_: recipe.langgraph_agent.react_agent_loop.ReactAgentLoop + # tools: ["get_current_temperature"] + # - name: math_expression + # _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop + # min_terms: 2 + # max_terms: 6 + agent_loop_config_path: null + + # custom async server configs + custom_async_server: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.CustomAsyncServerConfig + + # Path to the custom async server implementation + path: null + + # Class name of the custom async server class (e.g. AsyncvLLMServer) + name: null + +# Checkpoint Engine config for update weights from trainer to rollout +checkpoint_engine: + + # Target class for checkpoint engine config + _target_: verl.workers.config.CheckpointEngineConfig + + # Backend for checkpoint engine: naive, nccl, nixl, hccl + backend: naive + + # Specifies the tensor bucket size (in megabytes) for batch weight updates during rollout operations. + # This parameter controls the maximum payload size for a single weight update request. + # Reference: https://github.com/volcengine/verl/pull/2418 + # Currently only supported in SGLang rollout implementations + # Larger values may improve throughput but increase memory overhead + # Detailed performance comparison: + # https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/169#issuecomment-3070686720 + # Default value (512MB) is optimized for typical GPU memory configurations + # For the best performance of `rebuild_cuda_tensor`, it is recommended to: + # 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES` + # 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` + # when using Tensor Parallelism (TP) >= 8. + update_weights_bucket_megabytes: 2048 + + # Additional keyword arguments to pass to the checkpoint engine constructor + engine_kwargs: {} + +# trace rollout data +# trace: + +# # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +# _target_: verl.workers.config.TraceConfig + +# # trace backend, support mlflow, weave +# backend: null + +# # whether translate token id to text in output +# token2text: False + +# # Maximum number of unique samples to trace per agent worker per training step. +# # If null, all samples are traced. If set to N, each agent loop worker will randomly +# # select N unique samples to trace (including all their rollouts for GRPO). +# # Total traces per step = max_samples_per_step_per_worker * num_workers * n_rollouts_per_sample +# max_samples_per_step_per_worker: null + +# Whether to enable rollout routing replay for MoE models +# When enabled (True), the rollout will record the routing decisions. +enable_rollout_routing_replay: False + + +# profile the rollout model in `generate_sequence` +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # profiler tool, default same as profiler.tool in global config + # choices: npu, torch + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on rollout + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + + # Whether to profile all ranks. + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + + # The ranks that will be profiled. [] or [0,1,...] + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config + tool_config: + + # npu config + npu: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NPUToolConfig + + # Contents to profile, can be empty + # options: npu, cpu, memory, shapes, module, stack + contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.contents,[]} + + # Collection level, optional values: level_none, level0, level1, level2. + level: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.level,level0} + + # Whether to automatically parse the data. + analysis: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.analysis,false} + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.discrete,false} + + # torch profiler config + torch: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + + # Contents to profile, can be empty + # options: cuda, cpu, memory, shapes, stack + contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.contents,[]} + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.discrete,false} + +# prometheus configuration for vllm/sglang server mode +prometheus: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.PrometheusConfig + + # whether enable prometheus on server mode rollout + enable: false + + # Port number that Prometheus listens on, default is 9090 + port: 9090 + + # Path to Prometheus configuration file + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + + # Specify served_model_name to avoid displaying overly long model paths in Grafana + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + +# type of quantization in vllm, currently support fp8 and torchao +quantization: null + +# extra quantization information serialized in a config file, e.g. torchao_config.json +quantization_config_file: null + diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 2c84374d245..ea177564057 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -25,6 +25,7 @@ from verl.experimental.dataset.sampler import AbstractSampler from verl.experimental.reward_loop import migrate_legacy_reward_impl from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.ppo.ray_diffusion_trainer import RayFlowGRPOTrainer from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.trainer.ppo.utils import need_critic, need_reference_policy from verl.utils.config import validate_config @@ -312,9 +313,13 @@ def run(self, config): from verl.utils import hf_processor, hf_tokenizer trust_remote_code = config.data.get("trust_remote_code", False) - tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + tokenizer = hf_tokenizer(config.actor_rollout_ref.model.tokenizer_path, trust_remote_code=trust_remote_code) # Used for multimodal LLM, could be None - processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + if os.path.exists(os.path.join(local_path, "processor")): + processor_path = os.path.join(local_path, "processor") + else: + processor_path = local_path + processor = hf_processor(processor_path, trust_remote_code=trust_remote_code, use_fast=True) resource_pool_manager = self.init_resource_pool_mgr(config) @@ -340,7 +345,12 @@ def run(self, config): train_sampler = create_rl_sampler(config.data, train_dataset) # Initialize the PPO trainer. - trainer = RayPPOTrainer( + trainer_cls = ( + RayFlowGRPOTrainer + if config.actor_rollout_ref.model.get("model_type", None) == "diffusion_model" + else RayPPOTrainer + ) + trainer = trainer_cls( config=config, tokenizer=tokenizer, processor=processor, diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index a78bd400e10..046fd8e0728 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -107,6 +107,7 @@ class AdvantageEstimator(str, Enum): GRPO_VECTORIZED = "grpo_vectorized" OPTIMAL_TOKEN_BASELINE = "optimal_token_baseline" TIR_OPTIMAL_TOKEN_BASELINE = "tir_optimal_token_baseline" + FLOW_GRPO = "flow_grpo" ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {} @@ -1006,6 +1007,89 @@ def compute_multi_turn_optimal_token_baseline_advantage( return advantages, token_returns +@register_adv_est(AdvantageEstimator.FLOW_GRPO) +def compute_flow_grpo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-4, + norm_adv_by_std_in_grpo: bool = True, + global_std: bool = True, + config: Optional[DictConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for GRPO, operating only on Outcome reward + (with only one scalar reward for each response). + + Args: + token_level_rewards: `(torch.Tensor)` + shape is (bs, ), (bs, 1) or (bs, response_length) + response_mask: `(torch.Tensor)` + shape is (bs, response_length) + index: `(np.ndarray)` + index array for grouping + epsilon: `(float)` + small value to avoid division by zero + norm_adv_by_std_in_grpo: `(bool)` + whether to scale the GRPO advantage + global_std: `(bool)` + whether to use global std for advantage normalization + config: `(Optional[DictConfig])` + algorithm configuration object + + Note: + If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. + If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). + + Returns: + advantages: `(torch.Tensor)` + shape is (bs, response_length) + Returns: `(torch.Tensor)` + shape is (bs, response_length) + """ + scores = token_level_rewards + if scores.ndim == 1: + scores = scores.unsqueeze(-1) + scores = scores.expand_as(response_mask).clone() + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + if global_std: + batch_std = torch.std(scores) + else: + batch_std = None + + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + if global_std: + id2std[idx] = batch_std + else: + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + scores_tensor = torch.stack(id2score[idx]) + id2mean[idx] = torch.mean(scores_tensor) + if global_std: + id2std[idx] = batch_std + else: + id2std[idx] = torch.std(scores_tensor) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + if norm_adv_by_std_in_grpo: + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + else: + scores[i] = scores[i] - id2mean[index[i]] + + return scores, scores + + def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): """Compute token-level rewards with KL penalty. @@ -1951,6 +2035,58 @@ def compute_policy_loss_cispo( return pg_loss, pg_metrics +@register_policy_loss("flow_grpo") +def compute_policy_loss_flow_grpo( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute the clipped policy objective and related metrics for FlowGRPO. + + Adapted from + https://github.com/yifan123/flow_grpo/blob/main/scripts/train_sd3_fast.py#L885 + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size,). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size,). + response_mask (torch.Tensor): + Not used. + loss_agg_mode (str, optional): + Not used. + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size,). + config: `(verl.trainer.config.ActorConfig)`: + config for the actor. + rollout_is_weights: `(torch.Tensor, optional)`: + Not used. + """ + assert config is not None + assert isinstance(config, ActorConfig) + advantages = torch.clamp( + advantages, + -config.clip_ratio_high, + config.clip_ratio_high, + ) + ratio = torch.exp(log_prob - old_log_prob) + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, + 1.0 - config.clip_ratio, + 1.0 + config.clip_ratio, + ) + pg_loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) + + pg_metrics = {"actor/ppo_kl": pg_loss.detach().item()} + return pg_loss, pg_metrics + + def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"): """Compute categorical entropy loss (For backward compatibility) @@ -2074,6 +2210,20 @@ def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTenso raise NotImplementedError +def kl_penalty_image( + prev_sample_mean: torch.Tensor, ref_prev_sample_mean: torch.Tensor, std_dev_t: torch.Tensor +) -> torch.Tensor: + """Compute KL divergence given previous sample mean and reference previous sample mean (for images or videos). + + Args: + prev_sample_mean: (torch.Tensor) shape is (bs, s, c) + ref_prev_sample_mean: (torch.Tensor) shape is (bs, s, c) + std_dev_t: (torch.Tensor) shape is (bs, 1, 1) + """ + kl_loss = ((prev_sample_mean - ref_prev_sample_mean) ** 2).mean(dim=(1, 2), keepdim=True) / (2 * std_dev_t**2) + return kl_loss.mean() + + def compute_pf_ppo_reweight_data( data, reweight_method: str = "pow", diff --git a/verl/trainer/ppo/ray_diffusion_trainer.py b/verl/trainer/ppo/ray_diffusion_trainer.py new file mode 100644 index 00000000000..c96134c0fa7 --- /dev/null +++ b/verl/trainer/ppo/ray_diffusion_trainer.py @@ -0,0 +1,1486 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import json +import os +import uuid +from collections import defaultdict +from pprint import pprint +from typing import Any, Optional + +import numpy as np +import torch +from omegaconf import OmegaConf, open_dict +from PIL import Image +from torch.utils.data import Dataset, Sampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl import DataProto +from verl.checkpoint_engine import CheckpointEngineManager +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.config import AlgoConfig +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import AdvantageEstimator +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + compute_variance_proxy_metrics, + process_validation_metrics, +) +from verl.trainer.ppo.reward import extract_reward +from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model +from verl.utils import tensordict_utils as tu +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.debug import marked_timer +from verl.utils.import_utils import load_class_from_fqn +from verl.utils.metric import reduce_metrics +from verl.utils.py_functional import rename_dict +from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.tracking import ValidationGenerationsLogger +from verl.workers.config import FSDPEngineConfig +from verl.workers.utils.padding import embeds_padding_2_no_padding + + +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): + raise NotImplementedError("KL penalty is not supported.") + + +def compute_response_mask(data: DataProto): + """Compute the attention mask for latents + + Args: + data (DataProto): The data containing batched model outputs and inputs. + + Returns: + torch.Tensor: The attention mask for the response tokens. + """ + all_latents = data.batch["all_latents"] + b, t, _, _ = all_latents.shape + response_mask = torch.ones((b, t), dtype=torch.int32) + return response_mask + + +def compute_advantage( + data: DataProto, + adv_estimator: AdvantageEstimator, + norm_adv_by_std_in_grpo: bool = True, + global_std: bool = True, + config: Optional[AlgoConfig] = None, +) -> DataProto: + """Compute advantage estimates for policy optimization. + + This function computes advantage estimates using various estimators like FlowGRPO, etc. + The advantage estimates are used to guide policy optimization in RL algorithms. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). + norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in + GRPO. Defaults to True. + global_std (bool, optional): Whether to use global standard deviation for advantage normalization. + config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None. + + Returns: + DataProto: The updated data with computed advantages and returns. + """ + # Back-compatible with trainers that do not compute response mask in fit + if "response_mask" not in data.batch.keys(): + data.batch["response_mask"] = compute_response_mask(data) + # prepare response group + if adv_estimator == AdvantageEstimator.FLOW_GRPO: + # Initialize the mask for GRPO calculation + grpo_calculation_mask = data.batch["response_mask"] + + # Call compute_grpo_outcome_advantage with parameters matching its definition + advantages, returns = core_algos.compute_flow_grpo_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=grpo_calculation_mask, + index=data.non_tensor_batch["uid"], + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + global_std=global_std, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + else: + # handle all other adv estimator type other than GAE and GRPO + adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) + adv_kwargs = { + "token_level_rewards": data.batch["token_level_rewards"], + "response_mask": data.batch["response_mask"], + "config": config, + } + if "uid" in data.non_tensor_batch: # optional + adv_kwargs["index"] = data.non_tensor_batch["uid"] + if "reward_baselines" in data.batch: # optional + adv_kwargs["reward_baselines"] = data.batch["reward_baselines"] + + # calculate advantage estimator + advantages, returns = adv_estimator_fn(**adv_kwargs) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + return data + + +class RayFlowGRPOTrainer: + """Distributed Flow-GRPO trainer using Ray for scalable reinforcement learning. + + This trainer orchestrates distributed PPO training across multiple nodes and GPUs, + managing actor rollouts, critic training, and reward computation with Ray backend. + Supports various model architectures including FSDP, Megatron, vLLM, and SGLang integration. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup, + processor=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name=None, + ): + """ + Initialize distributed PPO trainer with Ray backend. + Note that this trainer runs on the driver process on a single CPU/GPU node. + + Args: + config: Configuration object containing training parameters. + tokenizer: Tokenizer used for encoding and decoding text. + role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. + resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. + ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. + processor: Optional data processor, used for multimodal data + train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. + collate_fn: Function to collate data samples into batches. + train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None. + """ + + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, "Currently, only support hybrid engine" + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping or Role.ActorRolloutRef in role_worker_mapping, ( + f"{role_worker_mapping.keys()=}" + ) + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = need_reference_policy(self.config) + + self.use_rm = need_reward_model(self.config) + + self.use_critic = need_critic(self.config) + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name if device_name else self.config.trainer.device + self.validation_generations_logger = ValidationGenerationsLogger( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + ) + + # if ref_in_actor is True, the reference policy will be actor without lora applied + lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0) + if lora_rank <= 0: + lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0) + self.ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None + + # define in-reward KL control + # kl loss control currently not suppoorted + if self.config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) + + self.use_prefix_grouper = self.config.actor_rollout_ref.actor.get("use_prefix_grouper", False) + self.use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + self.checkpoint_manager = None + + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): + """ + Creates the train and validation dataloaders. + """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset( + self.config.data.train_files, + self.config.data, + self.tokenizer, + self.processor, + max_samples=self.config.data.get("train_max_samples", -1), + ) + if val_dataset is None: + val_dataset = create_rl_dataset( + self.config.data.val_files, + self.config.data, + self.tokenizer, + self.processor, + max_samples=self.config.data.get("val_max_samples", -1), + ) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + num_workers = self.config.data["dataloader_num_workers"] + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=self.config.data.get("validation_shuffle", True), + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" + + print( + f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " + f"{len(self.val_dataloader)}" + ) + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + + def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): + """Dump rollout/validation samples as JSONL.""" + os.makedirs(dump_path, exist_ok=True) + + visual_folder = os.path.join(dump_path, f"{self.global_steps}") + os.makedirs(visual_folder, exist_ok=True) + + output_paths = [] + images_pil = outputs.cpu().float().permute(0, 2, 3, 1).numpy() + images_pil = (images_pil * 255).round().clip(0, 255).astype("uint8") + for i, image in enumerate(images_pil): + image_path = os.path.join(visual_folder, f"{i}.jpg") + Image.fromarray(image).save(image_path) + output_paths.append(image_path) + + filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") + + n = len(inputs) + base_data = { + "input": inputs, + "output": output_paths, + "gts": gts, + "score": scores, + "step": [self.global_steps] * n, + } + + for k, v in reward_extra_infos_dict.items(): + if len(v) == n: + base_data[k] = v + + lines = [] + for i in range(n): + entry = {k: v[i] for k, v in base_data.items()} + lines.append(json.dumps(entry, ensure_ascii=False)) + + with open(filename, "w") as f: + f.write("\n".join(lines) + "\n") + + print(f"Dumped generations to {filename}") + + def _log_rollout_data( + self, batch: DataProto, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str + ): + """Log rollout data to disk. + Args: + batch (DataProto): The batch containing rollout data + reward_extra_infos_dict (dict): Additional reward information to log + timing_raw (dict): Timing information for profiling + rollout_data_dir (str): Directory path to save the rollout data + """ + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = batch.batch["responses"] + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch] + + reward_extra_infos_to_dump = reward_extra_infos_dict.copy() + if "request_id" in batch.non_tensor_batch: + reward_extra_infos_dict.setdefault( + "request_id", + batch.non_tensor_batch["request_id"].tolist(), + ) + + self._dump_generations( + inputs=inputs, + outputs=outputs, + gts=sample_gts, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_to_dump, + dump_path=rollout_data_dir, + ) + + def _maybe_log_val_generations(self, inputs, outputs, scores): + """Log a table of validation samples to the configured logger (wandb or swanlab)""" + + generations_to_log = self.config.trainer.log_val_generations + + if generations_to_log == 0: + return + + import numpy as np + + # Create tuples of (input, output, score) and sort by input text + if "wandb" in self.config.trainer.logger: + import wandb + + outputs = [wandb.Image(image.float(), file_type="jpg") for image in outputs] + samples = list(zip(inputs, outputs, scores, strict=True)) + samples.sort(key=lambda x: x[0]) # Sort by input text + + # Use fixed random seed for deterministic shuffling + rng = np.random.RandomState(42) + rng.shuffle(samples) + + # Take first N samples after shuffling + samples = samples[:generations_to_log] + + # Log to each configured logger + self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + + def _get_gen_batch(self, batch: DataProto) -> DataProto: + reward_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys() + + # pop those keys for generation + batch_keys_to_pop = [] + non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_keys + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop), + ) + + # For agent loop, we need reward model keys to compute score. + gen_batch.non_tensor_batch.update(batch.non_tensor_batch) + + return gen_batch + + def _compute_reward_colocate(self, batch: DataProto) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor: + """ + compute reward use colocate reward model + """ + assert self.reward_loop_manager is not None, "RewardLoopManager is None" + batch_reward = self.reward_loop_manager.compute_rm_score(batch) + return batch_reward + + def _validate(self, merged: bool = False): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_gts = [] + sample_scores = [] + sample_turns = [] + sample_uids = [] + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + + if "uid" not in test_batch.non_tensor_batch: + test_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object + ) + + # repeat test batch + test_batch = test_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) + + ground_truths = [ + item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch + ] + sample_gts.extend(ground_truths) + + test_gen_batch = self._get_gen_batch(test_batch) + test_gen_batch.meta_info = { + "recompute_log_prob": False, + "validate": True, + "global_steps": self.global_steps, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # pad to be divisible by dp_size + size_divisor = self.config.actor_rollout_ref.rollout.agent.num_workers + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + + if self.use_rm and "rm_scores" not in test_output_gen_batch_padded.batch.keys(): + # for colocate reward models, we need to sleep rollout model + # to spare GPU memory for reward model + self.checkpoint_manager.sleep_replicas() + batch_reward = self._compute_reward_colocate(test_output_gen_batch_padded) + test_output_gen_batch_padded = test_output_gen_batch_padded.union(batch_reward) + # wake up rollout model + # replace with wake_up method once supported + self.checkpoint_manager.update_weights(self.global_steps) + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + + print("validation generation end") + + # Store generated outputs + output_images = test_output_gen_batch.batch["responses"] + sample_outputs.append(output_images) + + test_batch = test_batch.union(test_output_gen_batch) + test_batch.meta_info["validate"] = True + + # Store original inputs + input_ids = test_batch.batch["input_ids"] + # TODO: Can we keep special tokens except for padding tokens? + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + sample_uids.extend(test_batch.non_tensor_batch["uid"]) + + # evaluate using reward_function + reward_tensor, reward_extra_info = extract_reward(test_batch) + + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + for key, values in reward_extra_info.items(): + if key not in reward_extra_infos_dict: + reward_extra_infos_dict[key] = [] + if isinstance(values, np.ndarray): + reward_extra_infos_dict[key].extend(values.tolist()) + else: + reward_extra_infos_dict[key].extend(values if isinstance(values, list) else [values]) + + # collect num_turns of each prompt + if "__num_turns__" in test_batch.non_tensor_batch: + sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + sample_outputs = torch.cat(sample_outputs, dim=0) + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + gts=sample_gts, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + if merged: + print("_merge_validation_results validate result will be merged") + return { + "data_sources": data_source_lst, + "sample_uids": sample_uids, + "sample_turns": sample_turns, + "reward_extra_infos_dict": reward_extra_infos_dict, + } + data_sources = np.concatenate(data_source_lst, axis=0) + return self._val_metrics_update(data_sources, sample_uids, reward_extra_infos_dict, sample_turns) + + def _val_metrics_update(self, data_sources, sample_uids, reward_extra_infos_dict, sample_turns): + data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + if len(sample_turns) > 0: + sample_turns = np.concatenate(sample_turns) + metric_dict["val-aux/num_turns/min"] = sample_turns.min() + metric_dict["val-aux/num_turns/max"] = sample_turns.max() + metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() + + return metric_dict + + def _merge_validation_results(self, result_a, result_b): + if result_a is None and result_b is None: + return {} + if result_a is None: + result_a = {"data_sources": [], "sample_uids": [], "sample_turns": [], "reward_extra_infos_dict": {}} + if result_b is None: + result_b = {"data_sources": [], "sample_uids": [], "sample_turns": [], "reward_extra_infos_dict": {}} + + if not result_a.get("data_sources") and not result_b.get("data_sources"): + return {} + + data_sources = np.concatenate(result_a["data_sources"] + result_b["data_sources"], axis=0) + sample_uids = result_a["sample_uids"] + result_b["sample_uids"] + sample_turns = result_a["sample_turns"] + result_b["sample_turns"] + + reward_extra_infos_dict = {} + all_keys = set(result_a["reward_extra_infos_dict"].keys()) | set(result_b["reward_extra_infos_dict"].keys()) + for key in all_keys: + list_a = result_a["reward_extra_infos_dict"].get(key, []) + list_b = result_b["reward_extra_infos_dict"].get(key, []) + reward_extra_infos_dict[key] = list_a + list_b + + return self._val_metrics_update(data_sources, sample_uids, reward_extra_infos_dict, sample_turns) + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + actor_role = Role.ActorRolloutRef if Role.ActorRolloutRef in self.role_worker_mapping else Role.ActorRollout + if self.hybrid_engine: + actor_rollout_resource_pool = self.resource_pool_manager.get_resource_pool(actor_role) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[actor_role], + config=self.config.actor_rollout_ref, + role=str(actor_role), + ) + self.resource_pool_to_cls[actor_rollout_resource_pool][str(actor_role)] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + + from verl.workers.config import CriticConfig + + critic_cfg: CriticConfig = omega_conf_to_dataclass(self.config.critic) + + if self.use_legacy_worker_impl == "disable": + # convert critic_cfg into TrainingWorkerConfig + from verl.workers.engine_workers import TrainingWorkerConfig + + orig_critic_cfg = critic_cfg + if orig_critic_cfg.strategy == "fsdp": + engine_config: FSDPEngineConfig = orig_critic_cfg.model.fsdp_config + engine_config.infer_max_token_len_per_gpu = critic_cfg.ppo_infer_max_token_len_per_gpu + engine_config.max_token_len_per_gpu = critic_cfg.ppo_max_token_len_per_gpu + else: + raise NotImplementedError(f"Unknown strategy {orig_critic_cfg.strategy=}") + + critic_cfg = TrainingWorkerConfig( + model_type="value_model", + model_config=orig_critic_cfg.model_config, + engine_config=engine_config, + optimizer_config=orig_critic_cfg.optim, + checkpoint_config=orig_critic_cfg.checkpoint, + ) + + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls + + # create reference policy if needed + if self.use_reference_policy and Role.RefPolicy in self.role_worker_mapping: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role=str(Role.RefPolicy), + ) + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + if not class_dict: + continue + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg[str(Role.Critic)] + if self.use_legacy_worker_impl == "disable": + self.critic_wg.reset() + # assign critic loss + from functools import partial + + from verl.workers.utils.losses import value_loss + + value_loss_ = partial(value_loss, config=orig_critic_cfg) + self.critic_wg.set_loss_fn(value_loss_) + else: + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + if str(Role.RefPolicy) in all_wg: + self.ref_policy_wg = all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + else: + # Model engine: ActorRolloutRefWorker + assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}" + self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)] + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg[str(actor_role)] + self.actor_rollout_wg.init_model() + + if self.ref_in_actor: + self.ref_policy_wg = self.actor_rollout_wg + + # create reward loop manager + from verl.experimental.reward_loop import RewardLoopManager + + # initalize reward loop manager + # reward model (colocate or standalone): get resource_pool + # no reward model: resource_pool = None + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) if self.use_rm else None + self.reward_loop_manager = RewardLoopManager( + config=self.config, + rm_resource_pool=resource_pool, + ) + + # create async rollout manager and request scheduler + # Note: mode is always "async" since sync mode is deprecated + self.async_rollout_mode = True + + # Support custom AgentLoopManager via config + manager_class_fqn = self.config.actor_rollout_ref.rollout.get("agent", {}).get("agent_loop_manager_class") + if manager_class_fqn: + AgentLoopManager = load_class_from_fqn(manager_class_fqn, "AgentLoopManager") + else: + from verl.experimental.agent_loop import AgentLoopManager + + # infrastructure overview: https://verl.readthedocs.io/en/latest/advance/reward_loop.html#architecture-design + # agent_reward_loop: streaming reward computation with actor rollout + # two conditions satisfied: (1) no reward model, or (2) reward model with extra resource pool + enable_agent_reward_loop = not self.use_rm or self.config.reward.reward_model.enable_resource_pool + + # if enable_agent_reward_loop, we directly pass reward_loop_workers to agent loop manager + # to stream reward computation with actor rollout + reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None + self.async_rollout_manager = AgentLoopManager.create( + config=self.config, + worker_group=self.actor_rollout_wg, + rollout_resource_pool=actor_rollout_resource_pool, + reward_loop_worker_handles=reward_loop_worker_handles, + ) + + checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine) + self.checkpoint_manager = CheckpointEngineManager( + config=checkpoint_engine_config, + trainer=self.actor_rollout_wg, + replicas=self.async_rollout_manager.rollout_replicas, + ) + + # sleep all replicas to load checkpoint + self.checkpoint_manager.sleep_replicas() + + def _save_checkpoint(self): + from verl.utils.fs import local_mkdir_safe + + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) + + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") + + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) + + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) + if remove_previous_ckpt_in_save: + print( + "Warning: remove_previous_ckpt_in_save is deprecated," + + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" + ) + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + max_critic_ckpt_to_keep = ( + self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) + + if self.use_critic: + critic_local_path = os.path.join(local_global_step_folder, str(Role.Critic)) + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join( + self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", str(Role.Critic) + ) + ) + self.critic_wg.save_checkpoint( + critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + ) + + # save dataloader + local_mkdir_safe(local_global_step_folder) + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + + # latest checkpointed iteration tracker (for atomic usage) + if ( + hasattr(self.config.actor_rollout_ref.actor.checkpoint, "async_save") + and self.config.actor_rollout_ref.actor.checkpoint.async_save + ) or ( + "async_save" in self.config.actor_rollout_ref.actor.checkpoint + and self.config.actor_rollout_ref.actor.checkpoint["async_save"] + ): + print("skip write latest_checkpointed_iteration.txt when async_save is True") + return + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.global_steps)) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("Training from scratch") + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"Load from checkpoint folder: {global_step_folder}") + # set global step + self.global_steps = int(global_step_folder.split("global_step_")[-1]) + + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, str(Role.Critic)) + # load actor + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + # load critic + if self.use_critic: + self.critic_wg.load_checkpoint( + critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + if os.path.exists(dataloader_local_path): + dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + else: + print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + + def _start_profiling(self, do_profile: bool) -> None: + """Start profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + if self.use_reference_policy: + self.ref_policy_wg.start_profile(profile_step=self.global_steps) + if self.use_critic: + self.critic_wg.start_profile(profile_step=self.global_steps) + + def _stop_profiling(self, do_profile: bool) -> None: + """Stop profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + if self.use_critic: + self.critic_wg.stop_profile() + + def _get_dp_size(self, worker_group, role: str) -> int: + """Get data parallel size from worker group dispatch info. + + This method retrieves the data parallel size by querying the dispatch info + for the specified role. The dispatch info is cached for subsequent calls. + + Args: + worker_group: The worker group to query dispatch info from. + role: The role name (e.g., "actor", "critic") to get DP size for. + + Returns: + The data parallel size (number of DP ranks). + """ + if role not in worker_group._dispatch_info: + dp_rank_mapping = worker_group._query_dispatch_info(role) + worker_group._dispatch_info[role] = dp_rank_mapping + else: + dp_rank_mapping = worker_group._dispatch_info[role] + return max(dp_rank_mapping) + 1 + + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False): + """Reorder the data on single controller such that each dp rank gets similar total tokens. + + When use_prefix_grouper is enabled, uses group-level balancing to keep samples with + the same uid together on the same rank for prefix sharing optimization. + """ + attention_mask = batch.batch["attention_mask"] + batch_size = attention_mask.shape[0] + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,) + workload_lst = calculate_workload(global_seqlen_lst) + # Get dp_size from dispatch info to correctly balance across data parallel ranks + # Note: world_size may include tensor/pipeline parallel dimensions, but we only want DP + dp_size = self._get_dp_size(self.actor_rollout_wg, "actor") + + # Use group-level balancing for PrefixGrouper to keep same-uid samples together + if getattr(self, "use_prefix_grouper", False) and "uid" in batch.non_tensor_batch: + from verl.utils.seqlen_balancing import get_group_balanced_partitions + + uid_list = list(batch.non_tensor_batch["uid"]) + seqlen_list = global_seqlen_lst.tolist() + + # Count number of uid groups + num_groups = len(set(uid_list)) + + if num_groups % dp_size != 0: + raise ValueError( + f"PrefixGrouper with balance_batch requires num_uid_groups ({num_groups}) " + f"% dp_size ({dp_size}) == 0. " + f"This ensures each rank gets equal number of groups. " + f"Current batch_size={batch_size}, adjust batch_size to be a multiple of " + f"dp_size * rollout.n." + ) + + global_partition_lst = get_group_balanced_partitions( + seqlen_list=seqlen_list, + uid_list=uid_list, + k_partitions=dp_size, + ) + + elif keep_minibatch: + # Decouple the DP balancing and mini-batching. + minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size") + minibatch_num = len(workload_lst) // minibatch_size + global_partition_lst = [[] for _ in range(dp_size)] + for i in range(minibatch_num): + rearrange_minibatch_lst = get_seqlen_balanced_partitions( + workload_lst[i * minibatch_size : (i + 1) * minibatch_size], + k_partitions=dp_size, + equal_size=True, + ) + for j, part in enumerate(rearrange_minibatch_lst): + global_partition_lst[j].extend([x + minibatch_size * i for x in part]) + else: + global_partition_lst = get_seqlen_balanced_partitions(workload_lst, k_partitions=dp_size, equal_size=True) + # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel. + # Skip reordering within partitions for PrefixGrouper to maintain uid grouping + if not getattr(self, "use_prefix_grouper", False): + for idx, partition in enumerate(global_partition_lst): + partition.sort(key=lambda x: (workload_lst[x], x)) + ordered_partition = partition[::2] + partition[1::2][::-1] + global_partition_lst[idx] = ordered_partition + + # reorder based on index. The data will be automatically equally partitioned by dispatch function + global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) + batch.reorder(global_idx) + global_balance_stats = log_seqlen_unbalance( + seqlen_list=global_seqlen_lst.tolist(), partitions=global_partition_lst, prefix=logging_prefix + ) + metrics.update(global_balance_stats) + + def _compute_values(self, batch: DataProto) -> DataProto: + if self.use_legacy_worker_impl == "disable": + batch_td = batch.to_tensordict() + # step 2: convert from padding to nopadding + batch_td = embeds_padding_2_no_padding(batch_td) + # step 3: add meta info + tu.assign_non_tensor(batch_td, compute_loss=False) + output = self.critic_wg.infer_batch(batch_td) + output = output.get() + values = tu.get(output, "values") + values = tu.get_tensordict({"values": values.float()}) + values = DataProto.from_tensordict(values) + else: + values = self.critic_wg.compute_values(batch) + return values + + def _compute_ref_log_prob(self, batch: DataProto) -> DataProto: + if self.use_legacy_worker_impl == "disable": + # step 1: convert dataproto to tensordict. + batch_td = batch.to_tensordict() + # step 2: convert from padding to nopadding + batch_td = embeds_padding_2_no_padding(batch_td) + # step 3: add meta info + metadata = { + "compute_loss": False, + "height": self.config.actor_rollout_ref.model.image_height, + "width": self.config.actor_rollout_ref.model.image_width, + "vae_scale_factor": self.config.actor_rollout_ref.model.get("vae_scale_factor", 8), + } + if self.ref_in_actor: + metadata["no_lora_adapter"] = True + tu.assign_non_tensor(batch_td, **metadata) + if self.ref_in_actor: + output = self.actor_rollout_wg.compute_log_prob(batch_td) + else: + output = self.ref_policy_wg.compute_ref_log_prob(batch_td) + # gather output + log_probs = tu.get(output, "log_probs") + prev_sample_mean = tu.get(output, "prev_sample_mean") + # step 5: rebuild a tensordict and convert to dataproto + ref_log_prob = tu.get_tensordict( + {"ref_log_prob": log_probs.float(), "ref_prev_sample_mean": prev_sample_mean.float()} + ) + ref_log_prob = DataProto.from_tensordict(ref_log_prob) + else: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + + return ref_log_prob + + def _compute_old_log_prob(self, batch: DataProto): + if self.use_legacy_worker_impl == "disable": + # TODO: remove step 1, 2, 4 after we make the whole training tensordict and padding free + # step 1: convert dataproto to tensordict. + batch_td = batch.to_tensordict() + # step 2: convert from padding to nopadding + batch_td = embeds_padding_2_no_padding(batch_td) + # step 3: add meta info + tu.assign_non_tensor( + batch_td, + compute_loss=False, + height=self.config.actor_rollout_ref.model.image_height, + width=self.config.actor_rollout_ref.model.image_width, + vae_scale_factor=self.config.actor_rollout_ref.model.get("vae_scale_factor", 8), + ) + output = self.actor_rollout_wg.compute_log_prob(batch_td) + # gather output + log_probs = tu.get(output, "log_probs") + # step 5: rebuild a tensordict and convert to dataproto + old_log_prob = tu.get_tensordict({"old_log_probs": log_probs.float()}) + old_log_prob = DataProto.from_tensordict(old_log_prob) + else: + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + return old_log_prob + + def _update_actor(self, batch: DataProto) -> DataProto: + rollout_config = self.config.actor_rollout_ref.rollout + batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable + # update actor + if self.use_legacy_worker_impl == "disable": + batch_td = batch.to_tensordict() + # step 2: convert from padding to no-padding + batch_td = embeds_padding_2_no_padding(batch_td) + ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size + ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n + ppo_epochs = self.config.actor_rollout_ref.actor.ppo_epochs + seed = self.config.actor_rollout_ref.actor.data_loader_seed + shuffle = self.config.actor_rollout_ref.actor.shuffle + tu.assign_non_tensor( + batch_td, + global_batch_size=ppo_mini_batch_size, + mini_batch_size=ppo_mini_batch_size, + epochs=ppo_epochs, + seed=seed, + dataloader_kwargs={"shuffle": shuffle}, + height=self.config.actor_rollout_ref.model.image_height, + width=self.config.actor_rollout_ref.model.image_width, + vae_scale_factor=self.config.actor_rollout_ref.model.get("vae_scale_factor", 8), + ) + + actor_output = self.actor_rollout_wg.update_actor(batch_td) + actor_output = tu.get(actor_output, "metrics") + actor_output = rename_dict(actor_output, "actor/") + actor_output = DataProto.from_single_dict(data={}, meta_info={"metrics": actor_output}) + else: + actor_output = self.actor_rollout_wg.update_actor(batch) + + return actor_output + + def _update_critic(self, batch: DataProto) -> DataProto: + if self.use_legacy_worker_impl == "disable": + batch_td = batch.to_tensordict() + # step 2: convert from padding to no-padding + batch_td = embeds_padding_2_no_padding(batch_td) + ppo_mini_batch_size = self.config.critic.ppo_mini_batch_size + ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n + ppo_epochs = self.config.critic.ppo_epochs + seed = self.config.critic.data_loader_seed + shuffle = self.config.critic.shuffle + tu.assign_non_tensor( + batch_td, + global_batch_size=ppo_mini_batch_size, + mini_batch_size=ppo_mini_batch_size, + epochs=ppo_epochs, + seed=seed, + dataloader_kwargs={"shuffle": shuffle}, + height=self.config.actor_rollout_ref.model.image_height, + width=self.config.actor_rollout_ref.model.image_width, + vae_scale_factor=self.config.actor_rollout_ref.model.get("vae_scale_factor", 8), + ) + + output = self.critic_wg.train_mini_batch(batch_td) + output = output.get() + output = tu.get(output, "metrics") + output = rename_dict(output, "critic/") + # modify key name + output["perf/mfu/critic"] = output.pop("critic/mfu") + critic_output = DataProto.from_single_dict(data={}, meta_info={"metrics": output}) + else: + critic_output = self.critic_wg.update_critic(batch) + return critic_output + + def fit(self): + """ + The training loop of FlowGRPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint and update weights before doing anything + self._load_checkpoint() + self.checkpoint_manager.update_weights(self.global_steps) + + current_epoch = self.global_steps // len(self.train_dataloader) + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + for epoch in range(current_epoch, self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False) + metrics = {} + timing_raw = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + + gen_batch = self._get_gen_batch(batch) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if curr_step_profile: + self.async_rollout_manager.start_profile() + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + self.checkpoint_manager.sleep_replicas() + if curr_step_profile: + self.async_rollout_manager.stop_profile() + + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + # get images_seqlens + images_seqlens_all = [] + for multi_modal_input in batch.non_tensor_batch["multi_modal_inputs"]: + if "image_grid_thw" not in multi_modal_input.keys(): + continue + images_seqlens_all.extend(multi_modal_input["images_seqlens"].tolist()) + batch.meta_info["images_seqlens"] = images_seqlens_all + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm and "rm_scores" not in batch.batch.keys(): + batch_reward = self._compute_reward_colocate(batch) + batch = batch.union(batch_reward) + + # extract reward_tensor and reward_extra_infos_dict for training + reward_tensor, reward_extra_infos_dict = extract_reward(batch) + + # Operating Mode Selection: + # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ) + # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ) + # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False) + if bypass_recomputing_logprobs: # Use `rollout_log_probs` + batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"] + else: # Recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob = self._compute_old_log_prob(batch) + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}' + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): + ref_log_prob = self._compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self._compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # Compute rollout correction: IS weights, rejection sampling, and metrics + # Only runs in decoupled mode (computes once per batch using stable π_old) + # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout + if ( + rollout_corr_config is not None + and "rollout_log_probs" in batch.batch + and not bypass_recomputing_logprobs # Only in decoupled mode + ): + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + # Compute IS weights, apply rejection sampling, compute metrics + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + global_std=self.config.algorithm.global_std, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self._update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + actor_output = self._update_actor(batch) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 + or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + # update weights from trainer to rollout + with marked_timer("update_weights", timing_raw, color="red"): + self.checkpoint_manager.update_weights(self.global_steps) + + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if self.config.trainer.test_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.test_freq == 0 + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + # compute variance proxy metrics + gradient_norm = metrics.get("actor/grad_norm", None) + metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm)) + # Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation + + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + + if is_last_step: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/verl/trainer/ppo/reward.py b/verl/trainer/ppo/reward.py index f13a3abf976..0d8b108e1dd 100644 --- a/verl/trainer/ppo/reward.py +++ b/verl/trainer/ppo/reward.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast from verl import DataProto -from verl.utils.reward_score import default_compute_score +from verl.utils.reward_score import default_compute_score, default_compute_score_image if TYPE_CHECKING: from omegaconf import DictConfig @@ -123,6 +123,9 @@ def load_reward_manager(config: DictConfig, tokenizer: Any, **reward_kwargs: Any load_extern_object(module_path=module_cfg.path, object_name=reward_manager_cls_name), ) + default_compute_score_ = ( + default_compute_score_image if reward_manager_cfg.name == "image" else default_compute_score + ) if compute_score is None: sandbox_config = config.reward.get("sandbox_fusion") sandbox_url = sandbox_config.get("url") if sandbox_config else None @@ -132,13 +135,13 @@ def load_reward_manager(config: DictConfig, tokenizer: Any, **reward_kwargs: Any # Create a semaphore to control concurrent access to the sandbox _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) final_compute_score = partial( - default_compute_score, + default_compute_score_, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore, memory_limit_mb=memory_limit_mb, ) else: - final_compute_score = default_compute_score + final_compute_score = default_compute_score_ # Instantiate and return the reward manager with the specified parameters return reward_manager_cls( diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index 117f2df8d41..63fe574fd83 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -141,6 +141,9 @@ def __init__( self.shuffle = config.get("shuffle", False) self.seed = config.get("seed") + # For diffusion model training only + self.negative_prompt_key = config.get("negative_prompt_key", "negative_prompt") + self._download() self._read_files_and_tokenize() @@ -289,7 +292,7 @@ def __getstate__(self): def __len__(self): return len(self.dataframe) - def _build_messages(self, example: dict): + def _build_messages(self, example: dict, key: str): """Replace and