
Commit 059ae20 (2 parents: 8b959f2 + ee5ff20)

Merge pull request databricks#38 from sashaDoubov/sasha/glu

Add GLU support

File tree: 8 files changed, +207 / -16 lines


megablocks/layers/arguments.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ class Arguments:
 
     # Compute arguments.
     memory_optimized_mlp : bool = False
+    mlp_type : str = 'mlp'
     grouped_mlp : bool = False
     quantize_inputs_num_bits: int = -1 # -1 = no quantization
     quantize_rematerialize_num_bits: int = -1
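
The new mlp_type field is the only switch a caller needs to flip to opt into the GLU expert layers added by this commit. A minimal sketch of such a configuration (the sizes below are illustrative placeholders, not defaults from this change):

from megablocks.layers.arguments import Arguments

# Illustrative values only; the fields of interest are mlp_type
# (new in this commit, default 'mlp') and grouped_mlp, which
# together select the expert MLP class via dmlp_registry.
args = Arguments(
    hidden_size=1024,        # placeholder
    ffn_hidden_size=4096,    # placeholder
    moe_num_experts=8,       # placeholder
    moe_top_k=1,             # placeholder
    mlp_type='glu',          # 'mlp' or 'glu'
    grouped_mlp=False)       # False -> sparse backend, True -> grouped backend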

megablocks/layers/dmlp_registry.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+from typing import Union
+from megablocks.layers import mlp
+from megablocks.layers import glu
+from megablocks.layers.arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+    'mlp': {'grouped': mlp.GroupedMLP, 'sparse': mlp.SparseMLP},
+    'glu': {'grouped': glu.GroupedGLU, 'sparse': glu.SparseGLU},
+}
+
+def get(args: Arguments) -> MlpType:
+    """Returns an MLP for use in a dMoE instance.
+
+    Uses the provided arguments to instantiate the appropriate
+    MLP instance. This only contains MLPs for use in dMoEs
+    (ie. only for the dropless versions of MoEs).
+
+    Args:
+        args: propagated Arguments dataclass.
+
+    Returns:
+        An instantiated MLP constructed using the input args.
+
+    """
+    if args.mlp_type not in _REGISTRY:
+        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+    mlp_impl = 'grouped' if args.grouped_mlp else 'sparse'
+
+    if mlp_impl not in _REGISTRY[args.mlp_type]:
+        raise ValueError(f'{args.mlp_type} does not support {mlp_impl} backend.')
+
+    return _REGISTRY[args.mlp_type][mlp_impl](args)
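
As a usage sketch (not part of the commit; the Arguments values mirror those used in glu_test.py below and are otherwise placeholders), the registry maps the (mlp_type, backend) pair to a concrete expert class and instantiates it:

from functools import partial

import torch

from megablocks.layers import dmlp_registry, glu
from megablocks.layers.arguments import Arguments

# Hypothetical configuration for illustration.
args = Arguments(
    hidden_size=512,
    ffn_hidden_size=1024,
    moe_num_experts=1,
    moe_top_k=1,
    init_method=partial(torch.nn.init.normal_, mean=0.0, std=0.1),
    mlp_type='glu',
    grouped_mlp=False)

# 'glu' with the sparse backend resolves to glu.SparseGLU; 'mlp' with
# grouped_mlp=True would resolve to mlp.GroupedMLP, and so on.
expert_mlp = dmlp_registry.get(args)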

megablocks/layers/dmoe.py

Lines changed: 2 additions & 10 deletions
@@ -1,6 +1,6 @@
 from megablocks.layers import common
-from megablocks.layers import mlp
 from megablocks.layers import moe
+from megablocks.layers import dmlp_registry
 from megablocks.layers import mpu
 from megablocks.layers import router
 from megablocks.layers.arguments import Arguments
@@ -9,25 +9,17 @@
 import stk
 import torch
 
-
 def promote_scalar(x):
     return x.view(1) if not len(x.size()) else x
 
-
 class ParallelDroplessMLP(moe.ParallelMLP):
 
     def __init__(self, args : Arguments):
         super(ParallelDroplessMLP, self).__init__(args)
         self.hidden_size = args.hidden_size
         self.ffn_hidden_size = mpu.features_per_rank(args)
         self.blocking = 128
-
-        # Grouped or sparse MLP.
-        self.mlp = (
-            mlp.GroupedMLP(args)
-            if args.grouped_mlp
-            else mlp.SparseMLP(args)
-        )
+        self.mlp = dmlp_registry.get(args)
 
         # Calculate the number of bits needed to represent the column indices
         # in the intermediate sparse matrix.

megablocks/layers/dmoe_test.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ def test_modules(
         memory_optimized_mlp=True,
         quantize_inputs_num_bits=num_input_bits,
         quantize_rematerialize_num_bits=num_remat_bits,
+        mlp_type='mlp',
         grouped_mlp=grouped_mlp,
         fp16=False,
         bf16=True)

megablocks/layers/glu.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+from megablocks.layers import common
+from megablocks.layers import gelu
+from megablocks.layers.mlp import SparseMLP, create_dmoe_expert_weights
+from megablocks.layers import mpu
+from megablocks.layers.arguments import Arguments, InitFn
+from megablocks import grouped_gemm_util as gg
+import stk
+import torch
+import torch.nn.functional as F
+
+
+class SparseGLU(SparseMLP):
+
+    def __init__(self, args : Arguments):
+        super().__init__(args)
+        self.v1 = torch.nn.Parameter(torch.empty(
+            self._num_rows_per_rank,
+            args.hidden_size,
+            device=args.device,
+            dtype=common.dtype(args)))
+        with torch.no_grad():
+            self.v1.copy_(create_dmoe_expert_weights(
+                args, args.moe_num_experts, args.ffn_hidden_size,
+                args.hidden_size, args.init_method))
+
+        mpu.set_expert_model_parallel_attributes(
+            self.v1, self._should_set_parallelism_attribute)
+
+        if self.args.moe_weight_parallelism:
+            raise NotImplementedError("Weight parallelism not yet supported with GLU.")
+        elif self.args.memory_optimized_mlp:
+            raise NotImplementedError("Memory optimized implementation not yet supported with GLU.")
+
+    def forward(self, x, topo):
+        w1, v1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2))
+
+        # Compute the GLU.
+        x1 = stk.ops.sdd(x, w1.t(), topo)
+        x2 = stk.ops.sdd(x, v1.t(), topo)
+
+        x1 = stk.ops.mul(gelu.gelu(x1), x2)
+
+        return stk.ops.dsd(x1, w2)
+
+class GroupedGLU(SparseGLU):
+    def forward(self, x, tokens_per_expert):
+        batch_sizes = tokens_per_expert.cpu().to(torch.long)
+        w1, v1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2))
+
+        # Re-shape the weights for the grouped GEMMs.
+        ne = mpu.experts_per_rank(self.args)
+        w1 = w1.view(ne, -1, self.args.hidden_size)
+        v1 = v1.view(ne, -1, self.args.hidden_size)
+        w2 = w2.view(ne, -1, self.args.hidden_size)
+
+        # Compute the MLP.
+        x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+        x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+        x1 = F.gelu(x1, approximate="tanh") * x2
+        return gg.ops.gmm(x1, w2, batch_sizes)
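
Both classes compute the same GEGLU-style expert feed-forward and differ only in how tokens are batched (block-sparse STK ops vs. grouped GEMMs). For intuition, here is a dense single-expert reference; it is a simplification and not code from this commit, with weight shapes following the [ffn_hidden_size, hidden_size] layout used above:

import torch
import torch.nn.functional as F

def dense_glu_reference(x, w1, v1, w2):
    # x: [num_tokens, hidden_size]
    # w1, v1, w2: [ffn_hidden_size, hidden_size], matching the per-expert slices above.
    hidden = F.gelu(x @ w1.t(), approximate="tanh") * (x @ v1.t())  # gated hidden activations
    return hidden @ w2  # project back to [num_tokens, hidden_size]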

megablocks/layers/glu_test.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+import unittest
+from functools import partial
+
+from absl.testing import parameterized
+from megablocks.layers.arguments import Arguments
+from megablocks.layers.glu import SparseGLU, GroupedGLU
+from megablocks.layers import testing
+
+import torch
+import stk
+import numpy as np
+
+def test_modules(
+        hidden_size,
+        ffn_hidden_size,
+        grouped_mlp=False):
+    init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
+    args = Arguments(
+        hidden_size=hidden_size,
+        ffn_hidden_size=ffn_hidden_size,
+        moe_num_experts=1,
+        moe_top_k=1,
+        init_method=init_method,
+        memory_optimized_mlp=False,
+        mlp_type='glu',
+        grouped_mlp=grouped_mlp,
+        fp16=False,
+        bf16=True)
+
+    glu = testing.GLU(args)
+    dmoe_glu = GroupedGLU(args) if grouped_mlp else SparseGLU(args)
+
+    dmoe_glu.cuda(torch.cuda.current_device()).to(torch.bfloat16)
+    glu.cuda(torch.cuda.current_device()).to(torch.bfloat16)
+
+    with torch.no_grad():
+        glu.w1.copy_(dmoe_glu.w1.T)
+        glu.v1.copy_(dmoe_glu.v1.T)
+        glu.w2.copy_(dmoe_glu.w2)
+
+    return args, glu, dmoe_glu
+
+_DENSE_TESTS = (
+    (16, 1024, 512),
+    (8, 2048, 512),
+)
+
+class GLUTest(parameterized.TestCase):
+
+    @parameterized.parameters(*_DENSE_TESTS)
+    def testGLU_ForwardGroupedMLP(self, bs, sl, hs):
+        x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda()
+
+        _, glu, dmoe_glu = test_modules(
+            hidden_size=hs,
+            ffn_hidden_size=hs * 2,
+            grouped_mlp=True)
+
+        expected_out = glu(x)
+        tokens_per_expert = torch.tensor([bs * sl]).cuda()
+        out = dmoe_glu(x.view(bs * sl, hs), tokens_per_expert)
+        out = out.view(sl, bs, hs)
+
+        self.assertSequenceEqual(out.shape, x.shape)
+        self.assertSequenceEqual(expected_out.shape, x.shape)
+        self.assertTrue(testing.allclose(out, expected_out))
+
+    @parameterized.parameters(*_DENSE_TESTS)
+    def testGLU_ForwardSparseMLP(self, bs, sl, hs):
+        x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda()
+
+        _, glu, dmoe_glu = test_modules(
+            hidden_size=hs,
+            ffn_hidden_size=hs * 2,
+            grouped_mlp=False)
+
+        expected_out = glu(x)
+        with torch.no_grad():
+            topo = stk.random.mask(bs * sl, hs * 2, 0, blocking=128).cuda()
+        out = dmoe_glu(x.view(bs * sl, hs), topo)
+        out = out.view(sl, bs, hs)
+
+        self.assertSequenceEqual(out.shape, x.shape)
+        self.assertSequenceEqual(expected_out.shape, x.shape)
+        self.assertTrue(testing.allclose(out, expected_out))
+
+if __name__ == '__main__':
+    unittest.main()

megablocks/layers/mlp.py

Lines changed: 6 additions & 6 deletions
@@ -304,18 +304,18 @@ class SparseMLP(torch.nn.Module):
     def __init__(self, args : Arguments):
         super().__init__()
         self.args = args
-        num_rows_per_rank = (
+        self._num_rows_per_rank = (
            (mpu.experts_per_rank(args) * mpu.features_per_rank(args)) //
            mpu.get_weight_parallel_world_size(args)
        )
 
         self.w1 = torch.nn.Parameter(torch.empty(
-            num_rows_per_rank,
+            self._num_rows_per_rank,
             args.hidden_size,
             device=args.device,
             dtype=common.dtype(args)))
         self.w2 = torch.nn.Parameter(torch.empty(
-            num_rows_per_rank,
+            self._num_rows_per_rank,
             args.hidden_size,
             device=args.device,
             dtype=common.dtype(args)))
@@ -336,12 +336,12 @@ def __init__(self, args : Arguments):
             args, args.moe_num_experts, args.ffn_hidden_size,
             args.hidden_size, args.output_layer_init_method))
 
-        should_set_attribute = (
+        self._should_set_parallelism_attribute = (
             args.moe_expert_model_parallelism or args.moe_weight_parallelism)
         mpu.set_expert_model_parallel_attributes(
-            self.w1, should_set_attribute)
+            self.w1, self._should_set_parallelism_attribute)
         mpu.set_expert_model_parallel_attributes(
-            self.w2, should_set_attribute)
+            self.w2, self._should_set_parallelism_attribute)
 
         self.gradient_scale = None
         if self.args.moe_expert_model_parallelism:

megablocks/layers/testing.py

Lines changed: 14 additions & 0 deletions
@@ -30,3 +30,17 @@ def __init__(self, args : Arguments):
     def forward(self, x):
         return torch.matmul(F.gelu(
             torch.matmul(x, self.w1), approximate="tanh"), self.w2)
+
+class GLU(FFN):
+
+    def __init__(self, args : Arguments):
+        super().__init__(args)
+        self.v1 = torch.nn.Parameter(torch.empty(
+            args.hidden_size,
+            args.ffn_hidden_size,
+            device=args.device,
+            dtype=torch.float16 if args.fp16 else torch.float32))
+
+    def forward(self, x):
+        x1 = F.gelu(torch.matmul(x, self.w1), approximate="tanh") * torch.matmul(x, self.v1)
+        return torch.matmul(x1, self.w2)
