R3 replay: RuntimeError 'Split sizes doesn't match total dim 0 size' in Megatron all_to_all_single on MoE compute_log_prob #1002

@DavidBellamy

Description

Reproduction

On GLM-4.7-Flash with --use-rollout-routing-replay and --expert-model-parallel-size 8, the trainer crashes during compute_log_prob with:

RuntimeError: Split sizes doesn't match total dim 0 size

The traceback terminates at:

File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4688, in all_to_all_single
    work = group.alltoall_base(

The failing call is wrapped by miles/utils/reloadable_process_group.py:55.

Scope

  • Reliable reproduction with FAST_ITER_MOCK_AGENT + FAST_ITER_CONFIG=pd-disagg + USE_ROLLOUT_ROUTING_REPLAY=1
  • With USE_ROLLOUT_ROUTING_REPLAY=0 the same topology runs 10+ train_steps cleanly.

So the bug is specific to the R3 replay path. During replay, the MoE token-split sizes recorded at inference time evidently do not round-trip correctly to the training ranks when the expert-parallel all-to-all is re-enacted, producing a shape mismatch on the training side.
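
For illustration, the invariant the replay path has to preserve looks roughly like this (a minimal single-rank sketch with top-1 routing; `replayed_expert_ids` and the expert-to-rank mapping are hypothetical names, not the actual Megatron/R3 structures):

```python
import torch

# Illustrative only: how MoE dispatch split sizes feed all_to_all_single.
ep_world_size = 8                    # --expert-model-parallel-size 8
num_local_tokens = 16                # tokens held by this rank in the microbatch
num_experts = 64

# Routing decisions recorded during rollout and replayed at training time.
replayed_expert_ids = torch.randint(0, num_experts, (num_local_tokens,))

# Each expert lives on one EP rank; tokens are bucketed by destination rank.
dest_ranks = replayed_expert_ids % ep_world_size
input_split_sizes = torch.bincount(dest_ranks, minlength=ep_world_size).tolist()

# This is the invariant the collective enforces on dim 0 of its input:
# if the replayed routing was recorded for a different token layout than the
# training-side microbatch actually holds, the sums diverge and
# all_to_all_single raises "Split sizes doesn't match total dim 0 size".
assert sum(input_split_sizes) == num_local_tokens
```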

Suggested next step

Add a rank-level assertion in the R3 replay code that the sum of per-rank token counts matches the total before invoking all_to_all_single, together with a diagnostic dump of each rank's declared split sizes. This would help pin down whether the bug lies in the replay-record serialization or in the replay-reader reconstruction.
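
A rough sketch of what that pre-flight check could look like (hypothetical wrapper, assuming the replayed split sizes and the process group are available at the call site; this is not the existing miles code):

```python
import torch.distributed as dist

def checked_all_to_all_single(output, input, output_split_sizes, input_split_sizes, group=None):
    """Debug wrapper: verify replayed split sizes before the collective and
    dump every rank's declared splits on mismatch (hypothetical helper)."""
    declared = {
        "rank": dist.get_rank(group),
        "input_dim0": input.size(0),
        "split_sum": sum(input_split_sizes),
        "input_split_sizes": list(input_split_sizes),
    }
    # Gather every rank's declaration unconditionally so the check itself
    # cannot deadlock when only a subset of ranks is inconsistent.
    per_rank = [None] * dist.get_world_size(group)
    dist.all_gather_object(per_rank, declared, group=group)

    bad = [d for d in per_rank if d["split_sum"] != d["input_dim0"]]
    if bad:
        raise RuntimeError(
            f"[R3 replay] split-size mismatch before all_to_all_single; "
            f"offending ranks: {bad}; full dump: {per_rank}"
        )
    return dist.all_to_all_single(
        output, input,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )
```

Comparing the dumped splits against the counts recorded at rollout time should distinguish a serialization problem from a reader-side reconstruction problem.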

Happy to test fixes against FAST_ITER_MOCK_AGENT, which reproduces this reliably.
