Skip to content

Commit 5728eda

Browse files
authored
[megatron] Support megatron CP/non-padding-free more tasks (#9516)
1 parent f38a8f3 commit 5728eda

13 files changed

Lines changed: 96 additions & 46 deletions

File tree

swift/megatron/trainers/base.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@
3232
get_optimizer_param_scheduler, get_padding_to, init_persistent_async_worker,
3333
initialize_tp_communicators, load_mcore_checkpoint,
3434
logical_and_across_model_parallel_group, maybe_finalize_async_save,
35-
prepare_mcore_model, reduce_max_stat_across_model_parallel_group,
36-
save_mcore_checkpoint, should_disable_forward_pre_hook, warmup_jit_function,
37-
wrap_model)
35+
prepare_mcore_model, reconstruct_tensor_cp,
36+
reduce_max_stat_across_model_parallel_group, save_mcore_checkpoint,
37+
should_disable_forward_pre_hook, warmup_jit_function, wrap_model)
3838
from swift.template import Template
3939
from swift.trainers import dynamic_gradient_checkpointing
4040
from swift.trainers.utils import patch_modelscope_hub_timeout
4141
from swift.utils import (deep_getattr, gc_collect, get_current_device, get_last_valid_indices, get_logger, is_last_rank,
4242
is_master, ms_logger_context)
4343
from .batch_sampler import MegatronPretrainingRandomSampler, MegatronPretrainingSampler
44-
from .utils import TrainerState, build_streaming_dataloader
44+
from .utils import TrainerState, build_streaming_dataloader, prepare_batch
4545

4646
try:
4747
from megatron.core.optimizer import param_group_identifier_keys
@@ -985,7 +985,6 @@ def _should_use_npu_generated_attention_mask(self, args) -> bool:
985985
and getattr(args, 'attention_backend', None) != 'local' and getattr(args, 'use_flash_attn', False))
986986

987987
def _prepare_batch(self, data, vp_stage=None, num_samples=None):
988-
from .utils import prepare_batch
989988
return prepare_batch(self.args, data, vp_stage=vp_stage, num_samples=num_samples)
990989

991990
def get_batch(self, data_iterator, vp_stage=None):
@@ -1014,11 +1013,19 @@ def _collect_config_info(self) -> Dict[str, str]:
10141013
return {}
10151014

10161015
def get_last_tokens(self, output_tensor, packed_seq_params=None, attention_mask=None, num_samples=None):
1016+
if self.args.context_parallel_size > 1:
1017+
output_tensor = reconstruct_tensor_cp(output_tensor, packed_seq_params, dim=1)
10171018
if packed_seq_params is None:
1018-
last_token_idx = get_last_valid_indices((~attention_mask[:, 0, -1]).long())
1019+
# Compatible with attention_mask_2d
1020+
if attention_mask.dim() > 2:
1021+
attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0
1022+
last_token_idx = get_last_valid_indices(attention_mask.long())
10191023
last_tokens = output_tensor[torch.arange(output_tensor.shape[0]), last_token_idx]
10201024
else:
10211025
num_samples = num_samples or packed_seq_params.num_samples
1022-
last_token_idx = packed_seq_params.cu_seqlens_q[1:num_samples + 1] - 1
1026+
if self.args.context_parallel_size > 1:
1027+
last_token_idx = packed_seq_params.cu_seqlens_q[:num_samples] + packed_seq_params.seq_lens - 1
1028+
else:
1029+
last_token_idx = packed_seq_params.cu_seqlens_q[1:num_samples + 1] - 1
10231030
last_tokens = output_tensor[0, last_token_idx]
10241031
return last_tokens

swift/megatron/trainers/dpo_trainer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed
7575
return loss, metric
7676

7777
def forward_step(self, data_iterator, model):
78-
# Get the batch.
7978
unwrapped_model = model.module.module
8079
input_tensor = unwrapped_model.get_input_tensor()
8180
vp_stage = unwrapped_model.vp_stage

swift/megatron/trainers/embedding_trainer.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@ class MegatronEmbeddingTrainer(BaseMegatronTrainer):
1414

1515
def __init__(self, args, template):
1616
super().__init__(args, template)
17-
if args.context_parallel_size > 1:
18-
raise ValueError('Currently `task_type="embedding"` does not support context parallelism.')
19-
if not args.padding_free:
20-
raise ValueError('Currently, task_type embedding only supports padding_free.')
2117
self._loss_func = loss_map[args.loss_type](args, self)
2218
eval_metric = 'infonce' if args.loss_type == 'infonce' else 'paired'
2319
self.eval_metrics = eval_metrics_map[eval_metric](args, self)
2420

