Description
System Info
verl version: main, commit id f3a0233
platform: A100
python: 3.12
torch: 2.8.0
Bug in the mainline engine_worker implementation: when splitting the batch into micro-batches for the training forward pass, with dynamic batch size disabled, the code calls TensorDict's split method directly, which does not support NestedTensor.
Code location
Error traceback:
(TaskRunner pid=1953021) File "verl/verl/workers/engine_workers.py", line 323, in infer_batch
(TaskRunner pid=1953021) output = self.engine.infer_batch(data, loss_function=loss_function)
(TaskRunner pid=1953021) File "verl/verl/workers/engine/base.py", line 140, in infer_batch
(TaskRunner pid=1953021) outputs = self.forward_backward_batch(data, loss_function, forward_only=True)
(TaskRunner pid=1953021) File "verl/verl/workers/engine/fsdp/transformer_impl.py", line 499, in forward_backward_batch
(TaskRunner pid=1953021) micro_batches, indices = prepare_micro_batches(
(TaskRunner pid=1953021) File "verl/verl/workers/engine/utils.py", line 88, in prepare_micro_batches
(TaskRunner pid=1953021) micro_batches = data.split(micro_batch_size_per_gpu)
(TaskRunner pid=1953021) File "/root/miniconda3/envs/verl-sgl-mf-py312/lib/python3.12/site-packages/tensordict/_td.py", line 1774, in split
(TaskRunner pid=1953021) splits = {k: v.split(splits, dim) for k, v in self.items()}
(TaskRunner pid=1953021) File "/root/miniconda3/envs/verl-sgl-mf-py312/lib/python3.12/site-packages/torch/nested/_internal/ops.py", line 54, in _wrap_jagged_dim
(TaskRunner pid=1953021) raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
(TaskRunner pid=1953021) RuntimeError: split_with_sizes(): not supported for NestedTensor on dim=0
Preliminary analysis:
In the original fsdp worker and megatron worker implementations, the data being split is a DataProto, which defines its own split method. In the new engine worker, the data being split is a TensorDict directly, and input_ids and position_ids in the batch are NestedTensors with ragged shapes, so TensorDict.split fails on dim=0.
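One way around this, similar in spirit to what DataProto's custom split provides, is to split by contiguous index ranges and then select rows, rather than calling Tensor.split on every value in the TensorDict. A minimal sketch of the index-chunking logic only (plain Python; the helper name is hypothetical and not verl's actual code):

```python
def split_micro_batch_indices(batch_size: int, micro_batch_size: int) -> list[list[int]]:
    """Chunk [0, batch_size) into contiguous index lists of at most
    micro_batch_size rows each. Selecting rows by index can sidestep
    Tensor.split(dim=0), which NestedTensor does not support."""
    return [
        list(range(start, min(start + micro_batch_size, batch_size)))
        for start in range(0, batch_size, micro_batch_size)
    ]

# Example: a batch of 5 rows with micro-batch size 2 -> [[0, 1], [2, 3], [4]]
```

How the resulting index lists are turned back into micro-batch TensorDicts (e.g. by per-key row selection or unbind/regroup of the nested fields) would depend on the actual fix chosen in prepare_micro_batches.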
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Mainline GRPO script with:
use_legacy_worker_impl=disable
use_dynamic_bsz=False
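A hedged sketch of the launch (the script path is an assumption; only the two overrides come from this report):

```shell
# Hypothetical launcher path; substitute your actual GRPO mainline script.
# The two overrides below are the ones that trigger the bug.
bash examples/grpo_trainer/run_grpo.sh \
    use_legacy_worker_impl=disable \
    use_dynamic_bsz=False
```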
Expected behavior
Splitting the batch for the forward pass when computing old log p should succeed; instead it raises the error in the traceback above.