-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathgemma3_attention.py
More file actions
419 lines (362 loc) · 15.3 KB
/
gemma3_attention.py
File metadata and controls
419 lines (362 loc) · 15.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
import inspect
import keras
import numpy as np
from keras import ops
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available
from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
from keras_hub.src.utils.keras_utils import running_on_gpu
from keras_hub.src.utils.keras_utils import running_on_tpu
class CachedGemma3Attention(keras.layers.Layer):
    """A cached grouped query attention layer for Gemma3.

    This is the same as the attention layer used for Gemma and Gemma2, with a
    few additional options:

    `use_query_key_norm`: bool. If True, apply RMS normalization on query
        and key. For Gemma3, this is True.
    `rope_wavelength`: float. Configurable value for RoPE wavelength. Gemma3
        uses 10K for local attention layers and 1M for global attention layers.

    Moreover, the `call()` method takes in a `cache_update_mask` so as to make
    sure that the key-value cache is updated only for the non-prompt tokens
    during generation.

    Args:
        head_dim: int. Size of each attention head.
        num_query_heads: int. Number of query heads.
        num_key_value_heads: int. Number of key/value heads. Each key/value
            head is shared by `num_query_heads // num_key_value_heads` query
            heads; `num_query_heads` is presumably expected to be a multiple
            of `num_key_value_heads`.
        kernel_initializer: Initializer used for every projection kernel.
        logit_soft_cap: float or None. If set, attention logits are capped as
            `tanh(logits / logit_soft_cap) * logit_soft_cap`.
        use_sliding_window_attention: bool. If True, the attention mask is
            additionally restricted to a local window around each query.
        sliding_window_size: int. Size of the local attention window.
        query_head_dim_normalize: bool. If True, queries are scaled by
            `1 / sqrt(head_dim)`; otherwise by
            `1 / sqrt(hidden_dim // num_query_heads)`.
        use_query_key_norm: bool. See above.
        layer_norm_epsilon: float. Epsilon of the query/key RMS norms.
        rope_wavelength: float. See above.
        rope_scaling_factor: float. Scaling factor passed to the rotary
            embedding.
        use_bidirectional_attention: bool. Only affects the sliding window:
            if True, each query sees `sliding_window_size // 2` tokens after
            it and `sliding_window_size - sliding_window_size // 2 - 1`
            tokens before it, instead of a band of
            `sliding_window_size - 1` tokens on each side.
        dropout: float. Dropout rate applied to the attention probabilities
            in the non-fused path. Any positive rate disables the fused op.
    """

    def __init__(
        self,
        head_dim,
        num_query_heads,
        num_key_value_heads,
        kernel_initializer="glorot_uniform",
        logit_soft_cap=None,
        use_sliding_window_attention=False,
        sliding_window_size=4096,
        query_head_dim_normalize=True,
        use_query_key_norm=False,
        layer_norm_epsilon=1e-6,
        rope_wavelength=10_000.0,
        rope_scaling_factor=1.0,
        use_bidirectional_attention=False,
        dropout=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.logit_soft_cap = logit_soft_cap
        self.use_sliding_window_attention = use_sliding_window_attention
        self.sliding_window_size = sliding_window_size
        self.query_head_dim_normalize = query_head_dim_normalize
        self.use_query_key_norm = use_query_key_norm
        self.layer_norm_epsilon = layer_norm_epsilon
        self.rope_wavelength = rope_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.use_bidirectional_attention = use_bidirectional_attention
        self.dropout = dropout

        self._kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )
        # Number of query heads that share each key/value head.
        self.num_key_value_groups = num_query_heads // num_key_value_heads

    def build(self, inputs_shape):
        """Create the projection, norm, dropout, softmax and RoPE sublayers.

        Args:
            inputs_shape: Shape of `x` in `call()`; the last axis is taken as
                the model hidden dimension.
        """
        self.hidden_dim = inputs_shape[-1]

        self.query_dense = keras.layers.EinsumDense(
            "btd,ndh->btnh",
            output_shape=(None, self.num_query_heads, self.head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self.query_dense.build(inputs_shape)

        self.key_dense = keras.layers.EinsumDense(
            "bsd,kdh->bskh",
            output_shape=(None, self.num_key_value_heads, self.head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self.key_dense.build(inputs_shape)

        self.value_dense = keras.layers.EinsumDense(
            "bsd,kdh->bskh",
            output_shape=(None, self.num_key_value_heads, self.head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self.value_dense.build(inputs_shape)

        if self.use_query_key_norm:
            self.query_norm = RMSNormalization(
                epsilon=self.layer_norm_epsilon,
                dtype=self.dtype_policy,
                name="query_norm",
            )
            self.query_norm.build(
                self.query_dense.compute_output_shape(inputs_shape)
            )

            self.key_norm = RMSNormalization(
                epsilon=self.layer_norm_epsilon,
                dtype=self.dtype_policy,
                name="key_norm",
            )
            self.key_norm.build(
                self.key_dense.compute_output_shape(inputs_shape)
            )

        self.dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        self.output_dense = keras.layers.EinsumDense(
            equation="btnh,nhd->btd",
            output_shape=(None, self.hidden_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self.output_dense.build(
            (None, None, self.num_query_heads, self.head_dim)
        )

        # Softmax always runs in float32; the result is cast back afterwards.
        self.softmax = keras.layers.Softmax(dtype="float32")

        self.rope_layer = RotaryEmbedding(
            max_wavelength=self.rope_wavelength,
            scaling_factor=self.rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        self.built = True

    def _apply_rope(self, x, start_index):
        """Rope rotate q or k, offsetting positions by `start_index`."""
        x = self.rope_layer(x, start_index=start_index)
        return x

    def _use_fused_attention_op(self):
        """Return True if `ops.dot_product_attention` can be used here."""
        if not fused_attention_op_available():
            return False
        # The fused op is called without dropout, so fall back when needed.
        if self.dropout > 0.0:
            return False
        if running_on_gpu():
            # GPU never supports softcap in the fused op.
            if self.logit_soft_cap is not None:
                return False
            return gpu_supports_fused_attention_op()
        elif running_on_tpu():
            # TPU supports softcap only on keras >= 3.10, i.e. when the op
            # accepts `attn_logits_soft_cap`. This is required even when no
            # soft cap is configured.
            sig = inspect.signature(ops.dot_product_attention)
            return "attn_logits_soft_cap" in sig.parameters
        else:
            return False

    def _compute_attention(
        self,
        q,
        k,
        v,
        attention_mask,
        training=False,
        cache_update_index=0,
    ):
        """Compute grouped query attention.

        Args:
            q: Query of shape `(batch, q_len, num_query_heads, head_dim)`.
            k: Key of shape `(batch, kv_len, num_key_value_heads, head_dim)`.
            v: Value with the same shape as `k`.
            attention_mask: `(batch, q_len, kv_len)` mask, or None.
            training: Whether dropout is active.
            cache_update_index: Offset of the first query token; used to
                slice the sliding window mask during generation.

        Returns:
            Tensor of shape `(batch, q_len, num_query_heads, head_dim)`.
        """
        if self.query_head_dim_normalize:
            query_normalization = 1 / np.sqrt(self.head_dim)
        else:
            query_normalization = 1 / np.sqrt(
                self.hidden_dim // self.num_query_heads
            )

        if self.use_sliding_window_attention and attention_mask is not None:
            attention_mask = self._mask_sliding_window(
                attention_mask,
                cache_update_index=cache_update_index,
            )

        if self._use_fused_attention_op():
            if attention_mask is not None:
                # Add a heads axis: [B, 1, T, S].
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")
            # Only pass soft cap if needed as not all keras versions support.
            if self.logit_soft_cap:
                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
            else:
                kwargs = {}
            return ops.dot_product_attention(
                query=q,
                key=k,
                value=v,
                mask=attention_mask,
                scale=query_normalization,
                **kwargs,
            )

        q *= ops.cast(query_normalization, dtype=q.dtype)
        # Split query heads into (num_key_value_heads, group_size) so each
        # group attends to its shared key/value head.
        q_shape = ops.shape(q)
        q = ops.reshape(
            q,
            (
                *q_shape[:-2],
                self.num_key_value_heads,
                self.num_query_heads // self.num_key_value_heads,
                q_shape[-1],
            ),
        )
        b, q_len, _, _, h = ops.shape(q)

        # Fallback to standard attention if flash attention is disabled
        attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)

        if self.logit_soft_cap is not None:
            attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
            attention_logits = ops.multiply(
                ops.tanh(attention_logits), self.logit_soft_cap
            )

        if attention_mask is not None:
            # We add two dimensions at axis 1 and 2 to make it [B, 1, 1, S, S]
            attention_mask = ops.expand_dims(attention_mask, axis=1)
            attention_mask = ops.expand_dims(attention_mask, axis=1)

        orig_dtype = attention_logits.dtype
        attention_softmax = self.softmax(attention_logits, mask=attention_mask)
        attention_softmax = ops.cast(attention_softmax, orig_dtype)

        if self.dropout:
            attention_softmax = self.dropout_layer(
                attention_softmax, training=training
            )

        results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
        return ops.reshape(results, (b, q_len, self.num_query_heads, h))

    def _compute_bidirectional_sliding_mask(self, batch_size, sequence_length):
        """Computes a bidirectional sliding window attention mask.

        A token can attend to any other token if their absolute distance is
        within half the sliding window size. This mask is used in embedding
        models like `EmbeddingGemma`.

        Args:
            batch_size: The batch size for the mask.
            sequence_length: The length of the sequence.

        Returns:
            A boolean attention mask with shape
            `(batch_size, sequence_length, sequence_length)`.
        """
        i = keras.ops.expand_dims(
            keras.ops.arange(sequence_length, dtype="int32"), axis=1
        )
        j = keras.ops.arange(sequence_length, dtype="int32")

        # If sliding window size is 4, the token in question attends to 1
        # token before and 2 tokens after.
        w_right = self.sliding_window_size // 2
        w_left = self.sliding_window_size - w_right - 1

        # Calculate the relative distance.
        distance = i - j

        mask = keras.ops.logical_and(distance <= w_left, distance >= -w_right)

        mask = keras.ops.expand_dims(mask, axis=0)
        return keras.ops.broadcast_to(
            mask, (batch_size, sequence_length, sequence_length)
        )

    def _mask_sliding_window(
        self,
        attention_mask,
        cache_update_index=0,
    ):
        """Restrict a `(batch, q_len, kv_len)` mask to the sliding window.

        Args:
            attention_mask: The incoming attention mask.
            cache_update_index: Row offset of the first query; selects the
                rows of the square window mask that correspond to `q`.

        Returns:
            A boolean mask with the same shape as `attention_mask`.
        """
        batch_size, query_len, key_len = ops.shape(attention_mask)

        if self.use_bidirectional_attention:
            bidirectional_sliding_mask = (
                self._compute_bidirectional_sliding_mask(
                    batch_size=batch_size,
                    # `query_len = key_len` for embedding models
                    sequence_length=query_len,
                )
            )
            return ops.logical_and(attention_mask, bidirectional_sliding_mask)

        # Compute the sliding window for square attention. Position i may
        # attend to position j when |i - j| <= sliding_window_size - 1.
        all_ones = ops.ones((key_len, key_len), "bool")
        if keras.config.backend() == "tensorflow":
            # TODO: triu/tril has issues with dynamic shape on the tensorflow
            # backend. We should fix, but use `band_part` for now.
            import tensorflow as tf

            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
            band_size = ops.cast(band_size, "int32")
            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
        else:
            sliding_mask = ops.triu(
                all_ones, -1 * self.sliding_window_size + 1
            ) * ops.tril(all_ones, self.sliding_window_size - 1)

        # Slice the window for short queries during generation.
        start = (cache_update_index, 0)
        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
        sliding_mask = ops.expand_dims(sliding_mask, 0)
        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))

    def call(
        self,
        x,
        attention_mask=None,
        cache=None,
        cache_update_index=0,
        cache_update_mask=None,
        training=False,
    ):
        """Run cached attention over `x`.

        Args:
            x: Input of shape `(batch, seq_len, hidden_dim)`.
            attention_mask: Optional `(batch, seq_len, kv_len)` mask.
            cache: Optional key/value cache of shape
                `(batch, 2, max_len, num_key_value_heads, head_dim)`; index 0
                on axis 1 holds keys, index 1 holds values.
            cache_update_index: Position at which the new keys/values are
                written into `cache` (also the RoPE start index).
            cache_update_mask: Optional `(batch, seq_len)` mask; where False,
                the existing cache entries are kept instead of overwritten.
            training: Whether dropout is active.

        Returns:
            `attention_output` of shape `(batch, seq_len, hidden_dim)`, or
            `(attention_output, cache)` when `cache` is given.
        """
        query = self.query_dense(x)

        if self.use_query_key_norm:
            query = self.query_norm(query)

        query = self._apply_rope(query, cache_update_index)

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]

            key_update = self.key_dense(x)

            if self.use_query_key_norm:
                key_update = self.key_norm(key_update)

            key_update = self._apply_rope(key_update, cache_update_index)
            value_update = self.value_dense(x)

            # Update cache. Note that the cache is updated only if the
            # corresponding `cache_update_mask` value is True. This is to
            # ensure that we don't update the cache at indices corresponding to
            # the prompt. For Gemma3, in particular, this is useful because
            # image tokens have bidirectional attention. During generation,
            # if we have uneven inputs during generation, we might end up having
            # causal attention between image tokens, which is incorrect. To
            # avoid this, bidirectional attention is taken care of during
            # the prefill step, and during generation, the cache is not updated
            # for the prompt. The shape of `cache_update_mask` is
            # `(bsz, seq_len)`, where `seq_len` is 1 when we are generating
            # token-by-token.
            start = [0, cache_update_index, 0, 0]
            if cache_update_mask is not None:
                # (bsz, seq_len) -> (bsz, seq_len, 1, 1) to broadcast over
                # heads and head_dim.
                cache_update_mask = ops.expand_dims(
                    ops.expand_dims(cache_update_mask, axis=-1),
                    axis=-1,
                )
                key_original = ops.slice(
                    key_cache, start, ops.shape(key_update)
                )
                value_original = ops.slice(
                    value_cache, start, ops.shape(value_update)
                )

                key_update = ops.where(
                    cache_update_mask,
                    key_update,
                    key_original,
                )
                value_update = ops.where(
                    cache_update_mask,
                    value_update,
                    value_original,
                )

            key = ops.slice_update(key_cache, start, key_update)
            value = ops.slice_update(value_cache, start, value_update)
            cache = ops.stack((key, value), axis=1)
        else:
            key = self.key_dense(x)

            if self.use_query_key_norm:
                key = self.key_norm(key)

            key = self._apply_rope(key, cache_update_index)
            value = self.value_dense(x)

        attention_vec = self._compute_attention(
            query,
            key,
            value,
            attention_mask,
            training=training,
            cache_update_index=cache_update_index,
        )

        if attention_mask is not None:
            # Wipe attn vec if there are no attended tokens: rows whose mask
            # is entirely zero would otherwise still get a (meaningless)
            # attention output. Skipped when no mask is given, since
            # `ops.equal(None, 0)` would fail.
            no_attended_tokens = ops.expand_dims(
                ops.all(ops.equal(attention_mask, 0), axis=-1, keepdims=True),
                axis=-1,
            )
            attention_vec = ops.where(
                no_attended_tokens, ops.zeros_like(attention_vec), attention_vec
            )

        attention_output = self.output_dense(attention_vec)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def compute_output_shape(self, input_shape):
        """Output shape equals input shape (projected back to hidden_dim)."""
        return input_shape