-
Notifications
You must be signed in to change notification settings - Fork 4.5k
/
Copy pathhybrid_parallel_plugin.py
1523 lines (1318 loc) · 70.1 KB
/
hybrid_parallel_plugin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import ctypes
import random
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from functools import partial
from types import MethodType
from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor, inf
from torch.distributed import ProcessGroup, get_world_size
from torch.nn import Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
from .pp_plugin_base import PipelinePluginBase
# Names of the sequence-parallelism modes this module lists as supported.
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]
# Lookup table from a precision string to the matching torch dtype.
PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
return x.to(dtype)
return x
class HybridParallelModule(ModelWrapper, AMPModelMixin):
    """Model wrapper that shards a module with ShardFormer and prepares it for hybrid parallel training.

    On construction the module is sharded, optionally cast to fp16/bf16, moved to the current
    accelerator device and optionally wrapped in PyTorch DDP. The wrapper also exposes helpers to
    synchronize gradients across the data-parallel group, across pipeline stages that share
    parameters, and across the tensor-parallel group for sequence parallelism.
    """

    def __init__(
        self,
        module: Module,
        precision: str,
        shard_config: ShardConfig,
        dp_group: ProcessGroup,
        tp_group: ProcessGroup,
        sp_group: ProcessGroup,
        use_ddp: bool,
        ddp_config: dict,
        custom_policy: Policy,
        overlap_allgather: bool = False,
        use_fp8: bool = False,
    ) -> None:
        """Shard ``module`` and set it up for the requested precision / DDP / hook options.

        Args:
            module: The model to shard and wrap.
            precision: ``"fp16"`` or ``"bf16"`` enable a dtype cast; any other value leaves the dtype alone.
            shard_config: Configuration handed to ShardFormer; also provides the pipeline stage manager.
            dp_group: Data-parallel process group.
            tp_group: Tensor-parallel process group.
            sp_group: Sequence-parallel process group.
            use_ddp: Whether to convert batch norms to SyncBatchNorm and wrap the module in DDP.
            ddp_config: Extra keyword arguments forwarded to the DDP constructor.
            custom_policy: Optional sharding policy forwarded to ``ShardFormer.optimize``.
            overlap_allgather: Whether to register a ``ZeroOpHook`` for parameter ops.
            use_fp8: Whether to register an ``FP8Hook`` for parameter ops.
        """
        self.stage_manager = shard_config.pipeline_stage_manager
        self.shard_config = shard_config
        self.dp_group = dp_group
        self.tp_group = tp_group
        self.sp_group = sp_group
        self.use_ddp = use_ddp
        self.require_grad_sync = True
        self.overlap_allgather = overlap_allgather
        self.use_fp8 = use_fp8

        # Record every parameter's shape before sharding changes the module.
        self.param_origin_shape = {name: param.shape for name, param in module.named_parameters()}

        shardformer = ShardFormer(shard_config)
        if custom_policy is not None:
            # NOTE(review): every Python value is an ``object``, so this assertion can never fail;
            # presumably a check against ``Policy`` was intended - confirm before tightening.
            assert isinstance(custom_policy, object)
        module, self.shared_params = shardformer.optimize(module, policy=custom_policy)

        # One process group per non-empty set of parameters shared between pipeline stages.
        self.shared_param_process_groups = [
            self.stage_manager.init_process_group_by_stages(list(shared.keys()))
            for shared in self.shared_params
            if len(shared) > 0
        ]

        # Mixed precision: pick the dtype, cast the module, and build the input-cast function.
        if precision == "fp16":
            self.mixed_precision = torch.float16
        elif precision == "bf16":
            self.mixed_precision = torch.bfloat16
        else:
            self.mixed_precision = None

        self.convert_fn = None
        if self.mixed_precision is not None:
            module = module.to(self.mixed_precision)
            self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)
        module = module.to(get_accelerator().get_current_device())

        if use_ddp:
            # Batch norms are converted to SyncBatchNorm before the DDP wrap.
            module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
            module = DDP(module, process_group=dp_group, **ddp_config)

        super().__init__(module)

        # Parameter-op hooks that are active during forward/backward.
        self.op_hooks = []
        if use_fp8:
            self.op_hooks.append(FP8Hook())
        if overlap_allgather:
            self.op_hooks.append(ZeroOpHook())

        if use_fp8 or overlap_allgather:
            # Re-class trainable parameters as ColoParameter in place.
            for p in module.parameters():
                if not p.requires_grad or type(p) is ColoParameter:
                    continue
                p.__class__ = ColoParameter
                p.__init__(p, requires_grad=True)

    def sync_shared_params(self):
        """All-reduce the gradients of parameters shared between pipeline stages.

        For every shared-parameter set this stage takes part in, the local gradient is
        all-reduced over that set's process group. A barrier follows each set.
        """
        current_stage = self.stage_manager.stage
        for shared, group in zip(self.shared_params, self.shared_param_process_groups):
            if current_stage in shared:
                dist.all_reduce(shared[current_stage].grad, group=group)
            # Every rank reaches the barrier, whether or not it owns a shared parameter.
            dist.barrier()

    @contextmanager
    def no_sync(self):
        r"""
        Context manager that turns off automatic gradient synchronization.

        ``require_grad_sync`` is set to ``False`` inside the context and restored to its previous
        value on exit. When DDP is in use, DDP's own ``no_sync`` context is entered as well.
        """
        previous = self.require_grad_sync
        self.require_grad_sync = False
        try:
            ddp_guard = self.module.no_sync() if self.use_ddp else nullcontext()
            with ddp_guard:
                yield
        finally:
            # Always restore the flag, even if the body raised.
            self.require_grad_sync = previous

    def sync_dp_grads(self):
        r"""
        Average gradients across the data-parallel group.

        Does nothing when the DP group has a single member. Otherwise every existing gradient is
        all-reduced over the DP group and divided by the group size.

        Returns:
            None
        """
        dp_size = self.dp_group.size()
        if dp_size == 1:
            return

        for param in self.module.parameters():
            grad = param.grad
            if grad is None:
                continue
            dist.all_reduce(grad, group=self.dp_group)
            grad.div_(dp_size)

    def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
        r"""
        Synchronize gradients that are only partially derived under sequence parallelism.

        Nothing happens when sequence parallelism is disabled or the mode is ``all_to_all`` /
        ``ring_attn``. For ``split_gather`` and ``ring`` the gradients are all-reduced across the
        tensor-parallel group.

        Args:
            grads (Optional[List[torch.Tensor]]): Gradients to synchronize. When omitted, the
                gradients are taken from the wrapped module.

        Raises:
            ValueError: If the configured sequence parallelism mode is not recognized.
        """
        if not self.shard_config.enable_sequence_parallelism:
            return

        sp_mode = self.shard_config.sequence_parallelism_mode
        if sp_mode in ["all_to_all", "ring_attn"]:
            return
        if sp_mode not in ["split_gather", "ring"]:
            raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}")

        # split_gather / ring: partial gradients are reduced over the TP group.
        group = self.tp_group
        if grads is None:
            SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module)
        else:
            SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads)

    def forward(self, *args, **kwargs):
        """Run the wrapped module, casting floating-point inputs first when mixed precision is on."""
        cast = self.convert_fn
        if cast is not None:
            args = tree_map(cast, args)
            kwargs = tree_map(cast, kwargs)
        with self._hook_context():
            return super().forward(*args, **kwargs)

    def unwrap(self):
        """Return the underlying module, stripping the DDP wrapper if there is one."""
        inner = super().unwrap()
        return inner.module if isinstance(inner, DDP) else inner

    def _force_wait_all_gather(self):
        """Call ``wait_all_gather_handle`` on every parameter of the wrapped module."""
        for param in self.module.parameters():
            wait_all_gather_handle(param)

    def _hook_context(self):
        """Return a context manager that activates the registered op hooks (or a no-op context)."""
        if len(self.op_hooks) > 0:
            return ColoParamOpHookManager.use_hooks(*self.op_hooks)
        return nullcontext()
