
Commit 18713c6

jomayeri and jeffra authored
Updating API docs (#2586)
Co-authored-by: Jeff Rasley <[email protected]>
1 parent 377c770 commit 18713c6

File tree

2 files changed: +52 -49 lines changed


deepspeed/moe/layer.py (+18 -18)

@@ -13,6 +13,24 @@


 class MoE(torch.nn.Module):
+    """Initialize an MoE layer.
+
+    Arguments:
+        hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
+        expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
+        num_experts (int, optional): default=1, the total number of experts per layer.
+        ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
+        k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
+        capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
+        eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
+        min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
+        use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
+        noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
+        drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+        use_rts (bool, optional): default=True, whether to use Random Token Selection.
+        use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
+        enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts
+    """
     def __init__(self,
                  hidden_size,
                  expert,
@@ -28,24 +46,6 @@ def __init__(self,
                  use_rts=True,
                  use_tutel: bool = False,
                  enable_expert_tensor_parallelism: bool = False):
-        """Initialize an MoE layer.
-
-        Arguments:
-            hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
-            expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
-            num_experts (int, optional): default=1, the total number of experts per layer.
-            ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
-            k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
-            capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
-            eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
-            min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
-            use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
-            noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
-            drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
-            use_rts (bool, optional): default=True, whether to use Random Token Selection.
-            use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
-            enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts
-        """

         super(MoE, self).__init__()
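For context, a minimal usage sketch of the constructor documented above (not part of the commit): the ExpertMLP module, the sizes, and the gating settings are illustrative only, and the layer assumes torch.distributed and DeepSpeed's expert-parallel groups have already been set up (e.g. under the deepspeed launcher).

```python
import torch
import torch.nn as nn
from deepspeed.moe.layer import MoE

# Illustrative expert: any torch.nn.Module whose input and output
# dimensions equal hidden_size can serve as the expert.
class ExpertMLP(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, 4 * hidden_size)
        self.fc2 = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

hidden_size = 512  # hypothetical model width

# Assumes deepspeed.init_distributed() has already been called so that
# the expert-parallel process groups can be created.
moe = MoE(hidden_size=hidden_size,
          expert=ExpertMLP(hidden_size),
          num_experts=8,
          ep_size=1,
          k=1,
          capacity_factor=1.0,
          noisy_gate_policy='RSample')

x = torch.randn(4, 16, hidden_size)   # (batch, sequence, hidden)
output, l_aux, exp_counts = moe(x)    # routed output plus auxiliary gating loss
```

The auxiliary loss `l_aux` is typically added to the task loss so the gate learns a balanced assignment of tokens to experts.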

deepspeed/runtime/pipe/module.py (+34 -31)

@@ -83,6 +83,40 @@ def __init__(self,


 class PipelineModule(nn.Module):
+    """Modules to be parallelized with pipeline parallelism.
+
+    The key constraint that enables pipeline parallelism is the
+    representation of the forward pass as a sequence of layers
+    and the enforcement of a simple interface between them. The
+    forward pass is implicitly defined by the module ``layers``. The key
+    assumption is that the output of each layer can be directly fed as
+    input to the next, like a ``torch.nn.Sequential``. The forward pass is
+    implicitly:
+
+    .. code-block:: python
+
+        def forward(self, inputs):
+            x = inputs
+            for layer in self.layers:
+                x = layer(x)
+            return x
+
+    .. note::
+        Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.
+
+    Args:
+        layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module.
+        num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
+        topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism for training. Must be provided if ``num_stages`` is ``None``.
+        loss_fn (callable, optional): Loss is computed ``loss = loss_fn(outputs, label)``.
+        seed_layers (bool, optional): Use a different seed for each layer. Defaults to False.
+        seed_fn (type, optional): The custom seed generating function. Defaults to random seed generator.
+        base_seed (int, optional): The starting seed. Defaults to 1234.
+        partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
+        activation_checkpoint_interval (int, optional): The granularity of activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
+        activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
+        checkpointable_layers (list, optional): Layers not in this list may not be checkpointed. Defaults to None, which applies no additional filtering.
+    """
     def __init__(self,
                  layers,
                  num_stages=None,
@@ -95,37 +129,6 @@ def __init__(self,
                  activation_checkpoint_interval=0,
                  activation_checkpoint_func=checkpointing.checkpoint,
                  checkpointable_layers=None):
-        """Modules to be parallelized with pipeline parallelism.
-
-        The key constraint that enables pipeline parallelism is the
-        representation of the forward pass as a sequence of layers
-        and the enforcement of a simple interface between them. The
-        forward pass is implicitly defined by the module ``layers``. The key
-        assumption is that the output of each layer can be directly fed as
-        input to the next, like a ``torch.nn.Sequence``. The forward pass is
-        implicitly:
-
-        .. code-block:: python
-
-            def forward(self, inputs):
-                x = inputs
-                for layer in self.layers:
-                    x = layer(x)
-                return x
-
-        .. note::
-            Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.
-
-        Args:
-            layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module.
-            num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
-            topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism axes for training. Must be provided if ``num_stages`` is ``None``.
-            loss_fn (callable, optional): Loss is computed ``loss = loss_fn(outputs, label)``
-            base_seed (int, optional): [description]. Defaults to 1234.
-            partition_method (str, optional): [description]. Defaults to 'parameters'.
-            activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
-            activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
-        """

         super().__init__()
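Likewise, a minimal construction sketch for the documented ``PipelineModule`` (not from the commit): the layer stack, stage count, and loss are illustrative, and the process group is assumed to be initialized with a world size divisible by ``num_stages`` (e.g. when launched with the deepspeed runner).

```python
import torch.nn as nn
import deepspeed
from deepspeed.pipe import PipelineModule

# Assumed environment: launched with the deepspeed launcher so that
# rank/world-size variables are set before distributed init.
deepspeed.init_distributed()

# Any sequence of layers where each output feeds the next input
# satisfies the torch.nn.Sequential-like contract described above.
layers = [
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
]

net = PipelineModule(layers=layers,
                     num_stages=2,                      # degree of pipeline parallelism
                     loss_fn=nn.CrossEntropyLoss(),     # loss = loss_fn(outputs, label)
                     partition_method='parameters',     # default partitioning strategy
                     activation_checkpoint_interval=0)  # 0 disables activation checkpointing
```

Per the note in the docstring, such a module would be trained with ZeRO disabled or at stage 1, since pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.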
