[megatron, GRPO] fix: CP/padding_free repeat_interleave mismatch #6720
base: main
Conversation
This PR fixes the logic to ensure template encoding captures the true packed lengths and propagates them through seq_lengths, the advantages expansion, and the truncation masks. It normalizes packed-sequence metadata in loss_func by aligning lengths_with_padding with the completion tensor width and by padding/trimming per-token log-probabilities so that downstream splits stay consistent. It also guards against invalid metadata by validating the adjusted final length, and reuses the same helper for the ref/old log-probs.

Error stack trace:

```logs
Traceback (most recent call last):
  File "ms-swift/swift/cli/_megatron/rlhf.py", line 5, in <module>
    megatron_rlhf_main()
  File "ms-swift/swift/megatron/train/rlhf.py", line 70, in megatron_rlhf_main
    return MegatronRLHF(args).main()
  File "ms-swift/swift/llm/base.py", line 49, in main
    result = self.run()
  File "ms-swift/swift/megatron/train/sft.py", line 63, in run
    self.trainer.train(train_dataset, val_dataset, data_collator)
  File "ms-swift/swift/megatron/trainers/grpo_trainer.py", line 66, in train
    super().train(train_dataset, val_dataset, data_collator)
  File "ms-swift/swift/megatron/trainers/base.py", line 990, in train
    pretrain(
  File "megatron/training/training.py", line 710, in pretrain
    iteration, num_floating_point_operations_so_far = train(
  File "megatron/training/training.py", line 2122, in train
    ) = train_step(
  File "ms-swift/swift/megatron/trainers/base.py", line 496, in train_step
    new_data_iterator = self._replace_data_iterator(data_iterator, model)
  File "ms-swift/swift/megatron/trainers/grpo_trainer.py", line 468, in _replace_data_iterator
    micro_batch_data = self._generate_and_score_completions(rollout_batch)
  File "ms-swift/swift/megatron/trainers/grpo_trainer.py", line 557, in _generate_and_score_completions
    micro_batch_data = _get_encoded_batch(micro_batch_data, micro_batch_advantages)
  File "ms-swift/swift/megatron/trainers/grpo_trainer.py", line 521, in _get_encoded_batch
    advantages = torch.repeat_interleave(advantages, lengths)
RuntimeError: repeats must have the same size as input along dim, but got repeats.size(0) = 2 and input.size(0) = 1
```

Signed-off-by: Hollow Man <[email protected]>
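For context, `torch.repeat_interleave` requires the `repeats` tensor to have the same size as the input along the expanded dimension. A minimal sketch of the failure mode and of the alignment the fix aims for (the shapes and values here are hypothetical, not taken from the trainer's actual data):

```python
import torch

# One rollout sample whose packed metadata reports two segments,
# while only one scalar advantage exists for the sample: the shapes
# no longer line up, which triggers the RuntimeError above.
advantages = torch.tensor([0.7])    # input.size(0) = 1
lengths = torch.tensor([128, 96])   # repeats.size(0) = 2
# torch.repeat_interleave(advantages, lengths)  # -> RuntimeError

# With one true packed length per sample (as returned by the
# template's encode), each advantage expands over its own tokens:
lengths = torch.tensor([224])       # one length per sample
per_token_advantages = torch.repeat_interleave(advantages, lengths)
assert per_token_advantages.shape == (224,)
```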
Summary of Changes
Hello @HollowMan6, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request resolves a critical bug in the Megatron GRPO trainer related to tensor dimension mismatches, specifically when using context parallelism (CP) with padding-free packing.
Code Review
This pull request effectively resolves a repeat_interleave mismatch error that occurred during GRPO training with context parallelism. The fix correctly simplifies the logic for determining sequence lengths by using the length returned directly from the template's encode method, rather than relying on brittle calculations from position_ids. Furthermore, the introduction of a padding and trimming mechanism within the loss function is a robust solution to prevent tensor shape mismatches, which is crucial for stability in distributed training environments. The changes are well-reasoned and directly address the bug described.
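To illustrate the padding-and-trimming mechanism the review refers to, here is a minimal sketch of aligning per-token log-probabilities to a fixed completion width (the helper name, shapes, and zero-padding choice are illustrative assumptions, not the trainer's actual code):

```python
import torch
import torch.nn.functional as F

def align_to_width(per_token_logps: torch.Tensor, target_width: int) -> torch.Tensor:
    """Pad (with zeros) or trim along the sequence dimension so that
    downstream splits always see a consistent tensor width."""
    width = per_token_logps.size(-1)
    if width < target_width:
        # Pad on the right of the last dimension up to target_width.
        return F.pad(per_token_logps, (0, target_width - width))
    # Trim any excess tokens beyond target_width.
    return per_token_logps[..., :target_width]

# e.g. ref/old log-probs brought to the completion tensor's width:
logps = torch.randn(1, 250)
aligned = align_to_width(logps, target_width=256)
assert aligned.shape == (1, 256)
```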
Experiment results
Script to reproduce with 8 GPUs:
✨ Presented to you by Mind Lab - A Lab for Experiential Intelligence.