|
12 | 12 | logger = logging.getLogger(__name__) |
13 | 13 |
|
14 | 14 |
|
class MegatronParallelState(ParallelState):
    """ParallelState for the Megatron backend, initialized from the mpu module.

    Snapshots the ranks, world sizes, and process-group handles from
    Megatron's parallel-state (``mpu``) module at construction time, so
    ``mpu`` must already be initialized before instantiating this class.

    Args:
        model: The model (or sequence of virtual-pipeline model chunks)
            whose config supplies ``microbatch_group_size_per_vp_stage``.
            Required only when virtual pipeline parallelism is enabled
            (vpp world size > 1).

    Raises:
        ValueError: If virtual pipeline parallelism is enabled but ``model``
            is ``None``.
    """

    def __init__(
        self,
        model: torch.nn.Module | Sequence[torch.nn.Module] | None = None,
    ):
        super().__init__()

        # Ranks of this process within each parallel group.
        self.dp_rank = mpu.get_data_parallel_rank(with_context_parallel=False)
        self.cp_rank = mpu.get_context_parallel_rank()
        self.tp_rank = mpu.get_tensor_model_parallel_rank()
        self.dp_cp_rank = mpu.get_data_parallel_rank(with_context_parallel=True)
        self.dp_src_rank = mpu.get_data_parallel_src_rank(with_context_parallel=True)

        # World sizes of the parallel groups.
        self.dp_size = mpu.get_data_parallel_world_size(with_context_parallel=False)
        self.dp_cp_size = mpu.get_data_parallel_world_size(with_context_parallel=True)
        self.cp_size = mpu.get_context_parallel_world_size()
        self.tp_size = mpu.get_tensor_model_parallel_world_size()

        # Process-group handles for collectives.
        self.dp_group = mpu.get_data_parallel_group(with_context_parallel=False)
        self.dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True)
        self.dp_cp_group_gloo = mpu.get_data_parallel_group_gloo(with_context_parallel=True)
        self.cp_group = mpu.get_context_parallel_group()
        self.tp_group = mpu.get_tensor_model_parallel_group()

        self.is_pp_last_stage = mpu.is_pipeline_last_stage()

        # Virtual pipeline parallelism (VPP). mpu reports None when VPP is
        # disabled. BUG FIX: the previous version had no branch for
        # vpp_size == 1, so self.vpp_size and
        # self.microbatch_group_size_per_vp_stage were left unset in that
        # case (AttributeError on later access). Default both whenever VPP
        # is effectively disabled (None or 1), matching the pre-refactor
        # behavior of create_megatron_parallel_state.
        vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
        if vpp_size is not None and vpp_size > 1:
            # Explicit raise instead of `assert`: asserts are stripped
            # under `python -O`, which would turn this into a confusing
            # TypeError further down.
            if model is None:
                raise ValueError(
                    "model is required to read microbatch_group_size_per_vp_stage "
                    "when virtual pipeline parallelism is enabled"
                )
            model_to_check = model[0] if isinstance(model, Sequence) else model
            config = get_model_config(model_to_check)
            self.vpp_size = vpp_size
            self.microbatch_group_size_per_vp_stage = config.microbatch_group_size_per_vp_stage
        else:
            self.vpp_size = 1
            self.microbatch_group_size_per_vp_stage = None
53 | 54 |
|
54 | 55 |
|
55 | 56 | def get_packed_seq_params(batch: dict[str, torch.Tensor], args: Namespace) -> PackedSeqParams: |
|
0 commit comments