-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathopenelm.py
223 lines (180 loc) · 6.7 KB
/
openelm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration values consumed by the OpenELM modules in this file.

    The three list-valued fields are indexed by layer id, so each transformer
    layer can have its own head counts and feed-forward width.
    """

    # Stored on ``Model`` as ``model_type``.
    model_type: str
    # Size of each attention head; also the RMSNorm / RoPE dimension in Attention.
    head_dim: int
    # Number of TransformerBlock instances built by OpenELMModel.
    num_transformer_layers: int
    # Width of the token embeddings and of every block's input/output.
    model_dim: int
    # Number of rows in the token embedding table (and of output logits).
    vocab_size: int
    # Divisor handed to make_divisible when sizing the feed-forward layer.
    ffn_dim_divisor: int
    # Query head count for each layer, indexed by layer id.
    num_query_heads: List[int]
    # Key/value head count for each layer, indexed by layer id.
    num_kv_heads: List[int]
    # Per-layer factor multiplied by model_dim to get the feed-forward width
    # (before rounding with make_divisible).
    ffn_multipliers: List[float]
    # Presumably selects a gated (GLU) feed-forward network; check MLP for how
    # the flag is honored.
    ffn_with_glu: bool = True
    # When True, Attention applies RMSNorm to queries and keys per head.
    normalize_qk_projections: bool = True
    # When True, Model reuses the token embedding matrix to produce logits
    # instead of creating a separate lm_head.
    share_input_output_layers: bool = True
    # Epsilon passed to every nn.RMSNorm in this file.
    rms_norm_eps: float = 1e-6
    # Passed to nn.RoPE as its ``base`` argument.
    rope_freq_constant: float = 10000
def make_divisible(
    v: Union[float, int],
    divisor: int = 8,
    min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
    """
    Snap ``v`` to a multiple of ``divisor`` so layer widths stay divisible.

    Adapted from the TensorFlow MobileNet helper:
    https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62

    Args:
        v: value to round
        divisor: the result is built from multiples of this (default 8)
        min_value: lower bound for the result before the 10% correction;
            falls back to ``divisor`` when omitted
    Returns:
        The rounded value, bumped up by one ``divisor`` if rounding would
        otherwise have reduced ``v`` by more than 10%.
    """
    lower_bound = divisor if min_value is None else min_value

    # Adding half a divisor before the floor division rounds to the nearest
    # multiple rather than always rounding down.
    nearest_multiple = (int(v + divisor / 2) // divisor) * divisor
    rounded = max(lower_bound, nearest_multiple)

    # Never let the rounding shrink the value by more than 10%.
    if rounded < 0.9 * v:
        return rounded + divisor
    return rounded
class Attention(nn.Module):
    """Self-attention for a single layer, using one fused QKV projection.

    The number of query and key/value heads is read from the per-layer lists
    in ``args`` using ``layer_id``.
    """

    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()

        dims_per_head = args.head_dim
        width = args.model_dim
        q_heads = args.num_query_heads[layer_id]
        kv_heads = args.num_kv_heads[layer_id]

        self.head_dim = dims_per_head
        self.layer_id = layer_id
        self.model_dim = width
        self.n_heads = q_heads
        self.n_kv_heads = kv_heads
        self.scale = dims_per_head**-0.5

        # A single linear layer emits all query heads followed by the key
        # heads and the value heads.
        fused_width = (q_heads + (kv_heads * 2)) * dims_per_head
        self.qkv_proj = nn.Linear(width, fused_width, bias=False)
        self.out_proj = nn.Linear(q_heads * dims_per_head, width, bias=False)

        self.normalize_qk_projections = args.normalize_qk_projections
        if self.normalize_qk_projections:
            self.q_norm = nn.RMSNorm(dims_per_head, eps=args.rms_norm_eps)
            self.k_norm = nn.RMSNorm(dims_per_head, eps=args.rms_norm_eps)

        self.rope = nn.RoPE(
            dims_per_head, traditional=False, base=args.rope_freq_constant
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run attention over ``x`` (a 3-D array) and project back with out_proj.

        ``mask`` is forwarded unchanged to ``scaled_dot_product_attention``.
        When ``cache`` is given, its ``offset`` positions the rotary embedding
        and ``update_and_fetch`` supplies the keys/values actually attended to.
        """
        batch, seq_len, _ = x.shape
        total_heads = self.n_heads + (self.n_kv_heads * 2)

        # Split the fused projection into heads and move the head axis ahead
        # of the sequence axis.
        fused = self.qkv_proj(x)
        fused = fused.reshape(batch, seq_len, total_heads, self.head_dim)
        fused = fused.transpose(0, 2, 1, 3)

        key_start = self.n_heads
        value_start = self.n_heads + self.n_kv_heads
        queries, keys, values = mx.split(fused, [key_start, value_start], axis=1)

        if self.normalize_qk_projections:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        if cache is None:
            queries = self.rope(queries)
            keys = self.rope(keys)
        else:
            # Rotate with the cache offset first, then let the cache merge the
            # new keys/values with what it already holds.
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)

        attended = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        # Put the sequence axis back in front of the heads and merge the heads.
        attended = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.out_proj(attended)
