Skip to content

Commit e60a50a

Browse files
feat(parallel_context.py): remove useless gqa process group (#390)
1 parent 5ad2eb0 commit e60a50a

File tree

2 files changed

+0
-94
lines changed

2 files changed

+0
-94
lines changed

internlm/core/context/parallel_context.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -662,19 +662,6 @@ def init_parallel_groups(self):
662662
group_configs = generate_parallel_group_configs(parallel_strategy, parallel_sizes, enable_moe)
663663
group_results = create_parallel_process_groups(world_size, rank, group_configs, with_cpu_group=False)
664664

665-
# process group for extra gqa tensor parallel.
666-
if (
667-
"num_kv_attention_heads" in self.config.model
668-
and self.config.model.num_kv_attention_heads < self.tensor_parallel_size
669-
):
670-
group_results.append(
671-
create_single_process_group(
672-
world_size,
673-
rank,
674-
GroupConfig(ParallelMode.GQA, self.tensor_parallel_size // self.num_kv_attention_heads),
675-
)
676-
)
677-
678665
# process group for network test.
679666
group_results.append(
680667
create_single_process_group(

internlm/core/context/process_group_initializer.py

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,6 @@ class ParallelMode(Enum):
7474
# real data parallel for isp
7575
ISP_DATA = "isp_data"
7676

77-
# grouped query attention
78-
GQA = "gqa"
79-
8077
# sequence 2D parallel
8178
HEAD = "head"
8279
CONTEXT = "context"
@@ -1454,84 +1451,6 @@ def init_dist_group(self, use_cpu: bool = False):
14541451
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
14551452

14561453

1457-
class Initializer_GQA(ProcessGroupInitializer):
1458-
"""A ProcessGroupInitializer for all-reducing KV gradients across ranks that share a common KV attention head.
1459-
1460-
Args:
1461-
rank (int): The rank of current process.
1462-
world_size (int): Size of whole communication world.
1463-
weight_parallel_size (int): Size of model weight parallel.
1464-
weight_data_parallel_size (int): Size of data parallel for common weight.
1465-
sequence_parallel_size (int): Size of data sequence parallel.
1466-
data_parallel_size (int): Size of data parallel.
1467-
pipeline_parallel_size (int): Size of pipeline parallel.
1468-
tensor_parallel_size (int): Size of tensor parallel.
1469-
zero1_parallel_size (int): Size of zero1 parallel.
1470-
nettest_parallel_size (int): Size of net testing parallel.
1471-
expert_parallel_size (int): Size of expert parallel.
1472-
"""
1473-
1474-
def __init__(self, *args, **kwargs):
1475-
self.num_attention_heads = kwargs.pop("num_attention_heads")
1476-
self.num_kv_attention_heads = kwargs.pop("num_kv_attention_heads")
1477-
super().__init__(*args, **kwargs)
1478-
self.kv_head_repeats_num = self.tensor_parallel_size // self.num_kv_attention_heads
1479-
self.num_kv_group_per_tp = self.num_kv_attention_heads
1480-
self.num_kv_groups = self.num_kv_group_per_tp * self.data_parallel_size
1481-
1482-
assert self.world_size % self.tensor_parallel_size == 0
1483-
assert self.world_size % (self.pipeline_parallel_size * self.tensor_parallel_size) == 0
1484-
assert self.pipeline_parallel_size == 1
1485-
1486-
def init_dist_group(self, use_cpu: bool = False):
1487-
"""Initialize GQA KV-head process groups, and assign local_ranks and groups to each gpu.
1488-
1489-
Returns:
1490-
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
1491-
A GQA parallelism's information tuple.
1492-
1493-
n=128 sp=32 wp=64 zo1=1 with nopp
1494-
sp groups: [0-31] [32-63] [64-95] [96-127]
1495-
wp groups: [0-63] [64-127]
1496-
kv_head groups: [0,1,2,3] [4,5,6,7] [8,9,10,11] [12,13,14,15]
1497-
[16,17,18,19] [20,21,22,23] [24,25,26,27] [28,29,30,31]
1498-
...
1499-
...
1500-
...
1501-
"""
1502-
local_rank = None
1503-
ranks_in_group = None
1504-
process_group = None
1505-
cpu_group = None
1506-
group_world_size = None
1507-
mode = ParallelMode.GQA
1508-
1509-
for i in range(self.data_parallel_size):
1510-
for j in range(self.num_kv_group_per_tp):
1511-
ranks = [
1512-
i * self.tensor_parallel_size + j * self.kv_head_repeats_num + k
1513-
for k in range(self.kv_head_repeats_num)
1514-
]
1515-
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
1516-
if use_cpu:
1517-
group_cpu = (
1518-
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
1519-
if dist.get_backend() != "gloo"
1520-
else group
1521-
)
1522-
else:
1523-
group_cpu = None
1524-
1525-
if self.rank in ranks:
1526-
local_rank = ranks.index(self.rank)
1527-
group_world_size = len(ranks)
1528-
process_group = group
1529-
cpu_group = group_cpu
1530-
ranks_in_group = ranks
1531-
1532-
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
1533-
1534-
15351454
class Initializer_2D_SEQUENCE_PARALLEL(ProcessGroupInitializer):
15361455
"""
15371456
A ProcessGroupInitializer for 2D sequence parallel.

0 commit comments

Comments
 (0)