Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions keras/src/backend/jax/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,9 +1101,9 @@ def _get_distributed_iterator(self, distribution):
for data in self.data_adapter.get_jax_iterator():
if layouts is None:
layouts = tree.map_structure(
lambda d: distribution.get_data_layout(
d.shape
).backend_layout,
lambda d: (
distribution.get_data_layout(d.shape).backend_layout
),
data,
)
yield _distribute_data(data, layouts)
Expand Down
4 changes: 2 additions & 2 deletions keras/src/backend/numpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ def while_loop(
maximum_iterations=None,
):
current_iter = 0
iteration_check = (
lambda iter: maximum_iterations is None or iter < maximum_iterations
iteration_check = lambda iter: (
maximum_iterations is None or iter < maximum_iterations
)
is_tuple = isinstance(loop_vars, (tuple, list))
loop_vars = tuple(loop_vars) if is_tuple else (loop_vars,)
Expand Down
4 changes: 2 additions & 2 deletions keras/src/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,8 +650,8 @@ def while_loop(
maximum_iterations=None,
):
current_iter = 0
iteration_check = (
lambda iter: maximum_iterations is None or iter < maximum_iterations
iteration_check = lambda iter: (
maximum_iterations is None or iter < maximum_iterations
)
is_tuple = isinstance(loop_vars, (tuple, list))
loop_vars = tuple(loop_vars) if is_tuple else (loop_vars,)
Expand Down
32 changes: 26 additions & 6 deletions keras/src/layers/attention/grouped_query_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ class GroupedQueryAttention(Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_gate: Boolean, whether to apply a gated attention mechanism.
When True, an additional gating branch is added based on
[Gated Attention for Large Language Models](https://arxiv.org/abs/2505.06708).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The reference to "(NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708]" appears to be a placeholder. The year is in the future and the arXiv link is invalid. This violates the Keras API design guidelines, which require new features to be based on widely recognized best practices. Please replace this with a valid reference to the paper that introduced this gated attention mechanism, or provide a more general explanation of the technique if a specific paper isn't the source.

References
  1. New features should be widely recognized as a machine learning best practice and not based on very recent or non-existent papers. (link)

It applies a sigmoid-activated linear projection to the query
which then gates the attention output. This helps improve training
stability and eliminates "attention sinks".
seed: Optional integer to seed the dropout layer.

Call arguments:
Expand Down Expand Up @@ -102,6 +108,7 @@ def __init__(
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_gate=False,
seed=None,
**kwargs,
):
Expand All @@ -117,6 +124,7 @@ def __init__(
self.num_repeats = num_query_heads // num_key_value_heads
self.dropout = dropout
self.use_bias = use_bias
self.use_gate = use_gate
self._flash_attention = flash_attention or is_flash_attention_enabled()
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
Expand Down Expand Up @@ -170,7 +178,16 @@ def build(
**self._get_common_kwargs_for_sublayer(),
)
self._key_dense.build(key_shape)

if self.use_gate:
self._gate_dense = EinsumDense(
"bqm,muh->bquh",
output_shape=(None, self.num_query_heads, self.head_dim),
bias_axes="uh" if self.use_bias else None,
activation="sigmoid",
name="gate",
**self._get_common_kwargs_for_sublayer(),
)
self._gate_dense.build(key_shape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _gate_dense layer is a projection of the query, but it's being built using key_shape. While this may not cause an error if EinsumDense only relies on the feature dimension (which is the same for query and key), it's semantically incorrect and confusing for future maintenance. For clarity and correctness, please build this layer using query_shape.

Suggested change
self._gate_dense.build(key_shape)
self._gate_dense.build(query_shape)

self._value_dense = EinsumDense(
"bkm,mvh->bkvh",
output_shape=(None, self.num_key_value_heads, self.head_dim),
Expand Down Expand Up @@ -247,7 +264,8 @@ def call(
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
)

if self.use_gate:
gate = self._gate_dense(query)
query = self._query_dense(query)
key = self._key_dense(key)
value = self._value_dense(value)
Expand All @@ -266,10 +284,11 @@ def call(
attention_mask=attention_mask,
training=training,
)

output = self._output_dense(
output
) # (batch_dim, target_seq_len, feature_dim)
# (batch_dim, target_seq_len, feature_dim)
if self.use_gate:
output = self._output_dense(gate * output)
else:
output = self._output_dense(output)

if return_attention_scores:
return output, scores
Expand Down Expand Up @@ -483,6 +502,7 @@ def get_config(self):
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"use_bias": self.use_bias,
"use_gate": self.use_gate,
"dropout": self.dropout,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
Expand Down
116 changes: 116 additions & 0 deletions keras/src/layers/attention/grouped_query_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,44 @@ def test_basics(self):
run_training_check=False,
)

self.run_layer_test(
layers.GroupedQueryAttention,
init_kwargs={
"num_query_heads": 2,
"num_key_value_heads": 2,
"head_dim": 2,
"use_gate": True,
},
input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)},
expected_output_shape=(2, 8, 16),
expected_num_trainable_weights=10,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)

self.run_layer_test(
layers.GroupedQueryAttention,
init_kwargs={
"num_query_heads": 2,
"num_key_value_heads": 2,
"head_dim": 2,
"use_bias": False,
"dropout": 0.5,
"use_gate": True,
},
input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)},
expected_output_shape=(2, 8, 16),
expected_num_trainable_weights=5,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=1,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)

@pytest.mark.skipif(
backend.backend() not in ("jax", "torch"),
reason="Flash attention only supported on JAX and Torch",
Expand Down Expand Up @@ -187,6 +225,26 @@ def test_compute_output_shape(
)
self.assertEqual(output.shape, comp_output_shape)

layer = layers.GroupedQueryAttention(
num_query_heads=num_query_heads,
num_key_value_heads=num_key_value_heads,
head_dim=2,
use_gate=True,
)
batch_size = 7
query_shape = (batch_size,) + query_dims
value_shape = (batch_size,) + value_dims
key_shape = (batch_size,) + key_dims if key_dims else None

query = np.ones(query_shape)
value = np.ones(value_shape)
key = np.ones(key_shape) if key_shape else None
output = layer(query=query, value=value, key=key)
comp_output_shape = layer.compute_output_shape(
query_shape, value_shape, key_shape
)
self.assertEqual(output.shape, comp_output_shape)

@parameterized.named_parameters(
("query_value_dim_mismatch", (2, 4, 8), (2, 2, 7), 2),
("key_value_dim_mismatch", (2, 4, 8), (2, 2, 8), (2, 1, 7)),
Expand All @@ -207,6 +265,7 @@ def test_initializer(self):
num_query_heads=16,
num_key_value_heads=16,
head_dim=64,
use_gate=True,
kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
)
layer.build((2, 4, 8), (2, 4, 8))
Expand All @@ -225,6 +284,11 @@ def test_initializer(self):
layer._output_dense.kernel,
)

self.assertNotAllClose(
layer._query_dense.kernel,
layer._gate_dense.kernel,
)

@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support masking.",
Expand Down Expand Up @@ -252,6 +316,30 @@ def test_query_mask_propagation(self):
)
self.assertAllClose(masked_query._keras_mask, output._keras_mask)

try:
layer = layers.GroupedQueryAttention(
num_query_heads=2,
num_key_value_heads=2,
head_dim=2,
use_gate=True,
)
self.assertTrue(layer.supports_masking)
query = np.array(
[[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]
)
masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
value = np.random.normal(size=(3, 3, 8))
output = layer(query=masked_query, value=value)
except RuntimeError as e:
if e.args[0].startswith(
"(*bias): last dimension must be contiguous"
):
self.skipTest(
"PyTorch errors out on GPU: issue to track bug is here "
"https://github.com/keras-team/keras/issues/20459"
)
self.assertAllClose(masked_query._keras_mask, output._keras_mask)

@parameterized.named_parameters(("causal", True), ("not_causal", 0))
@pytest.mark.skipif(
backend.backend() == "numpy",
Expand Down Expand Up @@ -287,6 +375,34 @@ def test_masking(self, use_causal_mask):
)
self.assertAllClose(output, output_with_manual_mask)

layer = layers.GroupedQueryAttention(
num_query_heads=2, num_key_value_heads=2, head_dim=2, use_gate=True
)
query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])
masked_value = layers.Embedding(6, 8, mask_zero=True)(value)
output = layer(
query=masked_query,
value=masked_value,
use_causal_mask=use_causal_mask,
)
mask = np.array(
[[[1, 1, 0]] * 3 + [[0, 0, 0]] * 2]
+ [[[1, 0, 0]] * 5]
+ [[[1, 1, 1]] + [[0, 0, 0]] * 4]
).astype(bool)
if use_causal_mask:
mask = mask & np.array(
[[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3]
).astype(bool)
del masked_query._keras_mask
del masked_value._keras_mask
output_with_manual_mask = layer(
query=masked_query, value=masked_value, attention_mask=mask
)
self.assertAllClose(output, output_with_manual_mask)

@parameterized.named_parameters(
("disable_flash_attention", False), ("enable_flash_attention", True)
)
Expand Down
30 changes: 29 additions & 1 deletion keras/src/layers/attention/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ class MultiHeadAttention(Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_gate: Boolean, whether to apply a gated attention mechanism.
When True, an additional gating branch is added based on
[Gated Attention for Large Language Models](https://arxiv.org/abs/2505.06708).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The reference to "(NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708]" appears to be a placeholder. The year is in the future and the arXiv link is invalid. This violates the Keras API design guidelines, which require new features to be based on widely recognized best practices. Please replace this with a valid reference to the paper that introduced this gated attention mechanism, or provide a more general explanation of the technique if a specific paper isn't the source.

References
  1. New features should be widely recognized as a machine learning best practice and not based on very recent or non-existent papers. (link)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The link is fine, I'm not sure what Gemini is looking at. However, can you change the title from NeurIPS 2025 Best Paper to Gated Attention for Large Language Models?

It applies a sigmoid-activated linear projection to the query
which then gates the attention output. This helps improve training
stability and eliminates "attention sinks".
seed: Optional integer to seed the dropout layer.

Call arguments:
Expand Down Expand Up @@ -117,6 +123,7 @@ def __init__(
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_gate=False,
seed=None,
**kwargs,
):
Expand All @@ -127,6 +134,7 @@ def __init__(
self._value_dim = value_dim if value_dim else key_dim
self._dropout = dropout
self._use_bias = use_bias
self._use_gate = use_gate
if output_shape:
if isinstance(output_shape, int):
output_shape = (output_shape,)
Expand Down Expand Up @@ -201,6 +209,7 @@ def get_config(self):
"value_dim": self._value_dim,
"dropout": self._dropout,
"use_bias": self._use_bias,
"use_gate": self._use_gate,
"output_shape": self._output_shape,
"attention_axes": self._attention_axes,
"kernel_initializer": initializers.serialize(
Expand Down Expand Up @@ -271,6 +280,18 @@ def build(
**self._get_common_kwargs_for_sublayer(),
)
self._key_dense.build(key_shape)
if self._use_gate:
self._gate_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._num_heads, self._key_dim]
),
bias_axes=bias_axes if self._use_bias else None,
activation="sigmoid",
name="gate",
**self._get_common_kwargs_for_sublayer(),
)
self._gate_dense.build(key_shape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The implementation of the gated attention has a few issues that will lead to errors or incorrect behavior:

  1. Shape Mismatch Bug: The _gate_dense layer's output dimension is set to self._key_dim, but it's multiplied with attention_output, which has a dimension of self._value_dim. This will cause a runtime error if key_dim != value_dim.
  2. Incorrect Equation: The einsum_equation used is from the _key_dense layer, but the gate is a projection of the query. It should use an equation based on the query's rank.
  3. Incorrect Build Shape: The _gate_dense layer is built with key_shape, but it should be built with query_shape since it processes the query.

To fix these issues, the _gate_dense layer should be defined using the query's properties and its output dimension should be self._value_dim. This may require refactoring the build method slightly to define _gate_dense after _query_dense to reuse its equation variables.

Here is a suggestion to fix the most critical parts (the shape mismatch and incorrect build shape):

Suggested change
if self._use_gate:
self._gate_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._num_heads, self._key_dim]
),
bias_axes=bias_axes if self._use_bias else None,
activation="sigmoid",
name="gate",
**self._get_common_kwargs_for_sublayer(),
)
self._gate_dense.build(key_shape)
if self._use_gate:
query_einsum_equation, query_bias_axes, query_output_rank = _build_proj_equation(
query_rank - 1, bound_dims=1, output_dims=2
)
self._gate_dense = EinsumDense(
query_einsum_equation,
output_shape=_get_output_shape(
query_output_rank - 1, [self._num_heads, self._value_dim]
),
bias_axes=query_bias_axes if self._use_bias else None,
activation="sigmoid",
name="gate",
**self._get_common_kwargs_for_sublayer(),
)
self._gate_dense.build(query_shape)

einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_rank - 1, bound_dims=1, output_dims=2
)
Expand Down Expand Up @@ -549,6 +570,10 @@ def call(
# N = `num_attention_heads`
# H = `size_per_head`

# `gate` = [B, T, N, H]
if self._use_gate:
gate = self._gate_dense(query)

# `query` = [B, T, N, H]
query = self._query_dense(query)

Expand All @@ -565,7 +590,10 @@ def call(
training,
return_attention_scores,
)
attention_output = self._output_dense(attention_output)
if self._use_gate:
attention_output = self._output_dense(attention_output * gate)
else:
attention_output = self._output_dense(attention_output)

# Set mask on output if needed
if query_mask is not None:
Expand Down
Loading
Loading