
Matmul error when using output_all_encoded_layers = True, and pooler #451

Open

Description

Hi,

First off, thanks for this great contribution!

There seems to be an issue with how encoder_outputs is handled at the pooler level when passing output_all_encoded_layers=True.

encoder_outputs = self.encoder(
    embedding_output,
    attention_mask,
    output_all_encoded_layers=output_all_encoded_layers,
    subset_mask=subset_mask)
if masked_tokens_mask is None:
    sequence_output = encoder_outputs[-1]
    pooled_output = self.pooler(
        sequence_output) if self.pooler is not None else None
else:
    # TD [2022-03-01]: the indexing here is very tricky.
    attention_mask_bool = attention_mask.bool()

because when I do that, I get the following traceback:

File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/PatientTrajectoryForecasting/utils/bert_layers_mosa.py:567, in BertPooler.forward(self, hidden_states, pool)
    561 def forward(self,
    562             hidden_states: torch.Tensor,
    563             pool: Optional[bool] = True) -> torch.Tensor:
    564     # We "pool" the model by simply taking the hidden state corresponding
    565     # to the first token.
    566     first_token_tensor = hidden_states[:, 0] if pool else hidden_states
--> 567     pooled_output = self.dense(first_token_tensor)
    568     pooled_output = self.activation(pooled_output)
    569     return pooled_output

File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x54784 and 768x768)
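
For what it's worth, the mismatch is easy to reproduce in isolation. A minimal sketch with made-up sizes (1000 unpadded tokens standing in for the 54784 in the traceback above), just to show why F.linear ends up seeing a 1x54784 mat1:

import torch
import torch.nn as nn

hidden = 768
ntokens_unpad = 1000  # stand-in for the 54784 unpadded tokens in my run

dense = nn.Linear(hidden, hidden)

# What BertPooler expects: a padded [batch, seqlen, hidden] tensor,
# so hidden_states[:, 0] picks the first token of each sequence.
padded = torch.randn(4, 128, hidden)
print(dense(padded[:, 0]).shape)  # torch.Size([4, 768]) -- fine

# What it actually gets with output_all_encoded_layers=True: the unpadded
# [ntokens_unpad, hidden] tensor, so [:, 0] slices a column instead.
unpadded = torch.randn(ntokens_unpad, hidden)
first_token = unpadded[:, 0]  # shape [ntokens_unpad], not [batch, hidden]
dense(first_token)  # RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x1000 and 768x768)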

I believe the issue is that the padding function is not applied to the hidden states of each layer before they are appended to the list at the BERT encoder level:

all_encoder_layers = []
if subset_mask is None:
    for layer_module in self.layer:
        hidden_states = layer_module(hidden_states,
                                     cu_seqlens,
                                     seqlen,
                                     None,
                                     indices,
                                     attn_mask=attention_mask,
                                     bias=alibi_attn_mask)
        if output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
    # Pad inputs and mask. It will insert back zero-padded tokens.
    # Assume ntokens is total number of tokens (padded and non-padded)
    # and ntokens_unpad is total number of non-padded tokens.
    # Then padding performs the following de-compression:
    #   hidden_states[ntokens_unpad, hidden] -> hidden_states[ntokens, hidden]
    hidden_states = bert_padding_module.pad_input(
        hidden_states, indices, batch, seqlen)
else:

(Edit: yep, the change below works, but I haven't checked for downstream dependencies.)

all_encoder_layers.append(bert_padding_module.pad_input(
    hidden_states, indices, batch, seqlen))

The same thing should probably be done in the branch where subset_mask is not None...
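
For anyone unfamiliar with the unpad/pad round trip, here is a simplified sketch of the de-compression pad_input performs (not the actual bert_padding_module implementation, just an illustration of its behavior):

import torch

def pad_input_sketch(hidden_states: torch.Tensor,
                     indices: torch.Tensor,
                     batch: int,
                     seqlen: int) -> torch.Tensor:
    # Scatter [ntokens_unpad, hidden] back into [batch, seqlen, hidden],
    # inserting zeros at the padded positions.
    hidden = hidden_states.shape[-1]
    output = torch.zeros(batch * seqlen, hidden,
                         dtype=hidden_states.dtype,
                         device=hidden_states.device)
    output[indices] = hidden_states
    return output.view(batch, seqlen, hidden)

# Made-up example: batch=2, seqlen=4, hidden=3, five non-padded tokens
# at flat positions 0, 1, 2 (first sequence) and 4, 5 (second sequence).
unpadded = torch.randn(5, 3)
indices = torch.tensor([0, 1, 2, 4, 5])
padded = pad_input_sketch(unpadded, indices, batch=2, seqlen=4)
print(padded.shape)  # torch.Size([2, 4, 3]) -- the shape the pooler expects

Applying this per layer before the append (as in the edit above) gives every entry of all_encoder_layers the padded shape, which is what the pooler indexing assumes.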

Thanks again for your contribution to the community!
