[fsdp,trainer,vllm_omni,algo] feat: support FlowGRPO training for QwenImage #5297
Status: Open. zhtmike wants to merge 70 commits into verl-project:main from zhtmike:verl-omni-pr (+8,519 −87).
Commits (70)
- 70a155a (zhtmike) add entroypoint (#1)
- 62c5286 (zhtmike) add training engine (#2)
- c0150da (zhtmike) move folders & make for two-forward pass in training loop (#4)
- 43915bc (chenyingshu) Add diffusion reward loop (#3)
- 0833f81 (chenyingshu) [fix] update customized reward func in UT (#5)
- 4d0a8d8 (zhtmike) Update 20260109 (#8)
- 4480199 (chenyingshu) [data] feat: Add dataset for Qwen-Image (#6)
- 3c354d1 (zhtmike) small fix after rebase (#12)
- 01f6f7c (chenyingshu) [trainer, cfg] fix: actor engine and trainer debug (#10)
- b418656 (zhtmike) Merge branch 'main' into verl-omni
- 7d522ee (zhtmike) merge main (#13)
- abdb5d4 (zhtmike) Merge remote-tracking branch 'origin/main' into verl-omni
- 647c043 (chenyingshu) [data] fix: QwenDataset update (#14)
- d3d2ac4 (zhtmike) [rollout] feat: Add vllm-omni for rollout (#9)
- 80738a3 (zhtmike) fix worker extension (#15)
- a9b88f3 (zhtmike) fix worker extension
- cf314d0 (zhtmike) Merge branch 'main' into verl-omni
- 6eb395a (zhtmike) merge main
- a32de27 (zhtmike) [rollout] feat: flowgrpo with vllm-omni (rollout part) (#16)
- 24d00a7 (chenyingshu) [reward, misc] fix: support async reward loop for validation (#18)
- be667a3 (zhtmike) [rollout] feat: enable reward model (#17)
- 8edd6d5 (zhtmike) [trainer] feat: fix training loop (#19)
- b008b15 (zhtmike) [rollout] fix: fix misc. bugs (#20)
- 46ffce8 (zhtmike) turn on offload to avoid oom
- af7ab01 (chenyingshu) [misc] feat: support sync reward loop for validation (#21)
- 109427b (zhtmike) [rollout] fix: fix sleep mode & non-lora weight update (#22)
- 37f60a3 (chenyingshu) add padding conversion (#24)
- 8fe64da (zhtmike) [rollout] fix: fix lora weight export from trainer (#23)
- 838e28c (zhtmike) [trainer] fix: fix training (#25)
- ac8122a (zhtmike) Merge branch 'main' into verl-omni-main
- c903e63 (zhtmike) Merge branch 'main' into verl-omni-main
- 937abd0 (zhtmike) Merge branch 'main' into verl-omni-main
- e3b41ff (zhtmike) [fsdp,vllm_omni,algo] fix: Merge main (#26)
- 1942ed3 (zhtmike) revert python change
- 89b49e5 (chenyingshu) fix bug during ckpt saving (#27)
- 5da0abe (zhtmike) [vllm_omni] fix: add cfg & clean codes (#28)
- 0c0acfd (zhtmike) update license (#29)
- 0a2e3b9 (zhtmike) Merge branch 'main' into verl-omni
- 156014f (zhtmike) Merge branch 'main' into verl-omni
- 4ec5021 (zhtmike) [trainer] refactor: support kl training & clean codes (#30)
- 0905629 (zhtmike) Merge branch 'verl-omni' into verl-omni-pr
- e53770c (zhtmike) update ocr model (#31)
- c59767f (chenyingshu) [cfg] refactor: refactor rollout configurations (#32)
- 0df76a9 (chenyingshu) [reward] feat: async reward via a separate api call (#34)
- b4d5f80 (zhtmike) [misc] chore: change to fast UT (#33)
- d8eb0d2 (zhtmike) Merge branch 'verl-omni' into verl-omni-pr
- 9a81399 (zhtmike) [rollout] feat: support bypass mode (#35)
- 01fe220 (chenyingshu) [perf] chore: align flowgrpo Qwen-Image training config (#36)
- a80c0c4 (zhtmike) Merge branch 'main' into verl-omni
- 6a7798f (zhtmike) merge main (#37)
- 8dc25ae (zhtmike) Merge branch 'verl-omni' into verl-omni-pr
- 3169944 (zhtmike) update script (#38)
- f89a7e2 (zhtmike) [doc] chore: add README (#39)
- 9ff7986 (zhtmike) Merge branch 'verl-omni' into verl-omni-pr
- 5fd3362 (zhtmike) update doc (#40)
- 574363f (zhtmike) Merge branch 'verl-omni' into verl-omni-pr
- 41d0173 (zhtmike) Merge branch 'main' into verl-omni
- edfdee2 (zhtmike) merge main (#43)
- 3e28e06 (zhtmike) Merge branch 'verl-omni' into verl-omni-pr
- 1e0fd88 (zhtmike) [misc] chore: misc changes (#44)
- cad2165 (zhtmike) Merge branch 'verl-omni' into verl-omni-pr
- 79e6427 (zhtmike) Merge branch 'main' into verl-omni
- 36177ff (zhtmike) Merge branch 'main' into verl-omni
- d1379df (zhtmike) [misc] chore: merge main (#46)
- 716436b (zhtmike) Merge branch 'verl-omni' into verl-omni-pr
- 9938cd8 (knlnguyen1802) [rollout] feat: Rebase with vllm-omni 0.16.0 (#42)
- a5fdd4b (zhtmike) [misc] chore: fix CI & bugs after vllm-omni upgrade (#47)
- 2e428f5 (zhtmike) Merge branch 'verl-omni' into verl-omni-pr
- 6b7a4f0 (zhtmike) fix mask
- 1aa6693 (zhtmike) Merge branch 'verl-omni' into verl-omni-pr
New file (+68 lines): example launch script for Qwen-Image LoRA training with FlowGRPO and the vllm_omni rollout engine.

```shell
# Qwen-Image lora, vllm_omni rollout
set -x
export TOKENIZERS_PARALLELISM="false"

ENGINE=vllm_omni
REWARD_ENGINE=vllm

reward_path=tests/experimental/reward_loop/reward_fn.py
reward_model_name=$HOME/models/Qwen/Qwen2.5-VL-3B-Instruct

python3 -m verl.trainer.main_flowgrpo \
    algorithm.adv_estimator=flow_grpo \
    data.train_files=$HOME/dataset/ocr/train.txt \
    data.val_files=$HOME/dataset/ocr/test.txt \
    data.train_batch_size=32 \
    data.val_max_samples=128 \
    data.max_prompt_length=1058 \
    data.filter_overlong_prompts=True \
    data.data_source=ocr \
    data.custom_cls.path=verl/utils/dataset/qwen_dataset.py \
    data.custom_cls.name=QwenDataset \
    +data.apply_chat_template_kwargs.max_length=1058 \
    +data.apply_chat_template_kwargs.padding=True \
    +data.apply_chat_template_kwargs.truncation=True \
    actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen-Image \
    actor_rollout_ref.model.tokenizer_path=$HOME/models/Qwen/Qwen-Image/tokenizer \
    actor_rollout_ref.model.lora_rank=64 \
    actor_rollout_ref.model.lora_alpha=128 \
    actor_rollout_ref.model.target_modules="['to_q','to_k','to_v','to_out.0','add_q_proj','add_k_proj','add_v_proj','to_add_out','img_mlp.net.0.proj','img_mlp.net.2','txt_mlp.net.0.proj','txt_mlp.net.2']" \
    actor_rollout_ref.actor.optim.lr=3e-4 \
    actor_rollout_ref.actor.optim.weight_decay=0.0001 \
    actor_rollout_ref.actor.ppo_mini_batch_size=16 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=$ENGINE \
    actor_rollout_ref.rollout.n=16 \
    actor_rollout_ref.rollout.guidance_scale=1.0 \
    actor_rollout_ref.rollout.agent.default_agent_loop=diffusion_single_turn_agent \
    actor_rollout_ref.rollout.load_format=safetensors \
    actor_rollout_ref.rollout.layered_summon=True \
    actor_rollout_ref.rollout.max_model_len=1058 \
    actor_rollout_ref.rollout.sde_window_size=3 \
    actor_rollout_ref.rollout.sde_window_range="[0,5]" \
    +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.custom_pipeline=verl.workers.utils.vllm_omni_patch.pipelines.pipeline_qwenimage.QwenImagePipelineWithLogProb \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    reward.reward_manager.name=diffusion \
    reward.reward_model.model_path=$reward_model_name \
    reward.reward_model.enable=True \
    reward.reward_model.rollout.name=$REWARD_ENGINE \
    reward.custom_reward_function.path=$reward_path \
    reward.custom_reward_function.name=compute_score_ocr \
    trainer.use_legacy_worker_impl=disable \
    trainer.logger='["console", "wandb"]' \
    trainer.project_name=flow_grpo \
    trainer.experiment_name=qwen_image_ocr \
    trainer.log_val_generations=8 \
    trainer.n_gpus_per_node=4 \
    trainer.nnodes=1 \
    trainer.save_freq=100 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
```
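The script sets `algorithm.adv_estimator=flow_grpo` with `rollout.n=16` samples per prompt. The exact estimator lives in verl's algorithm code, but assuming FlowGRPO follows GRPO's group-relative normalization (reward minus group mean, divided by group standard deviation), the core computation can be sketched as follows; the function name and epsilon handling here are illustrative, not verl's actual implementation:

```python
import numpy as np

def group_relative_advantages(rewards: np.ndarray, group_size: int, eps: float = 1e-6) -> np.ndarray:
    """GRPO-style advantage sketch: rewards for samples sharing a prompt are
    laid out contiguously in groups of `group_size`; each sample's advantage
    is its reward standardized within its own group."""
    grouped = rewards.reshape(-1, group_size)          # (num_prompts, group_size)
    mean = grouped.mean(axis=1, keepdims=True)
    std = grouped.std(axis=1, keepdims=True)
    return ((grouped - mean) / (std + eps)).reshape(-1)
```

With `rollout.n=16`, the 16 images generated for one prompt form one group, so an image is rewarded only relative to its siblings, which removes the need for a learned value baseline.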
tests/experimental/agent_loop/test_diffusion_agent_loop.py (new file, +131 lines):

```python
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import numpy as np
import pytest
import ray
from omegaconf import DictConfig
from PIL import Image

from verl.experimental.agent_loop.diffusion_agent_loop import DiffusionAgentLoopManager
from verl.protocol import DataProto


@pytest.fixture
def init_config() -> DictConfig:
    from hydra import compose, initialize_config_dir

    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
        config = compose(config_name="ppo_diffusion_trainer")

    model_path = os.path.expanduser("~/models/Qwen/Qwen-Image")
    config.actor_rollout_ref.model.path = model_path
    config.actor_rollout_ref.model.tokenizer_path = os.path.join(model_path, "tokenizer")
    config.actor_rollout_ref.rollout.name = "vllm_omni"
    config.actor_rollout_ref.rollout.mode = "async"
    config.actor_rollout_ref.rollout.enforce_eager = True
    config.actor_rollout_ref.rollout.n = 4
    config.actor_rollout_ref.rollout.num_inference_steps = 10
    config.actor_rollout_ref.rollout.guidance_scale = 1.0
    config.actor_rollout_ref.rollout.agent.num_workers = 2
    config.actor_rollout_ref.rollout.skip_tokenizer_init = True
    config.actor_rollout_ref.rollout.agent.default_agent_loop = "diffusion_single_turn_agent"
    config.actor_rollout_ref.rollout.sde_window_size = 3
    config.actor_rollout_ref.rollout.sde_window_range = [0, 5]

    qwen_pipeline = "verl.workers.utils.vllm_omni_patch.pipelines.pipeline_qwenimage.QwenImagePipelineWithLogProb"
    config.actor_rollout_ref.rollout.engine_kwargs.vllm_omni = {"custom_pipeline": qwen_pipeline}
    config.data.custom_cls.path = "verl/utils/dataset/qwen_dataset.py"
    config.data.custom_cls.name = "QwenDataset"
    config.reward.reward_manager.name = "diffusion"
    config.trainer.n_gpus_per_node = 4

    tokenizer_max_length = 1024
    prompt_template_encode_start_idx = 34
    max_length = tokenizer_max_length + prompt_template_encode_start_idx

    config.data.apply_chat_template_kwargs = dict(max_length=max_length, padding=True, truncation=True)
    config.data.max_prompt_length = max_length
    config.actor_rollout_ref.rollout.max_model_len = max_length

    # TODO (mike): test with TP later
    config.actor_rollout_ref.rollout.tensor_model_parallel_size = 1
    return config


def test_single_turn(init_config):
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
            }
        }
    )

    agent_loop_manager = DiffusionAgentLoopManager(init_config)

    system_prompt = (
        "Describe the image by detailing the color, shape, size, texture, quantity, text, "
        "spatial relationships of the objects and background:"
    )
    user_prompts = ["A photo of cute cat with long fur and big eyes.", "A photo of cute dog with short hair."]

    raw_prompts = []
    for user_prompt in user_prompts:
        raw_prompts.append(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]
        )

    batch = DataProto(
        non_tensor_batch={
            "raw_prompt": np.array(raw_prompts),
            "data_source": np.array(["jpeg_compressibility"] * len(raw_prompts)),
            "reward_model": np.array([{"style": "rule", "ground_truth": ""}] * len(raw_prompts)),
        },
    )
    n = init_config.actor_rollout_ref.rollout.n
    batch = batch.repeat(n)
    result = agent_loop_manager.generate_sequences(prompts=batch)
    assert len(result) == len(raw_prompts) * n

    expected_batch_keys = [
        "responses",
        "all_latents",
        "all_timesteps",
        "prompt_embeds",
        "prompt_embeds_mask",
        "input_ids",
        "attention_mask",
    ]
    for key in expected_batch_keys:
        assert key in result.batch, f"Key {key} not found in result batch."

    # check turns
    num_turns = result.non_tensor_batch["__num_turns__"]
    assert np.all(num_turns == 2)

    # TODO: for visualization, drop later
    images_pil = (result.batch["responses"].permute(0, 2, 3, 1).numpy() * 255.0).astype("uint8")
    for i, image in enumerate(images_pil):
        image_path = os.path.join(f"{i}.jpg")
        Image.fromarray(image).save(image_path)

    print("Test passed!")
    ray.shutdown()
```
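The test routes rollout through `QwenImagePipelineWithLogProb` and asserts the result carries `all_latents` and `all_timesteps`, i.e. the denoising trajectory needed to score each step under the policy. The patched pipeline's internals are not shown in this diff, but a stochastic (SDE) denoising step is a Gaussian transition, so its per-step log-probability can be sketched generically; the function below is an illustration of that idea, not the pipeline's actual code:

```python
import numpy as np

def sde_step_log_prob(x_next: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Log-density of the next latent under the Gaussian transition
    N(mean, std^2) of one stochastic denoising step, summed over all
    latent dimensions to give one scalar per sample in the batch."""
    var = std ** 2
    log_prob = -0.5 * (((x_next - mean) ** 2) / var + np.log(2 * np.pi * var))
    return log_prob.reshape(log_prob.shape[0], -1).sum(axis=1)
```

Summing these per-step terms over the timesteps inside the SDE window (`sde_window_size=3` drawn from `sde_window_range=[0, 5]` in the config above) yields the trajectory log-probability that a GRPO-style policy-gradient update needs.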