
Fix total_batch_size logging for sequence parallelism#1542

Merged
hamishivi merged 40 commits into main from hamishivi/fix-sft-sp-batch-size-log
Mar 20, 2026

Conversation

@hamishivi
Collaborator

@hamishivi commented Mar 20, 2026

Fixes the total_batch_size logging to take sequence parallelism (SP) into account.

Made with Cursor

accelerator.get_tracker("wandb") returns a GeneralTracker stub on
non-main ranks that lacks the .run attribute, causing crashes in
multi-node training. Guard all wandb_tracker.run.url accesses with
main process checks.

Also adds a two-node SFT integration test script to catch this.

Made-with: Cursor
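For illustration, a minimal sketch of the guard described above, assuming the script's existing accelerator and args objects; the exact surrounding code in open_instruct/finetune.py may differ:

    # Non-main ranks get a GeneralTracker stub without a .run attribute, so only
    # the main process may dereference wandb_tracker.run.url.
    wandb_url = None
    if args.with_tracking and accelerator.is_main_process:
        wandb_tracker = accelerator.get_tracker("wandb")
        wandb_url = wandb_tracker.run.url
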
beaker_config was only defined inside the push_to_hub block but
referenced in the with_tracking block. Call maybe_get_beaker_config()
directly where it's needed.

Made-with: Cursor
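A hedged sketch of the pattern this commit describes; only maybe_get_beaker_config and the block names come from the message, the surrounding control flow is assumed:

    # Previously beaker_config was assigned only inside the push_to_hub block, so the
    # with_tracking block raised UnboundLocalError whenever push_to_hub was disabled.
    if args.with_tracking:
        beaker_config = maybe_get_beaker_config()  # fetched where it is actually used
        ...
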
- Remove try/except around ParallelismConfig imports
- Update help text: flash attn recommended, not required
- Remove redundant ParallelismConfig runtime check
- Add sequence_parallel_size to two-node test script
- Add SP changelog entry

Made-with: Cursor
The previous commit created the scheduler after prepare() without
re-wrapping it via accelerator.prepare(), which broke the
AcceleratedScheduler gating on sync_gradients — the scheduler would
step every micro-batch instead of every optimizer step.

Restored the original pre-prepare scheduler creation for non-SP. For SP,
recreate and re-wrap via accelerator.prepare() so the scheduler is
correctly gated.

Also updated changelog.

Made-with: Cursor
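For context, a rough illustration (not the PR's code) of the gating this commit relies on, using the standard Accelerate gradient-accumulation loop; a scheduler wrapped by accelerator.prepare() becomes an AcceleratedScheduler that only advances on real optimizer steps, while a raw scheduler advances on every call:

    for batch in train_dataloader:
        with accelerator.accumulate(model):
            loss = model(**batch).loss
            accelerator.backward(loss)
            optimizer.step()
            # Wrapped (AcceleratedScheduler): no-op unless accelerator.sync_gradients is True.
            # Raw scheduler: decays the LR on every micro-batch, i.e. too fast by a factor
            # of gradient_accumulation_steps.
            lr_scheduler.step()
            optimizer.zero_grad()
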
SP ranks process the same sequence, not independent data. The effective
data-parallel world size is num_processes // sequence_parallel_size.

Made-with: Cursor
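A minimal sketch of the corrected logging math; the names mirror the diff excerpt quoted later in this thread, but treat it as illustrative rather than the exact final code:

    # Ranks inside one SP group shard the same sequences, so only the data-parallel
    # groups contribute independent samples to the global batch.
    dp_world_size = accelerator.num_processes // args.sequence_parallel_size
    total_batch_size = (
        args.per_device_train_batch_size * dp_world_size * args.gradient_accumulation_steps
    )
    logger.info(f"Total train batch size = {total_batch_size}")
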
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily addresses an inaccuracy in total_batch_size logging when sequence parallelism is active, ensuring that the reported batch size correctly reflects the data-parallel world size. It also fully integrates Ulysses sequence parallelism into the SFT training pipeline, including necessary adjustments to data collation, loss aggregation, and learning rate scheduling. Additionally, it includes critical fixes for tracking setup in multi-node SFT training, improving robustness and correctness.

Highlights

  • Corrected Total Batch Size Logging: The calculation for total_batch_size was updated to correctly reflect the effective data-parallel world size when sequence parallelism is enabled. Previously, it used the total number of processes, leading to inflated log values. Now, it accounts for the sequence_parallel_size to provide accurate logging.
  • Ulysses Sequence Parallelism Integration: Implemented support for Ulysses sequence parallelism in SFT training via the --sequence_parallel_size argument. This leverages HF Accelerate's ParallelismConfig with the DeepSpeed Ulysses SP backend, enabling training with longer context lengths by sharding sequences across GPUs.
  • Sequence Parallelism-Aware Loss Aggregation: Introduced logic to correctly aggregate loss across sequence parallel ranks, ensuring accurate loss calculation during training with Ulysses sequence parallelism.
  • Data Collator and Scheduler Adjustments for SP: Modified the data collator to handle sequence parallelism by removing the 'index' column and padding sequence lengths to be divisible by the SP size (see the padding sketch after this list). The learning rate scheduler creation was also adjusted to account for changes in dataloader length post-preparation when SP is active.
  • Fixes for Multi-Node SFT Training: Resolved an AttributeError for wandb_tracker.run.url on non-main processes and an UnboundLocalError for beaker_config in SFT tracking setup, by guarding these accesses with accelerator.is_main_process checks.
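
A hypothetical sketch of the padding step mentioned in the collator highlight above; the function name, pad ids, and batch layout are assumptions rather than the PR's actual collator:

    import torch
    import torch.nn.functional as F

    def pad_to_sp_multiple(batch: dict[str, torch.Tensor], sp_size: int, pad_token_id: int) -> dict[str, torch.Tensor]:
        """Right-pad the sequence dimension so it splits evenly across SP ranks."""
        seq_len = batch["input_ids"].shape[1]
        pad_len = (-seq_len) % sp_size
        if pad_len == 0:
            return batch
        pad_values = {"input_ids": pad_token_id, "labels": -100, "attention_mask": 0}
        return {k: F.pad(v, (0, pad_len), value=pad_values.get(k, 0)) for k, v in batch.items()}

Padding labels with -100 keeps the padded positions out of the loss, which also matters for the non-padded token counts used in the SP loss aggregation discussed in the review below.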


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces significant changes to add Ulysses sequence parallelism support for SFT training, along with several related fixes. While the PR title focuses on a logging fix for total_batch_size, the scope is much broader. The implementation of sequence parallelism, including data collation, loss aggregation, and fixes for multi-process execution, is mostly well-handled. However, I've identified a critical issue with the learning rate scheduler recreation that could lead to incorrect learning rate decay in distributed settings. I've also suggested an improvement to the loss aggregation logic for better clarity and efficiency. Overall, this is a valuable addition, and addressing the identified issue is crucial for correct training behavior.

I am having trouble creating individual review comments, so my feedback is included inline below.

open_instruct/finetune.py (721-725)

high

The lr_scheduler is recreated here to account for changes in max_train_steps due to sequence parallelism. However, the new scheduler is not prepared with accelerator.prepare(). The original lr_scheduler was wrapped by accelerator.prepare() to handle distributed training details (like scaling steps by num_processes). By reassigning lr_scheduler to a raw scheduler, the lr_scheduler.step() call in the training loop will not behave as expected in a distributed environment, likely causing the learning rate to decay much slower than intended.

The new scheduler must also be prepared.

    if args.sequence_parallel_size > 1:
        # SP changes the dataloader length post-prepare. Recreate the scheduler using
        # the post-prepare max_train_steps. Multiply by gradient_accumulation_steps because
        # the scheduler is called every micro-batch (not just on optimizer steps).
        lr_scheduler = _create_scheduler(args, optimizer, args.max_train_steps * args.gradient_accumulation_steps)
        lr_scheduler = accelerator.prepare(lr_scheduler)

open_instruct/finetune.py (826-840)

medium

This sequence parallelism loss aggregation logic is correct. However, it can be implemented more cleanly and efficiently using torch.distributed.all_gather instead of the legacy torch.distributed.nn.functional.all_gather. This also allows removing some redundant code.

                if args.sequence_parallel_size > 1:
                    sp_group = accelerator.torch_device_mesh["sp"].get_group()
                    
                    # Gather losses from all SP ranks
                    losses_per_rank = [torch.empty_like(loss) for _ in range(args.sequence_parallel_size)]
                    torch.distributed.all_gather(losses_per_rank, loss, group=sp_group)

                    # Gather non-padded token counts from all SP ranks
                    good_tokens = (batch["labels"] != -100).view(-1).sum().to(loss.dtype)
                    good_tokens_per_rank = [torch.empty_like(good_tokens) for _ in range(args.sequence_parallel_size)]
                    torch.distributed.all_gather(good_tokens_per_rank, good_tokens, group=sp_group)

                    # Calculate weighted average loss across SP ranks
                    total_loss_sp = sum(l * n for l, n in zip(losses_per_rank, good_tokens_per_rank))
                    total_good_tokens = sum(good_tokens_per_rank)
                    loss = total_loss_sp / torch.clamp(total_good_tokens, min=1.0)

Comment thread: open_instruct/finetune.py

# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
dp_world_size = accelerator.num_processes // args.sequence_parallel_size
Collaborator


Up to you ofc but I think we should refactor this somewhere so we can share it across DPO, GRPO, SFT

@hamishivi added this pull request to the merge queue Mar 20, 2026
Merged via the queue into main with commit 1fe6153 Mar 20, 2026
6 of 7 checks passed
@hamishivi deleted the hamishivi/fix-sft-sp-batch-size-log branch March 20, 2026 15:28