
Commit 449d9fc

Commit message: save

1 parent 6976016 commit 449d9fc

10 files changed

Lines changed: 80 additions & 63 deletions


.gitignore

Lines changed: 4 additions & 0 deletions
@@ -175,3 +175,7 @@ outputs/
 checkpoints/
 wandb/
 tensorboard_log/
+
+# data
+images/
+images*

README.md

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ See [baselines.md](assets/baselines.md).
 - **ViGoRL**: Grounded Reinforcement Learning for Visual Reasoning. [![[code]](https://img.shields.io/github/stars/Gabesarch/grounded-rl)](https://github.com/Gabesarch/grounded-rl) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.22334-blue)](https://arxiv.org/abs/2505.23678)
 - **Revisual-R1**: Advancing Multimodal Reasoning: From Optimized Cold Start to Staged Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/CSfufu/Revisual-R1)](https://github.com/CSfufu/Revisual-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2506.04207-blue)](https://arxiv.org/abs/2506.04207)
 - **SophiaVL-R1**: Reinforcing MLLMs Reasoning with Thinking Reward. [![[code]](https://img.shields.io/github/stars/kxfan2002/SophiaVL-R1)](https://github.com/kxfan2002/SophiaVL-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.17018-blue)](https://arxiv.org/abs/2505.17018)
-
+
 ## TODO

 - Support LoRA (high priority).

examples/config.yaml

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ data:
   prompt_key: problem
   answer_key: answer
   image_key: images
+  image_dir: null
   max_prompt_length: 2048
   max_response_length: 2048
   rollout_batch_size: 512
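
The new image_dir key defaults to null, in which case image entries are used as-is. For a dataset whose rows store relative image paths, a config might look like this (the directory value is hypothetical; DataConfig.post_init resolves it to an absolute path for Ray jobs):

    data:
      prompt_key: problem
      answer_key: answer
      image_key: images
      image_dir: /data/geometry3k/images  # hypothetical local directory of image files
      max_prompt_length: 2048
      max_response_length: 2048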

verl/protocol.py

Lines changed: 2 additions & 2 deletions
@@ -384,11 +384,11 @@ def pop(
         meta_info_keys = meta_info_keys or []

         tensors = {}
-        for key in batch_keys:
+        for key in batch_keys and key in self.batch:
             tensors[key] = self.batch.pop(key)

         non_tensors = {}
-        for key in non_tensor_batch_keys:
+        for key in non_tensor_batch_keys and key in self.non_tensor_batch:
             non_tensors[key] = self.non_tensor_batch.pop(key)

         meta_info = {}
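
Note that `in` binds tighter than `and`, so `batch_keys and key in self.batch` parses as `batch_keys and (key in self.batch)` rather than as a filtered iteration. The behavior the rest of this commit relies on is "pop a key only if it is present"; a minimal standalone sketch of that guarded-pop pattern (a hypothetical helper, not the actual DataProto method):

    from typing import Any, Dict, List

    def pop_present_keys(store: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
        # Pop only the keys that actually exist; silently skip the rest.
        popped = {}
        for key in keys:
            if key in store:
                popped[key] = store.pop(key)
        return popped

    batch = {"input_ids": [1, 2], "raw_prompt_ids": [[1, 2]]}
    print(pop_present_keys(batch, ["raw_prompt_ids", "multi_modal_data"]))
    # {'raw_prompt_ids': [[1, 2]]} -- the absent 'multi_modal_data' key is skipped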

verl/trainer/config.py

Lines changed: 14 additions & 1 deletion
@@ -38,6 +38,7 @@ class DataConfig:
     prompt_key: str = "prompt"
     answer_key: str = "answer"
     image_key: str = "images"
+    image_dir: Optional[str] = None
     max_prompt_length: int = 512
     max_response_length: int = 512
     rollout_batch_size: int = 512
@@ -51,10 +52,18 @@ class DataConfig:
     filter_overlong_prompts: bool = True

     def post_init(self):
+        if self.image_dir is not None:
+            if os.path.exists(self.image_dir):  # ray job uses absolute path
+                self.image_dir = os.path.abspath(self.image_dir)
+            else:
+                print(f"Image directory {self.image_dir} is not found.")
+                self.image_dir = None
+
         if self.format_prompt is not None:
             if os.path.exists(self.format_prompt):  # ray job uses absolute path
                 self.format_prompt = os.path.abspath(self.format_prompt)
             else:
+                print(f"Format prompt file {self.format_prompt} is not found.")
                 self.format_prompt = None


@@ -97,7 +106,11 @@ def post_init(self):

         self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)  # ray job uses absolute path
         if self.load_checkpoint_path is not None:
-            self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
+            if os.path.exists(self.load_checkpoint_path):  # ray job uses absolute path
+                self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
+            else:
+                print(f"Model checkpoint {self.load_checkpoint_path} is not found.")
+                self.load_checkpoint_path = None


 @dataclass
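
The same resolve-or-drop pattern now appears three times in post_init (image_dir, format_prompt, load_checkpoint_path): an existing path is made absolute so Ray workers can find it, and a missing path is warned about and reset to None. A hypothetical helper capturing the pattern (not part of the commit):

    import os
    from typing import Optional

    def resolve_or_drop(path: Optional[str], label: str) -> Optional[str]:
        # Ray jobs need absolute paths; warn and drop the setting if the path is missing.
        if path is None:
            return None
        if os.path.exists(path):
            return os.path.abspath(path)
        print(f"{label} {path} is not found.")
        return None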

verl/trainer/data_loader.py

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,7 @@ def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, proces
         prompt_key=config.prompt_key,
         answer_key=config.answer_key,
         image_key=config.image_key,
+        image_dir=config.image_dir,
         max_prompt_length=config.max_prompt_length,
         truncation="right",
         format_prompt=config.format_prompt,
@@ -63,6 +64,7 @@ def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, proces
         prompt_key=config.prompt_key,
         answer_key=config.answer_key,
         image_key=config.image_key,
+        image_dir=config.image_dir,
         max_prompt_length=config.max_prompt_length,
         truncation="right",
         format_prompt=config.format_prompt,

verl/trainer/ray_trainer.py

Lines changed: 14 additions & 30 deletions
@@ -283,22 +283,13 @@ def _validate(self) -> Dict[str, Any]:
             input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
             sample_inputs.extend(input_texts)

-            if "multi_modal_data" in test_batch.non_tensor_batch.keys():
-                test_gen_batch = test_batch.pop(
-                    batch_keys=["input_ids", "attention_mask", "position_ids"],
-                    non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
-                )
-            else:
-                test_gen_batch = test_batch.pop(
-                    batch_keys=["input_ids", "attention_mask", "position_ids"],
-                    non_tensor_batch_keys=["raw_prompt_ids"],
-                )
-
+            test_gen_batch = test_batch.pop(
+                batch_keys=["input_ids", "attention_mask", "position_ids"],
+                non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
+            )
             test_gen_batch.meta_info = self.config.worker.rollout.val_override_config
-            test_gen_batch.meta_info.update({
-                "min_pixels": self.config.data.min_pixels,
-                "max_pixels": self.config.data.max_pixels,
-            })
+            test_gen_batch.meta_info["min_pixels"] = self.config.data.min_pixels
+            test_gen_batch.meta_info["max_pixels"] = self.config.data.max_pixels
             test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
             test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
             test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)
@@ -485,23 +476,16 @@ def fit(self):

                 metrics, timing_raw = {}, {}
                 batch: DataProto = DataProto.from_single_dict(batch_dict)
+                batch.meta_info = {
+                    "min_pixels": self.config.data.min_pixels,
+                    "max_pixels": self.config.data.max_pixels,
+                }

                 # pop those keys for generation
-                if "multi_modal_data" in batch.non_tensor_batch.keys():
-                    gen_batch = batch.pop(
-                        batch_keys=["input_ids", "attention_mask", "position_ids"],
-                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
-                    )
-                    gen_batch.meta_info.update({
-                        "min_pixels": self.config.data.min_pixels,
-                        "max_pixels": self.config.data.max_pixels,
-                    })
-                else:
-                    gen_batch = batch.pop(
-                        batch_keys=["input_ids", "attention_mask", "position_ids"],
-                        non_tensor_batch_keys=["raw_prompt_ids"],
-                    )
-
+                gen_batch = batch.pop(
+                    batch_keys=["input_ids", "attention_mask", "position_ids"],
+                    non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
+                )
                 with timer("step", timing_raw):
                     # generate a batch
                     with timer("gen", timing_raw):  # wg: worker group
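
Both call sites can now pop "multi_modal_data" unconditionally because of the guarded DataProto.pop above: on a text-only batch the missing key simply yields nothing instead of raising a KeyError. A toy dict-based illustration of that interaction (not the real DataProto API):

    text_only_batch = {"raw_prompt_ids": [[1, 2, 3]]}  # no "multi_modal_data" key
    requested = ["raw_prompt_ids", "multi_modal_data"]
    popped = {k: text_only_batch.pop(k) for k in requested if k in text_only_batch}
    print(popped)  # {'raw_prompt_ids': [[1, 2, 3]]} -- missing key ignored, no error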

verl/utils/dataset.py

Lines changed: 16 additions & 9 deletions
@@ -50,7 +50,6 @@ def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
     return {**tensors, **non_tensors}


-
 def process_image(image: Union[Dict[str, Any], ImageObject, str], min_pixels: int, max_pixels: int) -> ImageObject:
     if isinstance(image, str):
         image = Image.open(image)
@@ -59,6 +58,7 @@ def process_image(image: Union[Dict[str, Any], ImageObject, str], min_pixels: in
     elif isinstance(image, bytes):
         image = Image.open(BytesIO(image))

+    image.load()  # avoid "Too many open files" errors
     if (image.width * image.height) > max_pixels:
         resize_factor = math.sqrt(max_pixels / (image.width * image.height))
         width, height = int(image.width * resize_factor), int(image.height * resize_factor)
@@ -88,6 +88,7 @@ def __init__(
         prompt_key: str = "prompt",
         answer_key: str = "answer",
         image_key: str = "images",
+        image_dir: Optional[str] = None,
         max_prompt_length: int = 1024,
         truncation: str = "error",
         format_prompt: Optional[str] = None,
@@ -100,6 +101,7 @@ def __init__(
         self.prompt_key = prompt_key
         self.answer_key = answer_key
         self.image_key = image_key
+        self.image_dir = image_dir
         self.max_prompt_length = max_prompt_length
         self.truncation = truncation
         self.max_pixels = max_pixels
@@ -113,9 +115,11 @@ def __init__(

         if os.path.isdir(data_path):
             # when we use dataset builder, we should always refer to the train split
-            self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
+            file_type = os.path.splitext(os.listdir("images/train")[0])[-1][1:].replace("jsonl", "json")
+            self.dataset = load_dataset(file_type, data_dir=data_path, split=data_split)
         elif os.path.isfile(data_path):
-            self.dataset = load_dataset("parquet", data_files=data_path, split="train")
+            file_type = os.path.splitext(data_path)[-1][1:].replace("jsonl", "json")
+            self.dataset = load_dataset(file_type, data_files=data_path, split=data_split)
         else:
             # load remote dataset from huggingface hub
             self.dataset = load_dataset(data_path, split=data_split)
@@ -164,22 +168,25 @@ def __getitem__(self, index):

         if self.image_key in example:
             prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-            raw_image_data = example.pop(self.image_key)
-            images = [
+            images = example.pop(self.image_key)
+            if self.image_dir is not None and len(images) != 0 and isinstance(images[0], str):  # image paths
+                images = [os.path.join(self.image_dir, image) for image in images]
+
+            resized_images = [
                 process_image(image, min_pixels=self.min_pixels, max_pixels=self.max_pixels)
-                for image in raw_image_data
+                for image in images
             ]
-            model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
+            model_inputs = self.processor(resized_images, [prompt], add_special_tokens=False, return_tensors="pt")
             input_ids = model_inputs.pop("input_ids")[0]
             attention_mask = model_inputs.pop("attention_mask")[0]
-            example["multi_modal_data"] = {"image": raw_image_data}
+            example["multi_modal_inputs"] = {"images": images}
         else:
             prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
             model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
            input_ids = model_inputs.pop("input_ids")[0]
             attention_mask = model_inputs.pop("attention_mask")[0]

-        if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor":
+        if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
             # qwen2vl mrope
             position_ids = get_rope_index(
                 self.processor,
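
Two details here are worth unpacking. First, the file-type detection turns an extension such as .jsonl into the builder name that datasets.load_dataset expects (its "json" builder also reads JSON Lines); a quick check with a hypothetical path:

    import os

    path = "data/train.jsonl"  # hypothetical file
    file_type = os.path.splitext(path)[-1][1:].replace("jsonl", "json")
    print(file_type)  # json

Second, the pixel-budget resize in process_image scales both sides by sqrt(max_pixels / (width * height)), shrinking the area to just under the budget while preserving the aspect ratio. A worked example with illustrative numbers:

    import math

    width, height, max_pixels = 4032, 3024, 1048576  # ~12.2M pixels vs. a 1M-pixel budget
    resize_factor = math.sqrt(max_pixels / (width * height))
    new_w, new_h = int(width * resize_factor), int(height * resize_factor)
    print(new_w, new_h, new_w * new_h)  # 1182 886 1047252 -- under budget, 4:3 ratio kept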

verl/workers/fsdp_workers.py

Lines changed: 25 additions & 19 deletions
@@ -15,13 +15,13 @@
 The main entry point to run the PPO algorithm
 """

+from copy import deepcopy
 from typing import Literal, Optional, Union

 import numpy as np
 import psutil
 import torch
 import torch.distributed as dist
-from copy import deepcopy
 from accelerate import init_empty_weights
 from codetiming import Timer
 from torch.distributed.device_mesh import init_device_mesh
@@ -42,6 +42,7 @@
 from ..single_controller.base import Worker
 from ..single_controller.base.decorator import Dispatch, register
 from ..utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
+from ..utils.dataset import process_image
 from ..utils.flops_counter import FlopsCounter
 from ..utils.fsdp_utils import (
     get_fsdp_wrap_policy,
@@ -51,7 +52,6 @@
     offload_fsdp_model,
     offload_fsdp_optimizer,
 )
-from ..utils.dataset import process_image
 from ..utils.model_utils import print_gpu_memory_usage, print_model_size
 from ..utils.tokenizer import get_processor, get_tokenizer
 from ..utils.torch_dtypes import PrecisionType
@@ -436,10 +436,9 @@ def preprocess_multi_modal_data(self, data: DataProto):
         processed_images = []
         for multi_modal_data in multi_modal_data_copy:
             processed_per_query_images = []
-            for image in multi_modal_data['image']:
-                processed_per_query_images.append(
-                    process_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
-                )
+            for image in multi_modal_data["image"]:
+                processed_per_query_images.append(process_image(image, min_pixels=min_pixels, max_pixels=max_pixels))
+
             processed_images.append(processed_per_query_images)

         # Note: Using the alternative (commented) code below to process images can lead to subtle resize issues:
@@ -454,17 +453,20 @@ def preprocess_multi_modal_data(self, data: DataProto):
         #     for j, image in enumerate(per_query_images):
         #         images[i][j] = process_image(image, min_pixels=min_pixels, max_pixels=max_pixels)

-        multi_modal_inputs = np.array([
-            dict(self.processor.image_processor(images=per_query_images, videos=None))
-            for per_query_images in processed_images
-        ], dtype=object)
+        multi_modal_inputs = np.array(
+            [
+                dict(self.processor.image_processor(images=per_query_images, videos=None))
+                for per_query_images in processed_images
+            ],
+            dtype=object,
+        )
         data.non_tensor_batch["multi_modal_inputs"] = multi_modal_inputs

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def update_actor(self, data: DataProto):
         assert self._is_actor
         if "multi_modal_inputs" in self._cache:
-            data.non_tensor_batch['multi_modal_inputs'] = deepcopy(self._cache['multi_modal_inputs'])
+            data.non_tensor_batch["multi_modal_inputs"] = deepcopy(self._cache["multi_modal_inputs"])
         elif "multi_modal_data" in data.non_tensor_batch:
             self.preprocess_multi_modal_data(data)

@@ -545,12 +547,14 @@ def generate_sequences(self, prompts: DataProto):
         cached_multi_modal_data = None
         if "multi_modal_data" in prompts.non_tensor_batch:
             cached_multi_modal_data = deepcopy(prompts.non_tensor_batch["multi_modal_data"])
-            min_pixels = prompts.meta_info['min_pixels']
-            max_pixels = prompts.meta_info['max_pixels']
+            min_pixels = prompts.meta_info["min_pixels"]
+            max_pixels = prompts.meta_info["max_pixels"]
             processed_images = []
             for i, multi_modal_data in enumerate(prompts.non_tensor_batch["multi_modal_data"]):
                 for j, image in enumerate(multi_modal_data["image"]):
-                    multi_modal_data['image'][j] = process_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
+                    multi_modal_data["image"][j] = process_image(
+                        image, min_pixels=min_pixels, max_pixels=max_pixels
+                    )
                 processed_images.append(multi_modal_data)
             prompts.non_tensor_batch["multi_modal_data"] = processed_images

@@ -562,7 +566,9 @@ def generate_sequences(self, prompts: DataProto):
             output.non_tensor_batch["multi_modal_data"] = cached_multi_modal_data
             if sampling_n > 1:
                 output.non_tensor_batch["multi_modal_data"] = np.repeat(
-                    output.non_tensor_batch["multi_modal_data"], repeats=sampling_n, axis=0,
+                    output.non_tensor_batch["multi_modal_data"],
+                    repeats=sampling_n,
+                    axis=0,
                 )

         output = self.rollout_sharding_manager.postprocess_data(output)
@@ -577,7 +583,7 @@ def compute_log_probs(self, data: DataProto):
         if "multi_modal_data" in data.non_tensor_batch:
             self.preprocess_multi_modal_data(data)
             # create cache for multi_modal_inputs
-            self._cache['multi_modal_inputs'] = deepcopy(data.non_tensor_batch['multi_modal_inputs'])
+            self._cache["multi_modal_inputs"] = deepcopy(data.non_tensor_batch["multi_modal_inputs"])

         data = data.to(torch.cuda.current_device())
         if self._use_param_offload:
@@ -611,7 +617,7 @@ def compute_ref_log_probs(self, data: DataProto):
         # not in the ref_policy's or critic's caches.
         assert self._is_ref
         if "multi_modal_inputs" in self._cache:
-            data.non_tensor_batch['multi_modal_inputs'] = deepcopy(self._cache['multi_modal_inputs'])
+            data.non_tensor_batch["multi_modal_inputs"] = deepcopy(self._cache["multi_modal_inputs"])
         elif "multi_modal_data" in data.non_tensor_batch:
             self.preprocess_multi_modal_data(data)

@@ -643,7 +649,7 @@ def compute_values(self, data: DataProto):
         # The `self._cache` is empty here since cached `multi_modal_inputs` is only saved in the actor's _cache,
         # not in the ref_policy's or critic's caches.
         if "multi_modal_inputs" in self._cache:
-            data.non_tensor_batch['multi_modal_inputs'] = deepcopy(self._cache['multi_modal_inputs'])
+            data.non_tensor_batch["multi_modal_inputs"] = deepcopy(self._cache["multi_modal_inputs"])
         elif "multi_modal_data" in data.non_tensor_batch:
             self.preprocess_multi_modal_data(data)

@@ -668,7 +674,7 @@ def update_critic(self, data: DataProto):
         # The `self._cache` is empty here since cached `multi_modal_inputs` is only saved in the actor's _cache,
         # not in the ref_policy's or critic's caches.
         if "multi_modal_inputs" in self._cache:
-            data.non_tensor_batch['multi_modal_inputs'] = deepcopy(self._cache['multi_modal_inputs'])
+            data.non_tensor_batch["multi_modal_inputs"] = deepcopy(self._cache["multi_modal_inputs"])
         elif "multi_modal_data" not in data.non_tensor_batch:
             self.preprocess_multi_modal_data(data)

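
With group sampling (n responses per prompt), np.repeat(..., repeats=sampling_n, axis=0) duplicates each prompt's cached multi_modal_data entry n consecutive times, so row i of the repeated array still lines up with response i. A toy check:

    import numpy as np

    per_prompt = np.array(["imgs_for_prompt_0", "imgs_for_prompt_1"], dtype=object)
    print(np.repeat(per_prompt, repeats=3, axis=0))
    # ['imgs_for_prompt_0' 'imgs_for_prompt_0' 'imgs_for_prompt_0'
    #  'imgs_for_prompt_1' 'imgs_for_prompt_1' 'imgs_for_prompt_1']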

verl/workers/rollout/vllm_rollout_spmd.py

Lines changed: 1 addition & 1 deletion
@@ -201,5 +201,5 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
         return DataProto(
             batch=batch,
             non_tensor_batch=non_tensor_batch,
-            meta_info=prompts.meta_info.copy(),
+            meta_info=prompts.meta_info.copy()
         )
