Skip to content
4 changes: 2 additions & 2 deletions swift/megatron/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,8 +984,8 @@ def _should_use_npu_generated_attention_mask(self, args) -> bool:
return (is_torch_npu_available() and args.task_type == 'causal_lm' and not args.padding_free
and getattr(args, 'attention_backend', None) != 'local' and getattr(args, 'use_flash_attn', False))

def _prepare_batch(self, data, vp_stage=None, num_samples=None):
return prepare_batch(self.args, data, vp_stage=vp_stage, num_samples=num_samples)
def _prepare_batch(self, data, vp_stage=None):
return prepare_batch(self.args, data, vp_stage=vp_stage)
Comment thread
Jintao-Huang marked this conversation as resolved.

def get_batch(self, data_iterator, vp_stage=None):
"""Generate a batch."""
Expand Down
9 changes: 3 additions & 6 deletions swift/megatron/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,15 +353,13 @@ def _should_use_npu_attention_mask(args) -> bool:
and getattr(args, 'attention_backend', None) != 'local' and getattr(args, 'use_flash_attn', False))


def prepare_batch(args, data, vp_stage=None, num_samples=None):
def prepare_batch(args, data, vp_stage=None):
"""Prepare a micro-batch for Megatron forward: PP slicing, packed_seq_params, CP slicing.

Extracted from BaseMegatronTrainer._prepare_batch for reuse in ray workers.
"""
batch = get_batch_on_this_pp_rank(args, data, vp_stage=vp_stage)
if num_samples is None:
num_samples = batch.pop('num_samples')
seq_lens = batch.pop('seq_lens', None)
seq_lens = batch.pop('seq_lens')
Comment thread
Jintao-Huang marked this conversation as resolved.
Outdated
text_position_ids = batch.pop('text_position_ids', None)
if text_position_ids is None:
text_position_ids = batch.get('position_ids')
Expand All @@ -373,7 +371,6 @@ def prepare_batch(args, data, vp_stage=None, num_samples=None):
batch.pop('attention_mask_2d', None)
if args.padding_free and text_position_ids is not None:
batch['packed_seq_params'] = get_packed_seq_params(text_position_ids)
batch['packed_seq_params'].num_samples = num_samples
if seq_lens is not None:
batch['packed_seq_params'].seq_lens = torch.tensor(seq_lens, device=text_position_ids.device)
batch = get_batch_on_this_cp_rank(args, batch)
Expand Down Expand Up @@ -426,7 +423,7 @@ def compute_per_token_logps_fn(model, args, data_iterator, temperature=1.0, no_g

packed_seq_params = data.get('packed_seq_params')
if packed_seq_params is not None:
num_samples = packed_seq_params.num_samples
num_samples = len(packed_seq_params.seq_lens)
Comment thread
Jintao-Huang marked this conversation as resolved.
Outdated
else:
input_ids = data.get('input_ids')
num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
Expand Down
9 changes: 0 additions & 9 deletions swift/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,7 +1630,6 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int
from swift.dataset import RowPreprocessor
if self.packing and isinstance(batch[0], list):
batch = sum(batch, start=[])
num_samples = len(batch)
if self.task_type == 'causal_lm':
if self.mode in {'transformers', 'train'}:
res = self._data_collator(batch, padding_to=padding_to)
Expand All @@ -1655,10 +1654,6 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int
extra_kwargs = [b['_extra_kwargs'] for b in batch if b.get('_extra_kwargs') is not None]
extra_kwargs = RowPreprocessor.rows_to_batched(extra_kwargs)
res.update({k: v for k, v in extra_kwargs.items() if k not in res})
if 'num_samples' in res:
num_samples = res.pop('num_samples')
if self.use_megatron:
res['num_samples'] = num_samples
return res

@staticmethod
Expand Down Expand Up @@ -1765,9 +1760,7 @@ def _embedding_data_collator(self,
for prefix in indexes:
new_batch += self._fetch_inputs_startswith([b], prefix)
labels.extend(b.get('labels', []))
num_samples = len(new_batch)
res = self._data_collator(new_batch, padding_to=padding_to)
res['num_samples'] = num_samples
if labels:
res['labels'] = torch.tensor(labels, dtype=torch.float32)
return res
Expand Down Expand Up @@ -1801,9 +1794,7 @@ def _reranker_data_collator(self,
for key in b.keys() if isinstance(b[key], list) and b[key][j + positive_num] is not None
})
labels_list.append(0)
num_samples = len(new_batch)
res = self._data_collator(new_batch, padding_to=padding_to)
res['num_samples'] = num_samples
if labels_list:
res['labels'] = torch.tensor(labels_list, dtype=torch.long)
else:
Expand Down
Loading