def get_param_info(optim: Optimizer):
    """Snapshot parameter bookkeeping from an optimizer for later use.

    The returned dict contains:
      1. ``param_groups``: the optimizer's param groups with each parameter replaced by an integer id.
      2. ``param2id``: mapping from parameter address (``id(param)``) to that integer id.
      3. ``id2param``: the inverse mapping, integer id to parameter address.
      4. ``param2shape``: mapping from parameter address to the parameter's current shape
         (``None`` for entries that are not tensors).

    Integer ids are assigned consecutively across groups, starting at 0.

    Args:
        optim: The optimizer to inspect, or ``None``.

    Returns:
        The dict described above, or an empty dict when ``optim`` is ``None``.
    """
    if optim is None:
        return {}

    packed_groups = []
    param2id = {}
    id2param = {}
    param2shape = {}

    next_id = 0
    for group in optim.param_groups:
        # Copy every hyper-parameter of the group; the tensors are replaced by ids below.
        packed = {key: value for key, value in group.items() if key != "params"}
        ids = []
        for param in group["params"]:
            address = id(param)
            ids.append(next_id)
            param2id[address] = next_id
            id2param[next_id] = address
            param2shape[address] = param.shape if isinstance(param, torch.Tensor) else None
            next_id += 1
        packed["params"] = ids
        packed_groups.append(packed)

    return {
        "param_groups": packed_groups,
        "param2id": param2id,
        "id2param": id2param,
        "param2shape": param2shape,
    }
