diff --git a/keras/src/layers/attention/grouped_query_attention.py b/keras/src/layers/attention/grouped_query_attention.py
index b57028446f0d..308ee53e4198 100644
--- a/keras/src/layers/attention/grouped_query_attention.py
+++ b/keras/src/layers/attention/grouped_query_attention.py
@@ -49,6 +49,12 @@ class GroupedQueryAttention(Layer):
         activity_regularizer: Regularizer for dense layer activity.
         kernel_constraint: Constraint for dense layer kernels.
         bias_constraint: Constraint for dense layer kernels.
+        use_gate: Boolean, whether to apply a gated attention mechanism.
+            When True, an additional gating branch is added, as described in
+            [Gated Attention for Large Language Models](https://arxiv.org/abs/2505.06708).
+            It applies a sigmoid-activated linear projection to the query,
+            which then gates the attention output. This helps improve
+            training stability and eliminate "attention sinks".
         seed: Optional integer to seed the dropout layer.
 
     Call arguments:
@@ -102,6 +108,7 @@ def __init__(
         activity_regularizer=None,
         kernel_constraint=None,
         bias_constraint=None,
+        use_gate=False,
         seed=None,
         **kwargs,
     ):
@@ -117,6 +124,7 @@ def __init__(
         self.num_repeats = num_query_heads // num_key_value_heads
         self.dropout = dropout
         self.use_bias = use_bias
+        self.use_gate = use_gate
         self._flash_attention = flash_attention or is_flash_attention_enabled()
         self.kernel_initializer = initializers.get(kernel_initializer)
         self.bias_initializer = initializers.get(bias_initializer)
@@ -170,7 +178,16 @@ def build(
             **self._get_common_kwargs_for_sublayer(),
         )
         self._key_dense.build(key_shape)
-
+        if self.use_gate:
+            self._gate_dense = EinsumDense(
+                "bqm,muh->bquh",
+                output_shape=(None, self.num_query_heads, self.head_dim),
+                bias_axes="uh" if self.use_bias else None,
+                activation="sigmoid",
+                name="gate",
+                **self._get_common_kwargs_for_sublayer(),
+            )
+            self._gate_dense.build(query_shape)
         self._value_dense = EinsumDense(
             "bkm,mvh->bkvh",
             output_shape=(None, self.num_key_value_heads, self.head_dim),
@@ -247,7 +264,8 @@ def call(
             attention_mask=attention_mask,
             use_causal_mask=use_causal_mask,
         )
-
+        if self.use_gate:
+            gate = self._gate_dense(query)
         query = self._query_dense(query)
         key = self._key_dense(key)
         value = self._value_dense(value)
@@ -266,10 +284,11 @@ def call(
             attention_mask=attention_mask,
             training=training,
         )
-
-        output = self._output_dense(
-            output
-        )  # (batch_dim, target_seq_len, feature_dim)
+        # (batch_dim, target_seq_len, feature_dim)
+        if self.use_gate:
+            output = self._output_dense(ops.multiply(output, gate))
+        else:
+            output = self._output_dense(output)
 
         if return_attention_scores:
             return output, scores
@@ -483,6 +502,7 @@ def get_config(self):
             "num_query_heads": self.num_query_heads,
             "num_key_value_heads": self.num_key_value_heads,
             "use_bias": self.use_bias,
+            "use_gate": self.use_gate,
             "dropout": self.dropout,
             "kernel_initializer": initializers.serialize(
                 self.kernel_initializer
diff --git a/keras/src/layers/attention/grouped_query_attention_test.py b/keras/src/layers/attention/grouped_query_attention_test.py
index 7e2afc979506..8c157b23d11a 100644
--- a/keras/src/layers/attention/grouped_query_attention_test.py
+++ b/keras/src/layers/attention/grouped_query_attention_test.py
@@ -60,6 +60,44 @@ def test_basics(self):
             run_training_check=False,
         )
 
+        self.run_layer_test(
+            layers.GroupedQueryAttention,
+            init_kwargs={
+                "num_query_heads": 2,
+                "num_key_value_heads": 2,
+                "head_dim": 2,
+                "use_gate": True,
+            },
+            input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)},
+            expected_output_shape=(2, 
8, 16), + expected_num_trainable_weights=10, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + + self.run_layer_test( + layers.GroupedQueryAttention, + init_kwargs={ + "num_query_heads": 2, + "num_key_value_heads": 2, + "head_dim": 2, + "use_bias": False, + "dropout": 0.5, + "use_gate": True, + }, + input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)}, + expected_output_shape=(2, 8, 16), + expected_num_trainable_weights=5, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + @pytest.mark.skipif( backend.backend() not in ("jax", "torch"), reason="Flash attention only supported on JAX and Torch", @@ -207,6 +245,7 @@ def test_initializer(self): num_query_heads=16, num_key_value_heads=16, head_dim=64, + use_gate=True, kernel_initializer=initializers.TruncatedNormal(stddev=0.02), ) layer.build((2, 4, 8), (2, 4, 8)) @@ -225,6 +264,11 @@ def test_initializer(self): layer._output_dense.kernel, ) + self.assertNotAllClose( + layer._query_dense.kernel, + layer._gate_dense.kernel, + ) + @pytest.mark.skipif( backend.backend() == "numpy", reason="Numpy backend does not support masking.", @@ -241,6 +285,16 @@ def test_query_mask_propagation(self): output = layer(query=masked_query, value=value) self.assertAllClose(masked_query._keras_mask, output._keras_mask) + layer = layers.GroupedQueryAttention( + num_query_heads=2, num_key_value_heads=2, head_dim=2, use_gate=True + ) + self.assertTrue(layer.supports_masking) + query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]) + masked_query = layers.Embedding(4, 8, mask_zero=True)(query) + value = np.random.normal(size=(3, 3, 8)) + output = layer(query=masked_query, value=value) + self.assertAllClose(masked_query._keras_mask, output._keras_mask) + @parameterized.named_parameters(("causal", True), ("not_causal", 0)) @pytest.mark.skipif( backend.backend() == "numpy", @@ -276,6 +330,34 @@ def test_masking(self, use_causal_mask): ) self.assertAllClose(output, output_with_manual_mask) + layer = layers.GroupedQueryAttention( + num_query_heads=2, num_key_value_heads=2, head_dim=2, use_gate=True + ) + query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]) + masked_query = layers.Embedding(4, 8, mask_zero=True)(query) + value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]]) + masked_value = layers.Embedding(6, 8, mask_zero=True)(value) + output = layer( + query=masked_query, + value=masked_value, + use_causal_mask=use_causal_mask, + ) + mask = np.array( + [[[1, 1, 0]] * 3 + [[0, 0, 0]] * 2] + + [[[1, 0, 0]] * 5] + + [[[1, 1, 1]] + [[0, 0, 0]] * 4] + ).astype(bool) + if use_causal_mask: + mask = mask & np.array( + [[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3] + ).astype(bool) + del masked_query._keras_mask + del masked_value._keras_mask + output_with_manual_mask = layer( + query=masked_query, value=masked_value, attention_mask=mask + ) + self.assertAllClose(output, output_with_manual_mask) + @parameterized.named_parameters( ("disable_flash_attention", False), ("enable_flash_attention", True) ) diff --git a/keras/src/layers/attention/multi_head_attention.py b/keras/src/layers/attention/multi_head_attention.py index 4cf70ee2c112..b6391cbe4841 100644 --- a/keras/src/layers/attention/multi_head_attention.py +++ b/keras/src/layers/attention/multi_head_attention.py @@ -64,6 +64,12 @@ class MultiHeadAttention(Layer): 
         activity_regularizer: Regularizer for dense layer activity.
         kernel_constraint: Constraint for dense layer kernels.
         bias_constraint: Constraint for dense layer kernels.
+        use_gate: Boolean, whether to apply a gated attention mechanism.
+            When True, an additional gating branch is added, as described in
+            [Gated Attention for Large Language Models](https://arxiv.org/abs/2505.06708).
+            It applies a sigmoid-activated linear projection to the query,
+            which then gates the attention output. This helps improve
+            training stability and eliminate "attention sinks".
         seed: Optional integer to seed the dropout layer.
 
     Call arguments:
@@ -117,6 +123,7 @@ def __init__(
         activity_regularizer=None,
         kernel_constraint=None,
         bias_constraint=None,
+        use_gate=False,
         seed=None,
         **kwargs,
     ):
@@ -127,6 +134,7 @@ def __init__(
         self._value_dim = value_dim if value_dim else key_dim
         self._dropout = dropout
         self._use_bias = use_bias
+        self._use_gate = use_gate
         if output_shape:
             if isinstance(output_shape, int):
                 output_shape = (output_shape,)
@@ -201,6 +209,7 @@ def get_config(self):
             "value_dim": self._value_dim,
             "dropout": self._dropout,
             "use_bias": self._use_bias,
+            "use_gate": self._use_gate,
             "output_shape": self._output_shape,
             "attention_axes": self._attention_axes,
             "kernel_initializer": initializers.serialize(
@@ -271,6 +280,23 @@ def build(
             **self._get_common_kwargs_for_sublayer(),
         )
         self._key_dense.build(key_shape)
+        if self._use_gate:
+            query_einsum_equation, query_bias_axes, query_output_rank = (
+                _build_proj_equation(
+                    query_rank - 1, bound_dims=1, output_dims=2
+                )
+            )
+            self._gate_dense = EinsumDense(
+                query_einsum_equation,
+                output_shape=_get_output_shape(
+                    query_output_rank - 1, [self._num_heads, self._value_dim]
+                ),
+                bias_axes=query_bias_axes if self._use_bias else None,
+                activation="sigmoid",
+                name="gate",
+                **self._get_common_kwargs_for_sublayer(),
+            )
+            self._gate_dense.build(query_shape)
         einsum_equation, bias_axes, output_rank = _build_proj_equation(
             value_rank - 1, bound_dims=1, output_dims=2
         )
@@ -549,6 +575,10 @@ def call(
         # N = `num_attention_heads`
         # H = `size_per_head`
 
+        # `gate` = [B, T, N, H]
+        if self._use_gate:
+            gate = self._gate_dense(query)
+
         # `query` = [B, T, N, H]
         query = self._query_dense(query)
 
@@ -565,7 +595,12 @@ def call(
             training,
             return_attention_scores,
         )
-        attention_output = self._output_dense(attention_output)
+        if self._use_gate:
+            attention_output = self._output_dense(
+                ops.multiply(attention_output, gate)
+            )
+        else:
+            attention_output = self._output_dense(attention_output)
 
         # Set mask on output if needed
         if query_mask is not None:
diff --git a/keras/src/layers/attention/multi_head_attention_test.py b/keras/src/layers/attention/multi_head_attention_test.py
index 3e9a7325a4f7..a70488a0fb3f 100644
--- a/keras/src/layers/attention/multi_head_attention_test.py
+++ b/keras/src/layers/attention/multi_head_attention_test.py
@@ -68,6 +68,43 @@ def test_basics(self):
             run_training_check=False,
         )
 
+        self.run_layer_test(
+            layers.MultiHeadAttention,
+            init_kwargs={
+                "num_heads": 2,
+                "key_dim": 2,
+                "use_gate": True,
+            },
+            input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)},
+            expected_output_shape=(2, 8, 16),
+            expected_num_trainable_weights=10,
+            expected_num_non_trainable_weights=0,
+            expected_num_seed_generators=0,
+            expected_num_losses=0,
+            supports_masking=True,
+            run_training_check=False,
+        )
+
+        self.run_layer_test(
+            layers.MultiHeadAttention,
+            init_kwargs={
+                "num_heads": 2,
+                "key_dim": 2,
+                "value_dim": 4,
+                "use_bias": False,
+                "dropout": 0.5,
+                
"use_gate": True, + }, + input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)}, + expected_output_shape=(2, 8, 16), + expected_num_trainable_weights=5, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + @pytest.mark.skipif( backend.backend() not in ("jax", "torch"), reason="Flash attention only supported on JAX and Torch", @@ -202,6 +239,27 @@ def test_high_dim_attention( run_training_check=False, ) + self.run_layer_test( + layers.MultiHeadAttention, + init_kwargs={ + "num_heads": 2, + "key_dim": 2, + "use_gate": True, + "attention_axes": attention_axes, + }, + input_shape={ + "query_shape": query_shape, + "value_shape": value_shape, + }, + expected_output_shape=query_shape, + expected_num_trainable_weights=10, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + def test_attention_axes_negative_indexing(self): x = np.random.normal(size=(2, 3, 8, 4)) @@ -232,6 +290,33 @@ def test_attention_axes_negative_indexing(self): self.assertAllClose(z_pos, z_neg, rtol=1e-5, atol=1e-5) self.assertAllClose(a_pos, a_neg, rtol=1e-5, atol=1e-5) + # Create two layers with equivalent positive and negative indices + mha_pos = layers.MultiHeadAttention( + num_heads=2, key_dim=4, attention_axes=2, use_gate=True + ) + mha_neg = layers.MultiHeadAttention( + num_heads=2, key_dim=4, attention_axes=-2, use_gate=True + ) + + # Initialize both layers + _ = mha_pos(x, x) + _ = mha_neg(x, x) + + # Set same weights for fair comparison + mha_neg.set_weights(mha_pos.get_weights()) + + # Get outputs and attention scores + z_pos, a_pos = mha_pos(x, x, return_attention_scores=True) + z_neg, a_neg = mha_neg(x, x, return_attention_scores=True) + + # Verify shapes match + self.assertEqual(z_pos.shape, z_neg.shape) + self.assertEqual(a_pos.shape, a_neg.shape) + + # Verify outputs are identical + self.assertAllClose(z_pos, z_neg, rtol=1e-5, atol=1e-5) + self.assertAllClose(a_pos, a_neg, rtol=1e-5, atol=1e-5) + @parameterized.named_parameters( ("without_key_same_proj", (4, 8), (2, 8), None, None), ("with_key_same_proj", (4, 8), (2, 8), (2, 3), None), @@ -314,6 +399,7 @@ def test_initializer(self): layer = layers.MultiHeadAttention( num_heads=12, key_dim=64, + use_gate=True, kernel_initializer=initializers.TruncatedNormal(stddev=0.02), ) layer.build((2, 4, 8), (2, 4, 8)) @@ -331,6 +417,10 @@ def test_initializer(self): layer._query_dense.kernel, layer._output_dense.kernel, ) + self.assertNotAllClose( + layer._query_dense.kernel, + layer._gate_dense.kernel, + ) @pytest.mark.skipif( backend.backend() == "numpy", @@ -485,6 +575,7 @@ def test_mha_constraints(self): layer = layers.MultiHeadAttention( num_heads=num_heads, key_dim=key_dim, + use_gate=True, kernel_constraint="non_neg", ) layer.build(query.shape, key.shape, value.shape) @@ -497,9 +588,13 @@ def test_mha_constraints(self): self.assertIsInstance( layer._key_dense.kernel.constraint, constraints.NonNeg ) + self.assertIsInstance( + layer._gate_dense.kernel.constraint, constraints.NonNeg + ) layer = layers.MultiHeadAttention( num_heads=num_heads, key_dim=key_dim, + use_gate=True, bias_constraint="non_neg", ) layer.build(query.shape, key.shape, value.shape) @@ -512,6 +607,9 @@ def test_mha_constraints(self): self.assertIsInstance( layer._key_dense.bias.constraint, constraints.NonNeg ) + self.assertIsInstance( + layer._gate_dense.bias.constraint, 
constraints.NonNeg + ) @pytest.mark.requires_trainable_backend def test_lora(self): @@ -522,13 +620,14 @@ def test_lora(self): num_heads=3, key_dim=8, use_bias=False, + use_gate=True, ) layer.build(query.shape, key.shape, value.shape) layer.query_dense.enable_lora(2) layer.key_dense.enable_lora(2) layer.value_dense.enable_lora(2) - self.assertLen(layer.trainable_variables, 7) + self.assertLen(layer.trainable_variables, 8) self.assertLen(layer.non_trainable_variables, 3) # Try eager call @@ -574,6 +673,7 @@ def test_lora(self): num_heads=3, key_dim=8, use_bias=False, + use_gate=True, )(inputs["query"], inputs["key"], inputs["value"]) new_model = models.Model(inputs, outputs) @@ -705,3 +805,24 @@ def test_quantize_int8(self): output_quantized = layer(query, key, value) mse = ops.mean(ops.square(output_float - output_quantized)) self.assertLess(mse, 1e-3) # A weak correctness test + + layer = layers.MultiHeadAttention( + num_heads=3, + key_dim=8, + use_gate=True, + use_bias=False, + ) + layer.build(query.shape, value.shape, key.shape) + output_float = layer(query, key, value) + for sublayer in layer._flatten_layers(): + try: + sublayer.quantize("int8") + except: + pass + + # Verify weights dtype + self.assertDType(layer._query_dense._kernel, "int8") + self.assertDType(layer._key_dense._kernel, "int8") + self.assertDType(layer._value_dense._kernel, "int8") + self.assertDType(layer._gate_dense._kernel, "int8") + self.assertDType(layer._output_dense._kernel, "int8")
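
For reference, here is a minimal usage sketch of the new flag (not part of the patch; it assumes the diff above is applied to a Keras 3 build, and the head counts and shapes are purely illustrative):

```python
import numpy as np
from keras import layers

# With use_gate=True, a sigmoid-activated projection of the query gates the
# attention output before the final output projection (see the build()/call()
# changes in the diff above).
query = np.random.normal(size=(2, 8, 16)).astype("float32")
value = np.random.normal(size=(2, 4, 16)).astype("float32")

mha = layers.MultiHeadAttention(num_heads=2, key_dim=2, use_gate=True)
gqa = layers.GroupedQueryAttention(
    head_dim=2, num_query_heads=4, num_key_value_heads=2, use_gate=True
)

print(mha(query, value).shape)  # (2, 8, 16)
print(gqa(query, value).shape)  # (2, 8, 16)
```

The gate is a single extra `EinsumDense` on the query path, which is why the expected trainable-weight counts in the tests above increase by one kernel (plus one bias when `use_bias=True`).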