diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 50f492054831ad..34007a612d6326 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -869,7 +869,6 @@ def __init__( self._sharding_parallel_id = self._get_sharding_parallel_id() self._sep_parallel_id = self._get_parallel_id(self._dense_topo, "sep") - self._cp_parallel_id = self._get_parallel_id(self._cp_topo, "context") self._cp_sharding_degree = self._cp_topo.get_dim("cp_sharding") self.stage_id = self._get_parallel_id(self._moe_topo, "pipe") @@ -1032,6 +1031,8 @@ def __init__( ) ) + self._cp_parallel_id = self._cp_group.index(self.global_rank) + self._cp_sharding_group, self._cp_sharding_comm_group = ( self.build_context_sharding_group( self._dense_topo, diff --git a/test/collective/test_comm_group_num.py b/test/collective/test_comm_group_num.py index a87f6a49b0620a..6060cd2dde927f 100644 --- a/test/collective/test_comm_group_num.py +++ b/test/collective/test_comm_group_num.py @@ -58,5 +58,34 @@ def test_comm_group_num(self): ) -if __name__ == '__main__': +class TestContextParallelRankID(unittest.TestCase): + def setUp(self): + paddle.distributed.init_parallel_env() + group_names = [ + "moe_sharding", + "sharding", + "pipe", + "sep", + "data", + "expert", + "model", + "context", + ] + dims = [1, 2, 1, 1, 1, 8, 4, 2] + self.hcg = tp.EPHybridCommunicateGroup(group_names, dims) + self.dp_rank = self.hcg.get_data_parallel_rank() + self.mp_rank = self.hcg.get_model_parallel_rank() + self.pp_rank = self.hcg.get_stage_id() + self.group = self.hcg.get_context_parallel_group() + self.cp_degree = self.hcg.get_context_parallel_world_size() + self.cp_rank = self.hcg.get_context_parallel_rank() + + def test_cp_rank_id(self): + assert ( + self.hcg.get_context_parallel_rank() + == self.hcg._get_context_parallel_id() + ) + + +if __name__ == "__main__": unittest.main()