Fix total_batch_size logging for sequence parallelism #1542
Conversation
accelerator.get_tracker("wandb") returns a GeneralTracker stub on
non-main ranks that lacks the .run attribute, causing crashes in
multi-node training. Guard all wandb_tracker.run.url accesses with
main process checks.
Also adds a two-node SFT integration test script to catch this.
Made-with: Cursor
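A minimal sketch of the guard this commit describes, assuming the usual accelerator/args names in the training script (the exact call site is an assumption):

# Only the main process has a real wandb tracker; other ranks get a
# GeneralTracker stub without .run, so gate the access.
wandb_url = None
if args.with_tracking and accelerator.is_main_process:
    wandb_tracker = accelerator.get_tracker("wandb")
    wandb_url = wandb_tracker.run.url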
beaker_config was only defined inside the push_to_hub block but referenced in the with_tracking block. Call maybe_get_beaker_config() directly where it's needed. Made-with: Cursor
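A hedged sketch of the fix, assuming maybe_get_beaker_config() can be called where the tracking config is assembled:

if args.with_tracking:
    # Call the helper directly instead of relying on a variable that is
    # only defined inside the push_to_hub block.
    beaker_config = maybe_get_beaker_config()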
- Remove try/except around ParallelismConfig imports
- Update help text: flash attn recommended, not required
- Remove redundant ParallelismConfig runtime check
- Add sequence_parallel_size to two-node test script
- Add SP changelog entry
Made-with: Cursor
The previous commit created the scheduler after prepare() without re-wrapping it via accelerator.prepare(), which broke the AcceleratedScheduler gating on sync_gradients — the scheduler would step every micro-batch instead of every optimizer step. Restored the original pre-prepare scheduler creation for non-SP. For SP, recreate and re-wrap via accelerator.prepare() so the scheduler is correctly gated. Also updated changelog. Made-with: Cursor
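A rough sketch of the two paths this commit restores, using the _create_scheduler name that appears in the review below (the exact prepare() call shape is an assumption):

if args.sequence_parallel_size == 1:
    # Non-SP: create the scheduler before prepare() so accelerator wraps it in
    # AcceleratedScheduler, which only steps when sync_gradients is True.
    lr_scheduler = _create_scheduler(args, optimizer, args.max_train_steps)
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
else:
    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
    # SP: the dataloader length (and hence max_train_steps) changes post-prepare,
    # so recreate the scheduler and re-wrap it to keep the sync_gradients gating.
    lr_scheduler = _create_scheduler(args, optimizer, args.max_train_steps)
    lr_scheduler = accelerator.prepare(lr_scheduler)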
…ses multiplier Made-with: Cursor
…ltiplier Made-with: Cursor
SP ranks process the same sequence, not independent data. The effective data-parallel world size is num_processes // sequence_parallel_size. Made-with: Cursor
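In code, the corrected accounting looks roughly like this (names follow the diff below; a sketch, not the exact change):

# Ranks within an SP group see the same sequence, so only data-parallel
# replicas contribute distinct samples to each optimizer step.
dp_world_size = accelerator.num_processes // args.sequence_parallel_size
total_batch_size = args.per_device_train_batch_size * dp_world_size * args.gradient_accumulation_steps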
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request primarily addresses an inaccuracy in how total_batch_size is logged when sequence parallelism is enabled.
Code Review
This pull request introduces significant changes to add Ulysses sequence parallelism support for SFT training, along with several related fixes. While the PR title focuses on a logging fix for total_batch_size, the scope is much broader. The implementation of sequence parallelism, including data collation, loss aggregation, and fixes for multi-process execution, is mostly well-handled. However, I've identified a critical issue with the learning rate scheduler recreation that could lead to incorrect learning rate decay in distributed settings. I've also suggested an improvement to the loss aggregation logic for better clarity and efficiency. Overall, this is a valuable addition, and addressing the identified issue is crucial for correct training behavior.
open_instruct/finetune.py (721-725)
The lr_scheduler is recreated here to account for changes in max_train_steps due to sequence parallelism. However, the new scheduler is not prepared with accelerator.prepare(). The original lr_scheduler was wrapped by accelerator.prepare() to handle distributed training details (like scaling steps by num_processes). By reassigning lr_scheduler to a raw scheduler, the lr_scheduler.step() call in the training loop will not behave as expected in a distributed environment, likely causing the learning rate to decay much slower than intended.
The new scheduler must also be prepared.
if args.sequence_parallel_size > 1:
    # SP changes the dataloader length post-prepare. Recreate the scheduler using
    # the post-prepare max_train_steps. Multiply by gradient_accumulation_steps because
    # the scheduler is called every micro-batch (not just on optimizer steps).
    lr_scheduler = _create_scheduler(args, optimizer, args.max_train_steps * args.gradient_accumulation_steps)
    lr_scheduler = accelerator.prepare(lr_scheduler)
open_instruct/finetune.py (826-840)
This sequence parallelism loss aggregation logic is correct. However, it can be implemented more cleanly and efficiently using torch.distributed.all_gather instead of the legacy torch.distributed.nn.functional.all_gather. This also allows removing some redundant code.
if args.sequence_parallel_size > 1:
    sp_group = accelerator.torch_device_mesh["sp"].get_group()
    # Gather losses from all SP ranks
    losses_per_rank = [torch.empty_like(loss) for _ in range(args.sequence_parallel_size)]
    torch.distributed.all_gather(losses_per_rank, loss, group=sp_group)
    # Gather non-padded token counts from all SP ranks
    good_tokens = (batch["labels"] != -100).view(-1).sum().to(loss.dtype)
    good_tokens_per_rank = [torch.empty_like(good_tokens) for _ in range(args.sequence_parallel_size)]
    torch.distributed.all_gather(good_tokens_per_rank, good_tokens, group=sp_group)
    # Calculate weighted average loss across SP ranks
    total_loss_sp = sum(l * n for l, n in zip(losses_per_rank, good_tokens_per_rank))
    total_good_tokens = sum(good_tokens_per_rank)
    loss = total_loss_sp / torch.clamp(total_good_tokens, min=1.0)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
dp_world_size = accelerator.num_processes // args.sequence_parallel_size
Up to you ofc but I think we should refactor this somewhere so we can share it across DPO, GRPO, SFT
Fixing the batch size logging, taking SP into account.
Made with Cursor
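A hypothetical shape for the shared helper suggested above (name and signature are illustrative, not from the PR):

def get_effective_batch_sizes(
    per_device_train_batch_size: int,
    num_processes: int,
    gradient_accumulation_steps: int,
    sequence_parallel_size: int = 1,
) -> tuple[int, int]:
    # Shareable across SFT, DPO, and GRPO scripts: SP ranks duplicate data,
    # so divide them out before computing the effective global batch size.
    dp_world_size = num_processes // sequence_parallel_size
    total_batch_size = per_device_train_batch_size * dp_world_size * gradient_accumulation_steps
    return dp_world_size, total_batch_size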