
Commit 0ec6cdc

feat(isp): support switch for launch ag and forward overlap per module (#381)
1 parent e60a50a commit 0ec6cdc

7 files changed, +126 -67 lines changed


configs/7B_MoE4_sft.py (+10, -2)

@@ -183,6 +183,10 @@
 weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
+    3. launch_allgather_before: str, before which module to launch the all gather communication to
+        prefetch next layer's weight, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'.
+        Must be used with forward_overlap_per 'layer'.
+    4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
 expert parallel (dict):
     1. size: int
         * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
@@ -193,14 +197,18 @@
 expert weight parallel (dict):
     1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
+    3. launch_allgather_before: str, before which module to launch the all gather communication to
+        prefetch next layer's weight, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'.
+        Must be used with forward_overlap_per 'layer'.
+    4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=1, overlap=True),
+    weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
     expert=dict(size=-1, no_tp=False),
-    expert_weight=dict(size=1, overlap=True),
+    expert_weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
 )

 cudnn_deterministic = False
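
Note (editorial illustration, not part of the diff): the two new keys sit next to overlap in parallel.weight and parallel.expert_weight. Per the docstring above, launch_allgather_before picks the submodule whose forward kicks off the next layer's all-gather and is meant for forward_overlap_per="layer", while "module" granularity prefetches one ISP module ahead. A minimal config sketch with the documented defaults:

# Illustrative sketch only; values mirror the defaults documented above.
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=1, mode="mtp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    # per-layer prefetch: launch the next layer's all-gather before this layer's "wo"
    weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
    # alternative, per-module prefetch (launch_allgather_before is tied to the 'layer' mode):
    # weight=dict(size=1, overlap=True, forward_overlap_per="module"),
    expert=dict(size=-1, no_tp=False),
    expert_weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
)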

configs/7B_isp_sft.py (+5, -1)

@@ -186,6 +186,10 @@
 weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
+    3. launch_allgather_before: str, before which module to launch the all gather communication to
+        prefetch next layer's weight, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'.
+        Must be used with forward_overlap_per 'layer'.
+    4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
 sequence_2D (dict):
     1. enable: bool, whether enable the 2D sequence parallel or not.
     2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses).
@@ -205,7 +209,7 @@
     zero1=dict(size=-1),
     tensor=dict(size=2, mode="isp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=4, overlap=True),
+    weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
     sequence_2D=dict(
         enable=False,
         head_size=2,

internlm/core/parallel/comm/isp.py (+102, -58)

@@ -266,15 +266,13 @@ def __init__(
         dtype: torch.dtype = torch.half,
         device: torch.device = None,
         activation_checkpointing: float = 0.0,
-        module_shapes: Dict[str, torch.Size] = None,
     ) -> None:
         self.dtype = dtype
         if device is None:
             self.device = get_current_device()
         else:
             self.device = device
         self.activation_checkpointing = activation_checkpointing
-        self.module_shapes = module_shapes


 class ISPOverlapState:
@@ -285,7 +283,7 @@ class ISPOverlapState:
     def __init__(self) -> None:
         self.num_blocks: int = 0
         self.ckpt_block_num: int = 0
-        self.isp_outs: List[nn.Module] = []
+        self.isp_prefetch_launch_module: List[nn.Module] = []
         self.isp_modules: List[nn.Module] = []
         self.index_to_isp_modules: Dict[int, nn.Module] = {}
         self.index_to_block: Dict[int, nn.Module] = {}
@@ -315,16 +313,17 @@ def __init__(
         self.is_moe = is_moe
         self.is_forward = True
         self.reduce_scatter_handlers = {}
-        self._module_shapes = {}
         self._forward_prefetch_prerequisites = []
+        self._forward_overlap_per = self._get_forward_overlap_granularity()
+        self._launch_before_module = self._get_launch_before_module()

         # real overlap state for each chunk.
         self._overlap_states: Dict[int, ISPOverlapState] = {}

         # inner interface variables of overlap state.
         self._num_blocks = None
         self._ckpt_block_num = None
-        self._isp_outs = None
+        self._isp_prefetch_launch_module = None
         self._isp_modules = None
         # key: isp module; value: module global all-gather op handle
         self._weight_global_handle = None
@@ -351,14 +350,46 @@ def __init__(
         self._register_sync_parameters_hook()
         # switch to chunk 0 at first.
         self.switch_current_model_chunk(0)
-        self.model_conf.module_shapes = self._module_shapes
+
+    def _get_launch_before_module(self):
+        if self.is_moe is True:
+            _launch_before = gpc.config.parallel.expert_weight.get("launch_allgather_before", "wo")
+        else:
+            _launch_before = gpc.config.parallel.weight.get("launch_allgather_before", "wo")
+
+        if _launch_before == "wqkv":
+            return ["wqkv", "Wqkv", "qkv", "q_a_proj", "q_proj"]
+        elif _launch_before == "attn":
+            return ["attn"]
+        elif _launch_before == "wo":
+            return ["out_proj", "wo"]
+        elif _launch_before == "w1":
+            return ["w1", "fused_w1_w3"]
+        else:
+            assert False, "launch module should be in ['wqkv', 'attn', 'wo', 'w1']"
+
+    def _get_forward_overlap_granularity(self):
+        if self.is_moe is True:
+            _overlap_granularity = gpc.config.parallel.expert_weight.get("forward_overlap_per", "layer")
+        else:
+            _overlap_granularity = gpc.config.parallel.weight.get("forward_overlap_per", "layer")
+
+        assert _overlap_granularity in ["module", "layer"]
+        return _overlap_granularity

     def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
         self._overlap_states[cid] = ISPOverlapState()

         def get_model(obj: nn.Module) -> nn.Module:
             return get_model(obj.model) if hasattr(obj, "model") else obj

+        def is_allgather_launch_module(name, module):
+            return (
+                hasattr(module, "is_attn_cls")
+                and getattr(module, "is_attn_cls")
+                and self._launch_before_module == ["attn"]
+            ) or (name.split(".")[-1] in self._launch_before_module)
+
         # Important: only works for llama-class models
         children_name = get_model(model).named_children()
         for _, children in children_name:
@@ -369,18 +400,12 @@ def get_model(obj: nn.Module) -> nn.Module:
                     self._overlap_states[cid].index_to_isp_modules[idx] = []
                     self._overlap_states[cid].index_to_block[idx] = block
                     for name, child in block.named_modules():
-                        if name.split(".")[-1] in ["out_proj", "wo"]:
-                            self._overlap_states[cid].isp_outs.append(child)
-                            self._overlap_states[cid].module_to_index[child] = idx
+                        if is_allgather_launch_module(name, child):
+                            self._overlap_states[cid].isp_prefetch_launch_module.append(child)
                         if isinstance(child, (ParallelLinearWithCommExt)):
                             if is_moe_param(child.weight) != self.is_moe:
                                 continue
-                            if name not in self._module_shapes:
-                                weight_parallel_size = dist.get_world_size(self.process_group)
-                                origin_shape = tuple(
-                                    [child.weight.shape[0] * weight_parallel_size] + list(child.weight.shape[1:])
-                                )
-                                self._module_shapes[name] = torch.Size(origin_shape)
+
                             self._overlap_states[cid].module_to_index[child] = idx
                             self._overlap_states[cid].isp_modules.append(child)
                             self._overlap_states[cid].index_to_isp_modules[idx].append(child)
@@ -403,25 +428,28 @@ def get_model(obj: nn.Module) -> nn.Module:
         self._overlap_states[cid].num_blocks = len(self._overlap_states[cid].index_to_isp_modules)

     def _all_gather_module_weight(self, module):
+        assert module not in self._bias_global_output and module not in self._weight_global_output
         with_bias = module.bias is not None

         # submit the all-gather communication for weight and bias.
         if with_bias:
-            bias_output, bias_handle = all_gather_raw(
-                module.bias,
+            if module not in self._bias_global_output:
+                bias_output, bias_handle = all_gather_raw(
+                    module.bias,
+                    self.process_group,
+                    async_op=True,
+                )
+                self._bias_global_handle[module] = bias_handle
+                self._bias_global_output[module] = bias_output
+
+        if module not in self._weight_global_output:
+            weight_output, weight_handle = all_gather_raw(
+                module.weight,
                 self.process_group,
                 async_op=True,
             )
-            self._bias_global_handle[module] = bias_handle
-            self._bias_global_output[module] = bias_output
-
-        weight_output, weight_handle = all_gather_raw(
-            module.weight,
-            self.process_group,
-            async_op=True,
-        )
-        self._weight_global_handle[module] = weight_handle
-        self._weight_global_output[module] = weight_output
+            self._weight_global_handle[module] = weight_handle
+            self._weight_global_output[module] = weight_output

     def _all_gather_block_weight(self, block_index: int):
         block = self._index_to_block[block_index]
@@ -463,30 +491,53 @@ def _pre_forward_hook_for_first_block(self, *args): # pylint: disable=W0613
         """
         prefetch weight for block 0 before forward.
         """
-        if self.is_forward is True:
+        if self._forward_overlap_per == "layer" and self.is_forward is True:
             self._all_gather_block_weight(0)

-    def _pre_forward_hook_for_last_ckpt_block(self, *args): # pylint: disable=W0613
-        if self.is_forward is False:
-            self._all_gather_block_weight(self._ckpt_block_num - 1)
-
-    def _pre_forward_hook_for_out_proj(self, module: nn.Module, *args): # pylint: disable=W0613
+    def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args): # pylint: disable=W0613
         block_index = self._module_to_index[module]

-        if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False:
-            if block_index - 1 >= 0:
-                self._all_gather_block_weight(block_index - 1)
-        else:
-            # start the all-gather for next block
-            if block_index + 1 < self._num_blocks:
-                self._all_gather_block_weight(block_index + 1)
+        if self._forward_overlap_per == "layer":
+            if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False:
+                if block_index - 1 >= 0:
+                    self._all_gather_block_weight(block_index - 1)
+            else:
+                # start the all-gather for next block
+                if block_index + 1 < self._num_blocks:
+                    self._all_gather_block_weight(block_index + 1)

     def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613
         if module not in self._weight_global_handle:
             self._all_gather_module_weight(module)

         self._wait_handle(module)

+        if self._forward_overlap_per == "module":
+            # start the all-gather for next module
+            # 1.forward prefetch for next module
+            module_index = self._isp_modules.index(module)
+            module_layer_id = self._module_to_index[module]
+            if module_index + 1 < len(self._isp_modules) and self.is_forward is True:
+                next_module = self._isp_modules[module_index + 1]
+                self._all_gather_module_weight(next_module)
+
+            # 2.recompute forward prefetch for next module
+            if self.is_forward is False:
+                if module_index + 1 < len(self._isp_modules):
+                    next_module = self._isp_modules[module_index + 1]
+                    next_module_layer_id = self._module_to_index[next_module]
+                    if module_layer_id == next_module_layer_id:
+                        self._all_gather_module_weight(next_module)
+                    # if current module is the last module in current layer, prefetch previous layer's first module
+                    elif module_layer_id - 1 >= 0:
+                        next_module = self._index_to_isp_modules[module_layer_id - 1][0]
+                        self._all_gather_module_weight(next_module)
+                else:
+                    # if current module is the last module, prefetch previous layer's first module
+                    if module_layer_id - 1 >= 0:
+                        next_module = self._index_to_isp_modules[module_layer_id - 1][0]
+                        self._all_gather_module_weight(next_module)
+
     def _post_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613
         if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
             self._clear_handle(module)
@@ -515,29 +566,24 @@ def _register_sync_parameters_hook(self) -> None:
         register forward hooks and backward hooks for isp modules.
         """
         # register forward hooks
-        # 1. register pre_forward_hook @block_0 to prefetch for block 0
-        # 2. register pre_forward_hook @block_(ckpt_block_num-1) to prefetch for the last ckpt block
-        # 3. register pre_forward_hook @out_proj module to prefetch for next block,
-        #    notice that next block's all_gather op should be after current block's all_to_all op
-        # 4. register pre_forward_hook @isp_module to wait handle for current module
-        # 5. register post_forward_hook @isp_module to release resource
+        # 1. register pre_forward_hook @block_0 to prefetch weight for block 0.
+        # 2. register pre_forward_hook @prefetch_launch_module to prefetch weight for next block,
+        #    when forward overlap granularity is 'layer'.
+        # 3. register pre_forward_hook @isp_module to wait handle for current module,
+        #    and prefetch weight for next module when forward overlap granularity is 'module'.
+        # 4. register post_forward_hook @isp_module to release memory resource.
         self._index_to_block[0].register_forward_pre_hook(self._pre_forward_hook_for_first_block)

-        if self._ckpt_block_num >= 1:
-            self._index_to_block[self._ckpt_block_num - 1].register_forward_pre_hook(
-                self._pre_forward_hook_for_last_ckpt_block
-            )
-
-        for out_proj in self._isp_outs:
-            out_proj.register_forward_pre_hook(self._pre_forward_hook_for_out_proj)
+        for module in self._isp_prefetch_launch_module:
+            module.register_forward_pre_hook(self._pre_forward_hook_for_prefetch_launch_module)

         for module in self._isp_modules:
             module.register_forward_pre_hook(self._pre_forward_hook_for_module)
             module.register_forward_hook(self._post_forward_hook_for_module)

         # register backward hooks
-        # 1. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module
-        # 2. register post_backward_hook @isp_module to release resource
+        # 1. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module.
+        # 2. register post_backward_hook @isp_module to release memory resource.
         if self._ckpt_block_num < self._num_blocks:
             for module in self._isp_modules:
                 module.register_full_backward_pre_hook(self._pre_backward_hook_for_module)
@@ -556,7 +602,7 @@ def communication_mode(self) -> str:
         return "wp"

     def switch_current_model_chunk(self, chunk_id: int) -> None:
-        self._isp_outs = self._overlap_states[chunk_id].isp_outs
+        self._isp_prefetch_launch_module = self._overlap_states[chunk_id].isp_prefetch_launch_module
         self._isp_modules = self._overlap_states[chunk_id].isp_modules
         self._weight_global_handle = self._overlap_states[chunk_id].weight_global_handle
         self._bias_global_handle = self._overlap_states[chunk_id].bias_global_handle
@@ -872,9 +918,7 @@ def _q_kv(self, q: torch.Tensor, kv: torch.Tensor, *args, **kwargs) -> torch.Ten

         q, kv = _SeqAllToAll.apply(self.spg, [2, 3], [1, 1], q, kv)

-        torch.cuda.synchronize()
         context = self.local_attn(q, kv, *args, **kwargs)
-        torch.cuda.synchronize()

         context = _SeqAllToAll.apply(self.spg, 1, 2, context)
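
For readers unfamiliar with the hook-based prefetch pattern above: the communicator registers forward pre-hooks so that reaching the configured launch module of layer i kicks off the asynchronous all-gather for layer i+1 (or, at 'module' granularity, for the next ISP module). The toy sketch below shows only the hook wiring; ToyBlock, prefetch_log and the simulated "prefetch" are hypothetical stand-ins, while the real code issues all_gather_raw(..., async_op=True) on the weight-parallel process group and waits on the handle in each module's own pre-forward hook.

# Toy illustration of "launch all-gather before a chosen module":
# a forward pre-hook on block i's launch module schedules block i+1's prefetch.
import torch
from torch import nn

LAUNCH_BEFORE = {"wo": ["out_proj", "wo"], "w1": ["w1", "fused_w1_w3"]}


class ToyBlock(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.wqkv = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)
        self.w1 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.w1(self.wo(self.wqkv(x)))


blocks = nn.ModuleList([ToyBlock() for _ in range(3)])
launch_names = LAUNCH_BEFORE["wo"]   # mirrors launch_allgather_before="wo"
module_to_block = {}                 # launch module -> index of its block
prefetch_log = []                    # records simulated prefetches


def _prefetch_hook(module, args):    # standard forward pre-hook signature
    nxt = module_to_block[module] + 1
    if nxt < len(blocks):
        # the real communicator would start an async all-gather of block nxt's
        # sharded weights here, overlapping it with the rest of block module_to_block[module]
        prefetch_log.append(f"launch all-gather for block {nxt}")


for idx, block in enumerate(blocks):
    for name, child in block.named_modules():
        if name.split(".")[-1] in launch_names:
            module_to_block[child] = idx
            child.register_forward_pre_hook(_prefetch_hook)

with torch.no_grad():
    blocks[0](torch.randn(2, 8))     # running block 0 schedules block 1's prefetch

print(prefetch_log)                  # ['launch all-gather for block 1']

At 'module' granularity the same idea moves into _pre_forward_hook_for_module: after waiting on its own handle, a module prefetches the next ISP module, and during recomputation it falls back to the previous layer's first module once the current layer is exhausted, matching the reverse order in which checkpointed blocks are revisited in backward.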

internlm/initialize/launch.py (+6, -2)

@@ -94,13 +94,17 @@ def args_sanity_check():
         gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name))

     if "weight" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("weight", dict(size=1, overlap=False))
+        gpc.config.parallel._add_item(
+            "weight", dict(size=1, overlap=False, launch_allgather_before="wo", forward_overlap_per="layer")
+        )

     if "expert" not in gpc.config.parallel:
         gpc.config.parallel._add_item("expert", dict(size=-1, no_tp=False))

     if "expert_weight" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("expert_weight", dict(size=1, overlap=False))
+        gpc.config.parallel._add_item(
+            "expert_weight", dict(size=1, overlap=False, launch_allgather_before="wo", forward_overlap_per="layer")
+        )

     if isinstance(gpc.config.parallel.pipeline, int):
         pp = gpc.config.parallel.pipeline

internlm/model/ops/attention.py (+2)

@@ -886,6 +886,8 @@ class SelfAttention(nn.Module):
         attention_dropout (float): Dropout rate for attention scores. Defaults to 0.0.
     """

+    is_attn_cls = True
+
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, layer_idx=0):
         super().__init__()
         self.causal = causal
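
The is_attn_cls class attribute is what the new is_allgather_launch_module helper in isp.py keys on when launch_allgather_before="attn": attention modules are recognized by attribute rather than by submodule name. A standalone sketch of that check (ToyAttention is a hypothetical stand-in):

# Attribute-based detection, as used by is_allgather_launch_module in isp.py.
class ToyAttention:
    is_attn_cls = True


def is_attn_module(module) -> bool:
    return hasattr(module, "is_attn_cls") and getattr(module, "is_attn_cls")


print(is_attn_module(ToyAttention()))  # True
print(is_attn_module(object()))        # False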

internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py (-3)

@@ -481,7 +481,6 @@ def forward(
     @staticmethod
     def backward(ctx, dout, *args): # pylint: disable=W0613

-        torch.cuda.synchronize()
         q, k, v, out, softmax_lse = ctx.saved_tensors

         dq, dk, dv = zigzag_double_ring_flash_attn_backward(
@@ -504,8 +503,6 @@ def backward(ctx, dout, *args): # pylint: disable=W0613
             deterministic=ctx.deterministic,
         )

-        torch.cuda.synchronize()
-
         return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None


tests/test_training/test_loss.py (+1, -1)

@@ -109,7 +109,7 @@ def train(
     config.hybrid_zero_optimizer.overlap_sync_grad = False

     config.parallel.pipeline = dict(size=pp_size, mode=pp_mode)
-    config.parallel.weight = dict(size=wp_size, overlap=True)
+    config.parallel.weight = dict(size=wp_size, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer")
     if interleaved is True:
         config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True, mode=pp_mode)
         config.model.num_chunks = num_chunks

0 commit comments
