@@ -80,6 +80,35 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None]:
         return None, grad_output, None


+class AsyncColumnParallel(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias) -> torch.Tensor:
+        """
+        Forward pass.
+        """
+        ctx.use_bias = bias is not None
+        ctx.group = group
+        output = torch.matmul(input, weight.transpose(-1, -2))
+        if bias is not None:
+            output += bias
+
+        ctx.save_for_backward(input, weight)
+
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+
+        input, weight = ctx.saved_tensors
+        grad_input = grad_output.matmul(weight)
+        handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True)
+        grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1]))
+        grad_bias = grad_output.sum(0) if ctx.use_bias else None
+        handle.wait()
+        return None, grad_input, grad_weight, grad_bias
+
+
 class ColumnParallel(torch.autograd.Function):
     """
     Custom autograd function for column-wise parallelism.
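The backward of AsyncColumnParallel hides the latency of the grad_input all-reduce behind the weight- and bias-gradient GEMMs: the collective is launched with async_op=True, the independent matmuls run while it is in flight, and handle.wait() only blocks right before the gradients are returned. A rough, free-standing sketch of that overlap pattern (the helper name and signature are ours, not part of the patch, and torch.distributed is assumed to be initialized already, e.g. under torchrun):

import torch
import torch.distributed as dist


def overlapped_linear_backward(grad_output, input, weight, group=None):
    """Hypothetical sketch of the overlap used by AsyncColumnParallel.backward."""
    # Kick off the all-reduce of grad_input without blocking.
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=group, async_op=True)

    # These products do not depend on the reduced grad_input,
    # so they execute while the collective is still in flight.
    grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1]))
    grad_bias = grad_output.sum(0)

    # Block only at the point where the reduced gradient is actually needed.
    handle.wait()
    return grad_input, grad_weight, grad_bias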
@@ -124,11 +153,17 @@ class TensorParallel_Layer(nn.Module, ABC):
         support_training (bool): Flag indicating whether the layer supports training (default: False).
         name (Optional[str]): The name of the layer, if provided.
     """
+    ##### Initialize Parameter List #####

-    # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
-    # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+    # keep_module_on_host determines whether to keep the module on the host.
+    # Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory),
+    # so an additional copy is unnecessary.
     keep_module_on_host: bool = False

+    ##### Runtime Parameter List #####
+    overlap_comm: bool = False
+    """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
+
     def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
         """
         Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
@@ -260,6 +295,13 @@ def move(self, tensor):
         return cloned_tensor


+def configure_tensor_parallel_runtime(config):
+    runtime_keys = ['overlap_comm']
+    for key in runtime_keys:
+        if hasattr(config, key):
+            setattr(TensorParallel_Layer, key, getattr(config, key))
+
+
 class GatherReplacedLayerParams:
     """
     A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality
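configure_tensor_parallel_runtime copies any recognized runtime flags from a config object onto the TensorParallel_Layer class itself, so every layer instance reads the same value through self.__class__.overlap_comm. A small usage sketch with a stand-in config object (the SimpleNamespace config is an assumption; only the overlap_comm attribute comes from the patch):

from types import SimpleNamespace

# configure_tensor_parallel_runtime and TensorParallel_Layer are the names
# patched above; import them from wherever this module lives.
tp_config = SimpleNamespace(overlap_comm=True)  # stand-in for the real config object

configure_tensor_parallel_runtime(tp_config)

# The flag lands on the class, so existing and future tensor-parallel layers
# all observe it via self.__class__.overlap_comm.
assert TensorParallel_Layer.overlap_comm is True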
@@ -406,11 +448,15 @@ def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
             self.config_tp_params(self.bias)

     def forward(self, input):
-        if getattr(self, 'mp_group', None) is not None:
-            input = ColumnParallel.apply(self.mp_group, input)
-        output = torch.matmul(input, self.weight.transpose(-1, -2))
-        if self.bias is not None:
-            output += self.bias
+        if not self.__class__.overlap_comm:
+            if getattr(self, 'mp_group', None) is not None:
+                input = ColumnParallel.apply(self.mp_group, input)
+            output = torch.matmul(input, self.weight.transpose(-1, -2))
+            if self.bias is not None:
+                output += self.bias
+        else:
+            output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias)
+
         return output

     @torch.no_grad()
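Both branches of forward compute the same input @ weight.T + bias; the overlap_comm branch simply routes the whole computation through AsyncColumnParallel so that its backward can overlap the gradient all-reduce as shown above. A minimal single-process smoke test of that path (a sketch only: the gloo world-size-1 group is purely for illustration, and AsyncColumnParallel is assumed to be importable from the patched module):

import os
import torch
import torch.distributed as dist

# Single-process group, just enough to exercise the autograd function end to end;
# in real use the group comes from the tensor-parallel setup.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(2, 4, requires_grad=True)
w = torch.randn(3, 4, requires_grad=True)
b = torch.zeros(3, requires_grad=True)

# This is the path the layer's forward() takes when overlap_comm is True.
out = AsyncColumnParallel.apply(dist.group.WORLD, x, w, b)
out.sum().backward()
print(x.grad.shape, w.grad.shape, b.grad.shape)  # torch.Size([2, 4]) torch.Size([3, 4]) torch.Size([3])

dist.destroy_process_group()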