Different sized encoder for TransformerDecoder #182

@Adamits


It would be convenient to allow the encoder output_size to differ from the TransformerDecoder embedding size. To illustrate the issue, the code snippet below

import torch
import math

def generate_square_subsequent_mask(length: int) -> torch.Tensor:
    return torch.triu(torch.full((length, length), -math.inf), diagonal=1)


# INITIALIZE A TRANSFORMER WITH THIS HIDDEN AND EMBEDDING SIZE
hid=128
emb=64
decoder_layer = torch.nn.TransformerDecoderLayer(
    d_model=emb,
    dim_feedforward=hid,
    nhead=2,
    dropout=0.2,
    activation="relu",
    batch_first=True,
)
frank_transformer = torch.nn.TransformerDecoder(
    decoder_layer=decoder_layer,
    num_layers=2,
    norm=torch.nn.LayerNorm(emb),
)

# INITIALIZE TARGETS WITH EMBEDDING SIZE
# AND A FAKE ENCODER OUTPUT WITH HIDDEN SIZE
b = 4
seq_len = 10
target_embedding = torch.randn((b, seq_len, emb))
encoder_hidden = torch.randn(b, seq_len, hid)
target_sequence_length = target_embedding.size(1)
# -> seq_len x seq_len.
causal_mask = generate_square_subsequent_mask(target_sequence_length)
# -> B x seq_len x d_model.
output = frank_transformer(
    target_embedding,
    encoder_hidden,
    tgt_mask=causal_mask,
    # memory_key_padding_mask=source_mask,
    # tgt_key_padding_mask=target_mask,
)

throws:

File "test.py", line 34, in <module>
    output = frank_transformer(
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/transformer.py", line 460, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/transformer.py", line 847, in forward
    x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
  File "torch/nn/modules/transformer.py", line 865, in _mha_block
    x = self.multihead_attn(x, mem, mem,
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "torch/nn/functional.py", line 5300, in multi_head_attention_forward
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
  File "torch/nn/functional.py", line 4836, in _in_projection_packed
    kv_proj = linear(k, w_kv, b_kv)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (40x128 and 64x128)

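If I instead create the fake encoder output with the embedding size, i.e., change encoder_hidden = torch.randn(b, seq_len, hid) to encoder_hidden = torch.randn(b, seq_len, emb), the code runs fine. The error comes from the key/value projection of the encoder-decoder attention, which expects inputs of size d_model (64) but receives an encoder output of size 128.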
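To make the shapes in the error message concrete, here is a small sketch that inspects a standalone torch.nn.MultiheadAttention configured like the one inside the decoder layer. It assumes the default packed projection (no kdim/vdim), and the variable names are only illustrative.

```python
import torch

emb, hid = 64, 128
attn = torch.nn.MultiheadAttention(embed_dim=emb, num_heads=2, batch_first=True)

# With default settings, the query/key/value projections are packed into one
# weight of shape (3 * emb, emb).
print(attn.in_proj_weight.shape)

# The key/value part is the last 2 * emb rows: shape (128, 64). It is applied
# as k @ W^T, so keys must have emb (64) features.
w_kv = attn.in_proj_weight[emb:]
print(w_kv.shape)

# A fake encoder output with hid (128) features cannot be multiplied with it.
memory = torch.randn(4, 10, hid)
try:
    torch.nn.functional.linear(memory, w_kv)
    failed = False
except RuntimeError as error:
    failed = True
    print(error)
```

The 40 in the traceback is b * seq_len (4 * 10), and 64x128 is the transposed key/value weight.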

Essentially, we need the self-attention and the encoder-decoder multihead attention to expect different input sizes (which may also require changes to the layer norms).
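One thing that may help here: torch.nn.MultiheadAttention accepts kdim and vdim arguments, so keys and values can have a different feature size from the queries. The sketch below is a standalone check, not a change to TransformerDecoderLayer, and it suggests that the layer norm would not need to change, because the attention output keeps the query size.

```python
import torch

emb, hid, nhead = 64, 128, 2

# kdim/vdim tell the attention module that keys and values have hid features
# while queries (and the output) keep emb features.
cross_attn = torch.nn.MultiheadAttention(
    embed_dim=emb,
    num_heads=nhead,
    kdim=hid,
    vdim=hid,
    batch_first=True,
)

target = torch.randn(4, 10, emb)
memory = torch.randn(4, 10, hid)

# -> B x seq_len x emb.
attended, _ = cross_attn(target, memory, memory, need_weights=False)

# The residual connection and LayerNorm(emb) apply unchanged.
normed = torch.nn.LayerNorm(emb)(target + attended)
print(attended.shape, normed.shape)
```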

I am putting this up and will try to work out a solution. The easiest way to allow this behavior in yoyodyne would be to project the encoder output size to the decoder embedding size, or vice versa, but I feel that this changes the architecture more than necessary. Instead, I would like to consider whether there is an elegant way to update _sa_block and _mha_block without breaking other things in the transformer (e.g., layer norm).
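As a rough sketch of the two options, the code below first projects the encoder output down to the embedding size, then tries a subclass that only swaps the cross-attention module. WideMemoryDecoderLayer and memory_dim are hypothetical names made up for this example; this is not yoyodyne code and I have only checked it against the shapes from the snippet above.

```python
import torch


class WideMemoryDecoderLayer(torch.nn.TransformerDecoderLayer):
    """Hypothetical layer whose cross-attention accepts memory_dim features."""

    def __init__(self, d_model, nhead, memory_dim, dim_feedforward, dropout):
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        # Replace only the encoder-decoder attention; self-attention, the
        # feed-forward block and the layer norms keep d_model.
        self.multihead_attn = torch.nn.MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            kdim=memory_dim,
            vdim=memory_dim,
            batch_first=True,
        )


emb, hid, b, seq_len = 64, 128, 4, 10
target_embedding = torch.randn(b, seq_len, emb)
encoder_hidden = torch.randn(b, seq_len, hid)
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), float("-inf")), diagonal=1
)

# Option 1: project the encoder output to the embedding size.
plain_layer = torch.nn.TransformerDecoderLayer(
    d_model=emb, nhead=2, dim_feedforward=hid, dropout=0.2, batch_first=True
)
plain_decoder = torch.nn.TransformerDecoder(plain_layer, num_layers=2)
projection = torch.nn.Linear(hid, emb)
projected_output = plain_decoder(
    target_embedding, projection(encoder_hidden), tgt_mask=causal_mask
)

# Option 2: keep the encoder output as-is and widen the cross-attention.
wide_layer = WideMemoryDecoderLayer(
    d_model=emb, nhead=2, memory_dim=hid, dim_feedforward=hid, dropout=0.2
)
wide_decoder = torch.nn.TransformerDecoder(
    wide_layer, num_layers=2, norm=torch.nn.LayerNorm(emb)
)
wide_output = wide_decoder(
    target_embedding, encoder_hidden, tgt_mask=causal_mask
)

# Both -> B x seq_len x emb.
print(projected_output.shape, wide_output.shape)
```

Option 2 leaves _sa_block and _mha_block untouched, since the inherited forward pass just calls whatever module is stored in multihead_attn. Whether it behaves well with padding masks and in eval mode would still need checking.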

Labels: enhancement
