@@ -91,24 +91,24 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias
         ctx.group = group
         output = torch.matmul(input, weight.transpose(-1, -2))
         if bias is not None:
-            output += bias
-
+            output += bias
+
         ctx.save_for_backward(input, weight)

         return output
+
     @staticmethod
     def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
-
-
-        input, weight = ctx.saved_tensors
+
+        input, weight = ctx.saved_tensors
         grad_input = grad_output.matmul(weight)
-        handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True)
-        grad_weight = grad_output.view(-1,grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1]))
+        handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True)
+        grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1]))
         grad_bias = grad_output.sum(0) if ctx.use_bias else None
         handle.wait()
         return None, grad_input, grad_weight, grad_bias
-
-
+
+
 class ColumnParallel(torch.autograd.Function):
     """
     Custom autograd function for column-wise parallelism.
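The backward above issues the all-reduce of `grad_input` with `async_op=True` and only calls `handle.wait()` after `grad_weight` and `grad_bias` have been computed, so the collective overlaps with independent matmuls. A minimal standalone sketch of that pattern (assuming `torch.distributed` is already initialized; the helper name is illustrative, not part of the diff):

```python
import torch
import torch.distributed as dist


def allreduce_overlapped(tensor: torch.Tensor, independent_work):
    # Launch the collective without blocking; with async_op=True, all_reduce
    # returns a Work handle instead of completing in place.
    handle = dist.all_reduce(tensor, async_op=True)
    # Run computation that does not depend on the reduced tensor while the
    # backend processes the all-reduce in the background.
    other = independent_work()
    # Block only at the point where the reduced values are actually needed.
    handle.wait()
    return tensor, other
```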
@@ -153,11 +153,17 @@ class TensorParallel_Layer(nn.Module, ABC):
         support_training (bool): Flag indicating whether the layer supports training (default: False).
         name (Optional[str]): The name of the layer, if provided.
     """
+    ##### Initialize Parameter List #####

-    # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
-    # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+    # keep_module_on_host determines whether to keep the module on the host.
+    # Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory),
+    # so an additional copy is unnecessary.
     keep_module_on_host: bool = False

+    ##### Runtime Parameter List #####
+    overlap_comm: bool = False
+    """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
+
     def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
         """
         Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
@@ -289,6 +295,13 @@ def move(self, tensor):
         return cloned_tensor


+def configure_tensor_parallel_runtime(config):
+    runtime_keys = ['overlap_comm']
+    for key in runtime_keys:
+        if hasattr(config, key):
+            setattr(TensorParallel_Layer, key, getattr(config, key))
+
+
 class GatherReplacedLayerParams:
     """
     A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality
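Because `overlap_comm` is a class attribute, `configure_tensor_parallel_runtime` only has to copy matching keys from the config onto `TensorParallel_Layer` once, and every parallel layer then reads the new value. A hedged usage sketch, with a `SimpleNamespace` standing in for whatever config object is actually passed:

```python
from types import SimpleNamespace

# Assumed minimal config carrying only the runtime key the function looks for.
tp_config = SimpleNamespace(overlap_comm=True)

configure_tensor_parallel_runtime(tp_config)

# The flag lives on the class, so existing and future TensorParallel_Layer
# subclasses/instances see overlap_comm == True unless they shadow it.
assert TensorParallel_Layer.overlap_comm is True
```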
@@ -435,15 +448,15 @@ def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
             self.config_tp_params(self.bias)

     def forward(self, input):
-        if True:
+        if not self.__class__.overlap_comm:
             if getattr(self, 'mp_group', None) is not None:
                 input = ColumnParallel.apply(self.mp_group, input)
             output = torch.matmul(input, self.weight.transpose(-1, -2))
             if self.bias is not None:
                 output += self.bias
         else:
-            output = AsyncColumnParallel.apply(self.mp_group,input, self.weight, self.bias)
-
+            output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias)
+
         return output

     @torch.no_grad()
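For context, a rough sketch of what the dispatch above means at call time. It assumes a process group is already initialized, `layer` is an instance of this column-parallel linear layer, and `ColumnParallel`'s forward is the usual identity copy into the model-parallel region, so both branches produce the same output and differ only in how the backward schedules the `grad_input` all-reduce:

```python
import torch

x = torch.randn(4, layer.weight.shape[-1], device=layer.weight.device)

type(layer).overlap_comm = False   # ColumnParallel.apply + explicit matmul/bias
y_sync = layer(x)

type(layer).overlap_comm = True    # single AsyncColumnParallel.apply call
y_async = layer(x)

# Forward math is the same either way; the async path only changes when the
# backward issues dist.all_reduce relative to the weight/bias gradient matmuls.
torch.testing.assert_close(y_async, y_sync)
```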