-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Add optional Gated Attention #22372
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Add optional Gated Attention #22372
Changes from 2 commits
6330fb5
9077241
d605331
6d8c008
c826753
84d7033
97b9ae5
80c3d4a
b06a45c
35ebaeb
834d7cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -49,6 +49,12 @@ class GroupedQueryAttention(Layer): | |||||
| activity_regularizer: Regularizer for dense layer activity. | ||||||
| kernel_constraint: Constraint for dense layer kernels. | ||||||
| bias_constraint: Constraint for dense layer kernels. | ||||||
| use_gate: Boolean, whether to apply a gated attention mechanism. | ||||||
| When True, an additional gating branch is added based on the | ||||||
| (NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708]. | ||||||
|
||||||
| It applies a sigmoid-activated linear projection to the query | ||||||
| which then gates the attention output. This helps improve training | ||||||
| stability and eliminates "attention sinks". | ||||||
| seed: Optional integer to seed the dropout layer. | ||||||
|
|
||||||
| Call arguments: | ||||||
|
|
@@ -102,6 +108,7 @@ def __init__( | |||||
| activity_regularizer=None, | ||||||
| kernel_constraint=None, | ||||||
| bias_constraint=None, | ||||||
| use_gate=False, | ||||||
| seed=None, | ||||||
| **kwargs, | ||||||
| ): | ||||||
|
|
@@ -117,6 +124,7 @@ def __init__( | |||||
| self.num_repeats = num_query_heads // num_key_value_heads | ||||||
| self.dropout = dropout | ||||||
| self.use_bias = use_bias | ||||||
| self.use_gate = use_gate | ||||||
| self._flash_attention = flash_attention or is_flash_attention_enabled() | ||||||
| self.kernel_initializer = initializers.get(kernel_initializer) | ||||||
| self.bias_initializer = initializers.get(bias_initializer) | ||||||
|
|
@@ -170,7 +178,16 @@ def build( | |||||
| **self._get_common_kwargs_for_sublayer(), | ||||||
| ) | ||||||
| self._key_dense.build(key_shape) | ||||||
|
|
||||||
| if self.use_gate: | ||||||
| self._gate_dense = EinsumDense( | ||||||
| "bqm,muh->bquh", | ||||||
| output_shape=(None, self.num_query_heads, self.head_dim), | ||||||
| bias_axes="uh" if self.use_bias else None, | ||||||
| activation="sigmoid", | ||||||
| name="gate", | ||||||
| **self._get_common_kwargs_for_sublayer(), | ||||||
| ) | ||||||
| self._gate_dense.build(key_shape) | ||||||
|
||||||
| self._gate_dense.build(key_shape) | |
| self._gate_dense.build(query_shape) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -64,6 +64,12 @@ class MultiHeadAttention(Layer): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| activity_regularizer: Regularizer for dense layer activity. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel_constraint: Constraint for dense layer kernels. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bias_constraint: Constraint for dense layer kernels. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_gate: Boolean, whether to apply a gated attention mechanism. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| When True, an additional gating branch is added based on the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (NeurIPS 2025 Best Paper)[https://arxiv.org/abs/2505.06708]. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| It applies a sigmoid-activated linear projection to the query | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| which then gates the attention output. This helps improve training | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stability and eliminates "attention sinks". | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seed: Optional integer to seed the dropout layer. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Call arguments: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -117,6 +123,7 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| activity_regularizer=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel_constraint=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bias_constraint=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_gate=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seed=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -127,6 +134,7 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._value_dim = value_dim if value_dim else key_dim | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._dropout = dropout | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._use_bias = use_bias | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._use_gate = use_gate | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if output_shape: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(output_shape, int): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_shape = (output_shape,) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -201,6 +209,7 @@ def get_config(self): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "value_dim": self._value_dim, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "dropout": self._dropout, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "use_bias": self._use_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "use_gate": self._use_gate, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "output_shape": self._output_shape, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "attention_axes": self._attention_axes, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "kernel_initializer": initializers.serialize( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -271,6 +280,18 @@ def build( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **self._get_common_kwargs_for_sublayer(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._key_dense.build(key_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._use_gate: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._gate_dense = EinsumDense( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| einsum_equation, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_shape=_get_output_shape( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_rank - 1, [self._num_heads, self._key_dim] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bias_axes=bias_axes if self._use_bias else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| activation="sigmoid", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| name="gate", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **self._get_common_kwargs_for_sublayer(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._gate_dense.build(key_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._use_gate: | |
| self._gate_dense = EinsumDense( | |
| einsum_equation, | |
| output_shape=_get_output_shape( | |
| output_rank - 1, [self._num_heads, self._key_dim] | |
| ), | |
| bias_axes=bias_axes if self._use_bias else None, | |
| activation="sigmoid", | |
| name="gate", | |
| **self._get_common_kwargs_for_sublayer(), | |
| ) | |
| self._gate_dense.build(key_shape) | |
| if self._use_gate: | |
| query_einsum_equation, query_bias_axes, query_output_rank = _build_proj_equation( | |
| query_rank - 1, bound_dims=1, output_dims=2 | |
| ) | |
| self._gate_dense = EinsumDense( | |
| query_einsum_equation, | |
| output_shape=_get_output_shape( | |
| query_output_rank - 1, [self._num_heads, self._value_dim] | |
| ), | |
| bias_axes=query_bias_axes if self._use_bias else None, | |
| activation="sigmoid", | |
| name="gate", | |
| **self._get_common_kwargs_for_sublayer(), | |
| ) | |
| self._gate_dense.build(query_shape) |
Uh oh!
There was an error while loading. Please reload this page.