Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces an optional gated attention mechanism to the `GroupedQueryAttention` and `MultiHeadAttention` layers.
Code Review
This pull request introduces an optional gated attention mechanism to GroupedQueryAttention and MultiHeadAttention. While the feature is a good addition, the implementation has some critical issues. The docstrings reference a non-existent paper, which is misleading and violates the repository's contribution guidelines. Additionally, there is a bug in the MultiHeadAttention gating logic that will cause a runtime error when key_dim and value_dim differ. I've provided detailed comments and suggestions to address these issues.
Note: Security Review did not run due to the size of the PR.
```python
bias_constraint: Constraint for dense layer kernels.
use_gate: Boolean, whether to apply a gated attention mechanism.
    When True, an additional gating branch is added based on the
    (NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708].
```
The reference to "(NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708]" appears to be a placeholder. The year is in the future and the arXiv link is invalid. This violates the Keras API design guidelines, which require new features to be based on widely recognized best practices. Please replace this with a valid reference to the paper that introduced this gated attention mechanism, or provide a more general explanation of the technique if a specific paper isn't the source.
References
- New features should be widely recognized as a machine learning best practice and not based on very recent or non-existent papers. (link)
The link is fine; I'm not sure what Gemini is looking at. However, can you change the title from "NeurIPS 2025 Best Paper" to "Gated Attention for Large Language Models"?
```python
if self._use_gate:
    self._gate_dense = EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(
            output_rank - 1, [self._num_heads, self._key_dim]
        ),
        bias_axes=bias_axes if self._use_bias else None,
        activation="sigmoid",
        name="gate",
        **self._get_common_kwargs_for_sublayer(),
    )
    self._gate_dense.build(key_shape)
```
The implementation of the gated attention has a few issues that will lead to errors or incorrect behavior:

- **Shape mismatch bug:** The `_gate_dense` layer's output dimension is set to `self._key_dim`, but it's multiplied with `attention_output`, which has a dimension of `self._value_dim`. This will cause a runtime error if `key_dim != value_dim`.
- **Incorrect equation:** The `einsum_equation` used is from the `_key_dense` layer, but the gate is a projection of the `query`. It should use an equation based on the `query`'s rank.
- **Incorrect build shape:** The `_gate_dense` layer is built with `key_shape`, but it should be built with `query_shape` since it processes the `query`.

To fix these issues, the `_gate_dense` layer should be defined using the `query`'s properties and its output dimension should be `self._value_dim`. This may require refactoring the `build` method slightly to define `_gate_dense` after `_query_dense` to reuse its equation variables.
Here is a suggestion to fix the most critical parts (the shape mismatch and incorrect build shape):
```diff
 if self._use_gate:
-    self._gate_dense = EinsumDense(
-        einsum_equation,
-        output_shape=_get_output_shape(
-            output_rank - 1, [self._num_heads, self._key_dim]
-        ),
-        bias_axes=bias_axes if self._use_bias else None,
+    query_einsum_equation, query_bias_axes, query_output_rank = (
+        _build_proj_equation(query_rank - 1, bound_dims=1, output_dims=2)
+    )
+    self._gate_dense = EinsumDense(
+        query_einsum_equation,
+        output_shape=_get_output_shape(
+            query_output_rank - 1, [self._num_heads, self._value_dim]
+        ),
+        bias_axes=query_bias_axes if self._use_bias else None,
         activation="sigmoid",
         name="gate",
         **self._get_common_kwargs_for_sublayer(),
     )
-    self._gate_dense.build(key_shape)
+    self._gate_dense.build(query_shape)
```
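To illustrate why the gate's output dimension must match `value_dim`, here is a toy NumPy sketch (the shapes are illustrative assumptions, not the actual Keras internals): element-wise gating of `attention_output` only broadcasts when the gate's last dimension equals `value_dim`.

```python
import numpy as np

batch, seq, num_heads, key_dim, value_dim = 3, 5, 2, 4, 6

# attention_output has last dimension value_dim, as in multi-head attention.
attention_output = np.ones((batch, seq, num_heads, value_dim))

# A sigmoid gate projected to value_dim multiplies cleanly; sigmoid(0) = 0.5.
gate_ok = 1.0 / (1.0 + np.exp(-np.zeros((batch, seq, num_heads, value_dim))))
gated = attention_output * gate_ok  # shapes match: (3, 5, 2, 6)

# A gate projected to key_dim fails to broadcast when key_dim != value_dim.
gate_bad = np.ones((batch, seq, num_heads, key_dim))
try:
    attention_output * gate_bad
except ValueError:
    print("shape mismatch: key_dim-sized gate cannot scale value_dim output")
```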
```python
        name="gate",
        **self._get_common_kwargs_for_sublayer(),
    )
    self._gate_dense.build(key_shape)
```
The `_gate_dense` layer is a projection of the `query`, but it's being built using `key_shape`. While this may not cause an error if `EinsumDense` only relies on the feature dimension (which is the same for `query` and `key`), it's semantically incorrect and confusing for future maintenance. For clarity and correctness, please build this layer using `query_shape`.
```diff
-    self._gate_dense.build(key_shape)
+    self._gate_dense.build(query_shape)
```
/gemini review
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff            @@
##           master   #22372    +/-   ##
=======================================
  Coverage   83.04%   83.04%
=======================================
  Files         596      596
  Lines       66708    66725     +17
  Branches    10384    10390      +6
=======================================
+ Hits        55395    55412     +17
  Misses       8676     8676
  Partials     2637     2637
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
I need some help running this test.

```python
@parameterized.named_parameters(
    ("4d_inputs_1freebatch_mask2", (3, 4), (3, 2), (4, 2), (2,)),
    ("4d_inputs_1freebatch_mask3", (3, 4), (3, 2), (3, 4, 2), (2,)),
    ("4d_inputs_1freebatch_mask4", (3, 4), (3, 2), (3, 2, 4, 2), (2,)),
    ("4d_inputs_2d_attention", (3, 4), (3, 2), (3, 4, 3, 2), (1, 2)),
    ("5d_inputs_2d_attention", (5, 3, 4), (5, 3, 2), (3, 4, 3, 2), (2, 3)),
    (
        "5d_inputs_2d_attention_fullmask",
        (5, 3, 4),
        (5, 3, 2),
        (5, 3, 4, 3, 2),
        (2, 3),
    ),
)
def test_high_dim_attention(
    self, q_dims, v_dims, mask_dims, attention_axes
):
    batch_size, hidden_size = 3, 8
    query_shape = (batch_size,) + q_dims + (hidden_size,)
    value_shape = (batch_size,) + v_dims + (hidden_size,)
    self.run_layer_test(
        layers.MultiHeadAttention,
        init_kwargs={
            "num_heads": 2,
            "key_dim": 2,
            "attention_axes": attention_axes,
        },
        input_shape={
            "query_shape": query_shape,
            "value_shape": value_shape,
        },
        expected_output_shape=query_shape,
        expected_num_trainable_weights=8,
        expected_num_non_trainable_weights=0,
        expected_num_seed_generators=0,
        expected_num_losses=0,
        supports_masking=True,
        run_training_check=False,
    )
    self.run_layer_test(
        layers.MultiHeadAttention,
        init_kwargs={
            "num_heads": 2,
            "key_dim": 2,
            "use_gate": True,
            "attention_axes": attention_axes,
        },
        input_shape={
            "query_shape": query_shape,
            "value_shape": value_shape,
        },
        expected_output_shape=query_shape,
        expected_num_trainable_weights=10,
        expected_num_non_trainable_weights=0,
        expected_num_seed_generators=0,
        expected_num_losses=0,
        supports_masking=True,
        run_training_check=False,
    )
```

The error occurs only with the OpenVINO backend. I discovered that running this part works fine:

```python
batch_size, hidden_size = 3, 8
query_shape = (batch_size,) + q_dims + (hidden_size,)
value_shape = (batch_size,) + v_dims + (hidden_size,)
self.run_layer_test(
    layers.MultiHeadAttention,
    init_kwargs={
        "num_heads": 2,
        "key_dim": 2,
        "attention_axes": attention_axes,
    },
    input_shape={
        "query_shape": query_shape,
        "value_shape": value_shape,
    },
    expected_output_shape=query_shape,
    expected_num_trainable_weights=8,
    expected_num_non_trainable_weights=0,
    expected_num_seed_generators=0,
    expected_num_losses=0,
    supports_masking=True,
    run_training_check=False,
)
```

But when I run the following, it fails:

```python
self.run_layer_test(
    layers.MultiHeadAttention,
    init_kwargs={
        "num_heads": 2,
        "key_dim": 2,
        "attention_axes": attention_axes,
    },
    input_shape={
        "query_shape": query_shape,
        "value_shape": value_shape,
    },
    expected_output_shape=query_shape,
    expected_num_trainable_weights=10,
    expected_num_non_trainable_weights=0,
    expected_num_seed_generators=0,
    expected_num_losses=0,
    supports_masking=True,
    run_training_check=False,
)
```

The error is always the same. I tried modifying the code:

```python
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
    dtype = dtype or floatx()
    seed_val = draw_seed(seed)
    if isinstance(seed_val, OpenVINOKerasTensor):
        seed_data = convert_to_numpy(seed_val)
    else:
        seed_data = seed_val.data
    print(seed_data)
    rng = np.random.default_rng(seed_data)
    random_values = rng.uniform(minval, maxval, size=shape).astype(dtype)
    return OpenVINOKerasTensor(ov_opset.constant(random_values).output(0))
```

I printed `seed_data`.
/gemini review |
Code Review
This pull request introduces an optional gated attention mechanism to the GroupedQueryAttention and MultiHeadAttention layers, a useful feature for improving training stability. The implementation is sound and accompanied by thorough tests. My review focuses on correcting the academic citations in the docstrings, which currently contain inaccuracies and appear to be hallucinations. The remaining changes are minor stylistic adjustments that improve code formatting.
Hey, the OpenVINO failure isn't related to your gated attention changes; it's a pre-existing issue with the OpenVINO random backend. The issue is in `keras/src/backend/openvino/random.py`: the `uniform` function passes seed values directly to `np.random.default_rng()`, which requires non-negative integers. When two `run_layer_test` calls run in the same test, the seed state can wrap to negative values (like the `[-359424, 5387]` you printed), causing the `ValueError`. You should be able to fix it by wrapping the seed data with `np.abs()`.

The `normal` and `truncated_normal` functions in the same file (lines 19 and 133) have the same issue with `seed.data`, so those should probably be fixed too.
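As a sketch of the suggested fix (the helper name `_safe_rng` is hypothetical, not the actual Keras code), wrapping the seed state with `np.abs()` before handing it to `np.random.default_rng()` avoids the negative-seed `ValueError`:

```python
import numpy as np

def _safe_rng(seed_data):
    # Hypothetical helper: np.random.default_rng() rejects negative
    # integers, so take np.abs() of the (possibly wrapped) seed state.
    return np.random.default_rng(np.abs(np.asarray(seed_data)))

# The seed state printed in the failing test had wrapped negative:
rng = _safe_rng([-359424, 5387])
values = rng.uniform(0.0, 1.0, size=(2, 3))  # no longer raises
```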
hertschuh left a comment

Thanks for adding this!
```python
# Create two layers with equivalent positive and negative indices
mha_pos = layers.MultiHeadAttention(
    num_heads=2, key_dim=4, attention_axes=2, use_gate=True
)
mha_neg = layers.MultiHeadAttention(
    num_heads=2, key_dim=4, attention_axes=-2, use_gate=True
)

# Initialize both layers
_ = mha_pos(x, x)
_ = mha_neg(x, x)

# Set same weights for fair comparison
mha_neg.set_weights(mha_pos.get_weights())

# Get outputs and attention scores
z_pos, a_pos = mha_pos(x, x, return_attention_scores=True)
z_neg, a_neg = mha_neg(x, x, return_attention_scores=True)

# Verify shapes match
self.assertEqual(z_pos.shape, z_neg.shape)
self.assertEqual(a_pos.shape, a_neg.shape)

# Verify outputs are identical
self.assertAllClose(z_pos, z_neg, rtol=1e-5, atol=1e-5)
self.assertAllClose(a_pos, a_neg, rtol=1e-5, atol=1e-5)
```
I don't think this test is particularly useful. I believe if it works with `use_gate=False`, it will work with `use_gate=True`.

However, I'd like to see a test that validates that `use_gate=True` actually does something different. Maybe by creating 2 layers, one with `use_gate=False` and one with `use_gate=True`, and comparing their outputs. Although that is a weak verification; maybe you can think of something better.
> I don't think this test is particularly useful. I believe if it works with `use_gate=False`, it will work with `use_gate=True`.
>
> However, I'd like to see a test that validates that `use_gate=True` actually does something different. Maybe by creating 2 layers, one with `use_gate=False` and one with `use_gate=True`, and comparing their outputs. Although that is a weak verification; maybe you can think of something better.
I think this test can be kept to verify that the `use_gate=True` workflow works properly.
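One way to sketch the kind of check suggested above, using NumPy rather than the actual Keras layers (the toy shapes and zero-initialized gate are assumptions): at initialization, a sigmoid gate with zero weights outputs 0.5 everywhere, so a gated output measurably differs from the ungated one.

```python
import numpy as np

rng = np.random.default_rng(0)
attention_output = rng.normal(size=(2, 4, 8))  # toy (batch, seq, dim) output

# use_gate=False: the output passes through unchanged.
ungated = attention_output

# use_gate=True: multiply by sigmoid(gate projection); with zero-initialized
# gate weights the projection is 0, so the gate is sigmoid(0) = 0.5.
gate = 1.0 / (1.0 + np.exp(-np.zeros_like(attention_output)))
gated = attention_output * gate

# The two configurations produce different outputs.
different = not np.allclose(ungated, gated)
```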
hertschuh left a comment
Please rebase; this will remove the formatting changes.
from #22337