diff --git a/parlai/agents/image_seq2seq/image_seq2seq.py b/parlai/agents/image_seq2seq/image_seq2seq.py index a320df1947a..867643f3a96 100644 --- a/parlai/agents/image_seq2seq/image_seq2seq.py +++ b/parlai/agents/image_seq2seq/image_seq2seq.py @@ -6,23 +6,23 @@ """ Image+Seq2Seq Agent. """ -import torch + from typing import Dict, List, Tuple +import torch + from .modules import ImageSeq2seqModel from parlai.agents.transformer.transformer import TransformerGeneratorAgent from parlai.core.dict import DictionaryAgent -from parlai.core.message import Message from parlai.core.torch_agent import Batch - -# from parlai.utils.typing import Dict, List +from parlai.core.torch_image_agent import TorchImageAgent TOKEN_IMAGE = '__image__' TOKEN_NO_IMAGE = '__no_image__' -class ImageSeq2seqAgent(TransformerGeneratorAgent): +class ImageSeq2seqAgent(TransformerGeneratorAgent, TorchImageAgent): """ ImageSeq2seqAgent Agent. @@ -45,18 +45,9 @@ def add_cmdline_args(cls, argparser): """ Override to add one arg. """ - super(ImageSeq2seqAgent, cls).add_cmdline_args(argparser) + TransformerGeneratorAgent.add_cmdline_args(argparser) + TorchImageAgent.add_cmdline_args(argparser) group = argparser.add_argument_group('Image Encoder Args') - group.add_argument( - '--image-features-dim', type=int, default=2048, help='dim for image feats' - ) - group.add_argument( - '--image-encoder-num-layers', - type=int, - default=1, - recommended=1, - help='Number of layers for image encoder', - ) group.add_argument( '--include-image-token', type='bool', @@ -105,33 +96,22 @@ def _dummy_batch(self, batchsize: int, maxlen: int) -> Batch: return Batch( text_vec=torch.ones(batchsize, maxlen).long().cuda(), label_vec=torch.ones(batchsize, 2).long().cuda(), - image=torch.ones(batchsize, self.opt.get('image_features_dim')).cuda(), + image=torch.ones(batchsize, self.image_features_dim).cuda(), personalities=torch.ones(batchsize, self.opt.get('embedding_size')).cuda(), ) - def batchify(self, obs_batch: List[Message], sort: bool = False) -> Batch: - """ - Override to handle images. + def batchify_image_features(self, batch: Batch) -> Batch: """ - batch = super().batchify(obs_batch, sort) - - def _process_img(img): - if img is not None and isinstance(img, torch.Tensor): - if img.dim() == 4: - img = img[0, :, 0, 0] - if self.use_cuda: - img = img.cuda() - if self.opt.get('fp16'): - img = img.half() - else: - img = img.float() - - return img + Format and return the batched image features. + Image features represented by tensors will set to the right type. + """ if type(batch.image) == list and any(b is not None for b in batch): images = [] for img in batch.image: - images.append(_process_img(img)) + if isinstance(img, torch.Tensor): + img = self._process_image_features(img) + images.append(img) batch.image = images return batch diff --git a/parlai/agents/image_seq2seq/modules.py b/parlai/agents/image_seq2seq/modules.py index 383ab10aeab..24136f5f6c7 100644 --- a/parlai/agents/image_seq2seq/modules.py +++ b/parlai/agents/image_seq2seq/modules.py @@ -6,9 +6,12 @@ """ Modules for ImageSeq2seqAgent Agent. 
""" + +from functools import reduce +from typing import List, Tuple, Optional, Union + import torch import torch.nn as nn -from typing import List, Tuple from parlai.agents.transformer.modules import ( TransformerGeneratorModel, @@ -54,7 +57,6 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent): padding_idx=self.pad_idx, learn_positional_embeddings=opt['learn_positional_embeddings'], embeddings_scale=opt['embeddings_scale'], - reduction_type=None, n_positions=n_positions, n_segments=opt.get('n_segments', 0), activation=opt['activation'], @@ -69,7 +71,7 @@ class ContextWithImageEncoder(TransformerEncoder): """ ContextWithImage Module. - Encodes image and context via simple concatenation. + Encodes image features and context, and combines by summing or concatenation. """ def __init__( @@ -86,7 +88,6 @@ def __init__( padding_idx=0, learn_positional_embeddings=False, embeddings_scale=False, - reduction_type='mean', n_positions=1024, activation='relu', variant='aiayn', @@ -94,6 +95,8 @@ def __init__( output_scaling=1.0, image_encoder_num_layers=1, image_features_dim=2048, + image_combination_mode='append', + n_image_tokens=1, ): """ Override TransformerEncoder __init__. @@ -101,45 +104,59 @@ def __init__( Setup the image encoder; create some dummy tensors for inserting image into input """ + + self.padding_idx = padding_idx self.n_img_layers = image_encoder_num_layers self.img_dim = image_features_dim + self.image_combination_mode = image_combination_mode + self.n_image_tokens = n_image_tokens + if self.image_combination_mode == 'add' and self.n_image_tokens > 1: + raise ValueError( + 'Image encoding cannot be added to context encoding if there is more than one image token!' + ) + reduction_type = None # Must pass back unreduced encoding and mask super().__init__( - n_heads, - n_layers, - embedding_size, - ffn_size, - vocabulary_size, - embedding, - dropout, - attention_dropout, - relu_dropout, - padding_idx, - learn_positional_embeddings, - embeddings_scale, - reduction_type, - n_positions, - activation, - variant, - n_segments, - output_scaling, + n_heads=n_heads, + n_layers=n_layers, + embedding_size=embedding_size, + ffn_size=ffn_size, + vocabulary_size=vocabulary_size, + embedding=embedding, + dropout=dropout, + attention_dropout=attention_dropout, + relu_dropout=relu_dropout, + padding_idx=padding_idx, + learn_positional_embeddings=learn_positional_embeddings, + embeddings_scale=embeddings_scale, + reduction_type=reduction_type, + n_positions=n_positions, + activation=activation, + variant=variant, + n_segments=n_segments, + output_scaling=output_scaling, ) + self.full_embedding_size = self.embedding_size * self.n_image_tokens + # Images will be embedded to this size, and then the embedding will be folded + # into however many tokens are needed self._build_image_encoder() - self.dummy_image_enc = torch.nn.Parameter( - torch.zeros((self.embedding_size)), requires_grad=False + self.register_buffer( + 'dummy_image_enc', torch.zeros((self.full_embedding_size,)) ) - self.ones_mask = torch.nn.Parameter(torch.ones(1).bool(), requires_grad=False) + self.register_buffer('ones_mask', torch.ones(self.n_image_tokens).bool()) def _build_image_encoder(self): - image_layers = [nn.Linear(self.img_dim, self.embedding_size)] + image_layers = [nn.Linear(self.img_dim, self.full_embedding_size)] for _ in range(self.n_img_layers - 1): image_layers += [ nn.ReLU(), - nn.Dropout(p=self.opt['dropout']), - nn.Linear(self.img_dim, self.embedding_size), + nn.Dropout(p=self.dropout_frac), + 
nn.Linear(self.full_embedding_size, self.full_embedding_size), ] self.image_encoder = nn.Sequential(*image_layers) - def encode_images(self, images: List[object]) -> Tuple[List[int], torch.Tensor]: + def encode_images( + self, images: Union[List[object], torch.Tensor] + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: """ Encode Images. @@ -147,14 +164,15 @@ def encode_images(self, images: List[object]) -> Tuple[List[int], torch.Tensor]: is a tensor). :param images: - list of objects of length N, of which some maybe be None + either a list of objects of length N, of which some maybe be None, or a + tensor of shape (batch size, self.img_dim) :return: a (image_encoded, image_mask) tuple, where: - - image_enc is a torch.Tensor of dim N x self.img_dim, - representing the encoded batch of images - - image_mask is a torch.Tensor of dim N x 1 + - image_encoded is a torch.Tensor of dim N x self.n_image_tokens x + self.embedding_size, representing the encoded batch of images + - image_mask is a torch.Tensor of dim N x self.n_image_tokens """ image_masks = image_encoded = None valid_inds = [ @@ -164,8 +182,8 @@ def encode_images(self, images: List[object]) -> Tuple[List[int], torch.Tensor]: ] if valid_inds: - image_masks = [] - image_encoded = [] + image_mask_list = [] + image_encoded_list = [] valid_imgs = torch.stack([images[i] for i in valid_inds]) valid_img_enc = self.image_encoder(valid_imgs) @@ -173,31 +191,38 @@ def encode_images(self, images: List[object]) -> Tuple[List[int], torch.Tensor]: img_num = 0 for i in range(len(images)): if i in valid_inds: - image_masks.append(self.ones_mask) - image_encoded.append(valid_img_enc[img_num, :]) + image_mask_list.append(self.ones_mask) + image_encoded_list.append(valid_img_enc[img_num, :]) img_num += 1 else: - image_masks.append(~self.ones_mask) - image_encoded.append(self.dummy_image_enc) + image_mask_list.append(~self.ones_mask) + image_encoded_list.append(self.dummy_image_enc) - image_masks = torch.stack(image_masks) - image_encoded = torch.stack(image_encoded).unsqueeze(1) + image_masks = torch.stack(image_mask_list) + image_encoded = torch.stack(image_encoded_list).reshape( + (len(images), self.n_image_tokens, self.embedding_size) + ) + assert image_masks.shape == image_encoded.shape[:2] return image_encoded, image_masks def forward( - self, src_tokens: torch.Tensor, image_features: List[object] - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, + src_tokens: Optional[torch.Tensor], + image_features: Optional[Union[List[object], torch.Tensor]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Encode images with context. Encodes tokens (if given) and images (if given) separately. - Combines via concatenation, where images are added to the end of the tensor. + Combines via either addition, prepending, or appending the image embedding to + the context embedding. :param src_tokens: A bsz x seq_len tensor of src_tokens; possibly None :param image_features: - A list of (torch.tensor) + Either a list of (torch.tensor) or a tensor of shape (batch_size, + self.img_dim) :return: A (full_enc, full_mask) tuple, which represents the encoded context @@ -205,6 +230,8 @@ def forward( """ context_encoded = context_mask = None image_encoded = extra_masks = None + if src_tokens is not None and image_features is not None: + assert src_tokens.size(0) == len(image_features) if src_tokens is not None: context_encoded, context_mask = super().forward(src_tokens) if image_features is not None: @@ -219,11 +246,42 @@ def forward( 'set correctly.' 
) - full_enc = self.cat([context_encoded, image_encoded]) - full_mask = self.cat([context_mask, extra_masks]) + if self.image_combination_mode == 'add': + full_enc = self._add([context_encoded, image_encoded]) + # image_encoded broadcasted along dim=1 + full_mask = context_mask + elif self.image_combination_mode == 'append': + full_enc = self._cat([context_encoded, image_encoded]) + full_mask = self._cat([context_mask, extra_masks]) + elif self.image_combination_mode == 'prepend': + full_enc = self._cat([image_encoded, context_encoded]) + full_mask = self._cat([extra_masks, context_mask]) + else: + raise ValueError('Image combination mode not recognized!') + + if full_enc.dtype == torch.half: + full_enc, full_mask = self._fix_for_fp16( + full_enc=full_enc, full_mask=full_mask + ) + return full_enc, full_mask - def cat(self, tensors: List[torch.Tensor]) -> torch.Tensor: + def _add(self, tensors: List[Optional[torch.Tensor]]) -> torch.Tensor: + """ + Handle addition of None tensors. + + Smart addition. Adds tensors if they are not None. + + :param tensors: + A list of torch.Tensor, with at least one non-null object + + :return: + The result of adding all non-null objects in tensors + """ + tensors = [t for t in tensors if t is not None] + return reduce(lambda a, b: a + b, tensors) + + def _cat(self, tensors: List[Optional[torch.Tensor]]) -> torch.Tensor: """ Handle concatenation of None tensors. @@ -237,3 +295,39 @@ def cat(self, tensors: List[torch.Tensor]) -> torch.Tensor: """ tensors = [t for t in tensors if t is not None] return torch.cat([t for t in tensors], dim=1) + + def _fix_for_fp16( + self, full_enc: torch.Tensor, full_mask: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + In fp16 mode, either remove extra tokens or add new ones on to get to a multiple + of 8. + """ + + if full_mask is None: + # full_mask is None corresponds to no input tokens, and in case there are no + # tokens to add/remove to get a multiple of 8 + return full_enc, full_mask + + num_tokens_to_remove = full_enc.size(1) % 8 + if num_tokens_to_remove == 0: + # Tensor already divisible by 8 + pass + elif (~full_mask[:, -num_tokens_to_remove:].all()).item(): + # The tokens we'd like to remove are all padding, so subtract them from + # the end + full_enc = full_enc[:, :-1, :] + full_mask = full_mask[:, :-1] + else: + # We can't subtract that many padding tokens, so add some to the end + num_tokens_to_add = 8 - num_tokens_to_remove + enc_extension = full_enc.new_full( + size=(full_enc.size(0), num_tokens_to_add, full_enc.size(2)), + fill_value=self.padding_idx, + ) + mask_extension = full_mask.new_full( + size=(full_mask.size(0), num_tokens_to_add), fill_value=self.padding_idx + ) + full_enc = torch.cat([full_enc, enc_extension], dim=1) + full_mask = torch.cat([full_mask, mask_extension], dim=1) + return full_enc, full_mask diff --git a/parlai/agents/transformer/image_polyencoder.py b/parlai/agents/transformer/image_polyencoder.py new file mode 100644 index 00000000000..20c244bd8bc --- /dev/null +++ b/parlai/agents/transformer/image_polyencoder.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# hack to make sure -m transformer/generator works as expected +""" +Poly-encoder agent that ingests image features. 
+""" + +from typing import Any, Dict + +import torch + +from parlai.agents.image_seq2seq.modules import ContextWithImageEncoder +from parlai.agents.transformer.modules import get_n_positions_from_options +from parlai.agents.transformer.polyencoder import PolyencoderAgent, PolyEncoderModule +from parlai.core.torch_agent import Batch +from parlai.core.torch_image_agent import TorchImageAgent +from parlai.utils.misc import warn_once + + +class ImagePolyencoderAgent(PolyencoderAgent, TorchImageAgent): + """ + Poly-encoder Agent that ingests image features. + + Agent that allows encoding image features and adding or concatenating them to the + context encoding. + """ + + @classmethod + def add_cmdline_args(cls, argparser): + """ + Add command-line arguments specifically for this agent. + """ + PolyencoderAgent.add_cmdline_args(argparser) + TorchImageAgent.add_cmdline_args(argparser) + agent = argparser.add_argument_group('ImagePolyencoder Args') + agent.add_argument( + '--image-combination-mode', + type=str, + default='prepend', + choices=['add', 'append', 'prepend'], + help='How to combine image embedding (if used) with context embedding', + ) + # TODO: more thoroughly test out whether one of these choices is best and add a + # 'recommended' arg here. 'add' and 'prepend' seem to be roughly similar in + # performance + agent.add_argument( + '--n-image-tokens', + type=int, + default=1, + help=( + 'Number of tokens that the image encoding will consist of (when adding ' + 'or prepending)' + ), + ) + agent.set_defaults(reduction_type=None) + # This agent doesn't support any encoder output reductions + return agent + + def build_model(self, states=None): + """ + Return built model. + """ + return ImagePolyencoderModule(self.opt, self.dict, self.NULL_IDX) + + def batchify_image_features(self, batch: Batch) -> Batch: + """ + Return the image features as a Tensor of the correct type. + + Fill in missing feature vectors. Here, we require image features to be saved in + `batch` as a Tensor for passing through the image encoder. This is required for + data_parallel. + """ + + # Checks/formatting of batch.image + bsz = self._get_batch_size(batch) + if batch.image is None or len(batch.image) == 0: + batch.image = [None] * bsz + else: + assert len(batch.image) == bsz + + # Process all image feature vectors, or add in zero vectors if missing + processed_features_list = [] + processed_zero_features = self._process_image_features( + torch.zeros((self.image_features_dim,)) + ) + for orig_features in batch.image: + if isinstance(orig_features, torch.Tensor): + processed_features_list.append( + self._process_image_features(orig_features) + ) + else: + if orig_features is not None: + warn_once( + 'Unsupported image feature format. Image features will be ignored!' + ) + processed_features_list.append(processed_zero_features) + + # Turn into batchsize x image_features_dim for DataParallel + batch.image = torch.stack(processed_features_list) + + return batch + + def _get_batch_size(self, batch) -> int: + """ + Return the size of the batch. + + Use the size of the text vec if it exists; otherwise, use the length of the + image feature list. + """ + if batch.text_vec is not None: + return batch.text_vec.size(0) + else: + return len(batch.image) + + def _model_context_input(self, batch) -> Dict[str, Any]: + """ + Override PolyencoderAgent's context inputs into the model. 
+ """ + return {'ctxt_tokens': batch.text_vec, 'ctxt_image': batch.image} + + def load_state_dict(self, state_dict): + """ + Override to account for weights used for image features. + """ + for tensor in ['dummy_image_enc', 'ones_mask']: + key = f'encoder_ctxt.{tensor}' + if hasattr(self.model.encoder_ctxt, tensor) and key not in state_dict: + state_dict[key] = getattr(self.model.encoder_ctxt, tensor) + if hasattr(self.model.encoder_ctxt, 'image_encoder'): + for layer_idx, layer in enumerate(self.model.encoder_ctxt.image_encoder): + for tensor in ['weight', 'bias']: + key = f'encoder_ctxt.image_encoder.{layer_idx}.{tensor}' + if hasattr(layer, tensor) and key not in state_dict: + state_dict[key] = getattr(layer, tensor) + super().load_state_dict(state_dict) + + +class ImagePolyencoderModule(PolyEncoderModule): + """ + Poly-encoder model with image features. + + Model that allows encoding image features and adding or concatenating them to the + context encoding. + """ + + def get_encoder(self, opt, dict_, null_idx, reduction_type, for_context: bool): + """ + Return encoder that allows for image features to be passed in, given options. + + :param opt: + opt dict + :param dict: + dictionary agent + :param null_idx: + null/pad index into dict + :param reduction_type: only used for compatibility with the superclass method + :param for_context: + whether this is the context encoder (as opposed to the candidate encoder) + :return: + either a TransformerEncoder or a ContextWithImageEncoder, initialized + correctly + """ + if for_context: + if reduction_type is not None: + raise NotImplementedError('No encoder output reductions supported!') + n_positions = get_n_positions_from_options(opt) + embeddings = self._get_embeddings( + dict_=dict_, null_idx=null_idx, embedding_size=opt['embedding_size'] + ) + return ContextWithImageEncoder( + n_heads=opt['n_heads'], + n_layers=opt['n_layers'], + embedding_size=opt['embedding_size'], + ffn_size=opt['ffn_size'], + vocabulary_size=len(dict_), + embedding=embeddings, + dropout=opt['dropout'], + attention_dropout=opt['attention_dropout'], + relu_dropout=opt['relu_dropout'], + padding_idx=null_idx, + learn_positional_embeddings=opt['learn_positional_embeddings'], + embeddings_scale=opt['embeddings_scale'], + n_positions=n_positions, + n_segments=opt['n_segments'], + activation=opt['activation'], + variant=opt['variant'], + output_scaling=opt['output_scaling'], + image_encoder_num_layers=opt['image_encoder_num_layers'], + image_features_dim=opt['image_features_dim'], + image_combination_mode=opt['image_combination_mode'], + n_image_tokens=opt['n_image_tokens'], + ) + else: + # The candidate encoder is the same as for PolyEncoderModule + return super().get_encoder( + opt=opt, + dict_=dict_, + null_idx=null_idx, + reduction_type=reduction_type, + for_context=for_context, + ) + + def _context_encoder_input(self, ctxt_inputs: Dict[str, Any]) -> Dict[str, Any]: + """ + Override PolyEncoderModule's inputs into the context encoder. + """ + assert set(ctxt_inputs.keys()) == {'ctxt_tokens', 'ctxt_image'} + return { + 'src_tokens': ctxt_inputs['ctxt_tokens'], + 'image_features': ctxt_inputs['ctxt_image'], + } + + def _get_context_batch_size(self, **ctxt_inputs: torch.Tensor) -> int: + """ + Return the batch size of the context. 
+ """ + if ctxt_inputs['ctxt_tokens'] is not None: + return ctxt_inputs['ctxt_tokens'].size(0) + else: + return ctxt_inputs['ctxt_image'].size(0) diff --git a/parlai/agents/transformer/modules.py b/parlai/agents/transformer/modules.py index 87be814f8e2..f9714e526bd 100644 --- a/parlai/agents/transformer/modules.py +++ b/parlai/agents/transformer/modules.py @@ -429,7 +429,8 @@ def __init__( self.reduction_type = reduction_type self.padding_idx = padding_idx # this is --dropout, not --relu-dropout or --attention-dropout - self.dropout = nn.Dropout(p=dropout) + self.dropout_frac = dropout + self.dropout = nn.Dropout(p=self.dropout_frac) self.variant = variant self.n_segments = n_segments @@ -548,11 +549,7 @@ def forward(self, input, positions=None, segments=None): output = tensor.sum(dim=1) / divisor return output elif self.reduction_type is None or 'none' in self.reduction_type: - output = tensor - ret = (output, mask) - if self.reduction_type == 'none_with_pos_embs': - ret = (output, mask, position_embs) - return ret + return tensor, mask else: raise ValueError( "Can't handle --reduction-type {}".format(self.reduction_type) diff --git a/parlai/agents/transformer/polyencoder.py b/parlai/agents/transformer/polyencoder.py index 08de3274429..f97b3388769 100644 --- a/parlai/agents/transformer/polyencoder.py +++ b/parlai/agents/transformer/polyencoder.py @@ -7,13 +7,21 @@ """ Poly-encoder Agent. """ -from .biencoder import AddLabelFixedCandsTRA -from .modules import TransformerEncoder -from .modules import get_n_positions_from_options + +from typing import Any, Dict, Optional, Tuple + +import torch + +from parlai.core.opt import Opt from parlai.core.torch_ranker_agent import TorchRankerAgent +from .biencoder import AddLabelFixedCandsTRA +from .modules import ( + BasicAttention, + MultiHeadAttention, + TransformerEncoder, + get_n_positions_from_options, +) from .transformer import TransformerRankerAgent -from .modules import BasicAttention, MultiHeadAttention -import torch class PolyencoderAgent(TorchRankerAgent): @@ -60,16 +68,6 @@ def add_cmdline_args(cls, argparser): 'the key)', recommended='basic', ) - agent.add_argument( - '--polyencoder-attention-keys', - type=str, - default='context', - choices=['context', 'position'], - help='Input emb vectors for the first level of attention. ' - 'Context refers to the context outputs; position refers to the ' - 'computed position embeddings.', - recommended='context', - ) agent.add_argument( '--poly-attention-num-heads', type=int, @@ -96,6 +94,27 @@ def add_cmdline_args(cls, argparser): ) return agent + @classmethod + def upgrade_opt(cls, opt_from_disk: Opt): + # call the parent upgrades + opt_from_disk = super(PolyencoderAgent, cls).upgrade_opt(opt_from_disk) + + polyencoder_attention_keys_value = opt_from_disk.get( + 'polyencoder_attention_keys' + ) + if polyencoder_attention_keys_value is not None: + # 2020-02-19 We are deprecating this flag because it was used for a one-time + # set of experiments and won't be used again. This flag was defaulted to + # 'context', so throw an exception otherwise. + if polyencoder_attention_keys_value == 'context': + del opt_from_disk['polyencoder_attention_keys'] + else: + raise NotImplementedError( + 'This --polyencoder-attention-keys mode (found in commit 06f0d9f) is no longer supported!' 
+ ) + + return opt_from_disk + def __init__(self, opt, shared=None): super().__init__(opt, shared) self.rank_loss = torch.nn.CrossEntropyLoss(reduce=True, size_average=True) @@ -164,7 +183,7 @@ def encode_candidates(self, padded_cands): Encode candidates. """ padded_cands = padded_cands.unsqueeze(1) - _, _, _, cand_rep = self.model(cand_tokens=padded_cands) + _, _, cand_rep = self.model(cand_tokens=padded_cands) return cand_rep def score_candidates(self, batch, cand_vecs, cand_encs=None): @@ -174,8 +193,8 @@ def score_candidates(self, batch, cand_vecs, cand_encs=None): The Poly-encoder encodes the candidate and context independently. Then, the model applies additional attention before ultimately scoring a candidate. """ - bsz = batch.text_vec.size(0) - ctxt_rep, ctxt_rep_mask, ctxt_pos, _ = self.model(ctxt_tokens=batch.text_vec) + bsz = self._get_batch_size(batch) + ctxt_rep, ctxt_rep_mask, _ = self.model(**self._model_context_input(batch)) if cand_encs is not None: if bsz == 1: @@ -184,21 +203,39 @@ def score_candidates(self, batch, cand_vecs, cand_encs=None): cand_rep = cand_encs.expand(bsz, cand_encs.size(1), -1) # bsz x num cands x seq len elif len(cand_vecs.shape) == 3: - _, _, _, cand_rep = self.model(cand_tokens=cand_vecs) + _, _, cand_rep = self.model(cand_tokens=cand_vecs) # bsz x seq len (if batch cands) or num_cands x seq len (if fixed cands) elif len(cand_vecs.shape) == 2: - _, _, _, cand_rep = self.model(cand_tokens=cand_vecs.unsqueeze(1)) + _, _, cand_rep = self.model(cand_tokens=cand_vecs.unsqueeze(1)) num_cands = cand_rep.size(0) # will be bsz if using batch cands cand_rep = cand_rep.expand(num_cands, bsz, -1).transpose(0, 1).contiguous() - scores = self.model( - ctxt_rep=ctxt_rep, - ctxt_rep_mask=ctxt_rep_mask, - cand_rep=cand_rep, - ctxt_pos=ctxt_pos, + ctxt_rep=ctxt_rep, ctxt_rep_mask=ctxt_rep_mask, cand_rep=cand_rep ) return scores + def _get_batch_size(self, batch) -> int: + """ + Return the size of the batch. + + Can be overridden by subclasses that do not always have text input. + """ + return batch.text_vec.size(0) + + def _model_context_input(self, batch) -> Dict[str, Any]: + """ + Create the input context value for the model. + + Must return a dictionary. This will be passed directly into the model via + `**kwargs`, i.e., + + >>> model(**_model_context_input(batch)) + + This is intentionally overridable so that richer models can pass additional + inputs. + """ + return {'ctxt_tokens': batch.text_vec} + def load_state_dict(self, state_dict): """ Override to account for codes. 
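The `_get_batch_size` and `_model_context_input` hooks added in the hunk above are the extension points that the new `ImagePolyencoderAgent` builds on. As a rough sketch of the intended pattern (the `extra_vec`/`ctxt_extra` names are hypothetical and not part of this patch), a subclass that wants to feed one additional per-example tensor into the context encoder would override both hooks like so:

from typing import Any, Dict

from parlai.agents.transformer.polyencoder import PolyencoderAgent


class RicherContextAgent(PolyencoderAgent):
    # Sketch only: pass one extra tensor through to the context encoder.

    def _model_context_input(self, batch) -> Dict[str, Any]:
        # Keys must match the keyword arguments accepted by the module's
        # encode() and consumed by its _context_encoder_input().
        return {'ctxt_tokens': batch.text_vec, 'ctxt_extra': batch.extra_vec}

    def _get_batch_size(self, batch) -> int:
        # Fall back to the extra tensor when there is no text input.
        if batch.text_vec is not None:
            return batch.text_vec.size(0)
        return batch.extra_vec.size(0)

The matching module-side mapping of `ctxt_extra` onto the encoder's own argument name would follow the same shape as `ImagePolyencoderModule._context_encoder_input` earlier in this diff.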
@@ -215,16 +252,27 @@ class PolyEncoderModule(torch.nn.Module): See https://arxiv.org/abs/1905.01969 for more details """ - def __init__(self, opt, dict, null_idx): + def __init__(self, opt, dict_, null_idx): super(PolyEncoderModule, self).__init__() self.null_idx = null_idx - self.encoder_ctxt = self.get_encoder(opt, dict, null_idx, 'none_with_pos_embs') - self.encoder_cand = self.get_encoder(opt, dict, null_idx, opt['reduction_type']) + self.encoder_ctxt = self.get_encoder( + opt=opt, + dict_=dict_, + null_idx=null_idx, + reduction_type=None, + for_context=True, + ) + self.encoder_cand = self.get_encoder( + opt=opt, + dict_=dict_, + null_idx=null_idx, + reduction_type=opt['reduction_type'], + for_context=False, + ) self.type = opt['polyencoder_type'] self.n_codes = opt['poly_n_codes'] self.attention_type = opt['poly_attention_type'] - self.attention_keys = opt.get('polyencoder_attention_keys', 'context') self.attention_num_heads = opt['poly_attention_num_heads'] self.codes_attention_type = opt['codes_attention_type'] self.codes_attention_num_heads = opt['codes_attention_num_heads'] @@ -265,7 +313,7 @@ def __init__(self, opt, dict, null_idx): get_weights=False, ) - def get_encoder(self, opt, dict, null_idx, reduction_type): + def get_encoder(self, opt, dict_, null_idx, reduction_type, for_context: bool): """ Return encoder, given options. @@ -275,23 +323,24 @@ def get_encoder(self, opt, dict, null_idx, reduction_type): dictionary agent :param null_idx: null/pad index into dict - :reduction_type: + :param reduction_type: reduction type for the encoder - + :param for_context: + whether this is the context encoder (as opposed to the candidate encoder). + Useful for subclasses. :return: a TransformerEncoder, initialized correctly """ n_positions = get_n_positions_from_options(opt) - embeddings = torch.nn.Embedding( - len(dict), opt['embedding_size'], padding_idx=null_idx + embeddings = self._get_embeddings( + dict_=dict_, null_idx=null_idx, embedding_size=opt['embedding_size'] ) - torch.nn.init.normal_(embeddings.weight, 0, opt['embedding_size'] ** -0.5) return TransformerEncoder( n_heads=opt['n_heads'], n_layers=opt['n_layers'], embedding_size=opt['embedding_size'], ffn_size=opt['ffn_size'], - vocabulary_size=len(dict), + vocabulary_size=len(dict_), embedding=embeddings, dropout=opt['dropout'], attention_dropout=opt['attention_dropout'], @@ -307,6 +356,13 @@ def get_encoder(self, opt, dict, null_idx, reduction_type): output_scaling=opt['output_scaling'], ) + def _get_embeddings(self, dict_, null_idx, embedding_size): + embeddings = torch.nn.Embedding( + len(dict_), embedding_size, padding_idx=null_idx + ) + torch.nn.init.normal_(embeddings.weight, 0, embedding_size ** -0.5) + return embeddings + def attend(self, attention_layer, queries, keys, values, mask): """ Apply attention. @@ -335,28 +391,30 @@ def attend(self, attention_layer, queries, keys, values, mask): else: raise Exception('Unrecognized type of attention') - def encode(self, ctxt_tokens, cand_tokens): + def encode( + self, cand_tokens: Optional[torch.Tensor], **ctxt_inputs: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Encode a text sequence. - :param ctxt_tokens: - 2D long tensor, batchsize x sent_len + :param ctxt_inputs: + Dictionary of context inputs. 
If not empty, should contain at least + 'ctxt_tokens', a 2D long tensor of shape batchsize x sent_len :param cand_tokens: 3D long tensor, batchsize x num_cands x sent_len Note this will actually view it as a 2D tensor :return: - (ctxt_rep, ctxt_mask, ctxt_pos, cand_rep) + (ctxt_rep, ctxt_mask, cand_rep) - ctxt_rep 3D float tensor, batchsize x n_codes x dim - ctxt_mask byte: batchsize x n_codes (all 1 in case of polyencoder with code. Which are the vectors to use in the ctxt_rep) - - ctxt_pos 3D float tensor, batchsize x sent_len x dim - cand_rep (3D float tensor) batchsize x num_cands x dim """ cand_embed = None ctxt_rep = None ctxt_rep_mask = None - ctxt_pos = None + if cand_tokens is not None: assert len(cand_tokens.shape) == 3 bsz = cand_tokens.size(0) @@ -364,23 +422,25 @@ def encode(self, ctxt_tokens, cand_tokens): cand_embed = self.encoder_cand(cand_tokens.view(bsz * num_cands, -1)) cand_embed = cand_embed.view(bsz, num_cands, -1) - if ctxt_tokens is not None: - assert len(ctxt_tokens.shape) == 2 - bsz = ctxt_tokens.size(0) + if len(ctxt_inputs) > 0: + assert 'ctxt_tokens' in ctxt_inputs + if ctxt_inputs['ctxt_tokens'] is not None: + assert len(ctxt_inputs['ctxt_tokens'].shape) == 2 + bsz = self._get_context_batch_size(**ctxt_inputs) # get context_representation. Now that depends on the cases. - ctxt_out, ctxt_mask, ctxt_pos = self.encoder_ctxt(ctxt_tokens) - att_keys = ctxt_out if self.attention_keys == 'context' else ctxt_pos + ctxt_out, ctxt_mask = self.encoder_ctxt( + **self._context_encoder_input(ctxt_inputs) + ) dim = ctxt_out.size(2) if self.type == 'codes': ctxt_rep = self.attend( self.code_attention, queries=self.codes.repeat(bsz, 1, 1), - keys=att_keys, + keys=ctxt_out, values=ctxt_out, mask=ctxt_mask, ) - ctxt_pos = None # we don't need this anymore ctxt_rep_mask = ctxt_rep.new_ones(bsz, self.n_codes).byte() elif self.type == 'n_first': @@ -389,17 +449,40 @@ def encode(self, ctxt_tokens, cand_tokens): difference = self.n_codes - ctxt_out.size(1) extra_rep = ctxt_out.new_zeros(bsz, difference, dim) ctxt_rep = torch.cat([ctxt_out, extra_rep], dim=1) - ctxt_pos = torch.cat([ctxt_pos, extra_rep], dim=1) extra_mask = ctxt_mask.new_zeros(bsz, difference) ctxt_rep_mask = torch.cat([ctxt_mask, extra_mask], dim=1) else: ctxt_rep = ctxt_out[:, 0 : self.n_codes, :] - ctxt_pos = ctxt_pos[:, 0 : self.n_codes, :] ctxt_rep_mask = ctxt_mask[:, 0 : self.n_codes] - return ctxt_rep, ctxt_rep_mask, ctxt_pos, cand_embed + return ctxt_rep, ctxt_rep_mask, cand_embed + + def _get_context_batch_size(self, **ctxt_inputs: torch.Tensor) -> int: + """ + Return the batch size of the context. + + Can be overridden by subclasses that do not always have text tokens in the + context. + """ + return ctxt_inputs['ctxt_tokens'].size(0) + + def _context_encoder_input(self, ctxt_inputs: Dict[str, Any]) -> Dict[str, Any]: + """ + Return the inputs to the context encoder as a dictionary. + + Must return a dictionary. This will be passed directly into the model via + `**kwargs`, i.e., + + >>> encoder_ctxt(**_context_encoder_input(ctxt_inputs)) + + This is needed because the context encoder's forward function may have different + argument names than that of the model itself. This is intentionally overridable + so that richer models can pass additional inputs. + """ + assert set(ctxt_inputs.keys()) == {'ctxt_tokens'} + return {'input': ctxt_inputs['ctxt_tokens']} - def score(self, ctxt_rep, ctxt_rep_mask, ctxt_pos, cand_embed): + def score(self, ctxt_rep, ctxt_rep_mask, cand_embed): """ Score the candidates. 
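With `encode()` now taking context inputs as keyword arguments and returning a three-tuple (the `ctxt_pos` element is gone), the two-phase call flow that `score_candidates` relies on reduces to roughly the sketch below. It restates the batch-candidates path of this diff's own call sites; `model`, `batch`, and `cand_vecs` are supplied by the caller:

import torch

from parlai.core.torch_agent import Batch


def score_batch_candidates(
    model, batch: Batch, cand_vecs: torch.Tensor
) -> torch.Tensor:
    # Phase 1: encode. cand_vecs is bsz x num_cands x seq_len ('batch' candidates).
    ctxt_rep, ctxt_rep_mask, _ = model(ctxt_tokens=batch.text_vec)
    _, _, cand_rep = model(cand_tokens=cand_vecs)
    # Phase 2: score the encoded candidates against the encoded context.
    return model(ctxt_rep=ctxt_rep, ctxt_rep_mask=ctxt_rep_mask, cand_rep=cand_rep)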
@@ -408,29 +491,24 @@ def score(self, ctxt_rep, ctxt_rep_mask, ctxt_pos, cand_embed): :param ctxt_rep_mask: 2D byte tensor, bsz x ctxt_len, in case there are some elements of the ctxt that we should not take into account. - :param ctx_pos: 3D float tensor, bsz x sent_len x dim :param cand_embed: 3D float tensor, bsz x num_cands x dim :return: scores, 2D float tensor: bsz x num_cands """ - # Attention keys determined by self.attention_keys - # 'context' == use context final rep; otherwise use context position embs - keys = ctxt_rep if self.attention_keys == 'context' else ctxt_pos # reduces the context representation to a 3D tensor bsz x num_cands x dim ctxt_final_rep = self.attend( - self.attention, cand_embed, keys, ctxt_rep, ctxt_rep_mask + self.attention, cand_embed, ctxt_rep, ctxt_rep, ctxt_rep_mask ) scores = torch.sum(ctxt_final_rep * cand_embed, 2) return scores def forward( self, - ctxt_tokens=None, cand_tokens=None, ctxt_rep=None, ctxt_rep_mask=None, - ctxt_pos=None, cand_rep=None, + **ctxt_inputs, ): """ Forward pass of the model. @@ -440,8 +518,9 @@ def forward( we need to have one single forward() method. Therefore the operation_type can be either 'encode' or 'score'. - :param ctxt_tokens: - tokenized contexts + :param ctxt_inputs: + Dictionary of context inputs. Will include at least 'ctxt_tokens', + containing tokenized contexts :param cand_tokens: tokenized candidates :param ctxt_rep: @@ -451,19 +530,15 @@ def forward( encoder :param ctxt_rep_mask: mask for ctxt rep - :param ctxt_pos: - position embeddings for the ctxt_rep. If self.type == 'codes', these - are None, as their use is earlier in the pipeline. :param cand_rep: encoded representation of the candidates """ - if ctxt_tokens is not None or cand_tokens is not None: - return self.encode(ctxt_tokens, cand_tokens) + if len(ctxt_inputs) > 0 or cand_tokens is not None: + return self.encode(cand_tokens=cand_tokens, **ctxt_inputs) elif ( ctxt_rep is not None and ctxt_rep_mask is not None and cand_rep is not None ): - # ctxt_pos can be none, if we are using codes (not first M) - return self.score(ctxt_rep, ctxt_rep_mask, ctxt_pos, cand_rep) + return self.score(ctxt_rep, ctxt_rep_mask, cand_rep) raise Exception('Unsupported operation') diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 7728d80d983..c0fdacc8892 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -932,7 +932,12 @@ def init_optim(self, params, optim_states=None, saved_optim_type=None): elif not optimstate_fp16 and self.fp16: # old optimizer was fp32, but now we're doing fp16. # this is a bit clunky, but alternatives are worse - self.optimizer.optimizer.load_state_dict(optim_states) + try: + self.optimizer.optimizer.load_state_dict(optim_states) + except ValueError: + warn_once( + 'WARNING: not loading optim state since model params changed.' + ) return else: # previously trained in fp32, loading in fp32. @@ -943,7 +948,9 @@ def init_optim(self, params, optim_states=None, saved_optim_type=None): try: self.optimizer.load_state_dict(optim_states) except ValueError: - print('WARNING: not loading optim state since model params changed.') + warn_once( + 'WARNING: not loading optim state since model params changed.' 
+ ) def build_lr_scheduler(self, states=None, hard_reset=False): """ diff --git a/parlai/core/torch_image_agent.py b/parlai/core/torch_image_agent.py new file mode 100644 index 00000000000..c85d53e4960 --- /dev/null +++ b/parlai/core/torch_image_agent.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Subclass of TorchAgent used for handling image features. +""" + +from abc import abstractmethod +from typing import List + +import torch + +from parlai.core.message import Message +from parlai.core.torch_agent import Batch, TorchAgent + + +class TorchImageAgent(TorchAgent): + """ + Subclass of TorchAgent that allows for encoding image features. + + Provides flags and utility methods. + """ + + @classmethod + def add_cmdline_args(cls, argparser): + """ + Add command-line arguments specifically for this agent. + """ + super(TorchImageAgent, cls).add_cmdline_args(argparser) + agent = argparser.add_argument_group('Image args') + agent.add_argument( + '--image-features-dim', + type=int, + default=2048, + help='Dimensionality of image features', + ) + agent.add_argument( + '--image-encoder-num-layers', + type=int, + default=1, + recommended=1, + help='Number of linear layers to encode image features with', + ) + return agent + + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + self.image_features_dim = opt['image_features_dim'] + self.image_encoder_num_layers = opt['image_encoder_num_layers'] + + def batchify(self, obs_batch: List[Message], sort: bool = False) -> Batch: + """ + Override to handle image features. + """ + batch = super().batchify(obs_batch, sort) + batch = self.batchify_image_features(batch) + return batch + + @abstractmethod + def batchify_image_features(self, batch: Batch) -> Batch: + """ + Put this batch of images into the correct format for this agent. + + self._process_image_features() will likely be useful for this. + """ + raise NotImplementedError( + 'Subclasses must implement method for batching images!' + ) + + def _process_image_features(self, features: torch.Tensor) -> torch.Tensor: + """ + Format shape and type of input image-feature tensor. + """ + if features.dim() == 4: + features = features[0, :, 0, 0] + assert features.size() == (self.image_features_dim,) + if self.use_cuda: + features = features.cuda() + else: + features = features.cpu() + if self.opt.get('fp16'): + features = features.half() + else: + features = features.float() + + return features diff --git a/tests/test_transformers.py b/tests/test_transformers.py index 230897a02a3..a9d84d28bd8 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -793,5 +793,104 @@ def test_invsqrt_learning_rate(self): ) +@testing_utils.skipUnlessTorch14 +class TestImagePolyencoder(unittest.TestCase): + """ + Unit tests for the ImagePolyencoderAgent. + + Test that the model is able to handle simple train tasks. 
+ """ + + base_args = { + 'log_every_n_secs': 5, + 'validation_every_n_secs': 30, + 'model': 'transformer/image_polyencoder', + 'embedding_size': 32, + 'n_heads': 2, + 'n_layers': 2, + 'n_positions': 128, + 'truncate': 128, + 'ffn_size': 128, + 'variant': 'xlm', + 'activation': 'gelu', + 'candidates': 'batch', + 'eval_candidates': 'batch', # No inline cands + 'embeddings_scale': False, + 'gradient_clip': 0.1, + 'learningrate': 3e-5, + 'batchsize': 16, + 'optimizer': 'adamax', + 'learn_positional_embeddings': True, + 'reduction_type': 'first', + 'num_epochs': 30, + } + text_args = {'task': 'integration_tests:nocandidate'} + image_args = { + 'task': 'integration_tests:ImageTeacher', + 'image_mode': 'resnet152', + 'image_features_dim': 2048, + 'image_encoder_num_layers': 1, + 'image_combination_mode': 'prepend', + 'n_image_tokens': 1, + 'num_epochs': 60, + } + multitask_args = { + 'task': 'integration_tests:nocandidate,integration_tests:ImageTeacher', + 'image_mode': 'resnet152', + 'image_features_dim': 2048, + 'image_encoder_num_layers': 1, + 'image_combination_mode': 'prepend', + 'n_image_tokens': 1, + 'multitask_weights': [1, 1], + 'num_epochs': 30, + } + + @testing_utils.retry(ntries=3) + def test_text_task(self): + """ + Test that model correctly handles text task. + + Random chance is 10%, so this should be able to get much better than that very + quickly. + """ + args = Opt({**self.base_args, **self.text_args}) + valid, test = testing_utils.train_model(args) + assert ( + valid['accuracy'] > 0.2 + ), f'ImagePolyencoderAgent val-set accuracy on a simple task was {valid["accuracy"].value():0.2f}.' + + @testing_utils.retry(ntries=3) + @testing_utils.skipUnlessTorch + @testing_utils.skipUnlessGPU + def test_image_task(self): + """ + Test that model correctly handles a basic image training task. + + Random chance is 10%, so this should be able to get much better than that very + quickly. + """ + args = Opt({**self.base_args, **self.image_args}) + valid, test = testing_utils.train_model(args) + assert ( + valid['accuracy'] > 0.15 + ), f'ImagePolyencoderAgent val-set accuracy on a simple task was {valid["accuracy"].value():0.2f}.' + + @testing_utils.retry(ntries=3) + @testing_utils.skipUnlessTorch + @testing_utils.skipUnlessGPU + def test_multitask(self): + """ + Test that model correctly handles multiple inputs. + + Random chance is 10%, so this should be able to get much better than that very + quickly. + """ + args = Opt({**self.base_args, **self.multitask_args}) + valid, test = testing_utils.train_model(args) + assert ( + valid['accuracy'] > 0.2 + ), f'ImagePolyencoderAgent val-set accuracy on a simple task was {valid["accuracy"].value():0.2f}.' + + if __name__ == '__main__': unittest.main()
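Taken together, the argument dictionaries in `TestImagePolyencoder` double as a usage reference for the new agent. The sketch below shows roughly how the image task would be launched programmatically; the `parlai.utils.testing` import path is assumed to match what tests/test_transformers.py already uses, and the flag set is trimmed to the image-specific options plus a few required model settings:

from parlai.core.opt import Opt
import parlai.utils.testing as testing_utils

# Minimal config mirroring TestImagePolyencoder.base_args + image_args above.
args = Opt(
    {
        'model': 'transformer/image_polyencoder',
        'task': 'integration_tests:ImageTeacher',
        'image_mode': 'resnet152',
        'image_features_dim': 2048,
        'image_encoder_num_layers': 1,
        'image_combination_mode': 'prepend',
        'n_image_tokens': 1,
        'embedding_size': 32,
        'n_heads': 2,
        'n_layers': 2,
        'ffn_size': 128,
        'candidates': 'batch',
        'eval_candidates': 'batch',
        'batchsize': 16,
        'num_epochs': 60,
    }
)
valid, test = testing_utils.train_model(args)
print(valid['accuracy'])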