25-
def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params=None):
21+
def loss_func(self,
22+
output_tensor: torch.Tensor,
23+
*,
24+
labels: torch.Tensor,
25+
packed_seq_params=None,
26+
attention_mask=None):
2627
training = self.unwrapped_models[0].training
27-
last_hidden_state = self.get_last_tokens(output_tensor, packed_seq_params)
28+
last_hidden_state = self.get_last_tokens(output_tensor, packed_seq_params, attention_mask)
2829
if not training:
2930
self.eval_metrics.update(last_hidden_state.detach(), labels)
3031
loss = self._loss_func({'last_hidden_state': last_hidden_state}, labels)
@@ -33,11 +34,14 @@ def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed
3334
return loss, metric
3435

3536
def forward_step(self, data_iterator, model):
36-
# Get the batch.
3737
vp_stage = model.module.module.vp_stage
3838
data = self.get_batch(data_iterator, vp_stage)
3939
labels = data.pop('labels', None)
4040
output_tensor = model(**data)
41-
packed_seq_params = data.get('packed_seq_params')
42-
loss_func = partial(self.loss_func, labels=labels, packed_seq_params=packed_seq_params)
41+
loss_func = partial(
42+
self.loss_func,
43+
labels=labels,
44+
packed_seq_params=data.get('packed_seq_params'),
45+
attention_mask=data.get('attention_mask')
46+
if data.get('attention_mask') is not None else data.get('attention_mask_2d'))
4347
return output_tensor, loss_func

swift/megatron/trainers/kto_trainer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def _get_input_tensor(input_tensor, is_KL: bool, is_ref: bool, length: int, dim:
110110
return res
111111

112112
def forward_step(self, data_iterator, model):
113-
# Get the batch.
114113
unwrapped_model = model.module.module
115114
input_tensor = unwrapped_model.get_input_tensor()
116115
vp_stage = unwrapped_model.vp_stage

swift/megatron/trainers/reranker_trainer.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@ class MegatronRerankerTrainer(BaseMegatronTrainer):
1616

1717
def __init__(self, args, template):
1818
super().__init__(args, template)
19-
if args.context_parallel_size > 1:
20-
raise ValueError('Currently `task_type="reranker/generative_reranker"` does not support '
21-
'context parallelism.')
22-
if not args.padding_free:
23-
raise ValueError('Currently, task_type reranker/generative_reranker only supports padding_free.')
2419
self._loss_func = loss_map[args.loss_type](args, self)
2520
self.eval_metrics = eval_metrics_map['reranker'](args, self)
2621

@@ -36,9 +31,14 @@ def _get_listwise_reranker_preds(logits, labels):
3631
labels = torch.tensor([0] * (len(positive_indices) - 1), device=preds.device)
3732
return preds, labels
3833

39-
def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params=None):
34+
def loss_func(self,
35+
output_tensor: torch.Tensor,
36+
*,
37+
labels: torch.Tensor,
38+
packed_seq_params=None,
39+
attention_mask=None):
4040
training = self.unwrapped_models[0].training
41-
logits = self.get_last_tokens(output_tensor, packed_seq_params)
41+
logits = self.get_last_tokens(output_tensor, packed_seq_params, attention_mask)
4242
loss = self._loss_func(ModelOutputs(logits=logits), labels)
4343
args = self.args
4444
logits_detach = logits.detach().squeeze(-1)
@@ -60,11 +60,14 @@ def prepare_model(self):
6060
lm_model.tokenizer = self.template.tokenizer
6161

6262
def forward_step(self, data_iterator, model):
63-
# Get the batch.
6463
vp_stage = model.module.module.vp_stage
6564
data = self.get_batch(data_iterator, vp_stage)
6665
labels = data.pop('labels', None)
6766
output_tensor = model(**data)
68-
packed_seq_params = data.get('packed_seq_params')
69-
loss_func = partial(self.loss_func, labels=labels, packed_seq_params=packed_seq_params)
67+
loss_func = partial(
68+
self.loss_func,
69+
labels=labels,
70+
packed_seq_params=data.get('packed_seq_params'),
71+
attention_mask=data.get('attention_mask')
72+
if data.get('attention_mask') is not None else data.get('attention_mask_2d'))
7073
return output_tensor, loss_func

swift/megatron/trainers/reward_trainer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@
1111

1212
class MegatronRewardTrainer(MegatronRLHFTrainer):
1313

14-
def __init__(self, args, template):
15-
super().__init__(args, template)
16-
assert args.context_parallel_size == 1, 'Currently `rlhf_type="rm"` does not support context parallelism.'
17-
1814
def loss_func(self, output_tensor, *, data):
1915
packed_seq_params = data.get('packed_seq_params')
2016
margin = data.pop('margin', None)
@@ -43,7 +39,6 @@ def loss_func(self, output_tensor, *, data):
4339
return loss, metric
4440

4541
def forward_step(self, data_iterator, model):
46-
# Get the batch.
4742
vp_stage = model.module.module.vp_stage
4843
data = self.get_batch(data_iterator, vp_stage)
4944
data.pop('loss_scale', None)

