Help me understand how to use the flax.linen.make_attention_mask function #3163
Replies: 2 comments
-
Hi @davidshen84,
The masking function masks the attention weights in QK^T depending on which tokens in the input sequence are masked (see https://lukesalamone.github.io/posts/what-are-attention-masks/ for a brief review). The following code (also available as a Colab notebook: https://colab.research.google.com/drive/1TS6A7y2ALgeqDLWnlGtDKKHY-t-13DKW?usp=sharing) may provide clarification based on the example you provided.
Code example:
import chex
import jax
import jax.numpy as jnp
import flax.linen as nn
# check masking function
# here the last element of the sequence is masked
# resulting mask outlines masking of attention weights QK^T
q = jnp.array([1.0, 1.0, 1.0, 0.0])
mask = nn.make_attention_mask(q>0, q>0) # note we are using self attention
print(mask)
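# (my reading of make_attention_mask, worth verifying: it adds a leading
# broadcast axis for the attention heads, so the printed mask should have
# shape (1, 4, 4), with the last row and column zeroed out because the
# fourth token is masked)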
### example demonstrating the difference between masking and not masking ###
# initialise random key and dummy data
key = jax.random.PRNGKey(0)
dummy_data = jax.random.randint(key, (5, 10, 10), 0, 100)
# initialise attention layer
attention_layer = nn.SelfAttention(num_heads=1)
params = attention_layer.init(key, dummy_data)
# create attention mask
q = jnp.ones([5, 10]).at[:, 1:].set(0) # mask all tokens except the first in each sequence
mask = nn.make_attention_mask(q>0, q>0)
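# (assumption to verify: this mask should have shape (5, 1, 10, 10), i.e.
# (batch, broadcast-over-heads, query_length, key_length))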
# check equality
chex.assert_trees_all_close(
    attention_layer.apply(params, inputs_q=dummy_data, mask=mask),
    attention_layer.apply(params, inputs_q=dummy_data, mask=None),
)
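To make the effect of the mask concrete, here is a rough sketch of how a mask of shape (batch, 1, q_len, kv_len) is typically applied inside scaled dot-product attention: masked positions have their attention logits replaced by a very large negative number before the softmax. This is only an illustration of the general mechanism, not Flax's exact implementation.
import jax
import jax.numpy as jnp

def masked_attention_weights(query, key, mask):
    """Sketch of masked scaled dot-product attention weights.

    query, key: (batch, seq_len, features)
    mask: (batch, 1, seq_len, seq_len), 1.0 where attention is allowed, 0.0 otherwise.
    """
    depth = query.shape[-1]
    # attention logits QK^T, one (seq_len, seq_len) matrix per batch element
    logits = jnp.einsum('bqd,bkd->bqk', query, key) / jnp.sqrt(depth)
    # add a head axis so the mask broadcasts over it, then replace masked
    # logits with a very large negative number before the softmax
    logits = logits[:, None, :, :]
    big_neg = jnp.finfo(logits.dtype).min
    logits = jnp.where(mask > 0, logits, big_neg)
    # softmax over the key dimension; masked positions get ~zero weight
    return jax.nn.softmax(logits, axis=-1)
Flax's attention modules do the equivalent masking internally when you pass mask=..., so in practice you only need to build the mask with make_attention_mask and hand it over.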
In your code there are some errors:
m.apply(params, inputs_q, inputs_q, mask, mask)
The SelfAttention module is a special case of MultiHeadDotProductAttention and hence takes just one inputs_q parameter, so the call above isn't executable as written. In addition, your input embeddings and the inputs you pass to make_attention_mask have the same dimensions; these need to be revised, as the dimensions should be different. In particular, make_attention_mask doesn't expect an embedding, it expects token-level arrays of shape (batch, sequence_length). If I can provide further explanation, let me know. Hope this helps.
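For illustration, a corrected version of that call might look roughly like the sketch below; the names embeddings and padding are hypothetical placeholders for your own data, so treat this as a sketch rather than a drop-in fix.
import jax
import jax.numpy as jnp
import flax.linen as nn

key = jax.random.PRNGKey(0)

# hypothetical inputs: embeddings are (batch, seq_len, features);
# padding marks real tokens with 1 and padded tokens with 0, shape (batch, seq_len)
embeddings = jax.random.normal(key, (5, 10, 16))
padding = jnp.ones((5, 10)).at[:, 7:].set(0)

m = nn.SelfAttention(num_heads=1)
params = m.init(key, embeddings)

# the mask is built from the (batch, seq_len) padding indicator, not from the embeddings
mask = nn.make_attention_mask(padding > 0, padding > 0)  # (batch, 1, seq_len, seq_len)

# SelfAttention takes a single inputs_q argument plus the mask
out = m.apply(params, inputs_q=embeddings, mask=mask)
print(out.shape)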
-
Thanks a lot!
-
Hi,
I am trying to build a SelfAttention layer with a mask parameter. I think I should use the make_attention_mask function to create the mask variable. I created a trivial example to help me understand the matrix shapes, but the output is not what I expected. No matter how I set the mask in the q variable, the outputs of the last two lines are always the same. It looks like the mask parameter has no effect. I was expecting one column of the SelfAttention layer's output to be set to negative infinity.