[megatron] megatron remove num_samples#9522
Conversation
There was a problem hiding this comment.
Code Review
This pull request simplifies batch preparation and data collation by removing the num_samples parameter and instead deriving it dynamically from the length of seq_lens when packed_seq_params is present. While this cleanup streamlines the code, the review highlights a critical regression in get_last_tokens where packed_seq_params.num_samples is still accessed. Additionally, the feedback suggests adding safety checks when popping seq_lens from the batch and when accessing it within packed_seq_params to prevent potential KeyError or AttributeError exceptions.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _prepare_batch(self, data, vp_stage=None): | ||
| return prepare_batch(self.args, data, vp_stage=vp_stage) |
There was a problem hiding this comment.
The removal of num_samples from prepare_batch and packed_seq_params introduces a critical regression in get_last_tokens (lines 1015-1031 of this file), which is not modified in this PR but still relies on packed_seq_params.num_samples:
num_samples = num_samples or packed_seq_params.num_samplesSince packed_seq_params.num_samples is no longer set, this will raise an AttributeError or result in None, causing subsequent slicing operations like packed_seq_params.cu_seqlens_q[:num_samples] to fail.
Please update get_last_tokens to use len(packed_seq_params.seq_lens) instead:
num_samples = num_samples or len(packed_seq_params.seq_lens)| if packed_seq_params is not None: | ||
| num_samples = packed_seq_params.num_samples | ||
| num_samples = len(packed_seq_params.seq_lens) |
There was a problem hiding this comment.
If packed_seq_params is not None but seq_lens was not set (e.g., if seq_lens was None in prepare_batch), accessing packed_seq_params.seq_lens will raise an AttributeError or TypeError when calling len(). It is safer to check if seq_lens is present and not None before calling len() on it, falling back to the else block if it is missing.
| if packed_seq_params is not None: | |
| num_samples = packed_seq_params.num_samples | |
| num_samples = len(packed_seq_params.seq_lens) | |
| if packed_seq_params is not None and getattr(packed_seq_params, 'seq_lens', None) is not None: | |
| num_samples = len(packed_seq_params.seq_lens) |
| if num_samples is None: | ||
| num_samples = batch.pop('num_samples') | ||
| seq_lens = batch.pop('seq_lens', None) | ||
| seq_lens = batch.pop('seq_lens') |
There was a problem hiding this comment.
Using batch.pop('seq_lens') without a default value will raise a KeyError if seq_lens is not present in the batch (for example, in custom data pipelines or testing environments). It is safer to use .pop('seq_lens', None) to handle missing keys gracefully.
| seq_lens = batch.pop('seq_lens') | |
| seq_lens = batch.pop('seq_lens', None) |
823caf5 to
2cd141b
Compare
No description provided.