-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathgpt_oss_backbone.py
More file actions
219 lines (201 loc) · 8.29 KB
/
gpt_oss_backbone.py
File metadata and controls
219 lines (201 loc) · 8.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import keras
from keras.layers import ReversibleEmbedding
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.gpt_oss.gpt_oss_decoder import (
GptOssTransformerDecoder,
)
from keras_hub.src.models.gpt_oss.gpt_oss_layer_norm import (
GptOssLayerNormalization,
)
def _gpt_oss_kernel_initializer(stddev=0.02):
    """Build the normal-distribution initializer used for GPT-OSS weights.

    Args:
        stddev: float. Standard deviation passed to
            `keras.initializers.RandomNormal`. Defaults to `0.02`.

    Returns:
        A new `keras.initializers.RandomNormal` instance configured with
        `stddev`.
    """
    initializer = keras.initializers.RandomNormal(stddev=stddev)
    return initializer
@keras_hub_export("keras_hub.models.GptOssBackbone")
class GptOssBackbone(Backbone):
    """A GPT-style Transformer with a Mixture of Experts.

    This network implements a GPT-style decoder network with Mixture of
    Experts (MoE) layers, similar to the architecture described in
    ["Mixtral of Experts"](https://arxiv.org/pdf/2401.04088) but with
    customizations found in some open-source GPT models. It includes the
    embedding lookups and transformer layers.

    Decoder layers alternate between two attention types. Even-indexed
    layers (0, 2, 4, ...) use sliding-window attention with a window of
    `sliding_window` tokens. Odd-indexed layers are built without a window,
    i.e. full attention. This matches the layer layout of the reference
    GPT-OSS implementation.

    The default constructor gives a fully customizable, randomly initialized
    GptOss model with any number of layers, heads, and embedding
    dimensions. To load preset architectures and weights, use the
    `from_preset` constructor.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer layers.
        num_query_heads: int. The number of query attention heads for
            each transformer.
        hidden_dim: int. The size of the transformer encoding and pooling
            layers (the token embedding output dimension).
        intermediate_dim: int. The output dimension of the first Dense layer
            in a three-layer feedforward network for each transformer.
        num_key_value_heads: int. The number of key and value attention heads
            for each transformer.
        num_experts: int. The number of experts for the MoE layers.
        top_k: int. The number of experts to use for each token.
            Defaults to `2`.
        rope_max_wavelength: int. The maximum angular wavelength of
            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
        rope_scaling_factor: float. The scaling factor for
            calculation of rotary embedding. Defaults to `1.0`.
        layer_norm_epsilon: float. Epsilon for the layer normalization layers
            in the transformer decoder and for the final output
            normalization. Defaults to `1e-6`.
        sliding_window: int. The attention window size used by the
            sliding-window (even-indexed) decoder layers. Odd-indexed layers
            receive no window. Defaults to `4096`.
        head_dim: int or `None`. Head dimension for attention layers. This
            value is forwarded unchanged to every decoder layer. When `None`,
            the decoder is expected to derive it (presumably as
            `hidden_dim // num_query_heads`; confirm in
            `GptOssTransformerDecoder`). Defaults to `None`.
        dropout: float. Attention dropout probability. Defaults to `0`.
        output_router_logits: bool. Forwarded to every decoder layer and
            stored in the config. Presumably it asks the MoE layers to also
            produce router logits; confirm in `GptOssTransformerDecoder`.
            Defaults to `False`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            `float32` precision regardless of dtype.

    Examples:

    ```python
    import numpy as np
    import keras_hub

    # Load a pretrained GptOss backbone from a preset.
    model = keras_hub.models.GptOssBackbone.from_preset("gpt_oss_20b_en")

    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array(
            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32"
        ),
    }

    model(input_data)

    # Randomly initialized GptOss decoder with custom config.
    model = keras_hub.models.GptOssBackbone(
        vocabulary_size=10,
        hidden_dim=512,
        num_layers=2,
        num_query_heads=32,
        num_key_value_heads=8,
        intermediate_dim=1024,
        num_experts=4,
        top_k=2,
        sliding_window=256,
        layer_norm_epsilon=1e-6,
        dtype="float32",
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_query_heads,
        hidden_dim,
        intermediate_dim,
        num_key_value_heads,
        num_experts,
        top_k=2,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        layer_norm_epsilon=1e-6,
        sliding_window=4096,
        head_dim=None,
        dropout=0,
        output_router_logits=False,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            # Input embedding and output projection are kept untied.
            tie_weights=False,
            embeddings_initializer=_gpt_oss_kernel_initializer(stddev=0.01),
            dtype=dtype,
            name="token_embedding",
        )
        self.transformer_layers = []
        for i in range(num_layers):
            # GPT-OSS alternates attention types layer by layer. Even-indexed
            # layers (0, 2, ...) use sliding-window attention, as in the
            # reference implementation and HF `layer_types`. Odd-indexed
            # layers get `None`, presumably meaning "no window" (full
            # attention) in the decoder.
            layer_sliding_window = sliding_window if i % 2 == 0 else None
            layer = GptOssTransformerDecoder(
                intermediate_dim=intermediate_dim,
                num_query_heads=num_query_heads,
                num_key_value_heads=num_key_value_heads,
                num_experts=num_experts,
                top_k=top_k,
                output_router_logits=output_router_logits,
                rope_max_wavelength=rope_max_wavelength,
                rope_scaling_factor=rope_scaling_factor,
                layer_norm_epsilon=layer_norm_epsilon,
                kernel_initializer=_gpt_oss_kernel_initializer(stddev=0.02),
                sliding_window=layer_sliding_window,
                dropout=dropout,
                head_dim=head_dim,
                dtype=dtype,
                name=f"transformer_layer_{i}",
            )
            self.transformer_layers.append(layer)
        self.layer_norm = GptOssLayerNormalization(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="sequence_output_layernorm",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        x = self.token_embedding(token_id_input)
        for transformer_layer in self.transformer_layers:
            # NOTE(review): the layer's return value is used directly as the
            # next hidden state. This assumes the decoder returns a single
            # tensor from `call` even when `output_router_logits=True`;
            # confirm in `GptOssTransformerDecoder`.
            x = transformer_layer(x, decoder_padding_mask=padding_mask_input)
        sequence_output = self.layer_norm(x)
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_query_heads = num_query_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_experts = num_experts
        self.top_k = top_k
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.sliding_window = sliding_window
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dropout = dropout
        self.output_router_logits = output_router_logits
        self.head_dim = head_dim

    def get_config(self):
        """Return the constructor arguments needed to rebuild this backbone.

        Extends the base `Backbone` config with every architecture argument
        accepted by `__init__` (everything except `dtype` and `**kwargs`,
        which the base class handles).
        """
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_query_heads": self.num_query_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_experts": self.num_experts,
                "top_k": self.top_k,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "num_key_value_heads": self.num_key_value_heads,
                "sliding_window": self.sliding_window,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "dropout": self.dropout,
                "output_router_logits": self.output_router_logits,
                "head_dim": self.head_dim,
            }
        )
        return config