
Conversation

@zhengchenyu
Contributor

The current train_ddp.py has two problems:

  • It cannot guarantee that every sample is read exactly once in order. For example, if the replica group world size is 3 but only 2 replicas are working, some samples will be skipped.
  • When the replica group world size changes, the total batch size used for gradient aggregation changes with it, so the computation is no longer reproducible across restarts.

The following modifications were made:

  • A SkipDistributedSampler is provided so that training can resume from any sample offset (a sketch follows this list).
  • The dataloader is rebuilt whenever the quorum changes.
  • For training rounds that have just been initialized or in which the quorum changed, the commit is abandoned because a dirty flag is set.
  • Added the example train_ddp2.py.
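
The sketch below illustrates the sampler idea from the first bullet. It is a simplified example under my own assumptions (the real SkipDistributedSampler in this PR may take different arguments); the point is that indices are dealt round-robin across replicas after a globally committed skip offset, so a restarted or rescaled quorum resumes exactly where the previous one left off.

from torch.utils.data import Sampler

class SkipDistributedSampler(Sampler):
    # Deals dataset indices round-robin across replicas, starting after a
    # globally committed offset (skip), so no sample is lost or duplicated
    # when the replica world size changes between quorums.
    def __init__(self, dataset_len, num_replicas, rank, skip=0):
        self.dataset_len = dataset_len
        self.num_replicas = num_replicas
        self.rank = rank
        self.skip = skip  # samples already consumed by all replicas combined

    def __iter__(self):
        # Fixed, deterministic order: replica `rank` takes every
        # num_replicas-th index after the skipped prefix.
        return iter(range(self.skip + self.rank, self.dataset_len, self.num_replicas))

    def __len__(self):
        remaining = max(self.dataset_len - self.skip - self.rank, 0)
        return (remaining + self.num_replicas - 1) // self.num_replicas

On a quorum change, the dataloader is rebuilt with the new num_replicas and rank, and skip is set to the number of samples already committed, which is what the second bullet refers to.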

@meta-cla meta-cla bot added the CLA Signed label Nov 18, 2025
@zhengchenyu
Contributor Author

zhengchenyu commented Nov 28, 2025

To supplement the experimental results, adjustments were made to train_ddp_fix_batch.py: NUM_EPOCHS was set to 1 and BATCH_SIZE to 4 to produce more steps, and the same initial model was loaded for every run. In the figure below, the base run trains with a single worker, while the ft run frequently starts and stops workers, keeping the worker count between 1 and 3. The loss curves are almost identical.

[Figure: DDP loss curves, base vs. ft]

Using the same experimental conditions, fsdp2 yielded the following results:

[Figure: FSDP2 loss curves, base vs. ft]

I also applied the same method to DeepSpeed stage 3 and obtained almost identical curves. However, this PR does not involve DeepSpeed, so those results are not shown here.

Note: A slight inconsistency appears in the latter half of the curve. After debugging, I found that this is due to a loss of precision.

@zhengchenyu zhengchenyu marked this pull request as draft December 4, 2025 07:23
@zhengchenyu zhengchenyu closed this Dec 6, 2025
@zhengchenyu zhengchenyu deleted the fix.totat.batch branch December 6, 2025 01:57
@zhengchenyu zhengchenyu reopened this Dec 6, 2025
@zhengchenyu zhengchenyu marked this pull request as ready for review December 6, 2025 02:15
@zhengchenyu
Contributor Author

zhengchenyu commented Dec 6, 2025

In the fsdp2 experiment, I found that get_optimizer_state_dict and set_optimizer_state_dict may call optim.step, which increases the step counter used by the Adam optimizer. An unexpected extra step causes inconsistencies. Therefore, in the fsdp2 experiment, I commented out the call to optim.step in _init_optim_state.
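
For context, here is a minimal, self-contained sketch (an illustration under my own assumptions, not torchft or torch.distributed.checkpoint code) of why that extra call matters: stepping the optimizer with zero gradients still advances Adam's internal step counter, which changes the bias correction applied to later updates.

import torch

model = torch.nn.Linear(4, 1)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
weight = next(model.parameters())

# One real training step populates Adam's per-parameter state.
model(torch.randn(2, 4)).sum().backward()
optim.step()
step_before = float(optim.state[weight]["step"])

# An extra step() with zero gradients (similar in spirit to what state-dict
# helpers do when materializing optimizer state) still bumps the counter.
optim.zero_grad()
for p in model.parameters():
    p.grad = torch.zeros_like(p)
optim.step()
step_after = float(optim.state[weight]["step"])

print(step_before, step_after)  # e.g. 1.0 then 2.0: the counter advanced without real training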

@d4l3k
Member

d4l3k commented Dec 6, 2025

@zhengchenyu this is super cool! I'll take a deeper look when I'm back on Monday

@zhengchenyu
Contributor Author

zhengchenyu commented Dec 6, 2025

I added a sequence diagram to illustrate the start of training and the scale-up from a replica world size of 1 to 2, with local batch = 2 and total batch = 4. Each color represents one iteration of the training loop: the yellow iteration is the start of training, the orange iteration is normal training with a replica world size of 1, the green iteration is the scale-up from 1 to 2, and the blue iteration is normal training with a replica world size of 2.

[Sequence diagram: training start and scale-up from 1 replica to 2]
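
To make the batch arithmetic in the diagram concrete, here is a tiny sketch (variable names are mine, not the PR's) of how the number of gradient-accumulation micro-batches can be derived so that the total batch stays fixed while the replica world size changes:

TOTAL_BATCH = 4
LOCAL_BATCH = 2

for replica_world_size in (1, 2):
    accumulation_steps = TOTAL_BATCH // (LOCAL_BATCH * replica_world_size)
    print(replica_world_size, accumulation_steps)
    # 1 replica  -> accumulate over 2 micro-batches of size 2 (total 4)
    # 2 replicas -> 1 micro-batch of size 2 per replica (total 4 after all-reduce)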

loss.backward()
total_loss += loss.item()

if accumulation_steps > 1:
Contributor Author

@zhengchenyu zhengchenyu Dec 9, 2025


TODO: use loss * (1 / accumulation_steps)?
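
A minimal, self-contained sketch of the scaling the TODO suggests (illustrative names, not the code from this PR): dividing the loss by accumulation_steps before backward() makes the accumulated gradient equal to the gradient of the mean loss over the whole effective batch.

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

accumulation_steps = 4
data = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]

total_loss = 0.0
for step, (inputs, labels) in enumerate(data):
    loss = criterion(model(inputs), labels)
    # Scaling before backward() makes the summed micro-batch gradients equal
    # to the gradient of the mean loss over all accumulation_steps batches.
    (loss / accumulation_steps).backward()
    total_loss += loss.item()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()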
