-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathgpt_oss_attention.py
More file actions
330 lines (288 loc) · 12 KB
/
gpt_oss_attention.py
File metadata and controls
330 lines (288 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
import math
import keras
from keras import ops
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer
class GptOssAttention(keras.layers.Layer):
    """A cached attention layer with sliding window and sink tokens.

    This layer implements the attention mechanism described in the GPT-OSS
    paper. It includes grouped-query attention, rotary position embeddings,
    sliding window attention, and sink tokens for improved performance on
    long sequences.

    Args:
        num_query_heads: int. The number of query attention heads.
        num_key_value_heads: int. The number of key and value attention
            heads. `num_query_heads` must be divisible by this value; each
            key/value head is shared by
            `num_query_heads // num_key_value_heads` query heads.
        rope_max_wavelength: int. The maximum wavelength for the
            rotary position embedding. Defaults to 10000.
        rope_scaling_factor: float. The (YaRN) scaling factor for the
            rotary position embedding. Defaults to 1.0.
        kernel_initializer: str. The initializer for the kernel
            weights. Defaults to "glorot_uniform".
        sliding_window: int. The size of the sliding window: each query
            may attend to at most the `sliding_window` most recent key
            positions, counting its own position. `None` or a non-positive
            value disables the sliding window. Defaults to 4096.
        dropout: float. The dropout rate, applied to the attention output
            before the output projection. Defaults to 0.
        head_dim: int. Head dimension for attention. If None,
            calculated as hidden_dim // num_query_heads. Defaults to None.
    """

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        kernel_initializer="glorot_uniform",
        sliding_window=4096,
        dropout=0,
        head_dim=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Grouped-query attention repeats each key/value head an integer
        # number of times; fail early with a clear message instead of an
        # opaque einsum shape error later.
        if num_query_heads % num_key_value_heads != 0:
            raise ValueError(
                "`num_query_heads` must be divisible by "
                "`num_key_value_heads`. Received: "
                f"num_query_heads={num_query_heads}, "
                f"num_key_value_heads={num_key_value_heads}"
            )
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.sliding_window = sliding_window
        self.dropout = dropout
        self.head_dim = head_dim
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.num_key_value_groups = num_query_heads // num_key_value_heads
        self._kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

    def build(self, inputs_shape):
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = the model's hidden_dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        self._hidden_dim = inputs_shape[-1]
        if self.head_dim is not None:
            self._head_dim = self.head_dim
        else:
            self._head_dim = self._hidden_dim // self.num_query_heads
        self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
        # RoPE rotates feature pairs, so only an even number of features
        # can be rotated; any trailing odd feature is passed through as-is.
        self._rotary_dim = (self._head_dim // 2) * 2

        self.query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self.num_query_heads, self._head_dim),
            bias_axes="uh",
            kernel_initializer=self._kernel_initializer,
            bias_initializer="zeros",
            dtype=self.dtype_policy,
            name="query",
        )
        self.query_dense.build(inputs_shape)

        self.key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self._head_dim,
            ),
            bias_axes="vh",
            kernel_initializer=self._kernel_initializer,
            bias_initializer="zeros",
            dtype=self.dtype_policy,
            name="key",
        )
        self.key_dense.build(inputs_shape)

        self.value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self._head_dim,
            ),
            bias_axes="vh",
            kernel_initializer=self._kernel_initializer,
            bias_initializer="zeros",
            dtype=self.dtype_policy,
            name="value",
        )
        self.value_dense.build(inputs_shape)

        self.dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        self.output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, self._hidden_dim),
            bias_axes="m",
            kernel_initializer=self._kernel_initializer,
            bias_initializer="zeros",
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self.output_dense.build(
            (None, None, self.num_query_heads, self._head_dim)
        )

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self.rope_max_wavelength,
            scaling_factor=self.rope_scaling_factor,  # YaRN scaling factor
            rope_type="yarn",
            beta_fast=32.0,
            beta_slow=1.0,
            original_max_position_embeddings=4096,
            dtype=self.dtype_policy,
        )

        # One learned "sink" logit per query head. It competes in the
        # softmax, but its probability is discarded before weighting values.
        self.sinks = self.add_weight(
            shape=(self.num_query_heads,),
            initializer="random_normal",
            dtype=self.dtype,
            name="sinks",
        )

        self._dot_product_equation = "bquh,bkuh->buqk"
        self._combine_equation = "buqk,bkuh->bquh"

        self.built = True

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        """Run attention over `hidden_states`.

        Args:
            hidden_states: Rank-3 input tensor `(batch, seq_len, hidden_dim)`
                (as required by the `bqm` query/key/value einsums).
            attention_mask: Optional boolean tensor. `True` marks positions
                that MAY be attended to; `False` positions get a large
                negative logit. It is broadcast over heads via
                `expand_dims(axis=1)`.
            cache: Optional stacked key/value cache; `cache[:, 0]` holds keys
                and `cache[:, 1]` holds values.
            cache_update_index: Optional position at which freshly computed
                keys/values are written into `cache`. Also used as the
                position offset for RoPE and the sliding window. Must be
                `None` when `cache` is `None`.
            training: Whether dropout is active.

        Returns:
            The attention output, or a tuple `(attention_output, cache)`
            when `cache` is provided.

        Raises:
            ValueError: if `cache_update_index` is set without a `cache`.
        """
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        query = self.query_dense(hidden_states)
        query = self._apply_rotary_embedding(query, start_index)

        def _compute_key_value(x):
            key, value = self.key_dense(x), self.value_dense(x)
            # RoPE is applied to keys only; values are not rotated.
            key = self._apply_rotary_embedding(key, start_index)
            return key, value

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                # Read-only use of a cache that was already filled.
                key = key_cache
                value = value_cache
            else:
                key_update, value_update = _compute_key_value(hidden_states)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key, value = _compute_key_value(hidden_states)

        # Grouped-query attention: share each key/value head across
        # `num_key_value_groups` query heads.
        # [batch_shape, seq_len, num_key_value_heads, head_dim]
        # -> [batch_shape, seq_len, num_heads, head_dim]
        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query, key, value, attention_mask, start_index
        )

        attention_output = self.dropout_layer(
            attention_output, training=training
        )

        attention_output = self.output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def _apply_rotary_embedding(self, x, start_index):
        """Apply RoPE to the first `self._rotary_dim` features of `x`.

        Any trailing features beyond `self._rotary_dim` (only present for an
        odd head dimension) are concatenated back unchanged.
        """
        if self._rotary_dim < self._head_dim:
            x_rot = self.rotary_embedding_layer(
                x[..., : self._rotary_dim], start_index=start_index
            )
            return ops.concatenate(
                [x_rot, x[..., self._rotary_dim :]], axis=-1
            )
        return self.rotary_embedding_layer(x, start_index=start_index)

    def _get_masking_value(self):
        """Return the large negative logit used for masked positions."""
        # -1e9 overflows reduced-precision dtypes (float16 tops out around
        # 65504), so a smaller magnitude is used outside float32.
        if self.compute_dtype == "float32":
            return ops.cast(-1e9, self.compute_dtype)
        return ops.cast(-1e4, self.compute_dtype)

    def _compute_attention(
        self, query, key, value, attention_mask=None, start_index=0
    ):
        """Scaled dot-product attention with sliding window and sink logits.

        Args:
            query: `(batch, q_len, num_query_heads, head_dim)` queries.
            key: `(batch, kv_len, num_query_heads, head_dim)` keys (already
                repeated for grouped-query attention).
            value: Values with the same layout as `key`.
            attention_mask: Optional boolean mask, `True` = may attend.
            start_index: Absolute position of the first query row; used to
                place the sliding window during cached decoding.

        Returns:
            `(batch, q_len, num_query_heads, head_dim)` attention output.
        """
        attention_scores = ops.einsum(self._dot_product_equation, query, key)
        attention_scores = ops.multiply(
            attention_scores,
            ops.cast(self._inv_norm_factor, self.compute_dtype),
        )

        # Apply sliding window mask if specified.
        if self.sliding_window is not None and self.sliding_window > 0:
            q_len = ops.shape(attention_scores)[-2]
            kv_len = ops.shape(attention_scores)[-1]

            # Query positions are offset by start_index during generation.
            q_positions = ops.arange(q_len) + start_index
            kv_positions = ops.arange(kv_len)

            # True for key positions INSIDE the window, i.e. the
            # `sliding_window` most recent positions counting the query's
            # own: q_pos - kv_pos < sliding_window. Future positions are not
            # excluded here; causality presumably comes from
            # `attention_mask` (TODO confirm at the call site).
            in_window = (
                kv_positions[None, :]
                > q_positions[:, None] - self.sliding_window
            )
            attention_scores = ops.where(
                in_window[None, None, :, :],
                attention_scores,
                self._get_masking_value(),
            )

        if attention_mask is not None:
            # The mask is boolean and True for positions that may be
            # attended to; all other positions get a large negative logit.
            attention_scores = ops.where(
                ops.expand_dims(attention_mask, axis=1),
                attention_scores,
                self._get_masking_value(),
            )

        # Handle sink tokens by concatenating them to the logits.
        b = ops.shape(attention_scores)[0]
        q = ops.shape(attention_scores)[2]
        sinks = ops.reshape(self.sinks, (1, self.num_query_heads, 1, 1))
        sinks = ops.broadcast_to(sinks, (b, self.num_query_heads, q, 1))
        # attention_scores shape: [b, num_heads, q, k]
        # sinks shape: [b, num_heads, q, 1]
        # Concatenate along the key axis so the sink joins the softmax.
        combined_logits = ops.concatenate([attention_scores, sinks], axis=-1)

        # Stabilize logits before softmax for numerical stability.
        max_logits = ops.max(combined_logits, axis=-1, keepdims=True)
        max_logits = ops.stop_gradient(max_logits)
        combined_logits = combined_logits - max_logits

        probs = ops.softmax(combined_logits, axis=-1)

        # Drop the sink column: it absorbs probability mass but has no
        # corresponding value vector.
        attention_scores = probs[..., :-1]
        attention_scores = ops.cast(attention_scores, self.compute_dtype)

        attention_output = ops.einsum(
            self._combine_equation, attention_scores, value
        )

        return attention_output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self._kernel_initializer
                ),
                "sliding_window": self.sliding_window,
                "dropout": self.dropout,
                "head_dim": self.head_dim,
            }
        )
        return config