Commit 5d297e4

async tp
Signed-off-by: inkcherry <[email protected]>
1 parent b418cf6 commit 5d297e4

5 files changed: +76 -18 lines changed

deepspeed/__init__.py (+4 -3)

@@ -366,7 +366,7 @@ def init_inference(model, config=None, **kwargs):
     return engine


-def tp_model_init(model, tp_size, dtype):
+def tp_model_init(model, tp_size, dtype, config=None, **kwargs):
     """
     Initialize the model for tensor parallelism.

@@ -379,8 +379,9 @@ def tp_model_init(model, tp_size, dtype):
         torch.nn.Module: The initialized model with tensor parallelism.
     """
     # avoid re-entry
-    assert not hasattr(
-        model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."
+    if hasattr(model, 'ds_autotp_parsed'):
+        logger.warning("ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed.")
+        return

     set_autotp_mode(training=True)
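
For reference, a minimal usage sketch of the updated entry point. The toy module, tp_size, and dtype values are illustrative, and this assumes the distributed environment has already been initialized:

    import torch
    import deepspeed

    model = torch.nn.Linear(1024, 1024)  # placeholder module; any torch.nn.Module works here

    # tp_model_init now also accepts an optional config and extra kwargs; a repeated call
    # on the same model logs a warning and returns instead of raising an AssertionError.
    model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)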

deepspeed/module_inject/layers.py (+53 -7)

@@ -80,6 +80,35 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, N
         return None, grad_output, None


+class AsyncColumnParallel(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias) -> torch.Tensor:
+        """
+        Forward pass.
+        """
+        ctx.use_bias = bias is not None
+        ctx.group = group
+        output = torch.matmul(input, weight.transpose(-1, -2))
+        if bias is not None:
+            output += bias
+
+        ctx.save_for_backward(input, weight)
+
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
+
+        input, weight = ctx.saved_tensors
+        grad_input = grad_output.matmul(weight)
+        handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True)
+        grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1]))
+        grad_bias = grad_output.sum(0) if ctx.use_bias else None
+        handle.wait()
+        return None, grad_input, grad_weight, grad_bias
+
+
 class ColumnParallel(torch.autograd.Function):
     """
     Custom autograd function for column-wise parallelism.

@@ -124,11 +153,17 @@ class TensorParallel_Layer(nn.Module, ABC):
         support_training (bool): Flag indicating whether the layer supports training (default: False).
         name (Optional[str]): The name of the layer, if provided.
     """
+    ##### Initialize Parameter List #####

-    # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
-    # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+    # keep_module_on_host determines whether to keep the module on the host.
+    # Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory),
+    # so an additional copy is unnecessary.
     keep_module_on_host: bool = False

+    ##### Runtime Parameter List #####
+    overlap_comm: bool = False
+    """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
+
     def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
         """
         Initializes the TensorParallel_Layer with optional model parallelism group and layer name.

@@ -260,6 +295,13 @@ def move(self, tensor):
         return cloned_tensor


+def configure_tensor_parallel_runtime(config):
+    runtime_keys = ['overlap_comm']
+    for key in runtime_keys:
+        if hasattr(config, key):
+            setattr(TensorParallel_Layer, key, getattr(config, key))
+
+
 class GatherReplacedLayerParams:
     """
     A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality

@@ -406,11 +448,15 @@ def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
             self.config_tp_params(self.bias)

     def forward(self, input):
-        if getattr(self, 'mp_group', None) is not None:
-            input = ColumnParallel.apply(self.mp_group, input)
-        output = torch.matmul(input, self.weight.transpose(-1, -2))
-        if self.bias is not None:
-            output += self.bias
+        if not self.__class__.overlap_comm:
+            if getattr(self, 'mp_group', None) is not None:
+                input = ColumnParallel.apply(self.mp_group, input)
+            output = torch.matmul(input, self.weight.transpose(-1, -2))
+            if self.bias is not None:
+                output += self.bias
+        else:
+            output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias)
+
         return output

     @torch.no_grad()
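
The key change above is that AsyncColumnParallel.backward launches the grad_input all-reduce with async_op=True and overlaps it with the weight/bias gradient GEMMs before waiting on the handle. A standalone sketch of that overlap pattern, assuming torch.distributed is already initialized and using placeholder tensor names:

    import torch
    import torch.distributed as dist

    def overlapped_linear_backward(grad_output, input, weight, group):
        # Local part of the input gradient; the cross-rank reduction starts asynchronously.
        grad_input = grad_output.matmul(weight)
        handle = dist.all_reduce(grad_input, group=group, async_op=True)

        # Weight and bias gradients are computed while the all-reduce is in flight.
        grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1]))
        grad_bias = grad_output.sum(0)

        # Block only once the fully reduced grad_input is actually needed.
        handle.wait()
        return grad_input, grad_weight, grad_bias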

deepspeed/runtime/engine.py (+9 -4)

@@ -37,8 +37,7 @@
 from deepspeed.runtime.bf16_optimizer import BF16_Optimizer

 from deepspeed.linear.optimized_linear import LoRAOptimizedLinear
-from deepspeed.module_inject.layers import GatherReplacedLayerParams
-
+from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime
 from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
     ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
     TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \

@@ -248,7 +247,7 @@ def __init__(self,
         self._configure_with_arguments(args, mpu)
         self._do_sanity_check()
         if self.autotp_size() > 1:
-            self._configure_tensor_parallel_states(model)
+            self._configure_tensor_parallel(model, self.tensor_parallel_config())
         see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
         if mpu is not None:
             if self.elasticity_enabled():

@@ -416,14 +415,17 @@ def _optimized_linear_offload_setup(self):
             else:
                 p.ds_offload = False

+    def _configure_tensor_parallel(self, model, tp_config):
+        self._configure_tensor_parallel_states(model)
+        configure_tensor_parallel_runtime(tp_config)
+
     def _configure_tensor_parallel_states(self, model):
         """
         Configures the tensor parallel states for the model.
         This includes setting up the tensor parallel groups, initializing the TP mesh,
         and registering a pre-hook to ensure that the Dataloader inputs are consistent across ranks.
         """
         self._set_client_model(model)
-
         # sanity check
         # currently, the compatibility between 'autotp' and 'zero > 1' has not been validated
         assert self.zero_optimization_stage(

@@ -902,6 +904,9 @@ def zero_legacy_stage1(self):
     def zero_ignore_unused_parameters(self):
         return self._config.zero_config.ignore_unused_parameters

+    def tensor_parallel_config(self):
+        return self._config.tensor_parallel_config
+
     def autotp_size(self):
         return self._config.tensor_parallel_config.autotp_size
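
The engine now threads the tensor-parallel config through to the layer runtime via configure_tensor_parallel_runtime, which mirrors recognized keys onto class attributes of TensorParallel_Layer. A rough sketch of that effect, using a hypothetical stand-in object for the engine's tensor_parallel_config():

    from deepspeed.module_inject.layers import TensorParallel_Layer, configure_tensor_parallel_runtime

    class _TPConfigStub:  # hypothetical stand-in for the engine's TP config object
        overlap_comm = True

    # Inside deepspeed.initialize(...), _configure_tensor_parallel performs the equivalent of:
    configure_tensor_parallel_runtime(_TPConfigStub())
    assert TensorParallel_Layer.overlap_comm is True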

deepspeed/runtime/tensor_parallel/config.py (+3)

@@ -47,6 +47,9 @@ class TPTrainingConfig(DeepSpeedConfigModel):
     In automatic tensor-parallelism training, 'tensor_parallel_size'
     When set to 0, indicates that it is disabled.
     """
+    overlap_comm: bool = False
+    """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
+
     tensor_parallel: TPConfig = Field({}, alias="tp")
     """
     Configuration for tensor parallelism used to split the model across several
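
With this field, overlap_comm can be enabled next to autotp_size in the training config. A minimal config sketch in the dict style used by the unit tests below; the batch size and tp size values are illustrative:

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "tensor_parallel": {
            "autotp_size": 2,
            "overlap_comm": True  # overlap the tensor-parallel allreduce with computation
        },
        "zero_optimization": {
            "stage": 0
        }
    }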

tests/unit/model_parallelism/test_autotp_training.py (+7 -4)

@@ -163,11 +163,12 @@ def process_linear_layer(hidden_dim, input):

 @pytest.mark.sequential
 @pytest.mark.parametrize("tp_size", [2, 4])
+@pytest.mark.parametrize("overlap_comm", [True, False])
 class TestTpLayerFwdBwd(DistributedTest):
     world_size = 4
     reuse_dist_env = True

-    def testRowParallel(self, tp_size: int):
+    def testRowParallel(self, tp_size: int, overlap_comm: bool):
         skip_on_device()
         hidden_dim = 128
         batch_size_per_device = 1

@@ -182,7 +183,8 @@ def testRowParallel(self, tp_size: int):
                 }
             },
             "tensor_parallel": {
-                "autotp_size": tp_size
+                "autotp_size": tp_size,
+                "overlap_comm": overlap_comm
             },
             "zero_optimization": {
                 "stage": 0,

@@ -216,7 +218,7 @@ def testRowParallel(self, tp_size: int):
         assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
         assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-3)

-    def testColumnParallel(self, tp_size: int):
+    def testColumnParallel(self, tp_size: int, overlap_comm: bool):
         skip_on_device()
         hidden_dim = 128
         batch_size_per_device = 1

@@ -231,7 +233,8 @@ def testColumnParallel(self, tp_size: int):
                 }
             },
             "tensor_parallel": {
-                "autotp_size": tp_size
+                "autotp_size": tp_size,
+                "overlap_comm": overlap_comm
             },
             "zero_optimization": {
                 "stage": 0,
