@@ -74,9 +74,6 @@ class ParallelMode(Enum):
     # real data parallel for isp
     ISP_DATA = "isp_data"
 
-    # grouped query attention
-    GQA = "gqa"
-
     # sequence 2D parallel
     HEAD = "head"
     CONTEXT = "context"
@@ -1454,84 +1451,6 @@ def init_dist_group(self, use_cpu: bool = False):
         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
-class Initializer_GQA(ProcessGroupInitializer):
-    """A ProcessGroupInitializer for allreduce kv gradients with common attention head.
-
-    Args:
-        rank (int): The rank of current process.
-        world_size (int): Size of whole communication world.
-        weight_parallel_size (int): Size of model weight parallel.
-        weight_data_parallel_size (int): Size of data parallel for common weight.
-        sequence_parallel_size (int): Size of data sequence parallel.
-        data_parallel_size (int): Size of data parallel.
-        pipeline_parallel_size (int): Size of pipeline parallel.
-        tensor_parallel_size (int): Size of tensor parallel.
-        zero1_parallel_size (int): Size of zero1 parallel.
-        nettest_parallel_size (int): Size of net testing parallel.
-        expert_parallel_size (int): Size of expert parallel.
-    """
-
-    def __init__(self, *args, **kwargs):
-        self.num_attention_heads = kwargs.pop("num_attention_heads")
-        self.num_kv_attention_heads = kwargs.pop("num_kv_attention_heads")
-        super().__init__(*args, **kwargs)
-        self.kv_head_repeats_num = self.tensor_parallel_size // self.num_kv_attention_heads
-        self.num_kv_group_per_tp = self.num_kv_attention_heads
-        self.num_kv_groups = self.num_kv_group_per_tp * self.data_parallel_size
-
-        assert self.world_size % self.tensor_parallel_size == 0
-        assert self.world_size % (self.pipeline_parallel_size * self.tensor_parallel_size) == 0
-        assert self.pipeline_parallel_size == 1
-
-    def init_dist_group(self, use_cpu: bool = False):
-        """Initialize weight's data parallel groups, and assign local_ranks and groups to each gpu.
-
-        Returns:
-            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
-                A WEIGHT_DATA parallelism's information tuple.
-
-        n=128 sp=32 wp=64 zo1=1 with nopp
-        sp groups: [0-31] [32-63] [64-95] [96-127]
-        wp groups: [0-63] [64-127]
-        kv_head groups: [0,1,2,3] [4,5,6,7] [8,9,10,11] [12,13,14,15]
-                        [16,17,18,19] [20,21,22,23] [24,25,26,27] [28,29,30,31]
-                        ...
-                        ...
-                        ...
-        """
-        local_rank = None
-        ranks_in_group = None
-        process_group = None
-        cpu_group = None
-        group_world_size = None
-        mode = ParallelMode.GQA
-
-        for i in range(self.data_parallel_size):
-            for j in range(self.num_kv_group_per_tp):
-                ranks = [
-                    i * self.tensor_parallel_size + j * self.kv_head_repeats_num + k
-                    for k in range(self.kv_head_repeats_num)
-                ]
-                group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
-                if use_cpu:
-                    group_cpu = (
-                        dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
-                        if dist.get_backend() != "gloo"
-                        else group
-                    )
-                else:
-                    group_cpu = None
-
-                if self.rank in ranks:
-                    local_rank = ranks.index(self.rank)
-                    group_world_size = len(ranks)
-                    process_group = group
-                    cpu_group = group_cpu
-                    ranks_in_group = ranks
-
-        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
-
-
 class Initializer_2D_SEQUENCE_PARALLEL(ProcessGroupInitializer):
     """
     A ProcessGroupInitializer for 2D sequence parallel.
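For reference, the removed Initializer_GQA grouped together the tensor-parallel ranks that hold replicas of the same key/value head, so kv gradients could be all-reduced within each group. The sketch below is not repository code; it only reproduces the rank-grouping arithmetic of the deleted loop, using a configuration consistent with the removed docstring example (128 ranks, tensor/sequence parallel size 32; 8 kv heads is an assumed value that yields the groups of 4 shown there). The helper name gqa_rank_groups and its standalone form are illustrative.

```python
# Illustrative sketch only -- mirrors the rank grouping that the removed
# Initializer_GQA.init_dist_group() computed, assuming pipeline_parallel_size == 1
# (the removed class asserted this).
def gqa_rank_groups(world_size: int, tensor_parallel_size: int, num_kv_attention_heads: int):
    # Ranks holding replicas of the same kv head sit next to each other inside a
    # tensor-parallel group; each block of kv_head_repeats_num ranks is one group.
    kv_head_repeats_num = tensor_parallel_size // num_kv_attention_heads
    data_parallel_size = world_size // tensor_parallel_size
    groups = []
    for i in range(data_parallel_size):
        for j in range(num_kv_attention_heads):
            groups.append(
                [i * tensor_parallel_size + j * kv_head_repeats_num + k for k in range(kv_head_repeats_num)]
            )
    return groups


if __name__ == "__main__":
    # Matches the removed docstring example: n=128, sp=32 -> [0,1,2,3] [4,5,6,7] ...
    print(gqa_rank_groups(world_size=128, tensor_parallel_size=32, num_kv_attention_heads=8)[:4])
```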