swift/megatron/trainers/trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ class MegatronTrainer(BaseMegatronTrainer):
1919

2020
def seq_cls_loss_func(self, output_tensor, *, labels: torch.Tensor, packed_seq_params=None, attention_mask=None):
2121
args = self.args
22-
if args.context_parallel_size > 1:
23-
raise ValueError('Currently `task_type="seq_cls"` does not support context parallelism.')
2422
logits = self.get_last_tokens(output_tensor, packed_seq_params, attention_mask)
2523
num_labels = args.num_labels
2624
acc = None
@@ -106,7 +104,6 @@ def _compute_channel_loss(self, losses, loss_mask, channels, packed_seq_params=N
106104
return new_metrics
107105

108106
def forward_step(self, data_iterator, model):
109-
# Get the batch.
110107
vp_stage = model.module.module.vp_stage
111108
data = self.get_batch(data_iterator, vp_stage)
112109
loss_scale = data.pop('loss_scale', None)
@@ -121,7 +118,8 @@ def forward_step(self, data_iterator, model):
121118
self.seq_cls_loss_func,
122119
labels=labels,
123120
packed_seq_params=packed_seq_params,
124-
attention_mask=data.get('attention_mask'))
121+
attention_mask=data.get('attention_mask')
122+
if data.get('attention_mask') is not None else data.get('attention_mask_2d'))
125123
else:
126124
loss_func = partial(
127125
self.loss_func,

swift/megatron/trainers/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Any, Optional
1212

1313
from swift.dataloader import DataLoaderDispatcher
14+
from swift.megatron.utils import get_batch_on_this_cp_rank, get_packed_seq_params
1415
from swift.utils import empty_cache, get_current_device, get_logger, to_device
1516

1617
logger = get_logger()
@@ -312,7 +313,7 @@ class TrainerState:
312313
should_log: bool = False
313314

314315
iteration: int = 0
315-
consumed_train_samples = 0
316+
consumed_train_samples: int = 0
316317
# compat transformers
317318
max_steps: Optional[int] = None
318319

@@ -357,10 +358,10 @@ def prepare_batch(args, data, vp_stage=None, num_samples=None):
357358
358359
Extracted from BaseMegatronTrainer._prepare_batch for reuse in ray workers.
359360
"""
360-
from swift.megatron.utils import get_batch_on_this_cp_rank, get_packed_seq_params
361361
batch = get_batch_on_this_pp_rank(args, data, vp_stage=vp_stage)
362362
if num_samples is None:
363363
num_samples = batch.pop('num_samples')
364+
seq_lens = batch.pop('seq_lens', None)
364365
text_position_ids = batch.pop('text_position_ids', None)
365366
if text_position_ids is None:
366367
text_position_ids = batch.get('position_ids')
@@ -373,6 +374,8 @@ def prepare_batch(args, data, vp_stage=None, num_samples=None):
373374
if args.padding_free and text_position_ids is not None:
374375
batch['packed_seq_params'] = get_packed_seq_params(text_position_ids)
375376
batch['packed_seq_params'].num_samples = num_samples
377+
if seq_lens is not None:
378+
batch['packed_seq_params'].seq_lens = torch.tensor(seq_lens, device=text_position_ids.device)
376379
batch = get_batch_on_this_cp_rank(args, batch)
377380
return batch
378381

swift/megatron/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@
99
from .patcher import patch_merge_fn, patch_torch_dist_shard
1010
from .router_replay_utils import (RouterReplayHelper, apply_router_replay_patch, get_local_topk_idx_for_current_rank,
1111
get_router_replay_data, set_router_replay_data)
12-
from .utils import forward_step_helper, get_packed_seq_params, get_padding_to, prepare_mcore_model
12+
from .utils import (forward_step_helper, get_packed_seq_params, get_padding_to, prepare_mcore_model,
13+
reconstruct_tensor_cp)

swift/megatron/utils/convert_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from contextlib import contextmanager, nullcontext
88
from megatron.core import mpu
99
from megatron.core.extensions.transformer_engine import TEDotProductAttention
10+
from megatron.core.ssm.mamba_context_parallel import _undo_attention_load_balancing
1011
from megatron.core.tensor_parallel import VocabParallelEmbedding
1112
from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region,
1213
gather_from_tensor_model_parallel_region)
@@ -267,7 +268,6 @@ def test_convert_precision(args, hf_model, mg_model, template, test_convert_dtyp
267268
if mg_logits is not None:
268269
mg_logits = gather_from_tensor_model_parallel_region(mg_logits)
269270
if args.context_parallel_size > 1:
270-
from megatron.core.ssm.mamba_context_parallel import _undo_attention_load_balancing
271271
if mg_logits is not None:
272272
mg_logits = gather_from_sequence_parallel_region(
273273
mg_logits.transpose(0, 1), group=mpu.get_context_parallel_group())

0 commit comments

Comments
 (0)