def reinitialize_optimizer(optim: Optimizer, model: Module):
    """Drop from ``optim`` every parameter that is not a parameter of ``model``.

    Each param group keeps its hyper-parameters; only its ``"params"`` list is filtered.
    The new groups are installed through ``optim.__setstate__``.

    Args:
        optim: Optimizer whose param groups are rewritten in place.
        model: Module whose parameters should remain in the optimizer.
    """
    kept = set(model.parameters())
    filtered_groups = [
        {**group, "params": [param for param in group["params"] if param in kept]}
        for group in optim.param_groups
    ]
    optim.__setstate__({"param_groups": filtered_groups})
class HybridParallelNaiveOptimizer(OptimizerWrapper):
    """Optimizer wrapper for a :class:`HybridParallelModule`.

    After each backward pass it asks the model to synchronize sequence-parallel gradients
    (unless the model's ``require_grad_sync`` flag is off). Before each step it optionally clips
    gradients by a global norm that is reduced across the tensor-parallel and pipeline-parallel
    process groups.
    """

    def __init__(
        self,
        optim: Optimizer,
        model: HybridParallelModule,
        use_pipeline: bool,
        param_info: OrderedDict,
        max_norm: float = 0,
        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
    ):
        """
        Args:
            optim: The optimizer to wrap.
            model: The hybrid parallel model whose parameters are optimized.
            use_pipeline: When true, parameters not owned by ``model`` are removed from ``optim``.
            param_info: Parameter bookkeeping stored on the wrapper as ``self.param_info``.
            max_norm: Gradient clipping threshold; clipping is skipped unless it is positive.
            tp_process_group: Tensor-parallel process group, if tensor parallelism is used.
            pp_process_group: Pipeline-parallel process group, if pipeline parallelism is used.
        """
        self.param_info = param_info
        if use_pipeline:
            reinitialize_optimizer(optim, model)
        self.model = model
        self.stage_manager = model.stage_manager
        self.shared_params = model.shared_params
        self.max_norm = max_norm
        self.tp_pg = tp_process_group
        self.pp_pg = pp_process_group
        self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
        self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
        # Norm computed by the most recent clipping step; None until clipping has run once.
        self._current_grad_norm: Optional[float] = None
        super().__init__(optim)

    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        r"""
        Backpropagate ``loss`` and, if required, synchronize sequence parallelism gradients.

        The backward pass runs inside the model's op-hook context. Afterwards, if the model's
        ``require_grad_sync`` flag is set, gradients that are partially derived within sequence
        parallelism are synchronized via ``model.sync_sp_grads()``.

        Args:
            loss (Tensor): The loss tensor to compute gradients with respect to.
            inputs: Forwarded to the superclass backward method.
            retain_graph (bool): Forwarded to the superclass backward method.
            **kwargs: Additional keyword arguments forwarded to the superclass backward method.

        Returns:
            None
        """
        with self.model._hook_context():
            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

        if self.model.require_grad_sync:
            # Gradient synchronization is required: sync sequence parallelism gradients.
            self.model.sync_sp_grads()

    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate using a precomputed gradient and, if required, synchronize sequence parallelism gradients.

        Args:
            tensor (Tensor): The tensor to backpropagate from.
            grad (Tensor): The precomputed gradient with respect to ``tensor``.
            inputs (Tensor): Forwarded to the superclass ``backward_by_grad``.
            retain_graph (bool): Forwarded to the superclass ``backward_by_grad``.

        Returns:
            None
        """
        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.model.require_grad_sync:
            # Gradient synchronization is required: sync sequence parallelism gradients.
            self.model.sync_sp_grads()

    def step(self, *args, **kwargs):
        r"""
        Perform an optimization step, clipping gradients first when ``max_norm`` is positive.

        Args:
            *args: Variable-length positional arguments to be passed to the optimizer's step function.
            **kwargs: Keyword arguments to be passed to the optimizer's step function.
        """
        if self.max_norm > 0:
            # Compute the total gradient norm over every parameter that has a gradient.
            param_gradient_pairs = [
                (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
            ]
            total_norm = self._compute_grad_norm(param_gradient_pairs)
            self._current_grad_norm = total_norm

            # Clip the gradients to prevent exploding gradients.
            self._clip_grad_norm(total_norm)

        # Perform the optimization step using the underlying optimizer.
        self.optim.step(*args, **kwargs)

    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor, Tensor]], norm_type: float = 2) -> float:
        r"""
        Compute and return the gradient norm for gradient clipping.

        Args:
            param_gradient_pairs (List[Tuple[Tensor, Tensor]]): List of (parameter, gradient) pairs;
                gradients are used for norm calculation.
            norm_type (float, optional): Type of the norm used (e.g., 2 for L2 norm, ``inf`` for max norm).
                Defaults to 2.

        Returns:
            float: The total norm of the given gradients (``0.0`` for an empty list).
        """
        if len(param_gradient_pairs) == 0:
            return 0.0

        norm_type = float(norm_type)

        # gradients used for norm calculation.
        gradients = [grad for param, grad in param_gradient_pairs]

        if norm_type == inf:
            total_norm = max(grad.data.abs().max() for grad in gradients)
            total_norm_cuda = torch.tensor(
                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
            )
            if self.tp_size > 1:
                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
            if self.pp_size > 1:
                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
            total_norm = total_norm_cuda.item()
        else:
            # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'.
            grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}

            total_norm_exponentiated = 0.0
            for grad in gradients:
                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type

                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
                if self.tp_size > 1:
                    param_for_grad = grad_to_param_mapping[id(grad)]
                    if not is_distributed_tensor(param_for_grad):
                        grad_norm_exponentiated /= self.tp_size

                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
                # it means that this parameter is used in two different pipeline stages.
                # To avoid redundant norm calculations, we divide the exponent of this norm by
                # the number of shared stages.
                if self.pp_size > 1:
                    for shared_param in self.shared_params:
                        if self.stage_manager.stage in shared_param:
                            stage_shared_param = shared_param[self.stage_manager.stage]
                            if grad is stage_shared_param.grad:
                                grad_norm_exponentiated /= len(shared_param)

                total_norm_exponentiated += grad_norm_exponentiated

            total_norm_exponentiated_cuda = torch.tensor(
                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
            )
            if self.tp_size > 1:
                # compute norm in tp process group
                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
            if self.pp_size > 1:
                # compute norm in pp process group
                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)

            # compute the total_norm
            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

        return total_norm

    def _clip_grad_norm(self, total_norm: float) -> None:
        r"""
        Scale every gradient in place so that the total norm does not exceed ``max_norm``.

        The scale factor is ``max_norm / (total_norm + 1e-6)``, clamped to at most 1.0, so
        gradients are never scaled up.

        Args:
            total_norm (float): The computed total gradient norm.

        Returns:
            None
        """
        clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6))
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

        for group in self.optim.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.grad.data.mul_(clip_coef_clamped)

    def update_master_params(self, model: Module):
        """No-op in this wrapper."""
        pass

    def get_working_to_master_map(self):
        """Return ``None``; this wrapper keeps no working-to-master parameter mapping."""
        return None

    def get_master_to_working_map(self):
        """Return ``None``; this wrapper keeps no master-to-working parameter mapping."""
        return None

    def get_grad_norm(self, norm_type=2, **kwargs):
        """Return the gradient norm recorded by the last clipping step (``None`` if none was recorded).

        ``norm_type`` and ``**kwargs`` are accepted but not used.
        """
        return self._current_grad_norm
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
    """Mixed-precision optimizer for a :class:`HybridParallelModule`.

    Extends ``MixedPrecisionOptimizer`` so that sequence-parallel gradients are synchronized after
    each backward pass and the gradient norm used for clipping is additionally reduced across the
    tensor-parallel and pipeline-parallel process groups.
    """

    def __init__(
        self,
        optim: Optimizer,
        model: HybridParallelModule,
        use_pipeline: bool,
        param_info: OrderedDict,
        precision: str = "fp16",
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0,
        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
    ):
        """
        Args:
            optim: The optimizer to wrap.
            model: The hybrid parallel model whose parameters are optimized.
            use_pipeline: When true, parameters not owned by ``model`` are removed from ``optim``.
            param_info: Parameter bookkeeping stored on the wrapper as ``self.param_info``.
            precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
                hysteresis, max_scale, max_norm: Forwarded unchanged to ``MixedPrecisionOptimizer``.
            tp_process_group: Tensor-parallel process group, if tensor parallelism is used.
            pp_process_group: Pipeline-parallel process group, if pipeline parallelism is used.
        """
        self.model = model
        self.param_info = param_info
        self.stage_manager = model.stage_manager
        self.shared_params = model.shared_params
        self.tp_pg = tp_process_group
        self.pp_pg = pp_process_group
        self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
        self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
        if use_pipeline:
            reinitialize_optimizer(optim, model)
        super().__init__(
            optim,
            precision=precision,
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale,
            max_norm=max_norm,
        )

    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
        r"""
        Backpropagate ``loss`` and, if required, synchronize sequence parallelism gradients.

        The backward pass runs inside the model's op-hook context. Afterwards, if the model's
        ``require_grad_sync`` flag is set, gradients that are partially derived within sequence
        parallelism are synchronized via ``model.sync_sp_grads()``.

        Args:
            loss (Tensor): The loss tensor to compute gradients with respect to.
            inputs: Forwarded to the superclass backward method.
            retain_graph (bool): Forwarded to the superclass backward method.
            **kwargs: Additional keyword arguments forwarded to the superclass backward method.

        Returns:
            None
        """
        with self.model._hook_context():
            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

        if self.model.require_grad_sync:
            # Gradient synchronization is required: sync sequence parallelism gradients.
            self.model.sync_sp_grads()

    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        """
        Backpropagate using a precomputed gradient and, if required, synchronize sequence parallelism gradients.

        Args:
            tensor (Tensor): The tensor to backpropagate from.
            grad (Tensor): The precomputed gradient with respect to ``tensor``.
            inputs (Tensor): Forwarded to the superclass ``backward_by_grad``.
            retain_graph (bool): Forwarded to the superclass ``backward_by_grad``.

        Returns:
            None
        """
        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

        if self.model.require_grad_sync:
            # Gradient synchronization is required: sync sequence parallelism gradients.
            self.model.sync_sp_grads()

    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor, Tensor]], norm_type: float = 2) -> float:
        r"""
        Compute and return the gradient norm for gradient clipping.

        Args:
            param_gradient_pairs (List[Tuple[Tensor, Tensor]]): List of (parameter, gradient) pairs;
                gradients are used for norm calculation.
            norm_type (float, optional): Type of the norm used (e.g., 2 for L2 norm, ``inf`` for max norm).
                Defaults to 2.

        Returns:
            float: The total norm of the given gradients (``0.0`` for an empty list).
        """
        if len(param_gradient_pairs) == 0:
            return 0.0

        norm_type = float(norm_type)

        if norm_type == inf:
            # The parent class calculates the norm of 'dp' gradients,
            # so we need to calculate the norm of 'tp' and 'pp' gradients.
            total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)

            total_norm_cuda = torch.tensor(
                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
            )
            if self.tp_size > 1:
                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
            if self.pp_size > 1:
                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)

            total_norm = total_norm_cuda.item()
        else:
            # gradients used for norm calculation.
            gradients = [grad for param, grad in param_gradient_pairs]
            # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism.
            grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}

            total_norm_exponentiated = 0.0
            for grad in gradients:
                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type

                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
                if self.tp_size > 1:
                    param_for_grad = grad_to_param_mapping[id(grad)]
                    if not is_distributed_tensor(param_for_grad):
                        grad_norm_exponentiated /= self.tp_size

                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
                # it means that this parameter is used in two different pipeline stages.
                # To avoid redundant norm calculations, we divide the exponent of this norm by
                # the number of shared stages.
                if self.pp_size > 1:
                    for shared_param in self.shared_params:
                        if self.stage_manager.stage in shared_param:
                            # The shared entry is the working parameter; the gradient being
                            # measured is compared against its master copy's gradient.
                            stage_working_shared_param = shared_param[self.stage_manager.stage]
                            stage_master_shared_param = self.working_to_master_map[stage_working_shared_param]
                            if grad is stage_master_shared_param.grad:
                                grad_norm_exponentiated /= len(shared_param)

                total_norm_exponentiated += grad_norm_exponentiated

            total_norm_exponentiated_cuda = torch.tensor(
                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
            )
            if self.tp_size > 1:
                # compute norm in tp process group
                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
            if self.pp_size > 1:
                # compute norm in pp process group
                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)

            # compute the total_norm
            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

        return total_norm
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
    self,
    optimizer: Optimizer,
    model: HybridParallelModule,
    use_pipeline: bool,
    param_info: OrderedDict,
    pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None,
    initial_scale: int = 2**16,  # grad scaler config
    min_scale: int = 1,
    growth_factor: float = 2.0,
    backoff_factor: float = 0.5,
    growth_interval: int = 2000,
    hysteresis: int = 2,
    max_scale: int = 2**24,
    clip_grad_norm: float = 0.0,  # grad clipping
    verbose: bool = False,
    reduce_bucket_size: int = 1024 * 1024,  # communication
    communication_dtype: Optional[torch.dtype] = None,
    overlap_communication: bool = True,
    partition_grad: bool = False,  # stage 2 flag
    cpu_offload: bool = False,  # cpu offload
    dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
    tp_process_group: Optional[ProcessGroup] = None,  # if using tp
    pp_process_group: Optional[ProcessGroup] = None,  # if using pp
    forced_dtype: Optional[torch.dtype] = None,
    overlap_allgather: bool = False,
    fp8_communication: bool = False,
):
    """Store hybrid-parallel state on the wrapper, then delegate to ``LowLevelZeroOptimizer``.

    ``model``, ``param_info`` and the TP/PP process groups are kept on ``self``. When
    ``use_pipeline`` is true, parameters that do not belong to ``model`` are removed from
    ``optimizer`` first. Every remaining argument is forwarded unchanged to the parent
    constructor, with the model's op-hook context passed as ``backward_context``.
    """
    # State the parent class does not manage.
    self.tp_pg = tp_process_group
    self.pp_pg = pp_process_group
    self.model = model
    self.stage_manager = model.stage_manager
    self.shared_params = model.shared_params
    self.param_info = param_info

    if use_pipeline:
        # Must run before the parent constructor reads the optimizer's param groups.
        reinitialize_optimizer(optimizer, model)

    zero_kwargs = dict(
        optimizer=optimizer,
        initial_scale=initial_scale,
        min_scale=min_scale,
        pg_to_param_list=pg_to_param_list,
        growth_factor=growth_factor,
        backoff_factor=backoff_factor,
        growth_interval=growth_interval,
        hysteresis=hysteresis,
        max_scale=max_scale,
        clip_grad_norm=clip_grad_norm,
        verbose=verbose,
        reduce_bucket_size=reduce_bucket_size,
        communication_dtype=communication_dtype,
        overlap_communication=overlap_communication,
        partition_grad=partition_grad,
        cpu_offload=cpu_offload,
        dp_process_group=dp_process_group,
        forced_dtype=forced_dtype,
        overlap_allgather=overlap_allgather,
        fp8_communication=fp8_communication,
        backward_context=model._hook_context,
    )
    super().__init__(**zero_kwargs)
def sync_dp_grads(self):
    r"""
    Synchronize gradients along the data-parallel dimension.

    This is a thin, explicitly named wrapper around the parent class's ``_sync_grad``. It exists
    so that data-parallel synchronization can be told apart from the tensor-parallel (tp) and
    pipeline-parallel (pp) dimensions.

    Returns:
        None
    """
    # Delegate to the parent implementation.
    super(HybridParallelZeroOptimizer, self)._sync_grad()
def _sync_sp_grads(self):
    r"""
    Synchronize gradients that are partially derived within sequence parallelism.

    All working gradients of every parameter group are collected. Those whose owning parameter
    is reported as partially derived by ``SeqParallelUtils.is_sp_partial_derived_param`` are
    all-reduced across the tensor-parallel group. Nothing is reduced when ``require_grad_sync``
    is off or no such gradient exists.

    Returns:
        None
    """
    # Collect the working gradients of every parameter group.
    working_grads: List[Tensor] = []
    for gid in range(self.num_param_groups):
        working_grads.extend(self.get_working_grads_by_group_id(gid))

    # Keep only the gradients whose owning parameter is partially derived.
    partial_grads: List[Tensor] = []
    for working_grad in working_grads:
        # The value from get_param_id_for_grad is treated as an object address and turned back
        # into the parameter object (presumably it is ``id(param)`` - confirm against the parent class).
        owner_address = self.get_param_id_for_grad(working_grad)
        owner = ctypes.cast(owner_address, ctypes.py_object).value
        if SeqParallelUtils.is_sp_partial_derived_param(owner):
            partial_grads.append(working_grad)

    if self.require_grad_sync and partial_grads:
        # Synchronize sequence parallelism gradients across the tp group.
        SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=partial_grads)
def backward(self, loss, inputs=None, retain_graph=False):
    """
    Run the backward pass for a loss and, if needed, synchronize sequence parallelism gradients.

    Gradient computation is delegated to the superclass. Afterwards, when gradient
    synchronization is required and sequence parallelism is enabled in the shard config,
    the partially derived sequence parallelism gradients are synchronized across the
    TP parallelism group via ``_sync_sp_grads``.

    Args:
        loss: The loss tensor to compute gradients with respect to.
        inputs: Forwarded unchanged to the superclass ``backward``.
        retain_graph (bool): Whether to retain the computation graph.

    Returns:
        None
    """
    # Let the superclass compute the gradients first.
    super().backward(loss, inputs=inputs, retain_graph=retain_graph)

    # Both conditions must hold for the sequence parallelism sync to run.
    needs_sp_sync = self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism
    if not needs_sp_sync:
        return

    self._sync_sp_grads()
def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
    """
    Run the backward pass from a precomputed gradient and, if needed, synchronize sequence parallelism gradients.

    Gradient computation is delegated to the superclass. Afterwards, when gradient
    synchronization is required and sequence parallelism is enabled in the shard config,
    the partially derived sequence parallelism gradients are synchronized across the
    TP parallelism group via ``_sync_sp_grads``.

    Args:
        tensor: The tensor from which the backward pass starts.
        grad: The precomputed gradient with respect to ``tensor``.
        inputs (Tensor): Forwarded unchanged to the superclass ``backward_by_grad``.
        retain_graph (bool): Forwarded unchanged to the superclass ``backward_by_grad``.

    Returns:
        None
    """
    # Let the superclass compute the gradients first.
    super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

    # Both conditions must hold for the sequence parallelism sync to run.
    needs_sp_sync = self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism
    if not needs_sp_sync:
        return

    self._sync_sp_grads()
def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float:
    r"""
    Compute and return the gradient norm for gradient clipping.

    For the infinity norm, the value returned by the superclass is max-reduced over the
    tensor parallel and pipeline parallel groups. For any other norm type, the per-gradient
    ``norm ** norm_type`` values are accumulated locally, sum-reduced over the data, tensor
    and pipeline parallel groups, and the ``1 / norm_type`` power of the total is returned.

    Args:
        dp_pg: Data parallel process group used for the sum reduction. ``None`` is treated
            as a group of size 1.
        gradients (List[Tensor]): A list of tensors containing gradients.
        norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2.

    Returns:
        float: The computed gradient norm, or 0.0 when ``gradients`` is empty.
    """

    # Check if the list of gradients is empty.
    # NOTE(review): this returns before any collective is issued; presumably every rank of a
    # group sees the same emptiness, otherwise the all_reduce calls below would not match up - confirm.
    if len(gradients) == 0:
        return 0.0

    # A missing process group counts as a group of size 1 (no communication needed).
    dp_size = get_world_size(dp_pg) if dp_pg is not None else 1
    tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
    pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
    norm_type = float(norm_type)

    if norm_type == inf:
        # The parent class calculates the norm of 'dp' gradients,
        # so we only need to reduce the norm over the 'tp' and 'pp' groups.
        # NOTE(review): `dp_pg` is not forwarded here although this override takes it as its
        # first argument - verify this matches the parent's `_compute_grad_norm` signature.
        total_norm = super()._compute_grad_norm(gradients, norm_type)

        # Wrap the scalar in a one-element tensor on the current device so it can be all-reduced.
        total_norm_cuda = torch.tensor(
            [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
        )

        if tp_size > 1:
            dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
        if pp_size > 1:
            dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)

        total_norm = total_norm_cuda.item()
    else:
        # Running sum of `norm ** norm_type` over the local gradients.
        total_norm_exponentiated = 0.0
        for grad in gradients:
            # The norm is computed in double precision before being raised to `norm_type`.
            grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type

            # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
            # it indicates that the parameter is not distributed across devices of the 'tp_group'.
            # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
            # However, the 'all_reduce' below is still performed for every gradient, so to keep the
            # result mathematically equivalent we divide the 'grad_norm' by 'tp_size' here.
            if tp_size > 1:
                # Recover the parameter object from the id the optimizer associates with this gradient.
                param_id_for_grad = self.get_param_id_for_grad(grad)
                param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value

                if not is_distributed_tensor(param_for_grad):
                    grad_norm_exponentiated /= tp_size

            # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
            # it means that this parameter is used in several pipeline stages.
            # To avoid counting it more than once after the 'pp' reduction, we divide the
            # exponentiated norm by the number of stages sharing it.
            if pp_size > 1:
                for shared_param in self.shared_params:
                    # Each entry is indexed by pipeline stage; only entries containing this stage matter.
                    if self.stage_manager.stage in shared_param:
                        stage_shared_param = shared_param[self.stage_manager.stage]
                        working_grad = self.get_working_grad_by_param_id(id(stage_shared_param))
                        # Identity comparison: is this exactly the gradient of the shared parameter?
                        if grad is working_grad:
                            grad_norm_exponentiated /= len(shared_param)

            total_norm_exponentiated += grad_norm_exponentiated

        # Wrap the local sum in a one-element float32 tensor on the current device for the reductions.
        total_norm_exponentiated_cuda = torch.tensor(
            [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
        )
        if dp_size > 1:
            # compute norm in dp process group
            dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg)
        if tp_size > 1:
            # compute norm in tp process group
            dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
        if pp_size > 1:
            # compute norm in pp process group
            dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)

        # Compute the 'total_norm' from 'total_norm_exponentiated'
        total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

    return total_norm
class HybridParallelPlugin(PipelinePluginBase):
"""
Plugin for Hybrid Parallel Training.
Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
model, train_dataset, optimizer, criterion = ...
plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
```
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
sp_size (int): The size of sequence parallelism.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
Defaults to 'fp16'.
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
When set to 0, ZeRO will not be used. Defaults to 0.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): Used when padding the vocabulary size, so that a faster kernel can be chosen. Defaults to 64.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
"""
def __init__(
self,