Skip to content

Commit 0989315

Browse files
authored
[misc] fix data proto (#458)
1 parent cb9166c commit 0989315

5 files changed

Lines changed: 11 additions & 11 deletions

File tree

examples/config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ algorithm:
3434
worker:
3535
actor:
3636
global_batch_size: 128 # equivalent to verl's actor.ppo_mini_batch_size
37-
micro_batch_size_per_device_for_update: 4 # equivalent to verl's actor.ppo_micro_batch_size_per_gpu
38-
micro_batch_size_per_device_for_experience: 16 # equivalent to verl's rollout.log_prob_micro_batch_size_per_gpu
37+
micro_batch_size_per_device_for_update: 1 # equivalent to verl's actor.ppo_micro_batch_size_per_gpu
38+
micro_batch_size_per_device_for_experience: 2 # equivalent to verl's rollout.log_prob_micro_batch_size_per_gpu
3939
max_grad_norm: 1.0
4040
padding_free: true
4141
dynamic_batching: true

examples/qwen2_5_vl_32b_geo3k_grpo.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ python3 -m verl.trainer.main \
1111
data.train_files=hiyouga/geometry3k@train \
1212
data.val_files=hiyouga/geometry3k@test \
1313
worker.actor.model.model_path=${MODEL_PATH} \
14-
worker.actor.micro_batch_size_per_device_for_update=1 \
15-
worker.actor.micro_batch_size_per_device_for_experience=8 \
1614
worker.actor.fsdp.torch_dtype=bf16 \
1715
worker.actor.optim.strategy=adamw_bf16 \
1816
worker.rollout.tensor_parallel_size=8 \

verl/protocol.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,14 @@ def __getitem__(
204204
raise TypeError(f"Indexing with {type(item)} is not supported.")
205205

206206
def __getstate__(self) -> tuple[bytes, dict[str, NDArray], dict[str, Any]]:
207-
buffer = io.BytesIO()
208207
if self.batch is not None:
209-
self.batch: TensorDict = self.batch.contiguous()
210-
self.batch: TensorDict = self.batch.consolidate()
208+
batch_to_save: TensorDict = self.batch.contiguous()
209+
batch_to_save: TensorDict = batch_to_save.consolidate()
210+
else:
211+
batch_to_save = None
211212

212-
torch.save(self.batch, buffer)
213+
buffer = io.BytesIO()
214+
torch.save(batch_to_save, buffer)
213215
buffer_bytes = buffer.getvalue()
214216
return buffer_bytes, self.non_tensor_batch, self.meta_info
215217

verl/trainer/ray_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from ..utils import torch_functional as VF
4040
from ..utils.checkpoint import CHECKPOINT_TRACKER, find_latest_ckpt, remove_obsolete_ckpt
4141
from ..utils.logger import Tracker
42-
from ..utils.py_functional import convert_dict_to_str, timer
42+
from ..utils.py_functional import convert_dict_to_str, timer, unflatten_dict
4343
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
4444
from ..workers.fsdp_workers import FSDPWorker
4545
from ..workers.reward import FunctionRewardManager
@@ -694,7 +694,7 @@ def fit(self):
694694
val_metrics = self._validate()
695695
self.logger.log(data=val_metrics, step=self.global_step)
696696

697-
print(f"Final validation metrics: {convert_dict_to_str(val_metrics)}")
697+
print(f"Final validation metrics:\n{convert_dict_to_str(unflatten_dict(val_metrics))}")
698698

699699
if self.config.trainer.save_freq <= 0 or self.global_step % self.config.trainer.save_freq != 0:
700700
self._save_checkpoint()

verl/workers/fsdp_workers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def _process_multi_modal_inputs(self, data: DataProto):
456456
multi_modal_inputs_cache = {} # avoid repeated processing for n > 1 samples
457457
for index, multi_modal_data in zip(
458458
data.non_tensor_batch["uid"], data.non_tensor_batch["multi_modal_data"]
459-
): # process multi modal data per sample
459+
): # process multi modal data per sample
460460
if index not in multi_modal_inputs_cache:
461461
images, videos = [], []
462462
if "images" in multi_modal_data:

0 commit comments

Comments (0)