[Bug] RuntimeError: split_with_sizes(): not supported for NestedTensor on dim=0 #4737

@hustmf

Description

System Info

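verl version: main, commit id f3a0233
platform: A100
python: 3.12
torch: 2.8.0
There is a bug in the mainline engine_worker implementation. When micro-batches are split for the forward pass during training and dynamic batch size is disabled, the code calls TensorDict's split method directly, and that method does not support NestedTensor.
Code location
Error stack trace: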
(TaskRunner pid=1953021) File "verl/verl/workers/engine_workers.py", line 323, in infer_batch
(TaskRunner pid=1953021) output = self.engine.infer_batch(data, loss_function=loss_function)
(TaskRunner pid=1953021) File "verl/verl/workers/engine/base.py", line 140, in infer_batch
(TaskRunner pid=1953021) outputs = self.forward_backward_batch(data, loss_function, forward_only=True)
(TaskRunner pid=1953021) File "verl/verl/workers/engine/fsdp/transformer_impl.py", line 499, in forward_backward_batch
(TaskRunner pid=1953021) micro_batches, indices = prepare_micro_batches(
(TaskRunner pid=1953021) File "verl/verl/workers/engine/utils.py", line 88, in prepare_micro_batches
(TaskRunner pid=1953021) micro_batches = data.split(micro_batch_size_per_gpu)
(TaskRunner pid=1953021) File "/root/miniconda3/envs/verl-sgl-mf-py312/lib/python3.12/site-packages/tensordict/_td.py", line 1774, in split
(TaskRunner pid=1953021) splits = {k: v.split(splits, dim) for k, v in self.items()}
(TaskRunner pid=1953021) File "/root/miniconda3/envs/verl-sgl-mf-py312/lib/python3.12/site-packages/torch/nested/_internal/ops.py", line 54, in _wrap_jagged_dim
(TaskRunner pid=1953021) raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
(TaskRunner pid=1953021) RuntimeError: split_with_sizes(): not supported for NestedTensor on dim=0
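Preliminary analysis:
In the original fsdpworker and megatronworker implementations, the batch being split was a DataProto, which defines its own split method. In the latest engine worker, the batch being split is a plain TensorDict, and the input_ids and position_ids entries in the batch are NestedTensors with irregular (jagged) shapes.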

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

GRPO mainline script
use_legacy_worker_impl=disable
use_dynamic_bsz=False

Expected behavior

An error is raised when the batch is split for the forward pass while computing old log p.
