async tp allreduce #7115

Open · wants to merge 2 commits into base: master
7 changes: 4 additions & 3 deletions deepspeed/__init__.py
@@ -366,7 +366,7 @@ def init_inference(model, config=None, **kwargs):
return engine


def tp_model_init(model, tp_size, dtype):
def tp_model_init(model, tp_size, dtype, config=None, **kwargs):
"""
Initialize the model for tensor parallelism.

@@ -379,8 +379,9 @@ def tp_model_init(model, tp_size, dtype):
torch.nn.Module: The initialized model with tensor parallelism.
"""
# avoid re-entry
assert not hasattr(
model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."
if hasattr(model, 'ds_autotp_parsed'):
logger.warning("'ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed.")
return

set_autotp_mode(training=True)

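For context, a minimal usage sketch of the updated entry point (not part of the diff itself). It assumes torch.distributed is already initialized and uses a toy module purely for illustration:

import torch
import deepspeed

# Toy stand-in model; any torch.nn.Module works for this illustration.
model = torch.nn.Linear(1024, 1024)

# First call sets up tensor parallelism and returns the parallelized model.
model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)

# With this change, a repeated call no longer trips an assertion; it logs a warning and returns early.
deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)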
60 changes: 53 additions & 7 deletions deepspeed/module_inject/layers.py
@@ -80,6 +80,35 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, N
return None, grad_output, None


class AsyncColumnParallel(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias) -> torch.Tensor:
"""
Forward pass.
"""
ctx.use_bias = bias is not None
ctx.group = group
output = torch.matmul(input, weight.transpose(-1, -2))
if bias is not None:
output += bias

ctx.save_for_backward(input, weight)

return output

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:

input, weight = ctx.saved_tensors
grad_input = grad_output.matmul(weight)
# Kick off the all-reduce of grad_input asynchronously so it overlaps with the
# weight/bias gradient computation below.
handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True)
grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1]))
grad_bias = grad_output.sum(0) if ctx.use_bias else None
# Wait for the asynchronous all-reduce to finish before handing grad_input back to autograd.
handle.wait()
return None, grad_input, grad_weight, grad_bias


class ColumnParallel(torch.autograd.Function):
"""
Custom autograd function for column-wise parallelism.
@@ -124,11 +153,17 @@ class TensorParallel_Layer(nn.Module, ABC):
support_training (bool): Flag indicating whether the layer supports training (default: False).
name (Optional[str]): The name of the layer, if provided.
"""
##### Initialize Parameter List #####

# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
# cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
# keep_module_on_host determines whether to keep the module on the host.
# Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory),
# so an additional copy is unnecessary.
keep_module_on_host: bool = False

##### Runtime Parameter List #####
overlap_comm: bool = False
""" Whether to overlap communication with computation. Currently, only allreduce supports overlap. """

def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
"""
Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
@@ -260,6 +295,13 @@ def move(self, tensor):
return cloned_tensor


def configure_tensor_parallel_runtime(config):
runtime_keys = ['overlap_comm']
for key in runtime_keys:
if hasattr(config, key):
setattr(TensorParallel_Layer, key, getattr(config, key))


class GatherReplacedLayerParams:
"""
A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality
@@ -406,11 +448,15 @@ def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
self.config_tp_params(self.bias)

def forward(self, input):
if getattr(self, 'mp_group', None) is not None:
input = ColumnParallel.apply(self.mp_group, input)
output = torch.matmul(input, self.weight.transpose(-1, -2))
if self.bias is not None:
output += self.bias
if not self.__class__.overlap_comm:
if getattr(self, 'mp_group', None) is not None:
input = ColumnParallel.apply(self.mp_group, input)
output = torch.matmul(input, self.weight.transpose(-1, -2))
if self.bias is not None:
output += self.bias
else:
output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias)

return output

@torch.no_grad()
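As a rough sketch of how the new runtime switch is applied (not part of the diff): configure_tensor_parallel_runtime only copies matching attributes onto the class, so any object exposing an overlap_comm attribute works for illustration, a SimpleNamespace here standing in for the parsed TP training config:

from types import SimpleNamespace

from deepspeed.module_inject.layers import TensorParallel_Layer, configure_tensor_parallel_runtime

configure_tensor_parallel_runtime(SimpleNamespace(overlap_comm=True))

# overlap_comm is a class-level attribute, so every TP layer instance sees it and the
# column-parallel forward above dispatches to AsyncColumnParallel.
assert TensorParallel_Layer.overlap_comm is True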
13 changes: 9 additions & 4 deletions deepspeed/runtime/engine.py
@@ -37,8 +37,7 @@
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer

from deepspeed.linear.optimized_linear import LoRAOptimizedLinear
from deepspeed.module_inject.layers import GatherReplacedLayerParams

from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
@@ -248,7 +247,7 @@ def __init__(self,
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
if self.autotp_size() > 1:
self._configure_tensor_parallel_states(model)
self._configure_tensor_parallel(model, self.tensor_parallel_config())
see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
if mpu is not None:
if self.elasticity_enabled():
@@ -416,14 +415,17 @@ def _optimized_linear_offload_setup(self):
else:
p.ds_offload = False

def _configure_tensor_parallel(self, model, tp_config):
self._configure_tensor_parallel_states(model)
configure_tensor_parallel_runtime(tp_config)

def _configure_tensor_parallel_states(self, model):
"""
Configures the tensor parallel states for the model.
This includes setting up the tensor parallel groups, initializing the TP mesh,
and registering a pre-hook to ensure that the Dataloader inputs are consistent across ranks.
"""
self._set_client_model(model)

# sanity check
# currently, the compatibility between 'autotp' and 'zero > 1' has not been validated
assert self.zero_optimization_stage(
@@ -902,6 +904,9 @@ def zero_legacy_stage1(self):
def zero_ignore_unused_parameters(self):
return self._config.zero_config.ignore_unused_parameters

def tensor_parallel_config(self):
return self._config.tensor_parallel_config

def autotp_size(self):
return self._config.tensor_parallel_config.autotp_size

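In normal use the flag arrives through the DeepSpeed config rather than being set directly. A hedged end-to-end sketch, assuming a multi-rank launch (e.g. via the deepspeed launcher); the config values are illustrative and mirror the unit tests later in this diff:

import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-3
        }
    },
    "tensor_parallel": {
        "autotp_size": 2,
        "overlap_comm": True
    },
    "zero_optimization": {
        "stage": 0
    }
}

model = torch.nn.Linear(128, 128)
# During engine construction, autotp_size() > 1 triggers _configure_tensor_parallel, which in turn
# calls configure_tensor_parallel_runtime to push overlap_comm into the TP layers.
engine, optimizer, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)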
3 changes: 3 additions & 0 deletions deepspeed/runtime/tensor_parallel/config.py
@@ -47,6 +47,9 @@ class TPTrainingConfig(DeepSpeedConfigModel):
In automatic tensor-parallelism training, this sets the tensor-parallel world size ('tensor_parallel_size').
When set to 0, automatic tensor parallelism is disabled.
"""
overlap_comm: bool = False
""" Whether to overlap communication with computation. Currently, only allreduce supports overlap. """

tensor_parallel: TPConfig = Field({}, alias="tp")
"""
Configuration for tensor parallelism used to split the model across several
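The new field can also be exercised directly on the config model; a small sketch, assuming TPTrainingConfig behaves like other DeepSpeedConfigModel classes (keyword construction with validation):

from deepspeed.runtime.tensor_parallel.config import TPTrainingConfig

cfg = TPTrainingConfig(autotp_size=2, overlap_comm=True)
assert cfg.overlap_comm is True

# overlap_comm defaults to False, so existing configs keep the synchronous all-reduce behavior.
assert TPTrainingConfig(autotp_size=2).overlap_comm is False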
11 changes: 7 additions & 4 deletions tests/unit/model_parallelism/test_autotp_training.py
@@ -163,11 +163,12 @@ def process_linear_layer(hidden_dim, input):

@pytest.mark.sequential
@pytest.mark.parametrize("tp_size", [2, 4])
@pytest.mark.parametrize("overlap_comm", [True, False])
class TestTpLayerFwdBwd(DistributedTest):
world_size = 4
reuse_dist_env = True

def testRowParallel(self, tp_size: int):
def testRowParallel(self, tp_size: int, overlap_comm: bool):
skip_on_device()
hidden_dim = 128
batch_size_per_device = 1
@@ -182,7 +183,8 @@ def testRowParallel(self, tp_size: int):
}
},
"tensor_parallel": {
"autotp_size": tp_size
"autotp_size": tp_size,
"overlap_comm": overlap_comm
},
"zero_optimization": {
"stage": 0,
@@ -216,7 +218,7 @@ def testRowParallel(self, tp_size: int):
assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-3)

def testColumnParallel(self, tp_size: int):
def testColumnParallel(self, tp_size: int, overlap_comm: bool):
skip_on_device()
hidden_dim = 128
batch_size_per_device = 1
@@ -231,7 +233,8 @@ def testColumnParallel(self, tp_size: int):
}
},
"tensor_parallel": {
"autotp_size": tp_size
"autotp_size": tp_size,
"overlap_comm": overlap_comm
},
"zero_optimization": {
"stage": 0,
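To exercise both the synchronous and overlapped variants locally, one option is to select the parametrized class directly; this assumes the usual DeepSpeed unit-test environment with enough devices for world_size = 4:

import pytest

# Runs TestTpLayerFwdBwd across the tp_size and overlap_comm parametrizations shown above.
pytest.main(["-v", "tests/unit/model_parallelism/test_autotp_training.py::TestTpLayerFwdBwd"])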