class MLP(nn.Module):
    """Feed-forward network for one transformer layer.

    The hidden width is ``ffn_multipliers[layer_id] * model_dim`` rounded with
    ``make_divisible`` using ``ffn_dim_divisor``.

    ``args.ffn_with_glu`` selects the layout:

    * ``True`` (the default): ``proj_1`` produces two halves of the hidden
      width; the first half is passed through SiLU and gates the second.
    * ``False``: a plain two-layer network, ``proj_2(silu(proj_1(x)))``, with
      ``proj_1`` producing a single hidden-width output.
    """

    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
        self.args = args
        dim = args.model_dim
        ffn_multiplier = args.ffn_multipliers[layer_id]
        intermediate_dim = int(
            make_divisible(
                ffn_multiplier * args.model_dim,
                divisor=args.ffn_dim_divisor,
            )
        )

        use_glu = args.ffn_with_glu
        self.ffn_with_glu = use_glu

        # The gated variant needs room for both the gate and the value halves.
        proj_1_width = 2 * intermediate_dim if use_glu else intermediate_dim
        self.proj_1 = nn.Linear(dim, proj_1_width, bias=False)
        self.proj_2 = nn.Linear(intermediate_dim, dim, bias=False)

    def __call__(self, x) -> mx.array:
        """Apply the feed-forward network to ``x`` along its last axis."""
        x = self.proj_1(x)
        if not self.ffn_with_glu:
            return self.proj_2(nn.silu(x))

        # First half is the gate, second half is the gated value.
        gate, x = mx.split(x, 2, axis=-1)
        return self.proj_2(nn.silu(gate) * x)
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention then feed-forward.

    Each sub-layer is fed an RMS-normalised input and its output is added back
    onto the un-normalised stream.
    """

    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
        width = args.model_dim
        eps = args.rms_norm_eps

        self.attn = Attention(args, layer_id=layer_id)
        self.ffn = MLP(args, layer_id=layer_id)
        self.ffn_norm = nn.RMSNorm(width, eps=eps)
        self.attn_norm = nn.RMSNorm(width, eps=eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Return ``x`` after the attention and feed-forward residual updates."""
        attn_out = self.attn(self.attn_norm(x), mask, cache)
        hidden = x + attn_out

        ffn_out = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_out
class OpenELMModel(nn.Module):
    """Token embedding, the stack of transformer blocks, and a final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_transformer_layers = args.num_transformer_layers
        assert self.vocab_size > 0

        self.token_embeddings = nn.Embedding(args.vocab_size, args.model_dim)
        # Each block receives its index so it can look up its own per-layer
        # head counts and feed-forward multiplier.
        self.layers = [
            TransformerBlock(args, layer_id=index)
            for index in range(self.num_transformer_layers)
        ]
        self.norm = nn.RMSNorm(args.model_dim, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        """Embed ``inputs``, run every block, and return the normalised output.

        A mask is built with ``create_attention_mask`` when none is supplied.
        ``cache``, when given, is iterated in step with the layers (one entry
        per block); otherwise every block runs without a cache.
        """
        hidden = self.token_embeddings(inputs)

        if mask is None:
            mask = create_attention_mask(hidden, cache)

        per_layer_cache = cache if cache is not None else [None] * len(self.layers)

        for block, block_cache in zip(self.layers, per_layer_cache):
            hidden = block(hidden, mask, cache=block_cache)

        return self.norm(hidden)
class Model(nn.Module):
    """OpenELM backbone plus the projection from hidden states to vocabulary logits."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = OpenELMModel(args)

        # A dedicated output layer only exists when the embedding matrix is
        # not shared with the output projection.
        tied = args.share_input_output_layers
        if not tied:
            self.lm_head = nn.Linear(args.model_dim, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        """Run the backbone and project its output to vocabulary size."""
        hidden = self.transformer(inputs, mask, cache)

        if not self.args.share_input_output_layers:
            return self.lm_head(hidden)

        # Shared weights: reuse the embedding table as the output projection.
        return self.transformer.token_embeddings.as_linear(hidden)

    @property
    def layers(self):
        """The backbone's list of transformer blocks."""
        return self.transformer.layers