Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions keras/src/layers/attention/grouped_query_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ class GroupedQueryAttention(Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
use_gate: Boolean, whether to apply a gated attention mechanism.
When True, an additional gating branch is added based on the
[Gated Attention for Large Language Models](https://arxiv.org/abs/2505.06708) paper.
It applies a sigmoid-activated linear projection to the query
which then gates the attention output. This helps improve training
stability and eliminates "attention sinks".
seed: Optional integer to seed the dropout layer.

Call arguments:
Expand Down Expand Up @@ -102,6 +108,7 @@ def __init__(
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_gate=False,
seed=None,
**kwargs,
):
Expand All @@ -117,6 +124,7 @@ def __init__(
self.num_repeats = num_query_heads // num_key_value_heads
self.dropout = dropout
self.use_bias = use_bias
self.use_gate = use_gate
self._flash_attention = flash_attention or is_flash_attention_enabled()
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
Expand Down Expand Up @@ -170,7 +178,16 @@ def build(
**self._get_common_kwargs_for_sublayer(),
)
self._key_dense.build(key_shape)

if self.use_gate:
self._gate_dense = EinsumDense(
"bqm,muh->bquh",
output_shape=(None, self.num_query_heads, self.head_dim),
bias_axes="uh" if self.use_bias else None,
activation="sigmoid",
name="gate",
**self._get_common_kwargs_for_sublayer(),
)
self._gate_dense.build(query_shape)
self._value_dense = EinsumDense(
"bkm,mvh->bkvh",
output_shape=(None, self.num_key_value_heads, self.head_dim),
Expand Down Expand Up @@ -247,7 +264,8 @@ def call(
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
)

if self.use_gate:
gate = self._gate_dense(query)
query = self._query_dense(query)
key = self._key_dense(key)
value = self._value_dense(value)
Expand All @@ -266,10 +284,11 @@ def call(
attention_mask=attention_mask,
training=training,
)

output = self._output_dense(
output
) # (batch_dim, target_seq_len, feature_dim)
# (batch_dim, target_seq_len, feature_dim)
if self.use_gate:
output = self._output_dense(ops.multiply(output, gate))
else:
output = self._output_dense(output)

if return_attention_scores:
return output, scores
Expand Down Expand Up @@ -483,6 +502,7 @@ def get_config(self):
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"use_bias": self.use_bias,
"use_gate": self.use_gate,
"dropout": self.dropout,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
Expand Down
82 changes: 82 additions & 0 deletions keras/src/layers/attention/grouped_query_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,44 @@ def test_basics(self):
run_training_check=False,
)

self.run_layer_test(
layers.GroupedQueryAttention,
init_kwargs={
"num_query_heads": 2,
"num_key_value_heads": 2,
"head_dim": 2,
"use_gate": True,
},
input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)},
expected_output_shape=(2, 8, 16),
expected_num_trainable_weights=10,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)

self.run_layer_test(
layers.GroupedQueryAttention,
init_kwargs={
"num_query_heads": 2,
"num_key_value_heads": 2,
"head_dim": 2,
"use_bias": False,
"dropout": 0.5,
"use_gate": True,
},
input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)},
expected_output_shape=(2, 8, 16),
expected_num_trainable_weights=5,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=1,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)

@pytest.mark.skipif(
backend.backend() not in ("jax", "torch"),
reason="Flash attention only supported on JAX and Torch",
Expand Down Expand Up @@ -207,6 +245,7 @@ def test_initializer(self):
num_query_heads=16,
num_key_value_heads=16,
head_dim=64,
use_gate=True,
kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
)
layer.build((2, 4, 8), (2, 4, 8))
Expand All @@ -225,6 +264,11 @@ def test_initializer(self):
layer._output_dense.kernel,
)

self.assertNotAllClose(
layer._query_dense.kernel,
layer._gate_dense.kernel,
)

@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support masking.",
Expand All @@ -241,6 +285,16 @@ def test_query_mask_propagation(self):
output = layer(query=masked_query, value=value)
self.assertAllClose(masked_query._keras_mask, output._keras_mask)

