The way to mask attention matrices in Flax #2915
Unanswered · kenkenpa2126 asked this question in Q&A
Replies: 2 comments · 4 replies
-
Hey @kenkenpa2126, can you explain why this would happen? The way I see it ...
-
Hey @kenkenpa2126, I did a simple experiment and can confirm that the gradients are indeed masked (i.e. zero) when the masked positions are replaced by any constant:

```python
import jax
import jax.numpy as jnp

weights = jnp.full((3, 3), 2.0)
mask = jnp.array([
    [1, 0, 1],
    [1, 1, 0],
    [0, 1, 0],
])

def fn(weights, mask):
    # Replace masked-out positions with a large negative constant.
    big_neg = jnp.finfo(jnp.float32).min
    weights = jnp.where(mask, weights, big_neg)
    return jnp.sum(weights)

# Gradient w.r.t. `weights`: 1.0 where mask == 1, 0.0 where mask == 0.
grads = jax.grad(fn)(weights, mask)
print(grads)
```
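To connect this with the attention case, here is a minimal sketch of my own (the `values` matrix is arbitrary, not anything from Flax) showing that the same `jnp.where` masking keeps the gradient at zero for masked positions even when a softmax is applied on top:

```python
import jax
import jax.numpy as jnp

def masked_softmax_loss(weights, mask):
    # Same masking pattern as above, followed by a softmax over each row.
    big_neg = jnp.finfo(jnp.float32).min
    masked = jnp.where(mask, weights, big_neg)
    probs = jax.nn.softmax(masked, axis=-1)
    # Arbitrary downstream values so the loss actually depends on the probabilities.
    values = jnp.arange(9.0).reshape(3, 3)
    return jnp.sum(probs * values)

weights = jnp.full((3, 3), 2.0)
mask = jnp.array([
    [1, 0, 1],
    [1, 1, 0],
    [0, 1, 0],
])

# Every position where mask == 0 still receives a zero gradient,
# because jnp.where cuts the dependence on `weights` there.
grads = jax.grad(masked_softmax_loss)(weights, mask)
print(grads)
```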
-
It seems that, in the Flax implementation, attention matrices are masked for both queries and keys simultaneously, and masked positions are set to jnp.finfo(dtype).min rather than -jnp.inf:

flax/flax/linen/attention.py, line 104 in 10a2123

If both queries and keys are masked at the same time and the masked positions are set to -jnp.inf, a query with masked positions produces rows consisting entirely of -jnp.inf, which breaks the softmax because the denominator becomes zero. I guess that is why masked positions are set to jnp.finfo(dtype).min instead of -jnp.inf. However, this lets gradients flow into the masked positions, which is generally considered undesirable. Also, we could set the masked positions to -jnp.inf if attention masks were built only for keys. I wonder what led Flax to adopt this way of masking attention matrices for both queries and keys at the same time.
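For reference, the softmax behaviour described above is easy to reproduce in isolation. This is a small check of my own (not taken from the Flax source) of what happens to a fully masked row under the two fill values:

```python
import jax
import jax.numpy as jnp

# A row in which every position is masked, e.g. a query that is entirely padding.
all_inf = jnp.full((4,), -jnp.inf)
print(jax.nn.softmax(all_inf))  # [nan nan nan nan]: the normalizer degenerates (0/0)

all_min = jnp.full((4,), jnp.finfo(jnp.float32).min)
print(jax.nn.softmax(all_min))  # [0.25 0.25 0.25 0.25]: finite, but uniform attention over the row
```

So filling with jnp.finfo(dtype).min keeps fully masked query rows numerically well defined, at the cost of giving them a uniform attention distribution, which presumably has to be ignored downstream.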