Skip to content

Commit 97b9ae5

Browse files
committed
modify by review
1 parent 84d7033 commit 97b9ae5

File tree

4 files changed

+5
-52
lines changed

4 files changed

+5
-52
lines changed

keras/src/layers/attention/grouped_query_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def call(
286286
)
287287
# (batch_dim, target_seq_len, feature_dim)
288288
if self.use_gate:
289-
output = self._output_dense(gate * output)
289+
output = self._output_dense(ops.multiply(output, gate))
290290
else:
291291
output = self._output_dense(output)
292292

keras/src/layers/attention/grouped_query_attention_test.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -225,26 +225,6 @@ def test_compute_output_shape(
225225
)
226226
self.assertEqual(output.shape, comp_output_shape)
227227

228-
layer = layers.GroupedQueryAttention(
229-
num_query_heads=num_query_heads,
230-
num_key_value_heads=num_key_value_heads,
231-
head_dim=2,
232-
use_gate=True,
233-
)
234-
batch_size = 7
235-
query_shape = (batch_size,) + query_dims
236-
value_shape = (batch_size,) + value_dims
237-
key_shape = (batch_size,) + key_dims if key_dims else None
238-
239-
query = np.ones(query_shape)
240-
value = np.ones(value_shape)
241-
key = np.ones(key_shape) if key_shape else None
242-
output = layer(query=query, value=value, key=key)
243-
comp_output_shape = layer.compute_output_shape(
244-
query_shape, value_shape, key_shape
245-
)
246-
self.assertEqual(output.shape, comp_output_shape)
247-
248228
@parameterized.named_parameters(
249229
("query_value_dim_mismatch", (2, 4, 8), (2, 2, 7), 2),
250230
("key_value_dim_mismatch", (2, 4, 8), (2, 2, 8), (2, 1, 7)),

keras/src/layers/attention/multi_head_attention.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class MultiHeadAttention(Layer):
6666
bias_constraint: Constraint for dense layer kernels.
6767
use_gate: Boolean, whether to apply a gated attention mechanism.
6868
When True, an additional gating branch is added based on the
69-
(NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708].
69+
[Gated Attention for Large Language Models](https://arxiv.org/abs/2505.06708).
7070
It applies a sigmoid-activated linear projection to the query
7171
which then gates the attention output. This helps improve training
7272
stability and eliminates "attention sinks".
@@ -596,7 +596,9 @@ def call(
596596
return_attention_scores,
597597
)
598598
if self._use_gate:
599-
attention_output = self._output_dense(attention_output * gate)
599+
attention_output = self._output_dense(
600+
ops.multiply(attention_output, gate)
601+
)
600602
else:
601603
attention_output = self._output_dense(attention_output)
602604

keras/src/layers/attention/multi_head_attention_test.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -370,35 +370,6 @@ def test_compute_output_shape(
370370
)
371371
self.assertEqual(output.shape, comp_output_shape)
372372

373-
layer = layers.MultiHeadAttention(
374-
num_heads=2,
375-
key_dim=2,
376-
value_dim=2,
377-
output_shape=output_shape,
378-
use_gate=True,
379-
)
380-
batch_size = 7
381-
query_shape = (batch_size,) + query_dims
382-
value_shape = (batch_size,) + value_dims
383-
key_shape = (batch_size,) + key_dims if key_dims else None
384-
385-
query = np.ones(query_shape)
386-
value = np.ones(value_shape)
387-
key = np.ones(key_shape) if key_shape else None
388-
output = layer(query=query, value=value, key=key)
389-
comp_output_shape = layer.compute_output_shape(
390-
query_shape, value_shape, key_shape
391-
)
392-
self.assertEqual(output.shape, comp_output_shape)
393-
394-
# Test shapes as lists.
395-
comp_output_shape = layer.compute_output_shape(
396-
list(query_shape),
397-
list(value_shape),
398-
list(key_shape) if key_shape is not None else None,
399-
)
400-
self.assertEqual(output.shape, comp_output_shape)
401-
402373
@parameterized.named_parameters(
403374
("query_value_dim_mismatch", (2, 4, 8), (2, 2, 7), (2,)),
404375
("key_value_dim_mismatch", (2, 4, 8), (2, 2, 8), (2, 1, 7)),

0 commit comments

Comments
 (0)