
Commit 542103f

- Updated sinkhorn initialization and added a max_iter argument.
- Removed mp assertion for MoE.
- Removed mlp_type checks in MoE code.
- Added bf16 conversion to dmoe_gather.
1 parent: fb68c07

File tree: 6 files changed, +18 -36 lines

Diff for: megatron/model/moe.py (+5 -14)

@@ -53,20 +53,11 @@ def __init__(
         self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)

         # decide which parallel grouped MLP implementation to use
-        if neox_args.mlp_type == "regular":
-            self.mlp = ParallelGroupedMLP(
-                neox_args=neox_args,
-                init_method=init_method,
-                output_layer_init_method=output_layer_init_method,
-            )
-        elif neox_args.mlp_type == "llama":
-            self.mlp = ParallelGroupedLLaMAMLP(
-                neox_args=neox_args,
-                init_method=init_method,
-                output_layer_init_method=output_layer_init_method,
-            )
-        else:
-            raise KeyError(neox_args.mlp_type)
+        self.mlp = ParallelGroupedMLP(
+            neox_args=neox_args,
+            init_method=init_method,
+            output_layer_init_method=output_layer_init_method,
+        )

     def indices_and_bins(self, top_expert: torch.Tensor):
         # Sort the expert ids to produce the scatter/gather

Diff for: megatron/model/moe_mlp.py (+2 -2)

@@ -145,7 +145,7 @@ def __init__(
         """
         super(ParallelGroupedMLP, self).__init__()

-        self.activation_func = get_activation(neox_args)
+        self.activation_func, self.activation_fn_is_gated = get_activation(neox_args)
         self.activation_type = neox_args.activation

         self.multiple_of = multiple_of
@@ -334,7 +334,7 @@ def __init__(
         """
         super(ParallelGroupedLLaMAMLP, self).__init__()

-        self.activation_func = get_activation(neox_args)
+        self.activation_func, self.activation_fn_is_gated = get_activation(neox_args)
         self.activation_type = neox_args.activation

         self.multiple_of = multiple_of
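
The change above unpacks get_activation into a (callable, is_gated) pair. As a rough illustration of why a gating flag is useful, here is a minimal, self-contained sketch of how such a pair might drive a feed-forward block; example_get_activation, feed_forward, and the swiglu-style split are hypothetical stand-ins, not the repo's actual ParallelGroupedMLP code.

    import torch
    import torch.nn.functional as F

    # Illustrative only: mimics a get_activation() that returns
    # (activation_fn, is_gated), as the diff above unpacks it.
    def example_get_activation(name: str):
        if name == "swiglu":      # gated: activation applied to a gate half
            return F.silu, True
        return F.gelu, False      # non-gated: plain elementwise activation

    def feed_forward(x, w_in, w_out, activation_fn, is_gated):
        h = x @ w_in
        if is_gated:
            # Split the projection into value and gate halves, then gate.
            value, gate = h.chunk(2, dim=-1)
            h = value * activation_fn(gate)
        else:
            h = activation_fn(h)
        return h @ w_out

    x = torch.randn(4, 16)
    act_fn, gated = example_get_activation("swiglu")
    w_in = torch.randn(16, 64)    # 2x the hidden size for the gated case
    w_out = torch.randn(32, 16)
    y = feed_forward(x, w_in, w_out, act_fn, gated)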

Diff for: megatron/model/router.py (+5 -3)

@@ -60,20 +60,22 @@ def __init__(
         )
         init_method(self.layer.weight)

-    def sinkhorn(self, cost: torch.Tensor, tol: float = 0.0001):
+    def sinkhorn(self, cost: torch.Tensor, tol: float = 0.0001, max_iter=3):
         """Sinkhorn based MoE routing function"""
         cost = torch.exp(cost)
         d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
-        d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
+        d1 = 1 / (cost.size(1) * torch.sum(cost, 0))

         eps = 0.00000001
         error = 1e9
         d1_old = d1
-        while error > tol:
+        for iteration in range(max_iter):
             d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
             d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
             error = torch.mean(torch.abs(d1_old - d1))
             d1_old = d1
+            if error > tol:
+                break
         return d1 * cost * d0.unsqueeze(1)

     def sinkhorn_load_balancing(self, logits: torch.Tensor):
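
For context, the updated routine is a truncated Sinkhorn normalization: it alternately rescales rows (tokens) and columns (experts) of exp(logits) for at most max_iter passes so the routing scores approach uniform marginals. Below is a minimal standalone sketch under that reading; sinkhorn_sketch is an illustrative name, the convergence check is omitted, and this is not the Router class itself.

    import torch

    # Sketch of Sinkhorn-style normalization for MoE routing: alternately
    # rescale rows (tokens) and columns (experts) of exp(logits) so the
    # resulting assignment matrix has roughly uniform marginals.
    def sinkhorn_sketch(logits: torch.Tensor, n_iter: int = 3, eps: float = 1e-8):
        cost = torch.exp(logits)
        d0 = torch.ones(cost.size(0), dtype=cost.dtype)   # per-token scaling
        d1 = 1 / (cost.size(1) * torch.sum(cost, 0))       # per-expert scaling from column sums
        for _ in range(n_iter):
            d0 = (1 / d0.size(0)) / (torch.sum(d1 * cost, dim=1) + eps)
            d1 = (1 / d1.size(0)) / (torch.sum(d0.unsqueeze(1) * cost, dim=0) + eps)
        return d1 * cost * d0.unsqueeze(1)

    logits = torch.randn(8, 4)       # 8 tokens, 4 experts
    scores = sinkhorn_sketch(logits)
    print(scores.sum(dim=0))         # column sums ~ 1 / num_experts
    print(scores.sum(dim=1))         # row sums ~ 1 / num_tokens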

Diff for: megatron/model/transformer.py (+1 -1)

@@ -1006,7 +1006,7 @@ def _get_bias_dropout(self):
     def forward(self, x, attention_mask, layer_past=None):
         layer_past = layer_past if layer_past is not None else self.layer_past
         bias_dropout_fn = self._get_bias_dropout()
-
+
         # x: [b, s, h]
         if self.gpt_j_residual:
             # pseudocode:

Diff for: megatron/mpu/mappings.py (+5 -5)

@@ -86,11 +86,7 @@ def _gather(input_):
     torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())

     # Note: torch.cat already creates a contiguous tensor.
-    output = torch.cat(tensor_list, dim=last_dim)
-
-    # Bf16 convert
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = output.bfloat16()
+    output = torch.cat(tensor_list, dim=last_dim).contiguous()

     return output

@@ -180,6 +176,10 @@ def _dmoe_gather(input_: torch.Tensor, tokens_per_expert: torch.Tensor):
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=gather_dim)

+    # Bf16 convert
+    if dt == torch.bfloat16 and get_fp32_allreduce():
+        output = output.bfloat16()
+
     return output
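
The conversion added to _dmoe_gather restores bf16 after a gather that may have run in fp32. Below is a rough single-process sketch of that upcast/communicate/downcast round-trip; a local torch.cat stands in for the real all_gather, and the fp32_comm flag is a hypothetical stand-in for the repo's get_fp32_allreduce().

    import torch

    # Sketch of the upcast-communicate-downcast pattern implied by the diff.
    def gather_with_optional_fp32_comm(chunks, fp32_comm: bool = True):
        dt = chunks[0].dtype
        if dt == torch.bfloat16 and fp32_comm:
            # Upcast before the (simulated) collective for fp32 accumulation.
            chunks = [c.float() for c in chunks]
        output = torch.cat(chunks, dim=-1)
        if dt == torch.bfloat16 and fp32_comm:
            # Restore the original dtype after communication.
            output = output.bfloat16()
        return output

    parts = [torch.randn(2, 4, dtype=torch.bfloat16) for _ in range(3)]
    out = gather_with_optional_fp32_comm(parts)
    assert out.dtype == torch.bfloat16 and out.shape == (2, 12)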

Diff for: megatron/neox_arguments/arguments.py (-11)

@@ -1078,17 +1078,6 @@ def calculate_derived(self):
         # the sequential model without the PipelineModule wrapper to avoid the overhead it incurs
         self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1)

-        # Do MoE checks
-        if self.moe_num_experts > 1:
-            assert not (
-                self.is_pipe_parallel or self.pipe_parallel_size > 1
-            ), "MoE not supported with pipeline parallelism"
-            assert self.zero_optimization["stage"] != 3, "MoE not compatible with zero3"
-
-            assert (
-                self.sequence_parallel is False
-            ), "MoE not compatible with Sequence Parallel"
-
         # Attention config
         if self.attention_config is None:
             self.update_value("attention_config", [[["global"], self.num_layers]])
