|
32 | 32 | get_optimizer_param_scheduler, get_padding_to, init_persistent_async_worker, |
33 | 33 | initialize_tp_communicators, load_mcore_checkpoint, |
34 | 34 | logical_and_across_model_parallel_group, maybe_finalize_async_save, |
35 | | - prepare_mcore_model, reduce_max_stat_across_model_parallel_group, |
36 | | - save_mcore_checkpoint, should_disable_forward_pre_hook, warmup_jit_function, |
37 | | - wrap_model) |
| 35 | + prepare_mcore_model, reconstruct_tensor_cp, |
| 36 | + reduce_max_stat_across_model_parallel_group, save_mcore_checkpoint, |
| 37 | + should_disable_forward_pre_hook, warmup_jit_function, wrap_model) |
38 | 38 | from swift.template import Template |
39 | 39 | from swift.trainers import dynamic_gradient_checkpointing |
40 | 40 | from swift.trainers.utils import patch_modelscope_hub_timeout |
41 | 41 | from swift.utils import (deep_getattr, gc_collect, get_current_device, get_last_valid_indices, get_logger, is_last_rank, |
42 | 42 | is_master, ms_logger_context) |
43 | 43 | from .batch_sampler import MegatronPretrainingRandomSampler, MegatronPretrainingSampler |
44 | | -from .utils import TrainerState, build_streaming_dataloader |
| 44 | +from .utils import TrainerState, build_streaming_dataloader, prepare_batch |
45 | 45 |
|
46 | 46 | try: |
47 | 47 | from megatron.core.optimizer import param_group_identifier_keys |
@@ -985,7 +985,6 @@ def _should_use_npu_generated_attention_mask(self, args) -> bool: |
985 | 985 | and getattr(args, 'attention_backend', None) != 'local' and getattr(args, 'use_flash_attn', False)) |
986 | 986 |
|
987 | 987 | def _prepare_batch(self, data, vp_stage=None, num_samples=None): |
988 | | - from .utils import prepare_batch |
989 | 988 | return prepare_batch(self.args, data, vp_stage=vp_stage, num_samples=num_samples) |
990 | 989 |
|
991 | 990 | def get_batch(self, data_iterator, vp_stage=None): |
@@ -1014,11 +1013,19 @@ def _collect_config_info(self) -> Dict[str, str]: |
1014 | 1013 | return {} |
1015 | 1014 |
|
1016 | 1015 | def get_last_tokens(self, output_tensor, packed_seq_params=None, attention_mask=None, num_samples=None): |
| 1016 | + if self.args.context_parallel_size > 1: |
| 1017 | + output_tensor = reconstruct_tensor_cp(output_tensor, packed_seq_params, dim=1) |
1017 | 1018 | if packed_seq_params is None: |
1018 | | - last_token_idx = get_last_valid_indices((~attention_mask[:, 0, -1]).long()) |
| 1019 | + # Compatible with attention_mask_2d |
| 1020 | + if attention_mask.dim() > 2: |
| 1021 | + attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0 |
| 1022 | + last_token_idx = get_last_valid_indices(attention_mask.long()) |
1019 | 1023 | last_tokens = output_tensor[torch.arange(output_tensor.shape[0]), last_token_idx] |
1020 | 1024 | else: |
1021 | 1025 | num_samples = num_samples or packed_seq_params.num_samples |
1022 | | - last_token_idx = packed_seq_params.cu_seqlens_q[1:num_samples + 1] - 1 |
| 1026 | + if self.args.context_parallel_size > 1: |
| 1027 | + last_token_idx = packed_seq_params.cu_seqlens_q[:num_samples] + packed_seq_params.seq_lens - 1 |
| 1028 | + else: |
| 1029 | + last_token_idx = packed_seq_params.cu_seqlens_q[1:num_samples + 1] - 1 |
1023 | 1030 | last_tokens = output_tensor[0, last_token_idx] |
1024 | 1031 | return last_tokens |
0 commit comments