layer = layers.GroupedQueryAttention(
num_query_heads=2, num_key_value_heads=2, head_dim=2, use_gate=True
)
self.assertTrue(layer.supports_masking)
query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
value = np.random.normal(size=(3, 3, 8))
output = layer(query=masked_query, value=value)
self.assertAllClose(masked_query._keras_mask, output._keras_mask)

@parameterized.named_parameters(("causal", True), ("not_causal", 0))
@pytest.mark.skipif(
backend.backend() == "numpy",
Expand Down Expand Up @@ -276,6 +330,34 @@ def test_masking(self, use_causal_mask):
)
self.assertAllClose(output, output_with_manual_mask)

layer = layers.GroupedQueryAttention(
num_query_heads=2, num_key_value_heads=2, head_dim=2, use_gate=True
)
query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])
masked_value = layers.Embedding(6, 8, mask_zero=True)(value)
output = layer(
query=masked_query,
value=masked_value,
use_causal_mask=use_causal_mask,
)
mask = np.array(
[[[1, 1, 0]] * 3 + [[0, 0, 0]] * 2]
+ [[[1, 0, 0]] * 5]
+ [[[1, 1, 1]] + [[0, 0, 0]] * 4]
).astype(bool)
if use_causal_mask:
mask = mask & np.array(
[[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3]
).astype(bool)
del masked_query._keras_mask
del masked_value._keras_mask
output_with_manual_mask = layer(
query=masked_query, value=masked_value, attention_mask=mask
)
self.assertAllClose(output, output_with_manual_mask)

@parameterized.named_parameters(
("disable_flash_attention", False), ("enable_flash_attention", True)
)
Expand Down
37 changes: 36 additions & 1 deletion keras/src/layers/attention/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ class MultiHeadAttention(Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
use_gate: Boolean, whether to apply a gated attention mechanism.
When True, an additional gating branch is added based on the
[Gated Attention for Large Language Models](https://arxiv.org/abs/2505.06708) paper.
It applies a sigmoid-activated linear projection to the query
which then gates the attention output. This helps improve training
stability and eliminates "attention sinks".
seed: Optional integer to seed the dropout layer.

Call arguments:
Expand Down Expand Up @@ -117,6 +123,7 @@ def __init__(
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_gate=False,
seed=None,
**kwargs,
):
Expand All @@ -127,6 +134,7 @@ def __init__(
self._value_dim = value_dim if value_dim else key_dim
self._dropout = dropout
self._use_bias = use_bias
self._use_gate = use_gate
if output_shape:
if isinstance(output_shape, int):
output_shape = (output_shape,)
Expand Down Expand Up @@ -201,6 +209,7 @@ def get_config(self):
"value_dim": self._value_dim,
"dropout": self._dropout,
"use_bias": self._use_bias,
"use_gate": self._use_gate,
"output_shape": self._output_shape,
"attention_axes": self._attention_axes,
"kernel_initializer": initializers.serialize(
Expand Down Expand Up @@ -271,6 +280,23 @@ def build(
**self._get_common_kwargs_for_sublayer(),
)
self._key_dense.build(key_shape)
if self._use_gate:
query_einsum_equation, query_bias_axes, query_output_rank = (
_build_proj_equation(
query_rank - 1, bound_dims=1, output_dims=2
)
)
self._gate_dense = EinsumDense(
query_einsum_equation,
output_shape=_get_output_shape(
query_output_rank - 1, [self._num_heads, self._value_dim]
),
bias_axes=query_bias_axes if self._use_bias else None,
activation="sigmoid",
name="gate",
**self._get_common_kwargs_for_sublayer(),
)
self._gate_dense.build(query_shape)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_rank - 1, bound_dims=1, output_dims=2
)
Expand Down Expand Up @@ -549,6 +575,10 @@ def call(
# N = `num_attention_heads`
# H = `size_per_head`

# `gate` = [B, T, N, H]
if self._use_gate:
gate = self._gate_dense(query)

# `query` = [B, T, N, H]
query = self._query_dense(query)

Expand All @@ -565,7 +595,12 @@ def call(
training,
return_attention_scores,
)
attention_output = self._output_dense(attention_output)
if self._use_gate:
attention_output = self._output_dense(
ops.multiply(attention_output, gate)
)
else:
attention_output = self._output_dense(attention_output)

# Set mask on output if needed
if query_mask is not None:
Expand Down
Loading
Loading