Description
There is a variable error in the `full_bigbird_mask` method used by the multi-head attention class: it builds the BigBird mask from `MAX_SEQ_LEN` instead of the `from_seq_length` that is passed in. This affects the creation of the attention mask when the function is called from the `convert_attn_list_to_mask(self, rand_attn)` method:
```python
temp_mask = [
    full_bigbird_mask(  # pylint: disable=g-complex-comprehension
        self.from_seq_length, self.to_seq_length,
        self.from_block_size, self.to_block_size,
        rand_attn=rand_attn[i])
    for i in range(self.num_attention_heads)
]
```
```python
def full_bigbird_mask(from_seq_length,
                      to_seq_length,
                      from_block_size,
                      to_block_size,
                      rand_attn):
  """Calculate BigBird attention pattern as a full dense matrix.

  Args:
    from_seq_length: int. length of from sequence.
    to_seq_length: int. length of to sequence.
    from_block_size: int. size of block in from sequence.
    to_block_size: int. size of block in to sequence.
    rand_attn: adjacency matrix for random attention.

  Returns:
    attention mask matrix of shape [from_seq_length, to_seq_length]
  """
  attn_mask = np.zeros((MAX_SEQ_LEN, MAX_SEQ_LEN), dtype=np.int32)
  for i in range(1, (MAX_SEQ_LEN // from_block_size) - 1):
```
The `full_bigbird_mask` method uses `MAX_SEQ_LEN` instead of `from_seq_length` / `to_seq_length`, so the mask size is not dynamic: `MAX_SEQ_LEN` is only defined once at the top of the module, and this appears to be what breaks the `convert_attn_list_to_mask` method when the sequence lengths passed in differ from `MAX_SEQ_LEN`.
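Below is a minimal sketch of the change this report is suggesting, assuming the only lines that need to reference the sequence length are the two quoted above; the loop body is elided because the upstream window/global/random filling logic would stay the same:

```python
import numpy as np

def full_bigbird_mask(from_seq_length,
                      to_seq_length,
                      from_block_size,
                      to_block_size,
                      rand_attn):
  """Sketch: size the dense mask from the arguments, not MAX_SEQ_LEN."""
  # Allocate the mask with the lengths that were actually passed in, so the
  # result has the [from_seq_length, to_seq_length] shape the docstring
  # promises and that convert_attn_list_to_mask stacks per head.
  attn_mask = np.zeros((from_seq_length, to_seq_length), dtype=np.int32)
  # Iterate over blocks of the passed-in from-sequence rather than MAX_SEQ_LEN.
  for i in range(1, (from_seq_length // from_block_size) - 1):
    pass  # unchanged upstream logic that fills attn_mask goes here
  return attn_mask
```

With this change the per-head masks returned to `convert_attn_list_to_mask` would have shape `[from_seq_length, to_seq_length]` rather than `[MAX_SEQ_LEN, MAX_SEQ_LEN]`.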