@@ -266,15 +266,13 @@ def __init__(
266
266
dtype : torch .dtype = torch .half ,
267
267
device : torch .device = None ,
268
268
activation_checkpointing : float = 0.0 ,
269
- module_shapes : Dict [str , torch .Size ] = None ,
270
269
) -> None :
271
270
self .dtype = dtype
272
271
if device is None :
273
272
self .device = get_current_device ()
274
273
else :
275
274
self .device = device
276
275
self .activation_checkpointing = activation_checkpointing
277
- self .module_shapes = module_shapes
278
276
279
277
280
278
class ISPOverlapState :
@@ -285,7 +283,7 @@ class ISPOverlapState:
285
283
def __init__ (self ) -> None :
286
284
self .num_blocks : int = 0
287
285
self .ckpt_block_num : int = 0
288
- self .isp_outs : List [nn .Module ] = []
286
+ self .isp_prefetch_launch_module : List [nn .Module ] = []
289
287
self .isp_modules : List [nn .Module ] = []
290
288
self .index_to_isp_modules : Dict [int , nn .Module ] = {}
291
289
self .index_to_block : Dict [int , nn .Module ] = {}
@@ -315,16 +313,17 @@ def __init__(
315
313
self .is_moe = is_moe
316
314
self .is_forward = True
317
315
self .reduce_scatter_handlers = {}
318
- self ._module_shapes = {}
319
316
self ._forward_prefetch_prerequisites = []
317
+ self ._forward_overlap_per = self ._get_forward_overlap_granularity ()
318
+ self ._launch_before_module = self ._get_launch_before_module ()
320
319
321
320
# real overlap state for each chunk.
322
321
self ._overlap_states : Dict [int , ISPOverlapState ] = {}
323
322
324
323
# inner interface variables of overlap state.
325
324
self ._num_blocks = None
326
325
self ._ckpt_block_num = None
327
- self ._isp_outs = None
326
+ self ._isp_prefetch_launch_module = None
328
327
self ._isp_modules = None
329
328
# key: isp module; value: module global all-gather op handle
330
329
self ._weight_global_handle = None
@@ -351,14 +350,46 @@ def __init__(
351
350
self ._register_sync_parameters_hook ()
352
351
# switch to chunk 0 at first.
353
352
self .switch_current_model_chunk (0 )
354
- self .model_conf .module_shapes = self ._module_shapes
353
+
354
def _get_launch_before_module(self):
    """Return the module-name suffixes that trigger the next weight all-gather.

    The option ``launch_allgather_before`` is read from
    ``gpc.config.parallel.expert_weight`` when ``self.is_moe`` is True and from
    ``gpc.config.parallel.weight`` otherwise; it defaults to ``"wo"``.

    Returns:
        A list of names. ``"attn"`` maps to ``["attn"]``; every other option
        maps to the spellings accepted for that projection layer.

    Raises:
        AssertionError: if the configured value is not one of
            ``'wqkv'``, ``'attn'``, ``'wo'`` or ``'w1'``.
    """
    if self.is_moe is True:
        _launch_before = gpc.config.parallel.expert_weight.get("launch_allgather_before", "wo")
    else:
        _launch_before = gpc.config.parallel.weight.get("launch_allgather_before", "wo")

    if _launch_before == "wqkv":
        return ["wqkv", "Wqkv", "qkv", "q_a_proj", "q_proj"]
    elif _launch_before == "attn":
        return ["attn"]
    elif _launch_before == "wo":
        return ["out_proj", "wo"]
    elif _launch_before == "w1":
        return ["w1", "fused_w1_w3"]

    # Raise explicitly instead of `assert False`: assert statements are removed
    # under `python -O`, which would make this method silently return None for
    # an invalid configuration value. The exception type is kept unchanged.
    raise AssertionError("launch module should be in ['wqkv', 'attn', 'wo', 'w1']")
370
+
371
def _get_forward_overlap_granularity(self):
    """Return the configured forward-overlap granularity.

    The option ``forward_overlap_per`` is read from
    ``gpc.config.parallel.expert_weight`` when ``self.is_moe`` is True and from
    ``gpc.config.parallel.weight`` otherwise; it defaults to ``"layer"``.

    Returns:
        Either ``"module"`` or ``"layer"``.

    Raises:
        AssertionError: if the configured value is neither ``"module"`` nor
            ``"layer"``.
    """
    if self.is_moe is True:
        _overlap_granularity = gpc.config.parallel.expert_weight.get("forward_overlap_per", "layer")
    else:
        _overlap_granularity = gpc.config.parallel.weight.get("forward_overlap_per", "layer")

    # Validate with an explicit raise: a bare `assert` is removed under
    # `python -O`, which would let an unsupported value through unnoticed.
    # The exception type is kept unchanged.
    if _overlap_granularity not in ["module", "layer"]:
        raise AssertionError(
            f"forward_overlap_per should be in ['module', 'layer'], got {_overlap_granularity!r}"
        )
    return _overlap_granularity
355
379
356
380
def _parse_model_structure (self , cid : int , model : nn .Module ) -> None :
357
381
self ._overlap_states [cid ] = ISPOverlapState ()
358
382
359
383
def get_model (obj : nn .Module ) -> nn .Module :
360
384
return get_model (obj .model ) if hasattr (obj , "model" ) else obj
361
385
386
+ def is_allgather_launch_module (name , module ):
387
+ return (
388
+ hasattr (module , "is_attn_cls" )
389
+ and getattr (module , "is_attn_cls" )
390
+ and self ._launch_before_module == ["attn" ]
391
+ ) or (name .split ("." )[- 1 ] in self ._launch_before_module )
392
+
362
393
# Important: only works for llama-class models
363
394
children_name = get_model (model ).named_children ()
364
395
for _ , children in children_name :
@@ -369,18 +400,12 @@ def get_model(obj: nn.Module) -> nn.Module:
369
400
self ._overlap_states [cid ].index_to_isp_modules [idx ] = []
370
401
self ._overlap_states [cid ].index_to_block [idx ] = block
371
402
for name , child in block .named_modules ():
372
- if name .split ("." )[- 1 ] in ["out_proj" , "wo" ]:
373
- self ._overlap_states [cid ].isp_outs .append (child )
374
- self ._overlap_states [cid ].module_to_index [child ] = idx
403
+ if is_allgather_launch_module (name , child ):
404
+ self ._overlap_states [cid ].isp_prefetch_launch_module .append (child )
375
405
if isinstance (child , (ParallelLinearWithCommExt )):
376
406
if is_moe_param (child .weight ) != self .is_moe :
377
407
continue
378
- if name not in self ._module_shapes :
379
- weight_parallel_size = dist .get_world_size (self .process_group )
380
- origin_shape = tuple (
381
- [child .weight .shape [0 ] * weight_parallel_size ] + list (child .weight .shape [1 :])
382
- )
383
- self ._module_shapes [name ] = torch .Size (origin_shape )
408
+
384
409
self ._overlap_states [cid ].module_to_index [child ] = idx
385
410
self ._overlap_states [cid ].isp_modules .append (child )
386
411
self ._overlap_states [cid ].index_to_isp_modules [idx ].append (child )
@@ -403,25 +428,28 @@ def get_model(obj: nn.Module) -> nn.Module:
403
428
self ._overlap_states [cid ].num_blocks = len (self ._overlap_states [cid ].index_to_isp_modules )
404
429
405
430
def _all_gather_module_weight(self, module):
    """Launch the asynchronous all-gather of ``module``'s bias (if any) and weight.

    The gathered tensors and their communication handles are stored, keyed by
    ``module``, in ``_bias_global_output`` / ``_bias_global_handle`` and
    ``_weight_global_output`` / ``_weight_global_handle``.
    """
    # A module whose gathered tensors are still registered must not be gathered again.
    assert module not in self._bias_global_output and module not in self._weight_global_output

    def _launch(shard):
        # submit one async all-gather on the communicator's process group.
        return all_gather_raw(shard, self.process_group, async_op=True)

    # submit the all-gather communication for weight and bias.
    # The membership guards below only matter when assertions are disabled.
    if module.bias is not None and module not in self._bias_global_output:
        gathered_bias, pending_bias = _launch(module.bias)
        self._bias_global_handle[module] = pending_bias
        self._bias_global_output[module] = gathered_bias

    if module not in self._weight_global_output:
        gathered_weight, pending_weight = _launch(module.weight)
        self._weight_global_handle[module] = pending_weight
        self._weight_global_output[module] = gathered_weight
425
453
426
454
def _all_gather_block_weight (self , block_index : int ):
427
455
block = self ._index_to_block [block_index ]
@@ -463,30 +491,53 @@ def _pre_forward_hook_for_first_block(self, *args): # pylint: disable=W0613
463
491
"""
464
492
prefetch weight for block 0 before forward.
465
493
"""
466
- if self .is_forward is True :
494
+ if self ._forward_overlap_per == "layer" and self . is_forward is True :
467
495
self ._all_gather_block_weight (0 )
468
496
469
- def _pre_forward_hook_for_last_ckpt_block (self , * args ): # pylint: disable=W0613
470
- if self .is_forward is False :
471
- self ._all_gather_block_weight (self ._ckpt_block_num - 1 )
472
-
473
- def _pre_forward_hook_for_out_proj (self , module : nn .Module , * args ): # pylint: disable=W0613
497
+ def _pre_forward_hook_for_prefetch_launch_module (self , module : nn .Module , * args ): # pylint: disable=W0613
474
498
block_index = self ._module_to_index [module ]
475
499
476
- if (block_index - 1 < self ._ckpt_block_num ) and self .is_forward is False :
477
- if block_index - 1 >= 0 :
478
- self ._all_gather_block_weight (block_index - 1 )
479
- else :
480
- # start the all-gather for next block
481
- if block_index + 1 < self ._num_blocks :
482
- self ._all_gather_block_weight (block_index + 1 )
500
+ if self ._forward_overlap_per == "layer" :
501
+ if (block_index - 1 < self ._ckpt_block_num ) and self .is_forward is False :
502
+ if block_index - 1 >= 0 :
503
+ self ._all_gather_block_weight (block_index - 1 )
504
+ else :
505
+ # start the all-gather for next block
506
+ if block_index + 1 < self ._num_blocks :
507
+ self ._all_gather_block_weight (block_index + 1 )
483
508
484
509
def _pre_forward_hook_for_module (self , module : nn .Module , * args ): # pylint: disable=W0613
485
510
if module not in self ._weight_global_handle :
486
511
self ._all_gather_module_weight (module )
487
512
488
513
self ._wait_handle (module )
489
514
515
+ if self ._forward_overlap_per == "module" :
516
+ # start the all-gather for next module
517
+ # 1.forward prefetch for next module
518
+ module_index = self ._isp_modules .index (module )
519
+ module_layer_id = self ._module_to_index [module ]
520
+ if module_index + 1 < len (self ._isp_modules ) and self .is_forward is True :
521
+ next_module = self ._isp_modules [module_index + 1 ]
522
+ self ._all_gather_module_weight (next_module )
523
+
524
+ # 2.recompute forward prefetch for next module
525
+ if self .is_forward is False :
526
+ if module_index + 1 < len (self ._isp_modules ):
527
+ next_module = self ._isp_modules [module_index + 1 ]
528
+ next_module_layer_id = self ._module_to_index [next_module ]
529
+ if module_layer_id == next_module_layer_id :
530
+ self ._all_gather_module_weight (next_module )
531
+ # if current module is the last module in current layer, prefetch previous layer's first module
532
+ elif module_layer_id - 1 >= 0 :
533
+ next_module = self ._index_to_isp_modules [module_layer_id - 1 ][0 ]
534
+ self ._all_gather_module_weight (next_module )
535
+ else :
536
+ # if current module is the last module, prefetch previous layer's first module
537
+ if module_layer_id - 1 >= 0 :
538
+ next_module = self ._index_to_isp_modules [module_layer_id - 1 ][0 ]
539
+ self ._all_gather_module_weight (next_module )
540
+
490
541
def _post_forward_hook_for_module (self , module : nn .Module , * args ): # pylint: disable=W0613
491
542
if not ((self ._module_to_index [module ] < self ._ckpt_block_num ) and self .is_forward is False ):
492
543
self ._clear_handle (module )
@@ -515,29 +566,24 @@ def _register_sync_parameters_hook(self) -> None:
515
566
register forward hooks and backward hooks for isp modules.
516
567
"""
517
568
# register forward hooks
518
- # 1. register pre_forward_hook @block_0 to prefetch for block 0
519
- # 2. register pre_forward_hook @block_(ckpt_block_num-1) to prefetch for the last ckpt block
520
- # 3. register pre_forward_hook @out_proj module to prefetch for next block,
521
- # notice that next block's all_gather op should be after current block's all_to_all op
522
- # 4. register pre_forward_hook @isp_module to wait handle for current module
523
- # 5 . register post_forward_hook @isp_module to release resource
569
+ # 1. register pre_forward_hook @block_0 to prefetch weight for block 0.
570
+ # 2. register pre_forward_hook @prefetch_launch_module to prefetch weight for next block,
571
+ # when forward overlap granularity is 'layer'.
572
+ # 3. register pre_forward_hook @isp_module to wait handle for current module,
573
+ # and prefetch weight for next module when forward overlap granularity is 'module'.
574
+ # 4 . register post_forward_hook @isp_module to release memory resource.
524
575
self ._index_to_block [0 ].register_forward_pre_hook (self ._pre_forward_hook_for_first_block )
525
576
526
- if self ._ckpt_block_num >= 1 :
527
- self ._index_to_block [self ._ckpt_block_num - 1 ].register_forward_pre_hook (
528
- self ._pre_forward_hook_for_last_ckpt_block
529
- )
530
-
531
- for out_proj in self ._isp_outs :
532
- out_proj .register_forward_pre_hook (self ._pre_forward_hook_for_out_proj )
577
+ for module in self ._isp_prefetch_launch_module :
578
+ module .register_forward_pre_hook (self ._pre_forward_hook_for_prefetch_launch_module )
533
579
534
580
for module in self ._isp_modules :
535
581
module .register_forward_pre_hook (self ._pre_forward_hook_for_module )
536
582
module .register_forward_hook (self ._post_forward_hook_for_module )
537
583
538
584
# register backward hooks
539
- # 1. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module
540
- # 2. register post_backward_hook @isp_module to release resource
585
+ # 1. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module.
586
+ # 2. register post_backward_hook @isp_module to release memory resource.
541
587
if self ._ckpt_block_num < self ._num_blocks :
542
588
for module in self ._isp_modules :
543
589
module .register_full_backward_pre_hook (self ._pre_backward_hook_for_module )
@@ -556,7 +602,7 @@ def communication_mode(self) -> str:
556
602
return "wp"
557
603
558
604
def switch_current_model_chunk (self , chunk_id : int ) -> None :
559
- self ._isp_outs = self ._overlap_states [chunk_id ].isp_outs
605
+ self ._isp_prefetch_launch_module = self ._overlap_states [chunk_id ].isp_prefetch_launch_module
560
606
self ._isp_modules = self ._overlap_states [chunk_id ].isp_modules
561
607
self ._weight_global_handle = self ._overlap_states [chunk_id ].weight_global_handle
562
608
self ._bias_global_handle = self ._overlap_states [chunk_id ].bias_global_handle
@@ -872,9 +918,7 @@ def _q_kv(self, q: torch.Tensor, kv: torch.Tensor, *args, **kwargs) -> torch.Ten
872
918
873
919
q , kv = _SeqAllToAll .apply (self .spg , [2 , 3 ], [1 , 1 ], q , kv )
874
920
875
- torch .cuda .synchronize ()
876
921
context = self .local_attn (q , kv , * args , ** kwargs )
877
- torch .cuda .synchronize ()
878
922
879
923
context = _SeqAllToAll .apply (self .spg , 1 , 2 , context )
880
924
0 commit comments