@@ -17,6 +17,11 @@
from copy import deepcopy
from typing import Union

+__all__ = [
+    "TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce",
+    "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer"
+]
+
DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
DS_IS_REPLACED_MODULE = 'ds_is_replaced_module'
DS_TENSOR_MODEL_PARALLEL = 'tensor_model_parallel'
@@ -43,26 +48,6 @@ def set_autotp_mode(training=False):
        DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE


-def move(tensor, device):
-    # TODO: consider the timing of deletion
-    # to save host resources when DP > 1。
-
-    if tensor.is_meta:
-        # Keep tensor in meta device if tensor is meta.
-        return tensor
-    else:
-        # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
-        # Using copy=True instead of clone() will help in case of cpu --> cpu.
-        # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
-        cloned_tensor = tensor.to(device, copy=True)
-
-        # free the memory of the original tensor to reduce memory peak
-        # Equivalent to directly deleting the tensor reference outside the function.
-        # see https://github.com/microsoft/DeepSpeed/pull/4353
-        tensor.data = torch.empty(0, device=tensor.device)
-        return cloned_tensor
-
-
class RowParallel(torch.autograd.Function):
    """
    A custom autograd function for performing row-wise parallelism.
@@ -140,6 +125,10 @@ class TensorParallel_Layer(nn.Module, ABC):
        name (Optional[str]): The name of the layer, if provided.
    """

+    # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
+    # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+    keep_module_on_host: bool = False
+
    def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
        """
        Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
@@ -163,6 +152,16 @@ def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
        if kwargs.get('name') is not None:
            self.name = kwargs.get('name')  # Set the layer name if provided.

+    @classmethod
+    def set_keep_module_on_host(cls, value: bool):
+        """
+        Set the static variable keep_module_on_host.
+
+        Args:
+            value (bool): The new value for keep_module_on_host.
+        """
+        cls.keep_module_on_host = value
+
    @abstractmethod
    def forward(self, input):
        """
@@ -235,6 +234,31 @@ def extra_repr(self):
            in_features, out_features, self.bias is not None, dtype)
        return extra_repr_str

+    def move(self, tensor):
+        # TODO: consider the timing of deletion
+        # to save host resources when DP > 1。
+
+        # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
+        # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+        if tensor.is_meta:
+            # Keep tensor in meta device if tensor is meta.
+            return tensor
+        else:
+            device = 'cpu' if self.__class__.keep_module_on_host else get_accelerator().current_device_name()
+            return_new_copy = not self.__class__.keep_module_on_host
+
+            # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
+            # Using copy=True instead of clone() will help in case of cpu --> cpu.
+            # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
+            cloned_tensor = tensor.to(device, copy=return_new_copy)
+
+            if return_new_copy:
+                # free the memory of the original tensor to reduce memory peak
+                # Equivalent to directly deleting the tensor reference outside the function.
+                # see https://github.com/microsoft/DeepSpeed/pull/4353
+                tensor.data = torch.empty(0, device=tensor.device)
+            return cloned_tensor
+

class GatherReplacedLayerParams:
    """
@@ -349,7 +373,7 @@ def _tp_partition(self, params_list):
                return
            _partition = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

            params_list[idx].data = _partition

@@ -363,7 +387,7 @@ def uneven_partition(self, params_list):
                                                                    self.name),
                                                dim=1)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()
            params_list[idx].data = _partition

@@ -414,7 +438,7 @@ def _tp_partition(self, params_list):
            #split bias if provide
            _partition = torch.chunk(param, self.tp_world_size, dim=0)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

            params_list[idx].data = _partition

@@ -429,7 +453,7 @@ def uneven_partition(self, params_list):
                                                                    self.name),
                                                dim=0)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

            params_list[idx].data = _partition

@@ -475,7 +499,7 @@ def _tp_partition(self, params_list):

            _partition = prepare_tp_fused_qkvw(self.fused_module.module, param, self.tp_world_size, self.tp_index)

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

            params_list[idx].data = _partition

@@ -492,13 +516,13 @@ def _tp_partition(self, params_list):
        weight, bias = params_list[0], params_list[1]
        _partition = weight.data.split(get_shard_size_list(weight.shape[0], self.tp_world_size, self.name),
                                       dim=1)[self.tp_index]
-        _partition = move(_partition, get_accelerator().current_device_name()).detach()
+        _partition = self.move(_partition).detach()
        weight.data = _partition

        if bias is not None:
            _partition = bias.data.split(get_shard_size_list(weight.shape[1], self.tp_world_size, self.name),
                                         dim=0)[self.tp_index]
-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

            bias.data = _partition

@@ -522,19 +546,19 @@ class Yuan_LinearLayer(LinearLayer):
    def _tp_partition(self, params_list):
        weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
                                                 self.tp_world_size, True)
-        params_list[0].data = move(weight, get_accelerator().current_device_name()).detach()
+        params_list[0].data = self.move(weight).detach()
        if bias is not None:
-            params_list[1].data = move(bias, get_accelerator().current_device_name()).detach()
+            params_list[1].data = self.move(bias).detach()


class GateUpPack_LinearLayer(LinearLayer):
    # chatGLM2, chatGLM2
    @torch.no_grad()
    def _tp_partition(self, params_list):
        weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size)
-        params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach()
+        params_list[0].data = self.move(weight).detach()
        if bias is not None:
-            params_list[1].data = move(bias, device=get_accelerator().current_device_name()).detach()
+            params_list[1].data = self.move(bias).detach()


class Conv_LinearALlreduce(LinearAllreduce):
@@ -549,7 +573,7 @@ def _tp_partition(self, params_list):
            _partition = param.split(get_shard_size_list(param.shape[0], self.tp_world_size, self.name),
                                     dim=1)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

            params_list[idx].data = _partition