
Commit 4f779d5

add config
1 parent 9e8342a commit 4f779d5

File tree: 5 files changed (+48, -33 lines)

deepspeed/__init__.py
deepspeed/module_inject/layers.py
deepspeed/runtime/engine.py
deepspeed/runtime/tensor_parallel/config.py
tests/unit/model_parallelism/test_autotp_training.py


deepspeed/__init__.py (+4, -3)

@@ -366,7 +366,7 @@ def init_inference(model, config=None, **kwargs):
     return engine
 
 
-def tp_model_init(model, tp_size, dtype):
+def tp_model_init(model, tp_size, dtype, config=None, **kwargs):
     """
     Initialize the model for tensor parallelism.
 
@@ -379,8 +379,9 @@ def tp_model_init(model, tp_size, dtype):
         torch.nn.Module: The initialized model with tensor parallelism.
     """
     # avoid re-entry
-    assert not hasattr(
-        model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."
+    if hasattr(model, 'ds_autotp_parsed'):
+        logger.warning("ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed.")
+        return
 
     set_autotp_mode(training=True)
 
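A minimal usage sketch of the relaxed re-entry check, assuming a multi-GPU distributed environment is already set up (e.g. via deepspeed.init_distributed()); the toy model and the tp_size/dtype values are illustrative and not part of the commit:

import torch
import deepspeed

# Hypothetical stand-in model; any nn.Module with linear layers works the same way.
model = torch.nn.Linear(128, 128)
model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)

# Before this change a second call raised an AssertionError; it now only logs a
# warning and returns early, leaving the already-parsed model untouched.
deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)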

deepspeed/module_inject/layers.py (+27, -14)

@@ -91,24 +91,24 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bia
         ctx.group = group
         output = torch.matmul(input, weight.transpose(-1, -2))
         if bias is not None:
-            output+=bias
-
+            output += bias
+
         ctx.save_for_backward(input, weight)
 
         return output
+
     @staticmethod
     def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
-
-
-        input, weight = ctx.saved_tensors
+
+        input, weight = ctx.saved_tensors
         grad_input = grad_output.matmul(weight)
-        handle=dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True)
-        grad_weight = grad_output.view(-1,grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1]))
+        handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True)
+        grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1]))
         grad_bias = grad_output.sum(0) if ctx.use_bias else None
         handle.wait()
         return None, grad_input, grad_weight, grad_bias
-
-
+
+
 class ColumnParallel(torch.autograd.Function):
     """
     Custom autograd function for column-wise parallelism.
@@ -153,11 +153,17 @@ class TensorParallel_Layer(nn.Module, ABC):
         support_training (bool): Flag indicating whether the layer supports training (default: False).
         name (Optional[str]): The name of the layer, if provided.
     """
+    ##### Initialize Parameter List #####
 
-    # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
-    # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+    # keep_module_on_host determines whether to keep the module on the host.
+    # Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory),
+    # so an additional copy is unnecessary.
     keep_module_on_host: bool = False
 
+    ##### Runtime Parameter List #####
+    overlap_comm: bool = False
+    """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
+
     def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
         """
         Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
@@ -289,6 +295,13 @@ def move(self, tensor):
         return cloned_tensor
 
 
+def configure_tensor_parallel_runtime(config):
+    runtime_keys = ['overlap_comm']
+    for key in runtime_keys:
+        if hasattr(config, key):
+            setattr(TensorParallel_Layer, key, getattr(config, key))
+
+
 class GatherReplacedLayerParams:
     """
     A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality
@@ -435,15 +448,15 @@ def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
             self.config_tp_params(self.bias)
 
     def forward(self, input):
-        if True:
+        if not self.__class__.overlap_comm:
             if getattr(self, 'mp_group', None) is not None:
                 input = ColumnParallel.apply(self.mp_group, input)
             output = torch.matmul(input, self.weight.transpose(-1, -2))
             if self.bias is not None:
                 output += self.bias
         else:
-            output = AsyncColumnParallel.apply(self.mp_group,input, self.weight, self.bias)
-
+            output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias)
+
         return output
 
     @torch.no_grad()
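The new configure_tensor_parallel_runtime() helper copies recognized runtime keys from the TP training config onto the TensorParallel_Layer class, so every replaced layer sees the same setting. A small sketch of that plumbing; the SimpleNamespace is only a stand-in for the real config object:

from types import SimpleNamespace

from deepspeed.module_inject.layers import TensorParallel_Layer, configure_tensor_parallel_runtime

tp_config = SimpleNamespace(overlap_comm=True)  # stand-in for the parsed TPTrainingConfig
configure_tensor_parallel_runtime(tp_config)

# With the class-level flag set, the column-parallel forward above takes the
# AsyncColumnParallel path, whose backward overlaps the grad_input all-reduce
# with the grad_weight matmul.
assert TensorParallel_Layer.overlap_comm is True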

deepspeed/runtime/engine.py (+7, -6)

@@ -37,8 +37,7 @@
 from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
 
 from deepspeed.linear.optimized_linear import LoRAOptimizedLinear
-from deepspeed.module_inject.layers import GatherReplacedLayerParams
-
+from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime
 from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
     ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
     TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
@@ -248,7 +247,7 @@ def __init__(self,
         self._configure_with_arguments(args, mpu)
         self._do_sanity_check()
         if self.autotp_size() > 1:
-            self._configure_tensor_parallel_states(model)
+            self._configure_tensor_parallel(model, self.tensor_parallel_config())
         see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
         if mpu is not None:
             if self.elasticity_enabled():
@@ -416,15 +415,17 @@ def _optimized_linear_offload_setup(self):
             else:
                 p.ds_offload = False
 
-
+    def _configure_tensor_parallel(self, model, tp_config):
+        self._configure_tensor_parallel_states(model)
+        configure_tensor_parallel_runtime(tp_config)
+
     def _configure_tensor_parallel_states(self, model):
         """
         Configures the tensor parallel states for the model.
         This includes setting up the tensor parallel groups, initializing the TP mesh,
         and registering a pre-hook to ensure that the Dataloader inputs are consistent across ranks.
         """
         self._set_client_model(model)
-
         # sanity check
         # currently, the compatibility between 'autotp' and 'zero > 1' has not been validated
         assert self.zero_optimization_stage(
@@ -905,7 +906,7 @@ def zero_ignore_unused_parameters(self):
 
     def tensor_parallel_config(self):
         return self._config.tensor_parallel_config
-
+
     def autotp_size(self):
         return self._config.tensor_parallel_config.autotp_size
 
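For orientation, a hedged sketch of how the new hook is reached: when the config requests automatic tensor parallelism (autotp_size > 1), the engine constructor now calls _configure_tensor_parallel(), which sets up the TP states and then pushes the runtime keys onto the layer classes via configure_tensor_parallel_runtime(). The toy model, optimizer section, and config values below are illustrative only; a real run also needs a distributed launcher:

import torch
import deepspeed

model = torch.nn.Linear(128, 128)
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "tensor_parallel": {"autotp_size": 2, "overlap_comm": True},
    "zero_optimization": {"stage": 0},
}

# Inside the engine __init__, autotp_size() > 1 now routes through
# _configure_tensor_parallel(model, self.tensor_parallel_config()).
engine, *_ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)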

deepspeed/runtime/tensor_parallel/config.py (+3, -6)

@@ -8,25 +8,19 @@
 import torch
 from pydantic import Field
 from typing import Optional
-from deepspeed.module_inject.layers import TensorParallel_Layer
 
 
 class AUTOTP_MODE(Enum):
     TRAINING = "TRAINING"
     INFERENCE = "INFERENCE"
 
-def configure_tensor_parallel_runtime(config):
-    TensorParallel_Layer.overlap_comm = config['overlap_comm']
 
 
 class TPConfig(DeepSpeedConfigModel):
     """ Configure tensor parallelism settings """
 
     tp_size: int = 1
     """ Number of devices to split the model across using tensor parallelism. """
-    overlap_comm : bool = False
-    """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
-
     tp_grain_size: int = 1
     "The variable required by the autoTP parser has not been activated in training yet"
     "as it depends on the gather logic that supports uneven partitioning. "
@@ -53,6 +47,9 @@ class TPTrainingConfig(DeepSpeedConfigModel):
     In automatic tensor-parallelism training, 'tensor_parallel_size'
     When set to 0, indicates that it is disabled.
     """
+    overlap_comm: bool = False
+    """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
+
     tensor_parallel: TPConfig = Field({}, alias="tp")
     """
     Configuration for tensor parallelism used to split the model across several
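Since overlap_comm now lives on TPTrainingConfig rather than the nested TPConfig, it is read from the top level of the "tensor_parallel" section of the DeepSpeed config (see the tests below), and the engine hands the parsed object to configure_tensor_parallel_runtime(). A construction sketch, assuming the pydantic config models accept keyword arguments as usual:

from deepspeed.runtime.tensor_parallel.config import TPTrainingConfig

tp_cfg = TPTrainingConfig(autotp_size=2, overlap_comm=True)
print(tp_cfg.overlap_comm)  # True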

tests/unit/model_parallelism/test_autotp_training.py (+7, -4)

@@ -163,11 +163,12 @@ def process_linear_layer(hidden_dim, input):
 
 @pytest.mark.sequential
 @pytest.mark.parametrize("tp_size", [2, 4])
+@pytest.mark.parametrize("overlap_comm", [True, False])
 class TestTpLayerFwdBwd(DistributedTest):
     world_size = 4
     reuse_dist_env = True
 
-    def testRowParallel(self, tp_size: int):
+    def testRowParallel(self, tp_size: int, overlap_comm: bool):
         skip_on_device()
         hidden_dim = 128
         batch_size_per_device = 1
@@ -182,7 +183,8 @@ def testRowParallel(self, tp_size: int):
                 }
             },
             "tensor_parallel": {
-                "autotp_size": tp_size
+                "autotp_size": tp_size,
+                "overlap_comm": overlap_comm
             },
             "zero_optimization": {
                 "stage": 0,
@@ -216,7 +218,7 @@ def testRowParallel(self, tp_size: int):
         assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
         assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-3)
 
-    def testColumnParallel(self, tp_size: int):
+    def testColumnParallel(self, tp_size: int, overlap_comm: bool):
         skip_on_device()
         hidden_dim = 128
         batch_size_per_device = 1
@@ -231,7 +233,8 @@ def testColumnParallel(self, tp_size: int):
                 }
             },
             "tensor_parallel": {
-                "autotp_size": tp_size
+                "autotp_size": tp_size,
+                "overlap_comm": overlap_comm
             },
             "zero_optimization": {
                 "stage": 0,
