diff --git a/.gitignore b/.gitignore index f6dd6c4..beca99b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,10 @@ # Data files -data +data/ +my_stuff/ +testing.py + +# Model files +models/ # Experiment files exp @@ -7,6 +12,7 @@ scripts/dev # Image files img/cst +img/attention_visualizations/ # Editor files *.DS_Store diff --git a/Hello.py b/Hello.py new file mode 100644 index 0000000..0ee124d --- /dev/null +++ b/Hello.py @@ -0,0 +1,28 @@ +import streamlit as st + +st.set_page_config( + page_title="Hello", + page_icon="👋", +) + +st.write("# Welcome to Streamlit! 👋") + +st.sidebar.success("Select a demo above.") + +st.markdown( + """ + Streamlit is an open-source app framework built specifically for + Machine Learning and Data Science projects. + **👈 Select a demo from the sidebar** to see some examples + of what Streamlit can do! + ### Want to learn more? + - Check out [streamlit.io](https://streamlit.io) + - Jump into our [documentation](https://docs.streamlit.io) + - Ask a question in our [community + forums](https://discuss.streamlit.io) + ### See more complex demos + - Use a neural net to [analyze the Udacity Self-driving Car Image + Dataset](https://github.com/streamlit/demo-self-driving) + - Explore a [New York City rideshare dataset](https://github.com/streamlit/demo-uber-nyc-pickups) +""" +) diff --git a/README.md b/README.md index 702e591..31c7486 100644 --- a/README.md +++ b/README.md @@ -1,79 +1,56 @@ -# FiLM: Visual Reasoning with a General Conditioning Layer - -## Ethan Perez, Florian Strub, Harm de Vries, Vincent Dumoulin, Aaron Courville - -This code implements a Feature-wise Linear Modulation approach to Visual Reasoning - answering multi-step questions on images. 
This codebase reproduces results from the AAAI 2018 paper "FiLM: Visual Reasoning with a General Conditioning Layer" (citation [here](https://github.com/ethanjperez/film#film)), which extends prior work "Learning Visual Reasoning Without Strong Priors" presented at ICML's MLSLP workshop. Please see the [retrospective paper](https://ml-retrospectives.github.io/neurips2019/accepted_retrospectives/2019/film/) (citation [here](https://github.com/ethanjperez/film#retrospective-for-film)) for an honest reflection on FiLM after the work that followed, including when to (and not to) use FiLM and tips-and-tricks for effectively training a network with FiLM layers. - -### Code Outline - -This code is a fork from the code for "Inferring and Executing Programs for Visual Reasoning" available [here](https://github.com/facebookresearch/clevr-iep). - -Our FiLM Generator is located in [vr/models/film_gen.py](https://github.com/ethanjperez/film/blob/master/vr/models/film_gen.py), and our FiLMed Network and FiLM layer implementation is located in [vr/models/filmed_net.py](https://github.com/ethanjperez/film/blob/master/vr/models/filmed_net.py). - -We inserted a new model mode "FiLM" which integrates into forked code for [CLEVR baselines](https://arxiv.org/abs/1612.06890) and the [Program Generator + Execution Engine model](https://arxiv.org/abs/1705.03633). Throughout the code, for our model, our FiLM Generator acts in place of the "program generator" which generates the FiLM parameters for an the FiLMed Network, i.e. "execution engine." In some sense, FiLM parameters can vaguely be thought of as a "soft program" of sorts, but we use this denotation in the code to integrate better with the forked models. - -### Setup and Training - -Because of this integration, setup instructions for the FiLM model are nearly the same as for "Inferring and Executing Programs for Visual Reasoning." 
We will post more detailed instructions on how to use our code in particular soon for more step-by-step guidance. For now, the guidelines below should give substantial direction to those interested. - -First, follow the virtual environment setup [instructions](https://github.com/facebookresearch/clevr-iep#setup). - -Second, follow the CLEVR data preprocessing [instructions](https://github.com/facebookresearch/clevr-iep/blob/master/TRAINING.md#preprocessing-clevr). - -Lastly, model training details are similar at a high level (though adapted for FiLM and our repo) to [these](https://github.com/facebookresearch/clevr-iep/blob/master/TRAINING.md#training-on-clevr) for the Program Generator + Execution Engine model, though our model only uses one step of training, rather than a 3-step training procedure. - -The below script has the hyperparameters and settings to reproduce FiLM CLEVR results: +# Todo +- [-] Rapport +- Streamlit + - [ ] Docu sphinx + - [ ] Gros modèle pré-entraîné + - [x] Obtention des weights + - [ ] Streamlit poser questions sur image + - [ ] Visualisation des histogrammes gamma/beta + - [ ] Visualisation tSNE + - [ ] Visualisation de ce que le MLP "voit" + - [ ] Petit modèle, train sur CPU + - Avoir aussi le preprocessing réduit? + - Comment avoir un temps d'entraînement rapide? réduire architecture? réduire train/val dataset? 
+ - [ ] Streamlit train + - [ ] Streamlit questions +- Bonus: + - Zero-shot + - Graph comparaison de performance sur jeux de donnée classique + +# Requirements +- Python 3.12 +- Other dependencies listed in `requirements.txt` + +# References +- The code in this repo is heavily inspired by the repos [Film](https://github.com/ethanjperez/film) and [Clever-iep](https://github.com/facebookresearch/clevr-iep) +- [Distill: Feature wise transformations](https://distill.pub/2018/feature-wise-transformations/) +- [Arxiv: FiLM: Visual Reasoning with a General Conditioning Layer](https://arxiv.org/pdf/1709.07871) + +# Get the data +For each script, check the `.sh` and/or the `.py` associated file to modify parameters. +To download the data, run: ```bash -sh scripts/train/film.sh +mkdir data +wget https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip -O data/CLEVR_v1.0.zip +unzip data/CLEVR_v1.0.zip -d data ``` - -For CLEVR-Humans, data preprocessing instructions are [here](https://github.com/facebookresearch/clevr-iep/blob/master/TRAINING.md#preprocessing-clevr-humans). -The below script has the hyperparameters and settings to reproduce FiLM CLEVR-Humans results: +To preprocess the data from pngs to a h5 file for each train/val/test set, run the following code. The data will be the raw pixels, there are options to extract features with the option `--model resnet101` (1024x14x14 output), or to set a maximum number of X processed images `--max_images X` (check `extract_features.py`). ```bash -sh scripts/train/film_humans.sh +sh scripts/extract_features.sh ``` - -Training a CLEVR-CoGenT model is very similar to training a normal CLEVR model. Training a model from pixels requires modifying the preprocessing with scripts included in the repo to preprocess pixels. The scripts to reproduce our results are also located in the scripts/train/ folder. 
- -We tried to not break existing models from the CLEVR codebase with our modifications, but we haven't tested their code after our changes. We recommend using using the CLEVR and "Inferring and Executing Programs for Visual Reasoning" code directly. - -Training a solid FiLM CLEVR model should only take ~12 hours on a good GPU (See training curves in the paper appendix). - -### Running models - -We added an interactive command line tool for use with the below command/script. It's actually super enjoyable to play around with trained models. It's great for gaining intuition around what various trained models have or have not learned and how they tackle reasoning questions. +To preprocess the questions, execute this script: ```bash -python run_model.py --program_generator --execution_engine +sh scripts/preprocess_questions.sh ``` -By default, the command runs on [this CLEVR image](https://github.com/ethanjperez/film/blob/master/img/CLEVR_val_000017.png) in our repo, but you may modify which image to use via command line flag to test on any CLEVR image. - -CLEVR vocab is enforced by default, but for CLEVR-Humans models, for example, you may append the command line flag option '--enforce_clevr_vocab 0' to ask any string of characters you please. - -In addition, one easier way to try out zero-shot with FiLM is to run a trained model with run_model.py, but with the implemented debug command line flag on so you can manipulate the FiLM parameters modulating the FiLMed network during the forward computation. For example, '--debug_every -1' will stop the program after the model generates FiLM parameters but before the FiLMed network carries out its forward pass using FiLM layers. - -Thanks for stopping by, and we hope you enjoy playing around with FiLM! 
- -### Bibtex - -#### FiLM +To train the model: ```bash -@InProceedings{perez2018film, - title={FiLM: Visual Reasoning with a General Conditioning Layer}, - author={Ethan Perez and Florian Strub and Harm de Vries and Vincent Dumoulin and Aaron C. Courville}, - booktitle={AAAI}, - year={2018} -} +sh scripts/train/film.sh ``` -#### Retrospective for FiLM +To run the model (on `CLEVR_val_000017.png` by default): ```bash -@misc{perez2019retrospective, - author = {Perez, Ethan}, - title = {{Retroespective for: "FiLM: Visual Reasoning with a General Conditioning Layer"}}, - year = {2019}, - howpublished = {\url{https://ml-retrospectives.github.io/published_retrospectives/2019/film/}}, -} -``` +sh scripts/run_model.sh +``` \ No newline at end of file diff --git a/docs/projet_DL_slidespres.pdf b/docs/projet_DL_slidespres.pdf new file mode 100644 index 0000000..2eb3fcb Binary files /dev/null and b/docs/projet_DL_slidespres.pdf differ diff --git a/img/stats/Betas: Layer 1.png b/img/stats/Betas Layer 1.png similarity index 100% rename from img/stats/Betas: Layer 1.png rename to img/stats/Betas Layer 1.png diff --git a/img/stats/Betas: Layer 2.png b/img/stats/Betas Layer 2.png similarity index 100% rename from img/stats/Betas: Layer 2.png rename to img/stats/Betas Layer 2.png diff --git a/img/stats/Betas: Layer 3.png b/img/stats/Betas Layer 3.png similarity index 100% rename from img/stats/Betas: Layer 3.png rename to img/stats/Betas Layer 3.png diff --git a/img/stats/Betas: Layer 4.png b/img/stats/Betas Layer 4.png similarity index 100% rename from img/stats/Betas: Layer 4.png rename to img/stats/Betas Layer 4.png diff --git a/img/stats/Gammas: Layer 1.png b/img/stats/Gammas Layer 1.png similarity index 100% rename from img/stats/Gammas: Layer 1.png rename to img/stats/Gammas Layer 1.png diff --git a/img/stats/Gammas: Layer 2.png b/img/stats/Gammas Layer 2.png similarity index 100% rename from img/stats/Gammas: Layer 2.png rename to img/stats/Gammas Layer 2.png diff --git 
a/img/stats/Gammas: Layer 3.png b/img/stats/Gammas Layer 3.png similarity index 100% rename from img/stats/Gammas: Layer 3.png rename to img/stats/Gammas Layer 3.png diff --git a/img/stats/Gammas: Layer 4.png b/img/stats/Gammas Layer 4.png similarity index 100% rename from img/stats/Gammas: Layer 4.png rename to img/stats/Gammas Layer 4.png diff --git a/pages/Large pre-trained model.py b/pages/Large pre-trained model.py new file mode 100644 index 0000000..22a5a9d --- /dev/null +++ b/pages/Large pre-trained model.py @@ -0,0 +1,105 @@ +import streamlit as st +import subprocess +import platform +import os +import time +import plotly_express as px +import numpy as np + +# Chose the python interpreter path +current_os = platform.system() +if current_os == "Windows": + python_executable = ".venv\Scripts\python.exe" +else: + python_executable = ".venv/bin/python" + +st.title("Feature-wise Linear Modulations") + +tab1, tab2 = st.tabs(["Visualizing", "Training"]) + +with tab1: + # Display error message if not data/best.pt + if not os.path.exists("data/best.pt"): + st.error("No model found at \"data/best.pt\". 
Please train or download the model") + + # Select and display the image with default image 17 + img_number = st.selectbox( + "Select an image number:", [str(i) for i in range(10, 20)], index=7 + ) + st.image( + f"img/CLEVR_val_0000{img_number}.png", + caption=f"CLEVR_val_0000{img_number}.png", + # use_container_width=True, + width=400, + ) + + # Checkbox to visualize attention + visualize = st.checkbox("Visualize attention") + + # Create a form so that hitting Enter submits the input + with st.form(key="question_form"): + user_input = st.text_input("Enter your question:") + submit_button = st.form_submit_button("Submit") + + + if submit_button: + # Launch the process (adjust parameters as needed) + process = subprocess.Popen( + [ + python_executable, + "scripts/run_model.py", + "--image", + f"img/CLEVR_val_0000{img_number}.png", + "--streamlit", + "True", + "--visualize_attention", + str(visualize), + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Send the user input to the process and capture the output + output, error = process.communicate(input = user_input) + output = output.strip() # Remove leading/trailing whitespace + + # Display the output + st.subheader("Model Response:") + st.write(output) + + # Optionally display any error messages + # if error: + # st.subheader("Errors:") + # st.write(error) + + # Display the image with attention, if requested + if visualize: + attention_img_path = f"img/attention_visualizations/{user_input} {output}/pool-feature-locations.png" + # Wait for the image to be created + while not os.path.exists(attention_img_path): + time.sleep(1) + st.image(attention_img_path, caption="Image with attention", width=400) + + # importation and processing of the parameters values for the three resblocks + parameters=torch.load('D:\\projet FiLM deep learning\\img\\params.pt') + beta=[] + gamma=[] + for i in range(3): + beta.extend(parameters[0][i][0:128].tolist()) + 
gamma.extend(parameters[0][i][128:256].tolist()) + + # plotting the histograms with Plotly + hist_gammas = px.histogram(gamma, nbins=70, marginal='rug') + hist_gammas.update_layout(title='Histogram of gammas values of the 3 resblocks', xaxis_title='Value', yaxis_title='Frequency') + st.plotly_chart(hist_gammas) + hist_betas = px.histogram(beta, nbins=70, marginal='rug') + hist_betas.update_layout(title='Histogram of betas values of the 3 resblocks', xaxis_title='Value', yaxis_title='Frequency') + st.plotly_chart(hist_betas) + +with tab2: + epoch = st.slider("Epoch", 1, 20, 1) + model_choice = st.selectbox("Model", ["resnet", "raw"]) + if st.button("Train"): + st.write(f"Training started with {model_choice} for {epoch} epochs") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..aaa94be --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.ruff] +ignore = ["F401","E402"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6ad76ab..f76414e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,80 @@ -http://download.pytorch.org/whl/cu80/torch-0.1.11.post5-cp35-cp35m-linux_x86_64.whl -numpy -Pillow -scipy -torchvision -h5py -tqdm +altair==5.5.0 +asttokens==3.0.0 +attrs==25.1.0 +blinker==1.9.0 +cachetools==5.5.1 +certifi==2024.12.14 +charset-normalizer==3.4.1 +click==8.1.8 +contourpy==1.3.1 +cycler==0.12.1 +decorator==5.2.1 +executing==2.2.0 +filelock==3.16.1 +fonttools==4.55.3 +fsspec==2024.12.0 +gitdb==4.0.12 +GitPython==3.1.44 +h5py==3.13.0 +idna==3.10 +imageio==2.37.0 +ipdb==0.13.13 +ipython==9.0.2 +ipython_pygments_lexers==1.1.1 +jedi==0.19.2 +Jinja2==3.1.5 +joblib==1.4.2 +jsonschema==4.23.0 +jsonschema-specifications==2024.10.1 +kiwisolver==1.4.8 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib==3.10.0 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mpmath==1.3.0 +narwhals==1.24.0 +networkx==3.4.2 +numpy==1.26.4 +packaging==24.2 +pandas==2.2.3 +parso==0.8.4 +pexpect==4.9.0 +pillow==11.1.0 
+prompt_toolkit==3.0.50 +protobuf==5.29.3 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pyarrow==19.0.0 +pydeck==0.9.1 +Pygments==2.19.1 +pyparsing==3.2.1 +python-dateutil==2.9.0.post0 +pytz==2024.2 +referencing==0.36.2 +requests==2.32.3 +rich==13.9.4 +rpds-py==0.22.3 +scikit-learn==1.6.1 +scipy==1.15.1 +setuptools==76.0.0 +six==1.17.0 +smmap==5.0.2 +stack-data==0.6.3 +streamlit==1.41.1 +sympy==1.13.1 +tenacity==9.0.0 +termcolor==2.5.0 +threadpoolctl==3.5.0 +toml==0.10.2 +torch==2.6.0 +torchvision==0.21.0 +tornado==6.4.2 +tqdm==4.67.1 +traitlets==5.14.3 +typing_extensions==4.12.2 +tzdata==2025.1 +urllib3==2.3.0 +watchdog==6.0.0 +wcwidth==0.2.13 +plotly_express==0.4.0 diff --git a/scripts/extract_features.py b/scripts/extract_features.py index d6ef680..062d065 100644 --- a/scripts/extract_features.py +++ b/scripts/extract_features.py @@ -4,117 +4,128 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import argparse, os, json +import argparse, os, json # noqa: E401, F401 import h5py import numpy as np -from scipy.misc import imread, imresize -from tqdm import tqdm +from PIL import Image +import imageio import torch import torchvision +if torch.cuda.is_available(): + device = torch.device("cuda") +elif torch.backends.mps.is_available(): + device = torch.device("mps") +else: + device = torch.device("cpu") parser = argparse.ArgumentParser() -parser.add_argument('--input_image_dir', required=True) -parser.add_argument('--max_images', default=None, type=int) -parser.add_argument('--output_h5_file', required=True) +parser.add_argument("--input_image_dir", required=True) +parser.add_argument("--max_images", default=None, type=int) +parser.add_argument("--output_h5_file", required=True) -parser.add_argument('--image_height', default=224, type=int) -parser.add_argument('--image_width', default=224, type=int) +parser.add_argument("--image_height", default=224, type=int) +parser.add_argument("--image_width", 
default=224, type=int) -parser.add_argument('--model', default='resnet101') -parser.add_argument('--model_stage', default=3, type=int) -parser.add_argument('--batch_size', default=128, type=int) +parser.add_argument("--model", default="none") +parser.add_argument("--model_stage", default=3, type=int) +parser.add_argument("--batch_size", default=128, type=int) def build_model(args): - if args.model.lower() == 'none': - return None - if not hasattr(torchvision.models, args.model): - raise ValueError('Invalid model "%s"' % args.model) - if not 'resnet' in args.model: - raise ValueError('Feature extraction only supports ResNets') - cnn = getattr(torchvision.models, args.model)(pretrained=True) - layers = [ - cnn.conv1, - cnn.bn1, - cnn.relu, - cnn.maxpool, - ] - for i in range(args.model_stage): - name = 'layer%d' % (i + 1) - layers.append(getattr(cnn, name)) - model = torch.nn.Sequential(*layers) - model.cuda() - model.eval() - return model + if args.model.lower() == "none": + return None + if not hasattr(torchvision.models, args.model): + raise ValueError('Invalid model "%s"' % args.model) + if "resnet" not in args.model: + raise ValueError("Feature extraction only supports ResNets") + cnn = getattr(torchvision.models, args.model)(pretrained=True) + layers = [ + cnn.conv1, + cnn.bn1, + cnn.relu, + cnn.maxpool, + ] + for i in range(args.model_stage): + name = "layer%d" % (i + 1) + layers.append(getattr(cnn, name)) + model = torch.nn.Sequential(*layers) + model.to(device) # Send model to the MPS (or CPU) device. + model.eval() + return model def run_batch(cur_batch, model): - if model is None: - image_batch = np.concatenate(cur_batch, 0).astype(np.float32) - return image_batch / 255. 
# Scale pixel values to [0, 1] - - mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) - std = np.array([0.229, 0.224, 0.224]).reshape(1, 3, 1, 1) + if model is None: + image_batch = np.concatenate(cur_batch, 0).astype(np.float32) + return image_batch / 255.0 # Scale pixel values to [0, 1] + mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) + std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1) # ImageNet std; 0.225 fixes upstream 0.224 typo - image_batch = np.concatenate(cur_batch, 0).astype(np.float32) - image_batch = (image_batch / 255.0 - mean) / std - image_batch = torch.FloatTensor(image_batch).cuda() - image_batch = torch.autograd.Variable(image_batch, volatile=True) + image_batch = np.concatenate(cur_batch, 0).astype(np.float32) + image_batch = (image_batch / 255.0 - mean) / std + image_batch = torch.FloatTensor(image_batch).to(device) - feats = model(image_batch) - feats = feats.data.cpu().clone().numpy() + # Disable gradient calculations for inference. + with torch.no_grad(): + feats = model(image_batch) - return feats + feats = feats.cpu().numpy() + return feats def main(args): - input_paths = [] - idx_set = set() - for fn in os.listdir(args.input_image_dir): - if not fn.endswith('.png'): continue - idx = int(os.path.splitext(fn)[0].split('_')[-1]) - input_paths.append((os.path.join(args.input_image_dir, fn), idx)) - idx_set.add(idx) - input_paths.sort(key=lambda x: x[1]) - assert len(idx_set) == len(input_paths) - assert min(idx_set) == 0 and max(idx_set) == len(idx_set) - 1 - if args.max_images is not None: - input_paths = input_paths[:args.max_images] - print(input_paths[0]) - print(input_paths[-1]) - - model = build_model(args) - - img_size = (args.image_height, args.image_width) - with h5py.File(args.output_h5_file, 'w') as f: - feat_dset = None - i0 = 0 - cur_batch = [] - for i, (path, idx) in tqdm(enumerate(input_paths)): - img = imread(path, mode='RGB') - img = imresize(img, img_size, interp='bicubic') - img = img.transpose(2, 0, 1)[None] - cur_batch.append(img) - if 
len(cur_batch) == args.batch_size: - feats = run_batch(cur_batch, model) - if feat_dset is None: - N = len(input_paths) - _, C, H, W = feats.shape - feat_dset = f.create_dataset('features', (N, C, H, W), - dtype=np.float32) - i1 = i0 + len(cur_batch) - feat_dset[i0:i1] = feats - i0 = i1 + input_paths = [] + idx_set = set() + for fn in os.listdir(args.input_image_dir): + if not fn.endswith(".png"): + continue + idx = int(os.path.splitext(fn)[0].split("_")[-1]) + input_paths.append((os.path.join(args.input_image_dir, fn), idx)) + idx_set.add(idx) + input_paths.sort(key=lambda x: x[1]) + assert len(idx_set) == len(input_paths) + assert min(idx_set) == 0 and max(idx_set) == len(idx_set) - 1 + if args.max_images is not None: + input_paths = input_paths[: args.max_images] + print(input_paths[0]) + print(input_paths[-1]) + + model = build_model(args) + + img_size = (args.image_height, args.image_width) + with h5py.File(args.output_h5_file, "w") as f: + feat_dset = None + i0 = 0 cur_batch = [] - if len(cur_batch) > 0: - feats = run_batch(cur_batch, model) - i1 = i0 + len(cur_batch) - feat_dset[i0:i1] = feats - return - - -if __name__ == '__main__': - args = parser.parse_args() - main(args) + for i, (path, idx) in enumerate(input_paths): + img = imageio.imread(path, pilmode="RGB") + im = Image.fromarray(img) + im_resized = im.resize((img_size[1], img_size[0]), resample=Image.BICUBIC) + img = np.array(im_resized) + img = img.transpose(2, 0, 1)[None] + cur_batch.append(img) + if len(cur_batch) == args.batch_size: + feats = run_batch(cur_batch, model) + if feat_dset is None: + N = len(input_paths) + _, C, H, W = feats.shape + feat_dset = f.create_dataset( + "features", (N, C, H, W), dtype=np.float32 + ) + i1 = i0 + len(cur_batch) + feat_dset[i0:i1] = feats + i0 = i1 + print("Processed %d / %d images" % (i1, len(input_paths))) + cur_batch = [] + if len(cur_batch) > 0: + feats = run_batch(cur_batch, model) + i1 = i0 + len(cur_batch) + feat_dset[i0:i1] = feats + print("Processed 
%d / %d images" % (i1, len(input_paths))) + + +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/scripts/extract_features.sh b/scripts/extract_features.sh new file mode 100644 index 0000000..9ca53be --- /dev/null +++ b/scripts/extract_features.sh @@ -0,0 +1,11 @@ +python scripts/extract_features.py \ + --input_image_dir data/CLEVR_v1.0/images/train \ + --output_h5_file data/train_features_raw.h5 + +python scripts/extract_features.py \ + --input_image_dir data/CLEVR_v1.0/images/val \ + --output_h5_file data/val_features_raw.h5 + +python scripts/extract_features.py \ + --input_image_dir data/CLEVR_v1.0/images/test \ + --output_h5_file data/test_features_raw.h5 \ No newline at end of file diff --git a/scripts/preprocess_questions.py b/scripts/preprocess_questions.py index 683783f..e65c519 100644 --- a/scripts/preprocess_questions.py +++ b/scripts/preprocess_questions.py @@ -8,7 +8,8 @@ import sys import os -sys.path.insert(0, os.path.abspath('.')) + +sys.path.insert(0, os.path.abspath(".")) import argparse @@ -28,166 +29,167 @@ parser = argparse.ArgumentParser() -parser.add_argument('--mode', default='prefix', - choices=['chain', 'prefix', 'postfix']) -parser.add_argument('--input_questions_json', required=True) -parser.add_argument('--input_vocab_json', default='') -parser.add_argument('--expand_vocab', default=0, type=int) -parser.add_argument('--unk_threshold', default=1, type=int) -parser.add_argument('--encode_unk', default=0, type=int) +parser.add_argument("--mode", default="prefix", choices=["chain", "prefix", "postfix"]) +parser.add_argument("--input_questions_json", required=True) +parser.add_argument("--input_vocab_json", default="") +parser.add_argument("--expand_vocab", default=0, type=int) +parser.add_argument("--unk_threshold", default=1, type=int) +parser.add_argument("--encode_unk", default=0, type=int) -parser.add_argument('--output_h5_file', required=True) -parser.add_argument('--output_vocab_json', default='') 
+parser.add_argument("--output_h5_file", required=True) +parser.add_argument("--output_vocab_json", default="") def program_to_str(program, mode): - if mode == 'chain': - if not vr.programs.is_chain(program): - return None - return vr.programs.list_to_str(program) - elif mode == 'prefix': - program_prefix = vr.programs.list_to_prefix(program) - return vr.programs.list_to_str(program_prefix) - elif mode == 'postfix': - program_postfix = vr.programs.list_to_postfix(program) - return vr.programs.list_to_str(program_postfix) - return None + if mode == "chain": + if not vr.programs.is_chain(program): + return None + return vr.programs.list_to_str(program) + elif mode == "prefix": + program_prefix = vr.programs.list_to_prefix(program) + return vr.programs.list_to_str(program_prefix) + elif mode == "postfix": + program_postfix = vr.programs.list_to_postfix(program) + return vr.programs.list_to_str(program_postfix) + return None def main(args): - if (args.input_vocab_json == '') and (args.output_vocab_json == ''): - print('Must give one of --input_vocab_json or --output_vocab_json') - return - - print('Loading data') - with open(args.input_questions_json, 'r') as f: - questions = json.load(f)['questions'] - - # Either create the vocab or load it from disk - if args.input_vocab_json == '' or args.expand_vocab == 1: - print('Building vocab') - if 'answer' in questions[0]: - answer_token_to_idx = build_vocab( - (q['answer'] for q in questions) - ) - question_token_to_idx = build_vocab( - (q['question'] for q in questions), - min_token_count=args.unk_threshold, - punct_to_keep=[';', ','], punct_to_remove=['?', '.'] - ) - all_program_strs = [] - for q in questions: - if 'program' not in q: continue - program_str = program_to_str(q['program'], args.mode) - if program_str is not None: - all_program_strs.append(program_str) - program_token_to_idx = build_vocab(all_program_strs) - vocab = { - 'question_token_to_idx': question_token_to_idx, - 'program_token_to_idx': 
program_token_to_idx, - 'answer_token_to_idx': answer_token_to_idx, - } - - if args.input_vocab_json != '': - print('Loading vocab') - if args.expand_vocab == 1: - new_vocab = vocab - with open(args.input_vocab_json, 'r') as f: - vocab = json.load(f) - if args.expand_vocab == 1: - num_new_words = 0 - for word in new_vocab['question_token_to_idx']: - if word not in vocab['question_token_to_idx']: - print('Found new word %s' % word) - idx = len(vocab['question_token_to_idx']) - vocab['question_token_to_idx'][word] = idx - num_new_words += 1 - print('Found %d new words' % num_new_words) - - if args.output_vocab_json != '': - with open(args.output_vocab_json, 'w') as f: - json.dump(vocab, f) - - # Encode all questions and programs - print('Encoding data') - questions_encoded = [] - programs_encoded = [] - question_families = [] - orig_idxs = [] - image_idxs = [] - answers = [] - types = [] - for orig_idx, q in enumerate(questions): - question = q['question'] - if 'program' in q: - types += [q['program'][-1]['function']] - - orig_idxs.append(orig_idx) - image_idxs.append(q['image_index']) - if 'question_family_index' in q: - question_families.append(q['question_family_index']) - question_tokens = tokenize(question, - punct_to_keep=[';', ','], - punct_to_remove=['?', '.']) - question_encoded = encode(question_tokens, - vocab['question_token_to_idx'], - allow_unk=args.encode_unk == 1) - questions_encoded.append(question_encoded) - - if 'program' in q: - program = q['program'] - program_str = program_to_str(program, args.mode) - program_tokens = tokenize(program_str) - program_encoded = encode(program_tokens, vocab['program_token_to_idx']) - programs_encoded.append(program_encoded) - - if 'answer' in q: - answers.append(vocab['answer_token_to_idx'][q['answer']]) - - # Pad encoded questions and programs - max_question_length = max(len(x) for x in questions_encoded) - for qe in questions_encoded: - while len(qe) < max_question_length: - 
qe.append(vocab['question_token_to_idx']['']) - - if len(programs_encoded) > 0: - max_program_length = max(len(x) for x in programs_encoded) - for pe in programs_encoded: - while len(pe) < max_program_length: - pe.append(vocab['program_token_to_idx']['']) - - # Create h5 file - print('Writing output') - questions_encoded = np.asarray(questions_encoded, dtype=np.int32) - programs_encoded = np.asarray(programs_encoded, dtype=np.int32) - print(questions_encoded.shape) - print(programs_encoded.shape) - - mapping = {} - for i, t in enumerate(set(types)): - mapping[t] = i - - print(mapping) - - types_coded = [] - for t in types: - types_coded += [mapping[t]] - - with h5py.File(args.output_h5_file, 'w') as f: - f.create_dataset('questions', data=questions_encoded) - f.create_dataset('image_idxs', data=np.asarray(image_idxs)) - f.create_dataset('orig_idxs', data=np.asarray(orig_idxs)) + if (args.input_vocab_json == "") and (args.output_vocab_json == ""): + print("Must give one of --input_vocab_json or --output_vocab_json") + return + + print("Loading data") + with open(args.input_questions_json, "r") as f: + questions = json.load(f)["questions"] + + # Either create the vocab or load it from disk + if args.input_vocab_json == "" or args.expand_vocab == 1: + print("Building vocab") + if "answer" in questions[0]: + answer_token_to_idx = build_vocab((q["answer"] for q in questions)) + question_token_to_idx = build_vocab( + (q["question"] for q in questions), + min_token_count=args.unk_threshold, + punct_to_keep=[";", ","], + punct_to_remove=["?", "."], + ) + all_program_strs = [] + for q in questions: + if "program" not in q: + continue + program_str = program_to_str(q["program"], args.mode) + if program_str is not None: + all_program_strs.append(program_str) + program_token_to_idx = build_vocab(all_program_strs) + vocab = { + "question_token_to_idx": question_token_to_idx, + "program_token_to_idx": program_token_to_idx, + "answer_token_to_idx": answer_token_to_idx, + } + + if 
args.input_vocab_json != "": + print("Loading vocab") + if args.expand_vocab == 1: + new_vocab = vocab + with open(args.input_vocab_json, "r") as f: + vocab = json.load(f) + if args.expand_vocab == 1: + num_new_words = 0 + for word in new_vocab["question_token_to_idx"]: + if word not in vocab["question_token_to_idx"]: + print("Found new word %s" % word) + idx = len(vocab["question_token_to_idx"]) + vocab["question_token_to_idx"][word] = idx + num_new_words += 1 + print("Found %d new words" % num_new_words) + + if args.output_vocab_json != "": + with open(args.output_vocab_json, "w") as f: + json.dump(vocab, f) + + # Encode all questions and programs + print("Encoding data") + questions_encoded = [] + programs_encoded = [] + question_families = [] + orig_idxs = [] + image_idxs = [] + answers = [] + types = [] + for orig_idx, q in enumerate(questions): + question = q["question"] + if "program" in q: + types += [q["program"][-1]["function"]] + + orig_idxs.append(orig_idx) + image_idxs.append(q["image_index"]) + if "question_family_index" in q: + question_families.append(q["question_family_index"]) + question_tokens = tokenize( + question, punct_to_keep=[";", ","], punct_to_remove=["?", "."] + ) + question_encoded = encode( + question_tokens, + vocab["question_token_to_idx"], + allow_unk=args.encode_unk == 1, + ) + questions_encoded.append(question_encoded) + + if "program" in q: + program = q["program"] + program_str = program_to_str(program, args.mode) + program_tokens = tokenize(program_str) + program_encoded = encode(program_tokens, vocab["program_token_to_idx"]) + programs_encoded.append(program_encoded) + + if "answer" in q: + answers.append(vocab["answer_token_to_idx"][q["answer"]]) + + # Pad encoded questions and programs + max_question_length = max(len(x) for x in questions_encoded) + for qe in questions_encoded: + while len(qe) < max_question_length: + qe.append(vocab["question_token_to_idx"][""]) if len(programs_encoded) > 0: - f.create_dataset('programs', 
data=programs_encoded) - if len(question_families) > 0: - f.create_dataset('question_families', data=np.asarray(question_families)) - if len(answers) > 0: - f.create_dataset('answers', data=np.asarray(answers)) - if len(types) > 0: - f.create_dataset('types', data=np.asarray(types_coded)) - - -if __name__ == '__main__': - args = parser.parse_args() - main(args) + max_program_length = max(len(x) for x in programs_encoded) + for pe in programs_encoded: + while len(pe) < max_program_length: + pe.append(vocab["program_token_to_idx"][""]) + + # Create h5 file + print("Writing output") + questions_encoded = np.asarray(questions_encoded, dtype=np.int32) + programs_encoded = np.asarray(programs_encoded, dtype=np.int32) + print(questions_encoded.shape) + print(programs_encoded.shape) + + mapping = {} + for i, t in enumerate(set(types)): + mapping[t] = i + + print(mapping) + + types_coded = [] + for t in types: + types_coded += [mapping[t]] + + with h5py.File(args.output_h5_file, "w") as f: + f.create_dataset("questions", data=questions_encoded) + f.create_dataset("image_idxs", data=np.asarray(image_idxs)) + f.create_dataset("orig_idxs", data=np.asarray(orig_idxs)) + + if len(programs_encoded) > 0: + f.create_dataset("programs", data=programs_encoded) + if len(question_families) > 0: + f.create_dataset("question_families", data=np.asarray(question_families)) + if len(answers) > 0: + f.create_dataset("answers", data=np.asarray(answers)) + if len(types) > 0: + f.create_dataset("types", data=np.asarray(types_coded)) + + +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/scripts/preprocess_questions.sh b/scripts/preprocess_questions.sh new file mode 100644 index 0000000..08bff6a --- /dev/null +++ b/scripts/preprocess_questions.sh @@ -0,0 +1,14 @@ +python scripts/preprocess_questions.py \ + --input_questions_json data/CLEVR_v1.0/questions/CLEVR_train_questions.json \ + --output_h5_file data/train_questions.h5 \ + --output_vocab_json data/vocab.json 
+ +python scripts/preprocess_questions.py \ + --input_questions_json data/CLEVR_v1.0/questions/CLEVR_val_questions.json \ + --output_h5_file data/val_questions.h5 \ + --input_vocab_json data/vocab.json + +python scripts/preprocess_questions.py \ + --input_questions_json data/CLEVR_v1.0/questions/CLEVR_test_questions.json \ + --output_h5_file data/test_questions.h5 \ + --input_vocab_json data/vocab.json \ No newline at end of file diff --git a/scripts/run_model.py b/scripts/run_model.py index 249356d..e87d559 100644 --- a/scripts/run_model.py +++ b/scripts/run_model.py @@ -4,6 +4,15 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import os +import imageio +import numpy as np +import torch +import torch.nn.functional as F + +from PIL import Image +from torch.autograd import Variable + import argparse import ipdb as pdb import json @@ -13,529 +22,615 @@ import time from tqdm import tqdm import sys -import os -sys.path.insert(0, os.path.abspath('.')) -import torch -from torch.autograd import Variable -import torch.nn.functional as F +sys.path.insert(0, os.path.abspath(".")) + import torchvision -import numpy as np import h5py -from scipy.misc import imread, imresize, imsave import vr.utils as utils import vr.programs from vr.data import ClevrDataset, ClevrDataLoader from vr.preprocess import tokenize, encode - parser = argparse.ArgumentParser() -parser.add_argument('--program_generator', default='models/best.pt') -parser.add_argument('--execution_engine', default='models/best.pt') -parser.add_argument('--baseline_model', default=None) -parser.add_argument('--model_type', default='FiLM') -parser.add_argument('--debug_every', default=float('inf'), type=float) -parser.add_argument('--use_gpu', default=1, type=int) +parser.add_argument("--program_generator", default="data/best.pt") +parser.add_argument("--execution_engine", default="data/best.pt") +parser.add_argument("--baseline_model", 
default=None) +parser.add_argument("--model_type", default="FiLM") +parser.add_argument("--debug_every", default=float("inf"), type=float) +parser.add_argument("--use_gpu", default=1, type=int) # For running on a preprocessed dataset -parser.add_argument('--input_question_h5', default=None) -parser.add_argument('--input_features_h5', default=None) +parser.add_argument("--input_question_h5", default=None) +parser.add_argument("--input_features_h5", default=None) # This will override the vocab stored in the checkpoint; # we need this to run CLEVR models on human data -parser.add_argument('--vocab_json', default=None) +parser.add_argument("--vocab_json", default=None) # For running on a single example -parser.add_argument('--question', default=None) -parser.add_argument('--image', default='img/CLEVR_val_000017.png') -parser.add_argument('--cnn_model', default='resnet101') -parser.add_argument('--cnn_model_stage', default=3, type=int) -parser.add_argument('--image_width', default=224, type=int) -parser.add_argument('--image_height', default=224, type=int) -parser.add_argument('--enforce_clevr_vocab', default=1, type=int) - -parser.add_argument('--batch_size', default=64, type=int) -parser.add_argument('--num_samples', default=None, type=int) -parser.add_argument('--num_last_words_shuffled', default=0, type=int) # -1 for all shuffled -parser.add_argument('--family_split_file', default=None) - -parser.add_argument('--sample_argmax', type=int, default=1) -parser.add_argument('--temperature', default=1.0, type=float) +parser.add_argument("--question", default=None) +parser.add_argument("--image", default="img/CLEVR_val_000017.png") +parser.add_argument("--cnn_model", default="resnet101") +parser.add_argument("--cnn_model_stage", default=3, type=int) +parser.add_argument("--image_width", default=224, type=int) +parser.add_argument("--image_height", default=224, type=int) +parser.add_argument("--enforce_clevr_vocab", default=1, type=int) + 
+parser.add_argument("--batch_size", default=64, type=int) +parser.add_argument("--num_samples", default=None, type=int) +parser.add_argument( + "--num_last_words_shuffled", default=0, type=int +) # -1 for all shuffled +parser.add_argument("--family_split_file", default=None) + +parser.add_argument("--sample_argmax", type=int, default=1) +parser.add_argument("--temperature", default=1.0, type=float) # FiLM models only -parser.add_argument('--gamma_option', default='linear', - choices=['linear', 'sigmoid', 'tanh', 'exp', 'relu', 'softplus']) -parser.add_argument('--gamma_scale', default=1, type=float) -parser.add_argument('--gamma_shift', default=0, type=float) -parser.add_argument('--gammas_from', default=None) # Load gammas from file -parser.add_argument('--beta_option', default='linear', - choices=['linear', 'sigmoid', 'tanh', 'exp', 'relu', 'softplus']) -parser.add_argument('--beta_scale', default=1, type=float) -parser.add_argument('--beta_shift', default=0, type=float) -parser.add_argument('--betas_from', default=None) # Load betas from file +parser.add_argument( + "--gamma_option", + default="linear", + choices=["linear", "sigmoid", "tanh", "exp", "relu", "softplus"], +) +parser.add_argument("--gamma_scale", default=1, type=float) +parser.add_argument("--gamma_shift", default=0, type=float) +parser.add_argument("--gammas_from", default=None) # Load gammas from file +parser.add_argument( + "--beta_option", + default="linear", + choices=["linear", "sigmoid", "tanh", "exp", "relu", "softplus"], +) +parser.add_argument("--beta_scale", default=1, type=float) +parser.add_argument("--beta_shift", default=0, type=float) +parser.add_argument("--betas_from", default=None) # Load betas from file # If this is passed, then save all predictions to this file -parser.add_argument('--output_h5', default=None) -parser.add_argument('--output_preds', default=None) -parser.add_argument('--output_viz_dir', default='img/') -parser.add_argument('--output_program_stats_dir', 
default=None) +parser.add_argument("--output_h5", default=None) +parser.add_argument("--output_preds", default=None) +parser.add_argument("--visualize_attention", default=False, type=bool) +parser.add_argument("--output_viz_dir", default="img/attention_visualizations/") +parser.add_argument("--output_program_stats_dir", default=None) +parser.add_argument("--streamlit", default=False, type=bool) grads = {} programs = {} # NOTE: Useful for zero-shot program manipulation when in debug mode + def main(args): - if args.debug_every <= 1: - pdb.set_trace() - model = None - if args.baseline_model is not None: - print('Loading baseline model from ', args.baseline_model) - model, _ = utils.load_baseline(args.baseline_model) - if args.vocab_json is not None: - new_vocab = utils.load_vocab(args.vocab_json) - model.rnn.expand_vocab(new_vocab['question_token_to_idx']) - elif args.program_generator is not None and args.execution_engine is not None: - pg, _ = utils.load_program_generator(args.program_generator, args.model_type) - ee, _ = utils.load_execution_engine( - args.execution_engine, verbose=False, model_type=args.model_type) - if args.vocab_json is not None: - new_vocab = utils.load_vocab(args.vocab_json) - pg.expand_encoder_vocab(new_vocab['question_token_to_idx']) - model = (pg, ee) - else: - print('Must give either --baseline_model or --program_generator and --execution_engine') - return + # Determine device: if use_gpu flag is set then use CUDA if available, + # else try MPS (for newer Macs) and fall back to CPU. 
+ if args.use_gpu: + device = torch.device( + "cuda" + if torch.cuda.is_available() + else ("mps" if torch.backends.mps.is_available() else "cpu") + ) + else: + device = torch.device("cpu") + if not args.streamlit: + print("Using device:", device) + + model = None + if args.baseline_model is not None: + if not args.streamlit: + print("Loading baseline model from ", args.baseline_model) + model, _ = utils.load_baseline(args.baseline_model) + if args.vocab_json is not None: + new_vocab = utils.load_vocab(args.vocab_json) + model.rnn.expand_vocab(new_vocab["question_token_to_idx"]) + elif args.program_generator is not None and args.execution_engine is not None: + pg, _ = utils.load_program_generator(args.program_generator, args.model_type) + ee, _ = utils.load_execution_engine( + args.execution_engine, verbose=False, model_type=args.model_type + ) + if args.vocab_json is not None: + new_vocab = utils.load_vocab(args.vocab_json) + pg.expand_encoder_vocab(new_vocab["question_token_to_idx"]) + model = (pg, ee) + else: + print( + "Must give either --baseline_model or --program_generator and --execution_engine" + ) + return - dtype = torch.FloatTensor - if args.use_gpu == 1: - dtype = torch.cuda.FloatTensor - if args.question is not None and args.image is not None: - run_single_example(args, model, dtype, args.question) - # Interactive mode - elif args.image is not None and args.input_question_h5 is None and args.input_features_h5 is None: - feats_var = extract_image_features(args, dtype) - print(colored('Ask me something!', 'cyan')) - while True: - # Get user question - question_raw = input(">>> ") - run_single_example(args, model, dtype, question_raw, feats_var) - else: + if args.question is not None and args.image is not None: + run_single_example(args, model, device, args.question) + # Interactive mode + elif ( + args.image is not None + and args.input_question_h5 is None + and args.input_features_h5 is None + ): + feats_var = extract_image_features(args, device) + if 
not args.streamlit: + print(colored("Ask me something!", "cyan")) + while True: + if not args.streamlit: + question_raw = input(">>> ") + else: + question_raw = input("") + run_single_example(args, model, device, question_raw, feats_var) + else: + vocab = load_vocab(args) + loader_kwargs = { + "question_h5": args.input_question_h5, + "feature_h5": args.input_features_h5, + "vocab": vocab, + "batch_size": args.batch_size, + } + if args.num_samples is not None and args.num_samples > 0: + loader_kwargs["max_samples"] = args.num_samples + if args.family_split_file is not None: + with open(args.family_split_file, "r") as f: + loader_kwargs["question_families"] = json.load(f) + with ClevrDataLoader(**loader_kwargs) as loader: + run_batch(args, model, device, loader) + + +def extract_image_features(args, device): + # Build the CNN to use for feature extraction + if not args.streamlit: + print("Extracting image features...") + cnn = build_cnn(args, device) + + # Load and preprocess the image + img_size = (args.image_height, args.image_width) + + # Read image using imageio (ensuring RGB mode) + img = imageio.imread(args.image, pilmode="RGB") + + # Resize image using PIL + im = Image.fromarray(img) + im_resized = im.resize((img_size[1], img_size[0]), resample=Image.BICUBIC) + img = np.array(im_resized) + + # Transpose image dimensions to (1, channels, height, width) + img = img.transpose(2, 0, 1)[None] + + # Normalize the image + mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) + std = np.array([0.229, 0.224, 0.224]).reshape(1, 3, 1, 1) + img = (img.astype(np.float32) / 255.0 - mean) / std + + # Create a PyTorch tensor for the image on the proper device + img_var = torch.tensor(img, dtype=torch.float32, device=device, requires_grad=True) + + # Use the CNN to extract features for the image + feats_var = cnn(img_var) + return feats_var + + +def run_single_example(args, model, device, question_raw, feats_var=None): + interactive = feats_var is not None + if not 
interactive: + feats_var = extract_image_features(args, device) + + # Tokenize the question vocab = load_vocab(args) - loader_kwargs = { - 'question_h5': args.input_question_h5, - 'feature_h5': args.input_features_h5, - 'vocab': vocab, - 'batch_size': args.batch_size, - } - if args.num_samples is not None and args.num_samples > 0: - loader_kwargs['max_samples'] = args.num_samples - if args.family_split_file is not None: - with open(args.family_split_file, 'r') as f: - loader_kwargs['question_families'] = json.load(f) - with ClevrDataLoader(**loader_kwargs) as loader: - run_batch(args, model, dtype, loader) - - -def extract_image_features(args, dtype): - # Build the CNN to use for feature extraction - print('Extracting image features...') - cnn = build_cnn(args, dtype) - - # Load and preprocess the image - img_size = (args.image_height, args.image_width) - img = imread(args.image, mode='RGB') - img = imresize(img, img_size, interp='bicubic') - img = img.transpose(2, 0, 1)[None] - mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) - std = np.array([0.229, 0.224, 0.224]).reshape(1, 3, 1, 1) - img = (img.astype(np.float32) / 255.0 - mean) / std - - # Use CNN to extract features for the image - img_var = Variable(torch.FloatTensor(img).type(dtype), volatile=False, requires_grad=True) - feats_var = cnn(img_var) - return feats_var - - -def run_single_example(args, model, dtype, question_raw, feats_var=None): - interactive = feats_var is not None - if not interactive: - feats_var = extract_image_features(args, dtype) - - # Tokenize the question - vocab = load_vocab(args) - question_tokens = tokenize(question_raw, - punct_to_keep=[';', ','], - punct_to_remove=['?', '.']) - if args.enforce_clevr_vocab == 1: - for word in question_tokens: - if word not in vocab['question_token_to_idx']: - print(colored('No one taught me what "%s" means :( Try me again!' 
% (word), 'magenta')) + question_tokens = tokenize( + question_raw, punct_to_keep=[";", ","], punct_to_remove=["?", "."] + ) + if args.enforce_clevr_vocab == 1: + for word in question_tokens: + if word not in vocab["question_token_to_idx"]: + print( + colored( + 'No one taught me what "%s" means :( Try me again!' % (word), + "magenta", + ) + ) + return + question_encoded = encode( + question_tokens, vocab["question_token_to_idx"], allow_unk=True + ) + question_encoded = torch.tensor( + question_encoded, dtype=torch.long, device=device + ).view(1, -1) + question_var = Variable(question_encoded) + + # Run the model + scores = None + predicted_program = None + if type(model) is tuple: + pg, ee = model + pg.to(device) + pg.eval() + ee.to(device) + ee.eval() + if args.model_type == "FiLM": + predicted_program = pg(question_var) + else: + predicted_program = pg.reinforce_sample( + question_var, + temperature=args.temperature, + argmax=(args.sample_argmax == 1), + ) + programs[question_raw] = predicted_program + if args.debug_every <= -1: + pdb.set_trace() + scores = ee(feats_var, predicted_program, save_activations=True) + else: + model.to(device) + scores = model(question_var, feats_var) + + # Print results + predicted_probs = scores.data.cpu() + _, predicted_answer_idx = predicted_probs[0].max(dim=0) + predicted_probs = F.softmax(Variable(predicted_probs[0]), dim=0).data + predicted_answer = vocab["answer_idx_to_token"][predicted_answer_idx.item()] + + answers_to_probs = {} + for i in range(len(vocab["answer_idx_to_token"])): + answers_to_probs[vocab["answer_idx_to_token"][i]] = predicted_probs[i] + answers_to_probs_sorted = sorted( + answers_to_probs.items(), key=lambda x: x[1], reverse=True + ) + for i in range(len(answers_to_probs_sorted)): + if answers_to_probs_sorted[i][1] >= 1e-3 and args.debug_every < float("inf"): + print( + "%s: %.1f%%" + % ( + answers_to_probs_sorted[i][0].capitalize(), + 100 * answers_to_probs_sorted[i][1], + ) + ) + + if not interactive: + 
print(colored('Question: "%s"' % question_raw, "cyan")) + print(colored(str(predicted_answer).capitalize(), "magenta")) + + if interactive and not args.visualize_attention: return - question_encoded = encode(question_tokens, - vocab['question_token_to_idx'], - allow_unk=True) - question_encoded = torch.LongTensor(question_encoded).view(1, -1) - question_encoded = question_encoded.type(dtype).long() - question_var = Variable(question_encoded, volatile=False) - - # Run the model - scores = None - predicted_program = None - if type(model) is tuple: - pg, ee = model - pg.type(dtype) + + # Visualize Gradients w.r.t. output + cf_conv = ee.classifier[0](ee.cf_input) + cf_bn = ee.classifier[1](cf_conv) + pre_pool = ee.classifier[2](cf_bn) + pooled = ee.classifier[3](pre_pool) # noqa: F841 + + pre_pool_max_per_c = pre_pool.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0].expand_as(pre_pool) + pre_pool_masked = (pre_pool_max_per_c == pre_pool).float() * pre_pool + pool_feat_locs = (pre_pool_masked > 0).float().sum(1) + if args.debug_every <= 1: + pdb.set_trace() + + if args.output_viz_dir != "NA": + viz_dir = args.output_viz_dir + question_raw + " " + predicted_answer + if not os.path.isdir(args.output_viz_dir): + os.mkdir(args.output_viz_dir) + if not os.path.isdir(viz_dir): + os.mkdir(viz_dir) + args.viz_dir = viz_dir + + if not args.streamlit: + print("Saving visualizations to " + args.viz_dir) + + # Saving Beta and Gamma parameters + path_param = os.path.join(args.viz_dir, "params.pt") + torch.save(predicted_program,path_param) + + # Backprop w.r.t. sum of output scores - What affected prediction most? + ee.feats.register_hook(save_grad("stem")) + for i in range(ee.num_modules): + ee.module_outputs[i].register_hook(save_grad("m" + str(i))) + scores_sum = scores.sum() + scores_sum.backward() + + # Visualizations! 
+ visualize(feats_var, args, "resnet101") + visualize(ee.feats, args, "conv-stem") + visualize(grads["stem"], args, "grad-conv-stem") + for i in range(ee.num_modules): + visualize(ee.module_outputs[i], args, "resblock" + str(i)) + visualize(grads["m" + str(i)], args, "grad-resblock" + str(i)) + visualize(pre_pool, args, "pre-pool") + visualize(pool_feat_locs, args, "pool-feature-locations") + + if (predicted_program is not None) and (args.model_type != "FiLM"): + print() + print("Predicted program:") + program = predicted_program.data.cpu()[0] + num_inputs = 1 + for fn_idx in program: + fn_str = vocab["program_idx_to_token"][fn_idx] + num_inputs += vr.programs.get_num_inputs(fn_str) - 1 + print(fn_str) + if num_inputs == 0: + break + if interactive: + return + + +def run_our_model_batch(args, pg, ee, loader, device): + pg.to(device) pg.eval() - ee.type(dtype) + ee.to(device) ee.eval() - if args.model_type == 'FiLM': - predicted_program = pg(question_var) - else: - predicted_program = pg.reinforce_sample( - question_var, - temperature=args.temperature, - argmax=(args.sample_argmax == 1)) - programs[question_raw] = predicted_program - if args.debug_every <= -1: - pdb.set_trace() - scores = ee(feats_var, predicted_program, save_activations=True) - else: - model.type(dtype) - scores = model(question_var, feats_var) - - # Print results - predicted_probs = scores.data.cpu() - _, predicted_answer_idx = predicted_probs[0].max(dim=0) - predicted_probs = F.softmax(Variable(predicted_probs[0])).data - predicted_answer = vocab['answer_idx_to_token'][predicted_answer_idx[0]] - - answers_to_probs = {} - for i in range(len(vocab['answer_idx_to_token'])): - answers_to_probs[vocab['answer_idx_to_token'][i]] = predicted_probs[i] - answers_to_probs_sorted = sorted(answers_to_probs.items(), key=lambda x: x[1]) - answers_to_probs_sorted.reverse() - for i in range(len(answers_to_probs_sorted)): - if answers_to_probs_sorted[i][1] >= 1e-3 and args.debug_every < float('inf'): - print("%s: 
%.1f%%" % (answers_to_probs_sorted[i][0].capitalize(), - 100 * answers_to_probs_sorted[i][1])) - - if not interactive: - print(colored('Question: "%s"' % question_raw, 'cyan')) - print(colored(str(predicted_answer).capitalize(), 'magenta')) - - if interactive: - return - - # Visualize Gradients w.r.t. output - cf_conv = ee.classifier[0](ee.cf_input) - cf_bn = ee.classifier[1](cf_conv) - pre_pool = ee.classifier[2](cf_bn) - pooled = ee.classifier[3](pre_pool) - - pre_pool_max_per_c = pre_pool.max(2)[0].max(3)[0].expand_as(pre_pool) - pre_pool_masked = (pre_pool_max_per_c == pre_pool).float() * pre_pool - pool_feat_locs = (pre_pool_masked > 0).float().sum(1) - if args.debug_every <= 1: - pdb.set_trace() - - if args.output_viz_dir != 'NA': - viz_dir = args.output_viz_dir + question_raw + ' ' + predicted_answer - if not os.path.isdir(viz_dir): - os.mkdir(viz_dir) - args.viz_dir = viz_dir - print('Saving visualizations to ' + args.viz_dir) - - # Backprop w.r.t. sum of output scores - What affected prediction most? - ee.feats.register_hook(save_grad('stem')) - for i in range(ee.num_modules): - ee.module_outputs[i].register_hook(save_grad('m' + str(i))) - scores_sum = scores.sum() - scores_sum.backward() - - # Visualizations! 
- visualize(feats_var, args, 'resnet101') - visualize(ee.feats, args, 'conv-stem') - visualize(grads['stem'], args, 'grad-conv-stem') - for i in range(ee.num_modules): - visualize(ee.module_outputs[i], args, 'resblock' + str(i)) - visualize(grads['m' + str(i)], args, 'grad-resblock' + str(i)) - visualize(pre_pool, args, 'pre-pool') - visualize(pool_feat_locs, args, 'pool-feature-locations') - - if (predicted_program is not None) and (args.model_type != 'FiLM'): - print() - print('Predicted program:') - program = predicted_program.data.cpu()[0] - num_inputs = 1 - for fn_idx in program: - fn_str = vocab['program_idx_to_token'][fn_idx] - num_inputs += vr.programs.get_num_inputs(fn_str) - 1 - print(fn_str) - if num_inputs == 0: - break - - -def run_our_model_batch(args, pg, ee, loader, dtype): - pg.type(dtype) - pg.eval() - ee.type(dtype) - ee.eval() - - all_scores, all_programs = [], [] - all_probs = [] - all_preds = [] - num_correct, num_samples = 0, 0 - - loaded_gammas = None - loaded_betas = None - if args.gammas_from: - print('Loading ') - loaded_gammas = torch.load(args.gammas_from) - if args.betas_from: - print('Betas loaded!') - loaded_betas = torch.load(args.betas_from) - - q_types = [] - film_params = [] - - if args.num_last_words_shuffled == -1: - print('All words of each question shuffled.') - elif args.num_last_words_shuffled > 0: - print('Last %d words of each question shuffled.' 
% args.num_last_words_shuffled) - start = time.time() - for batch in tqdm(loader): - assert(not pg.training) - assert(not ee.training) - questions, images, feats, answers, programs, program_lists = batch - - if args.num_last_words_shuffled != 0: - for i, question in enumerate(questions): - # Search for token to find question length - q_end = get_index(question.numpy().tolist(), index=2, default=len(question)) - if args.num_last_words_shuffled > 0: - q_end -= args.num_last_words_shuffled # Leave last few words unshuffled - if q_end < 2: - q_end = 2 - question = question[1:q_end] - random.shuffle(question) - questions[i][1:q_end] = question + all_scores, all_programs = [], [] + all_probs = [] + all_preds = [] + num_correct, num_samples = 0, 0 + + loaded_gammas = None + loaded_betas = None + if args.gammas_from: + print("Loading ") + loaded_gammas = torch.load(args.gammas_from, map_location=device) + if args.betas_from: + print("Betas loaded!") + loaded_betas = torch.load(args.betas_from, map_location=device) + + q_types = [] + film_params = [] + + if args.num_last_words_shuffled == -1: + print("All words of each question shuffled.") + elif args.num_last_words_shuffled > 0: + print("Last %d words of each question shuffled." 
% args.num_last_words_shuffled) + start = time.time() + for batch in tqdm(loader): + assert not pg.training + assert not ee.training + questions, images, feats, answers, programs, program_lists = batch + + if args.num_last_words_shuffled != 0: + for i, question in enumerate(questions): + # Search for token to find question length + q_end = get_index( + question.numpy().tolist(), index=2, default=len(question) + ) + if args.num_last_words_shuffled > 0: + q_end -= args.num_last_words_shuffled + if q_end < 2: + q_end = 2 + question = question[1:q_end] + random.shuffle(question) + questions[i][1:q_end] = question + + if isinstance(questions, list): + questions_var = Variable(questions[0].to(device).long(), volatile=True) + q_types += [questions[1].cpu().numpy()] + else: + questions_var = Variable(questions.to(device).long(), volatile=True) + feats_var = Variable(feats.to(device), volatile=True) + if args.model_type == "FiLM": + programs_pred = pg(questions_var) + # Examine effect of various conditioning modifications at test time! 
+ programs_pred = pg.modify_output( + programs_pred, + gamma_option=args.gamma_option, + gamma_scale=args.gamma_scale, + gamma_shift=args.gamma_shift, + beta_option=args.beta_option, + beta_scale=args.beta_scale, + beta_shift=args.beta_shift, + ) + if args.gammas_from: + programs_pred[:, :, : pg.module_dim] = loaded_gammas.expand_as( + programs_pred[:, :, : pg.module_dim] + ) + if args.betas_from: + programs_pred[:, :, pg.module_dim : 2 * pg.module_dim] = ( + loaded_betas.expand_as( + programs_pred[:, :, pg.module_dim : 2 * pg.module_dim] + ) + ) + else: + programs_pred = pg.reinforce_sample( + questions_var, + temperature=args.temperature, + argmax=(args.sample_argmax == 1), + ) + + film_params += [programs_pred.cpu().data.numpy()] + scores = ee(feats_var, programs_pred, save_activations=True) + probs = F.softmax(scores, dim=1) + + _, preds = scores.data.cpu().max(1) + all_programs.append(programs_pred.data.cpu().clone()) + all_scores.append(scores.data.cpu().clone()) + all_probs.append(probs.data.cpu().clone()) + all_preds.append(preds.cpu().clone()) + if answers[0] is not None: + num_correct += (preds == answers).sum() + num_samples += preds.size(0) + + acc = float(num_correct) / num_samples + print("Got %d / %d = %.2f correct" % (num_correct, num_samples, 100 * acc)) + print("%.2fs to evaluate" % (time.time() - start)) + all_programs = torch.cat(all_programs, 0) + all_scores = torch.cat(all_scores, 0) + all_probs = torch.cat(all_probs, 0) + all_preds = torch.cat(all_preds, 0).squeeze().numpy() + if args.output_h5 is not None: + print('Writing output to "%s"' % args.output_h5) + with h5py.File(args.output_h5, "w") as fout: + fout.create_dataset("scores", data=all_scores.numpy()) + fout.create_dataset("probs", data=all_probs.numpy()) + fout.create_dataset("predicted_programs", data=all_programs.numpy()) + + # Save FiLM params + np.save("film_params", np.vstack(film_params)) if isinstance(questions, list): - questions_var = 
Variable(questions[0].type(dtype).long(), volatile=True) - q_types += [questions[1].cpu().numpy()] - else: - questions_var = Variable(questions.type(dtype).long(), volatile=True) - feats_var = Variable(feats.type(dtype), volatile=True) - if args.model_type == 'FiLM': - programs_pred = pg(questions_var) - # Examine effect of various conditioning modifications at test time! - programs_pred = pg.modify_output(programs_pred, gamma_option=args.gamma_option, - gamma_scale=args.gamma_scale, gamma_shift=args.gamma_shift, - beta_option=args.beta_option, beta_scale=args.beta_scale, - beta_shift=args.beta_shift) - if args.gammas_from: - programs_pred[:,:,:pg.module_dim] = loaded_gammas.expand_as( - programs_pred[:,:,:pg.module_dim]) - if args.betas_from: - programs_pred[:,:,pg.module_dim:2*pg.module_dim] = loaded_betas.expand_as( - programs_pred[:,:,pg.module_dim:2*pg.module_dim]) - else: - programs_pred = pg.reinforce_sample( - questions_var, - temperature=args.temperature, - argmax=(args.sample_argmax == 1)) - - film_params += [programs_pred.cpu().data.numpy()] - scores = ee(feats_var, programs_pred, save_activations=True) - probs = F.softmax(scores) - - _, preds = scores.data.cpu().max(1) - all_programs.append(programs_pred.data.cpu().clone()) - all_scores.append(scores.data.cpu().clone()) - all_probs.append(probs.data.cpu().clone()) - all_preds.append(preds.cpu().clone()) - if answers[0] is not None: - num_correct += (preds == answers).sum() - num_samples += preds.size(0) - - acc = float(num_correct) / num_samples - print('Got %d / %d = %.2f correct' % (num_correct, num_samples, 100 * acc)) - print('%.2fs to evaluate' % (start - time.time())) - all_programs = torch.cat(all_programs, 0) - all_scores = torch.cat(all_scores, 0) - all_probs = torch.cat(all_probs, 0) - all_preds = torch.cat(all_preds, 0).squeeze().numpy() - if args.output_h5 is not None: - print('Writing output to "%s"' % args.output_h5) - with h5py.File(args.output_h5, 'w') as fout: - 
fout.create_dataset('scores', data=all_scores.numpy()) - fout.create_dataset('probs', data=all_probs.numpy()) - fout.create_dataset('predicted_programs', data=all_programs.numpy()) - - # Save FiLM params - np.save('film_params', np.vstack(film_params)) - if isinstance(questions, list): - np.save('q_types', np.vstack(q_types)) - - # Save FiLM param stats - if args.output_program_stats_dir: - if not os.path.isdir(args.output_program_stats_dir): - os.mkdir(args.output_program_stats_dir) - gammas = all_programs[:,:,:pg.module_dim] - betas = all_programs[:,:,pg.module_dim:2*pg.module_dim] - gamma_means = gammas.mean(0) - torch.save(gamma_means, os.path.join(args.output_program_stats_dir, 'gamma_means')) - beta_means = betas.mean(0) - torch.save(beta_means, os.path.join(args.output_program_stats_dir, 'beta_means')) - gamma_medians = gammas.median(0)[0] - torch.save(gamma_medians, os.path.join(args.output_program_stats_dir, 'gamma_medians')) - beta_medians = betas.median(0)[0] - torch.save(beta_medians, os.path.join(args.output_program_stats_dir, 'beta_medians')) - - # Note: Takes O(10GB) space - torch.save(gammas, os.path.join(args.output_program_stats_dir, 'gammas')) - torch.save(betas, os.path.join(args.output_program_stats_dir, 'betas')) - - if args.output_preds is not None: - vocab = load_vocab(args) - all_preds_strings = [] - for i in range(len(all_preds)): - all_preds_strings.append(vocab['answer_idx_to_token'][all_preds[i]]) - save_to_file(all_preds_strings, args.output_preds) - - if args.debug_every <= 1: - pdb.set_trace() - return + np.save("q_types", np.vstack(q_types)) + + # Save FiLM param stats + if args.output_program_stats_dir: + if not os.path.isdir(args.output_program_stats_dir): + os.mkdir(args.output_program_stats_dir) + gammas = all_programs[:, :, : pg.module_dim] + betas = all_programs[:, :, pg.module_dim : 2 * pg.module_dim] + gamma_means = gammas.mean(0) + torch.save( + gamma_means, os.path.join(args.output_program_stats_dir, "gamma_means") + ) + 
beta_means = betas.mean(0) + torch.save( + beta_means, os.path.join(args.output_program_stats_dir, "beta_means") + ) + gamma_medians = gammas.median(0)[0] + torch.save( + gamma_medians, os.path.join(args.output_program_stats_dir, "gamma_medians") + ) + beta_medians = betas.median(0)[0] + torch.save( + beta_medians, os.path.join(args.output_program_stats_dir, "beta_medians") + ) + + # Note: Takes O(10GB) space + torch.save(gammas, os.path.join(args.output_program_stats_dir, "gammas")) + torch.save(betas, os.path.join(args.output_program_stats_dir, "betas")) + + if args.output_preds is not None: + vocab = load_vocab(args) + all_preds_strings = [] + for i in range(len(all_preds)): + all_preds_strings.append(vocab["answer_idx_to_token"][all_preds[i]]) + save_to_file(all_preds_strings, args.output_preds) + + if args.debug_every <= 1: + pdb.set_trace() + return def visualize(features, args, file_name=None): """ Converts a 4d map of features to alpha attention weights, - According to their 2-Norm across dimensions 0 and 1. + according to their 2-norm across dimensions 0 and 1. Then saves the input RGB image as an RGBA image using an upsampling of this attention map. 
""" - save_file = os.path.join(args.viz_dir, file_name) + save_file = os.path.join(args.viz_dir, file_name) if file_name is not None else None img_path = args.image - - # Scale map to [0, 1] - f_map = (features ** 2).mean(0).mean(1).squeeze().sqrt() + + # Add a batch dimension or a channel dimension if it's lacking (for pool_feat_locs for example) + if features.dim() == 3: + features = features.unsqueeze(0) + # Scale feature map to [0, 1] + f_map = (features**2).mean(0).mean(0).squeeze().sqrt() f_map_shifted = f_map - f_map.min().expand_as(f_map) f_map_scaled = f_map_shifted / f_map_shifted.max().expand_as(f_map_shifted) if save_file is None: - print(f_map_scaled) + print(f_map_scaled) else: - # Read original image - img = imread(img_path, mode='RGB') - orig_img_size = img.shape - - # Convert to image format - alpha = (255 * f_map_scaled).round() - alpha4d = alpha.unsqueeze(0).unsqueeze(0) - alpha_upsampled = torch.nn.functional.upsample_bilinear( - alpha4d, size=torch.Size(orig_img_size)).squeeze(0).transpose(1, 0).transpose(1, 2) - alpha_upsampled_np = alpha_upsampled.cpu().data.numpy() - - # Create and save visualization - imga = np.concatenate([img, alpha_upsampled_np], axis=2) - if save_file[-4:] != '.png': save_file += '.png' - imsave(save_file, imga) + img = imageio.imread(img_path, pilmode="RGB") + orig_img_size = img.shape[:2] + + alpha = (255 * f_map_scaled).round().byte() + alpha4d = alpha.unsqueeze(0).unsqueeze(0) + alpha_upsampled = F.interpolate( + alpha4d, size=orig_img_size, mode="bilinear", align_corners=False + ) + alpha_upsampled = alpha_upsampled.squeeze(0).transpose(1, 0).transpose(1, 2) + alpha_upsampled_np = alpha_upsampled.cpu().data.numpy() + + imga = np.concatenate([img, alpha_upsampled_np], axis=2) + + if not save_file.lower().endswith(".png"): + save_file += ".png" + + imageio.imwrite(save_file, imga) return f_map_scaled -def build_cnn(args, dtype): - if not hasattr(torchvision.models, args.cnn_model): - raise ValueError('Invalid model 
"%s"' % args.cnn_model) - if not 'resnet' in args.cnn_model: - raise ValueError('Feature extraction only supports ResNets') - whole_cnn = getattr(torchvision.models, args.cnn_model)(pretrained=True) - layers = [ - whole_cnn.conv1, - whole_cnn.bn1, - whole_cnn.relu, - whole_cnn.maxpool, - ] - for i in range(args.cnn_model_stage): - name = 'layer%d' % (i + 1) - layers.append(getattr(whole_cnn, name)) - cnn = torch.nn.Sequential(*layers) - cnn.type(dtype) - cnn.eval() - return cnn - - -def run_batch(args, model, dtype, loader): - if type(model) is tuple: - pg, ee = model - run_our_model_batch(args, pg, ee, loader, dtype) - else: - run_baseline_batch(args, model, loader, dtype) - - -def run_baseline_batch(args, model, loader, dtype): - model.type(dtype) - model.eval() - - all_scores, all_probs = [], [] - num_correct, num_samples = 0, 0 - for batch in loader: - questions, images, feats, answers, programs, program_lists = batch - - questions_var = Variable(questions.type(dtype).long(), volatile=True) - feats_var = Variable(feats.type(dtype), volatile=True) - scores = model(questions_var, feats_var) - probs = F.softmax(scores) - - _, preds = scores.data.cpu().max(1) - all_scores.append(scores.data.cpu().clone()) - all_probs.append(probs.data.cpu().clone()) - - num_correct += (preds == answers).sum() - num_samples += preds.size(0) - print('Ran %d samples' % num_samples) - - acc = float(num_correct) / num_samples - print('Got %d / %d = %.2f correct' % (num_correct, num_samples, 100 * acc)) - - all_scores = torch.cat(all_scores, 0) - all_probs = torch.cat(all_probs, 0) - if args.output_h5 is not None: - print('Writing output to %s' % args.output_h5) - with h5py.File(args.output_h5, 'w') as fout: - fout.create_dataset('scores', data=all_scores.numpy()) - fout.create_dataset('probs', data=all_probs.numpy()) +def build_cnn(args, device): + if not hasattr(torchvision.models, args.cnn_model): + raise ValueError('Invalid model "%s"' % args.cnn_model) + if "resnet" not in 
args.cnn_model: + raise ValueError("Feature extraction only supports ResNets") + whole_cnn = getattr(torchvision.models, args.cnn_model)(pretrained=True) + layers = [ + whole_cnn.conv1, + whole_cnn.bn1, + whole_cnn.relu, + whole_cnn.maxpool, + ] + for i in range(args.cnn_model_stage): + name = "layer%d" % (i + 1) + layers.append(getattr(whole_cnn, name)) + cnn = torch.nn.Sequential(*layers) + cnn.to(device) + cnn.eval() + return cnn + + +def run_batch(args, model, device, loader): + if type(model) is tuple: + pg, ee = model + run_our_model_batch(args, pg, ee, loader, device) + else: + run_baseline_batch(args, model, loader, device) + + +def run_baseline_batch(args, model, loader, device): + model.to(device) + model.eval() + + all_scores, all_probs = [], [] + num_correct, num_samples = 0, 0 + for batch in loader: + questions, images, feats, answers, programs, program_lists = batch + + questions_var = Variable(questions.to(device).long(), volatile=True) + feats_var = Variable(feats.to(device), volatile=True) + scores = model(questions_var, feats_var) + probs = F.softmax(scores, dim=1) + + _, preds = scores.data.cpu().max(1) + all_scores.append(scores.data.cpu().clone()) + all_probs.append(probs.data.cpu().clone()) + + num_correct += (preds == answers).sum() + num_samples += preds.size(0) + print("Ran %d samples" % num_samples) + + acc = float(num_correct) / num_samples + print("Got %d / %d = %.2f correct" % (num_correct, num_samples, 100 * acc)) + + all_scores = torch.cat(all_scores, 0) + all_probs = torch.cat(all_probs, 0) + if args.output_h5 is not None: + print("Writing output to %s" % args.output_h5) + with h5py.File(args.output_h5, "w") as fout: + fout.create_dataset("scores", data=all_scores.numpy()) + fout.create_dataset("probs", data=all_probs.numpy()) def load_vocab(args): - path = None - if args.baseline_model is not None: - path = args.baseline_model - elif args.program_generator is not None: - path = args.program_generator - elif args.execution_engine is 
not None: - path = args.execution_engine - return utils.load_cpu(path)['vocab'] + path = None + if args.baseline_model is not None: + path = args.baseline_model + elif args.program_generator is not None: + path = args.program_generator + elif args.execution_engine is not None: + path = args.execution_engine + return utils.load_cpu(path)["vocab"] def save_grad(name): - def hook(grad): - grads[name] = grad - return hook + def hook(grad): + grads[name] = grad + + return hook def save_to_file(text, filename): - with open(filename, mode='wt', encoding='utf-8') as myfile: - myfile.write('\n'.join(text)) - myfile.write('\n') + with open(filename, mode="wt", encoding="utf-8") as myfile: + myfile.write("\n".join(text)) + myfile.write("\n") -def get_index(l, index, default=-1): - try: - return l.index(index) - except ValueError: - return default +def get_index(l, index, default=-1): # noqa: E741 + try: + return l.index(index) + except ValueError: + return default -if __name__ == '__main__': - args = parser.parse_args() - main(args) +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/scripts/run_model.sh b/scripts/run_model.sh new file mode 100644 index 0000000..2512c20 --- /dev/null +++ b/scripts/run_model.sh @@ -0,0 +1,3 @@ +python scripts/run_model.py \ + --program_generator data/best.pt \ + --execution_engine data/best.pt \ No newline at end of file diff --git a/scripts/train/film.sh b/scripts/train/film.sh index 141b8c6..3525183 100644 --- a/scripts/train/film.sh +++ b/scripts/train/film.sh @@ -37,7 +37,7 @@ python scripts/train_model.py \ --module_kernel_size 3 \ --module_batchnorm_affine 0 \ --module_num_layers 1 \ - --num_modules 4 \ + --num_modules 3 \ --condition_pattern 1,1,1,1 \ --gamma_option linear \ --gamma_baseline 1 \ diff --git a/scripts/train_model.py b/scripts/train_model.py index 5ed31f1..a5ee9ce 100644 --- a/scripts/train_model.py +++ b/scripts/train_model.py @@ -8,7 +8,8 @@ import sys import os -sys.path.insert(0, 
os.path.abspath('.')) + +sys.path.insert(0, os.path.abspath(".")) import argparse import ipdb as pdb @@ -17,8 +18,10 @@ import shutil from termcolor import colored import time +from datetime import datetime import torch + torch.backends.cudnn.enabled = True from torch.autograd import Variable import torch.nn.functional as F @@ -35,610 +38,740 @@ parser = argparse.ArgumentParser() # Input data -parser.add_argument('--train_question_h5', default='data/train_questions.h5') -parser.add_argument('--train_features_h5', default='data/train_features.h5') -parser.add_argument('--val_question_h5', default='data/val_questions.h5') -parser.add_argument('--val_features_h5', default='data/val_features.h5') -parser.add_argument('--feature_dim', default='1024,14,14') -parser.add_argument('--vocab_json', default='data/vocab.json') - -parser.add_argument('--loader_num_workers', type=int, default=1) -parser.add_argument('--use_local_copies', default=0, type=int) -parser.add_argument('--cleanup_local_copies', default=1, type=int) - -parser.add_argument('--family_split_file', default=None) -parser.add_argument('--num_train_samples', default=None, type=int) -parser.add_argument('--num_val_samples', default=None, type=int) -parser.add_argument('--shuffle_train_data', default=1, type=int) +parser.add_argument("--train_question_h5", default="data/train_questions.h5") +parser.add_argument("--train_features_h5", default="data/train_features.h5") +parser.add_argument("--val_question_h5", default="data/val_questions.h5") +parser.add_argument("--val_features_h5", default="data/val_features.h5") +parser.add_argument("--feature_dim", default="1024,14,14") +parser.add_argument("--vocab_json", default="data/vocab.json") + +parser.add_argument("--loader_num_workers", type=int, default=1) +parser.add_argument("--use_local_copies", default=0, type=int) +parser.add_argument("--cleanup_local_copies", default=1, type=int) + +parser.add_argument("--family_split_file", default=None) 
+parser.add_argument("--num_train_samples", default=None, type=int) +parser.add_argument("--num_val_samples", default=None, type=int) +parser.add_argument("--shuffle_train_data", default=1, type=int) # What type of model to use and which parts to train -parser.add_argument('--model_type', default='PG', - choices=['FiLM', 'PG', 'EE', 'PG+EE', 'LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']) -parser.add_argument('--train_program_generator', default=1, type=int) -parser.add_argument('--train_execution_engine', default=1, type=int) -parser.add_argument('--baseline_train_only_rnn', default=0, type=int) +parser.add_argument( + "--model_type", + default="PG", + choices=["FiLM", "PG", "EE", "PG+EE", "LSTM", "CNN+LSTM", "CNN+LSTM+SA"], +) +parser.add_argument("--train_program_generator", default=1, type=int) +parser.add_argument("--train_execution_engine", default=1, type=int) +parser.add_argument("--baseline_train_only_rnn", default=0, type=int) # Start from an existing checkpoint -parser.add_argument('--program_generator_start_from', default=None) -parser.add_argument('--execution_engine_start_from', default=None) -parser.add_argument('--baseline_start_from', default=None) +parser.add_argument("--program_generator_start_from", default=None) +parser.add_argument("--execution_engine_start_from", default=None) +parser.add_argument("--baseline_start_from", default=None) # RNN options -parser.add_argument('--rnn_wordvec_dim', default=300, type=int) -parser.add_argument('--rnn_hidden_dim', default=256, type=int) -parser.add_argument('--rnn_num_layers', default=2, type=int) -parser.add_argument('--rnn_dropout', default=0, type=float) +parser.add_argument("--rnn_wordvec_dim", default=300, type=int) +parser.add_argument("--rnn_hidden_dim", default=256, type=int) +parser.add_argument("--rnn_num_layers", default=2, type=int) +parser.add_argument("--rnn_dropout", default=0, type=float) # Module net / FiLMedNet options -parser.add_argument('--module_stem_num_layers', default=2, type=int) 
-parser.add_argument('--module_stem_batchnorm', default=0, type=int) -parser.add_argument('--module_dim', default=128, type=int) -parser.add_argument('--module_residual', default=1, type=int) -parser.add_argument('--module_batchnorm', default=0, type=int) +parser.add_argument("--module_stem_num_layers", default=2, type=int) +parser.add_argument("--module_stem_batchnorm", default=0, type=int) +parser.add_argument("--module_dim", default=128, type=int) +parser.add_argument("--module_residual", default=1, type=int) +parser.add_argument("--module_batchnorm", default=0, type=int) # FiLM only options -parser.add_argument('--set_execution_engine_eval', default=0, type=int) -parser.add_argument('--program_generator_parameter_efficient', default=1, type=int) -parser.add_argument('--rnn_output_batchnorm', default=0, type=int) -parser.add_argument('--bidirectional', default=0, type=int) -parser.add_argument('--encoder_type', default='gru', type=str, - choices=['linear', 'gru', 'lstm']) -parser.add_argument('--decoder_type', default='linear', type=str, - choices=['linear', 'gru', 'lstm']) -parser.add_argument('--gamma_option', default='linear', - choices=['linear', 'sigmoid', 'tanh', 'exp']) -parser.add_argument('--gamma_baseline', default=1, type=float) -parser.add_argument('--num_modules', default=4, type=int) -parser.add_argument('--module_stem_kernel_size', default=3, type=int) -parser.add_argument('--module_stem_stride', default=1, type=int) -parser.add_argument('--module_stem_padding', default=None, type=int) -parser.add_argument('--module_num_layers', default=1, type=int) # Only mnl=1 currently implemented -parser.add_argument('--module_batchnorm_affine', default=0, type=int) # 1 overrides other factors -parser.add_argument('--module_dropout', default=5e-2, type=float) -parser.add_argument('--module_input_proj', default=1, type=int) # Inp conv kernel size (0 for None) -parser.add_argument('--module_kernel_size', default=3, type=int) 
-parser.add_argument('--condition_method', default='bn-film', type=str, - choices=['block-input-film', 'block-output-film', 'bn-film', 'concat', 'conv-film', 'relu-film']) -parser.add_argument('--condition_pattern', default='', type=str) # List of 0/1's (len = # FiLMs) -parser.add_argument('--use_gamma', default=1, type=int) -parser.add_argument('--use_beta', default=1, type=int) -parser.add_argument('--use_coords', default=1, type=int) # 0: none, 1: low usage, 2: high usage -parser.add_argument('--grad_clip', default=0, type=float) # <= 0 for no grad clipping -parser.add_argument('--debug_every', default=float('inf'), type=float) # inf for no pdb -parser.add_argument('--print_verbose_every', default=float('inf'), type=float) # inf for min print +parser.add_argument("--set_execution_engine_eval", default=0, type=int) +parser.add_argument("--program_generator_parameter_efficient", default=1, type=int) +parser.add_argument("--rnn_output_batchnorm", default=0, type=int) +parser.add_argument("--bidirectional", default=0, type=int) +parser.add_argument( + "--encoder_type", default="gru", type=str, choices=["linear", "gru", "lstm"] +) +parser.add_argument( + "--decoder_type", default="linear", type=str, choices=["linear", "gru", "lstm"] +) +parser.add_argument( + "--gamma_option", default="linear", choices=["linear", "sigmoid", "tanh", "exp"] +) +parser.add_argument("--gamma_baseline", default=1, type=float) +parser.add_argument("--num_modules", default=4, type=int) +parser.add_argument("--module_stem_kernel_size", default=3, type=int) +parser.add_argument("--module_stem_stride", default=1, type=int) +parser.add_argument("--module_stem_padding", default=None, type=int) +parser.add_argument( + "--module_num_layers", default=1, type=int +) # Only mnl=1 currently implemented +parser.add_argument( + "--module_batchnorm_affine", default=0, type=int +) # 1 overrides other factors +parser.add_argument("--module_dropout", default=5e-2, type=float) +parser.add_argument( + 
"--module_input_proj", default=1, type=int +) # Inp conv kernel size (0 for None) +parser.add_argument("--module_kernel_size", default=3, type=int) +parser.add_argument( + "--condition_method", + default="bn-film", + type=str, + choices=[ + "block-input-film", + "block-output-film", + "bn-film", + "concat", + "conv-film", + "relu-film", + ], +) +parser.add_argument( + "--condition_pattern", default="", type=str +) # List of 0/1's (len = # FiLMs) +parser.add_argument("--use_gamma", default=1, type=int) +parser.add_argument("--use_beta", default=1, type=int) +parser.add_argument( + "--use_coords", default=1, type=int +) # 0: none, 1: low usage, 2: high usage +parser.add_argument("--grad_clip", default=0, type=float) # <= 0 for no grad clipping +parser.add_argument("--debug_every", default=float("inf"), type=float) # inf for no pdb +parser.add_argument( + "--print_verbose_every", default=float("inf"), type=float +) # inf for min print # CNN options (for baselines) -parser.add_argument('--cnn_res_block_dim', default=128, type=int) -parser.add_argument('--cnn_num_res_blocks', default=0, type=int) -parser.add_argument('--cnn_proj_dim', default=512, type=int) -parser.add_argument('--cnn_pooling', default='maxpool2', - choices=['none', 'maxpool2']) +parser.add_argument("--cnn_res_block_dim", default=128, type=int) +parser.add_argument("--cnn_num_res_blocks", default=0, type=int) +parser.add_argument("--cnn_proj_dim", default=512, type=int) +parser.add_argument("--cnn_pooling", default="maxpool2", choices=["none", "maxpool2"]) # Stacked-Attention options -parser.add_argument('--stacked_attn_dim', default=512, type=int) -parser.add_argument('--num_stacked_attn', default=2, type=int) +parser.add_argument("--stacked_attn_dim", default=512, type=int) +parser.add_argument("--num_stacked_attn", default=2, type=int) # Classifier options -parser.add_argument('--classifier_proj_dim', default=512, type=int) -parser.add_argument('--classifier_downsample', default='maxpool2', - 
choices=['maxpool2', 'maxpool3', 'maxpool4', 'maxpool5', 'maxpool7', 'maxpoolfull', 'none', - 'avgpool2', 'avgpool3', 'avgpool4', 'avgpool5', 'avgpool7', 'avgpoolfull', 'aggressive']) -parser.add_argument('--classifier_fc_dims', default='1024') -parser.add_argument('--classifier_batchnorm', default=0, type=int) -parser.add_argument('--classifier_dropout', default=0, type=float) +parser.add_argument("--classifier_proj_dim", default=512, type=int) +parser.add_argument( + "--classifier_downsample", + default="maxpool2", + choices=[ + "maxpool2", + "maxpool3", + "maxpool4", + "maxpool5", + "maxpool7", + "maxpoolfull", + "none", + "avgpool2", + "avgpool3", + "avgpool4", + "avgpool5", + "avgpool7", + "avgpoolfull", + "aggressive", + ], +) +parser.add_argument("--classifier_fc_dims", default="1024") +parser.add_argument("--classifier_batchnorm", default=0, type=int) +parser.add_argument("--classifier_dropout", default=0, type=float) # Optimization options -parser.add_argument('--batch_size', default=64, type=int) -parser.add_argument('--num_iterations', default=100000, type=int) -parser.add_argument('--optimizer', default='Adam', - choices=['Adadelta', 'Adagrad', 'Adam', 'Adamax', 'ASGD', 'RMSprop', 'SGD']) -parser.add_argument('--learning_rate', default=5e-4, type=float) -parser.add_argument('--reward_decay', default=0.9, type=float) -parser.add_argument('--weight_decay', default=0, type=float) +parser.add_argument("--batch_size", default=64, type=int) +parser.add_argument("--num_iterations", default=100000, type=int) +parser.add_argument( + "--optimizer", + default="Adam", + choices=["Adadelta", "Adagrad", "Adam", "Adamax", "ASGD", "RMSprop", "SGD"], +) +parser.add_argument("--learning_rate", default=5e-4, type=float) +parser.add_argument("--reward_decay", default=0.9, type=float) +parser.add_argument("--weight_decay", default=0, type=float) # Output options -parser.add_argument('--checkpoint_path', default='data/checkpoint.pt') 
-parser.add_argument('--randomize_checkpoint_path', type=int, default=0) -parser.add_argument('--avoid_checkpoint_override', default=0, type=int) -parser.add_argument('--record_loss_every', default=1, type=int) -parser.add_argument('--checkpoint_every', default=10000, type=int) -parser.add_argument('--time', default=0, type=int) +parser.add_argument("--checkpoint_path", default="data/checkpoint.pt") +parser.add_argument("--randomize_checkpoint_path", type=int, default=0) +parser.add_argument("--avoid_checkpoint_override", default=0, type=int) +parser.add_argument("--record_loss_every", default=1, type=int) +parser.add_argument("--checkpoint_every", default=10000, type=int) +parser.add_argument("--time", default=0, type=int) def main(args): - if args.randomize_checkpoint_path == 1: - name, ext = os.path.splitext(args.checkpoint_path) - num = random.randint(1, 1000000) - args.checkpoint_path = '%s_%06d%s' % (name, num, ext) - print('Will save checkpoints to %s' % args.checkpoint_path) - - vocab = utils.load_vocab(args.vocab_json) - - if args.use_local_copies == 1: - shutil.copy(args.train_question_h5, '/tmp/train_questions.h5') - shutil.copy(args.train_features_h5, '/tmp/train_features.h5') - shutil.copy(args.val_question_h5, '/tmp/val_questions.h5') - shutil.copy(args.val_features_h5, '/tmp/val_features.h5') - args.train_question_h5 = '/tmp/train_questions.h5' - args.train_features_h5 = '/tmp/train_features.h5' - args.val_question_h5 = '/tmp/val_questions.h5' - args.val_features_h5 = '/tmp/val_features.h5' - - question_families = None - if args.family_split_file is not None: - with open(args.family_split_file, 'r') as f: - question_families = json.load(f) - - train_loader_kwargs = { - 'question_h5': args.train_question_h5, - 'feature_h5': args.train_features_h5, - 'vocab': vocab, - 'batch_size': args.batch_size, - 'shuffle': args.shuffle_train_data == 1, - 'question_families': question_families, - 'max_samples': args.num_train_samples, - 'num_workers': 
args.loader_num_workers, - } - val_loader_kwargs = { - 'question_h5': args.val_question_h5, - 'feature_h5': args.val_features_h5, - 'vocab': vocab, - 'batch_size': args.batch_size, - 'question_families': question_families, - 'max_samples': args.num_val_samples, - 'num_workers': args.loader_num_workers, - } - - with ClevrDataLoader(**train_loader_kwargs) as train_loader, \ - ClevrDataLoader(**val_loader_kwargs) as val_loader: - train_loop(args, train_loader, val_loader) - - if args.use_local_copies == 1 and args.cleanup_local_copies == 1: - os.remove('/tmp/train_questions.h5') - os.remove('/tmp/train_features.h5') - os.remove('/tmp/val_questions.h5') - os.remove('/tmp/val_features.h5') - - -def train_loop(args, train_loader, val_loader): - vocab = utils.load_vocab(args.vocab_json) - program_generator, pg_kwargs, pg_optimizer = None, None, None - execution_engine, ee_kwargs, ee_optimizer = None, None, None - baseline_model, baseline_kwargs, baseline_optimizer = None, None, None - baseline_type = None - - pg_best_state, ee_best_state, baseline_best_state = None, None, None - - # Set up model - optim_method = getattr(torch.optim, args.optimizer) - if args.model_type in ['FiLM', 'PG', 'PG+EE']: - program_generator, pg_kwargs = get_program_generator(args) - pg_optimizer = optim_method(program_generator.parameters(), - lr=args.learning_rate, - weight_decay=args.weight_decay) - print('Here is the conditioning network:') - print(program_generator) - if args.model_type in ['FiLM', 'EE', 'PG+EE']: - execution_engine, ee_kwargs = get_execution_engine(args) - ee_optimizer = optim_method(execution_engine.parameters(), - lr=args.learning_rate, - weight_decay=args.weight_decay) - print('Here is the conditioned network:') - print(execution_engine) - if args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']: - baseline_model, baseline_kwargs = get_baseline_model(args) - params = baseline_model.parameters() - if args.baseline_train_only_rnn == 1: - params = 
baseline_model.rnn.parameters() - baseline_optimizer = optim_method(params, - lr=args.learning_rate, - weight_decay=args.weight_decay) - print('Here is the baseline model') - print(baseline_model) - baseline_type = args.model_type - loss_fn = torch.nn.CrossEntropyLoss().cuda() - - stats = { - 'train_losses': [], 'train_rewards': [], 'train_losses_ts': [], - 'train_accs': [], 'val_accs': [], 'val_accs_ts': [], - 'best_val_acc': -1, 'model_t': 0, - } - t, epoch, reward_moving_average = 0, 0, 0 - - set_mode('train', [program_generator, execution_engine, baseline_model]) - - print('train_loader has %d samples' % len(train_loader.dataset)) - print('val_loader has %d samples' % len(val_loader.dataset)) - - num_checkpoints = 0 - epoch_start_time = 0.0 - epoch_total_time = 0.0 - train_pass_total_time = 0.0 - val_pass_total_time = 0.0 - running_loss = 0.0 - while t < args.num_iterations: - if (epoch > 0) and (args.time == 1): - epoch_time = time.time() - epoch_start_time - epoch_total_time += epoch_time - print(colored('EPOCH PASS AVG TIME: ' + str(epoch_total_time / epoch), 'white')) - print(colored('Epoch Pass Time : ' + str(epoch_time), 'white')) - epoch_start_time = time.time() - - epoch += 1 - print('Starting epoch %d' % epoch) - for batch in train_loader: - t += 1 - questions, _, feats, answers, programs, _ = batch - if isinstance(questions, list): - questions = questions[0] - questions_var = Variable(questions.cuda()) - feats_var = Variable(feats.cuda()) - answers_var = Variable(answers.cuda()) - if programs[0] is not None: - programs_var = Variable(programs.cuda()) - - reward = None - if args.model_type == 'PG': - # Train program generator with ground-truth programs - pg_optimizer.zero_grad() - loss = program_generator(questions_var, programs_var) - loss.backward() - pg_optimizer.step() - elif args.model_type == 'EE': - # Train execution engine with ground-truth programs - ee_optimizer.zero_grad() - scores = execution_engine(feats_var, programs_var) - loss = 
loss_fn(scores, answers_var) - loss.backward() - ee_optimizer.step() - elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']: - baseline_optimizer.zero_grad() - baseline_model.zero_grad() - scores = baseline_model(questions_var, feats_var) - loss = loss_fn(scores, answers_var) - loss.backward() - baseline_optimizer.step() - elif args.model_type == 'PG+EE': - programs_pred = program_generator.reinforce_sample(questions_var) - scores = execution_engine(feats_var, programs_pred) - - loss = loss_fn(scores, answers_var) - _, preds = scores.data.cpu().max(1) - raw_reward = (preds == answers).float() - reward_moving_average *= args.reward_decay - reward_moving_average += (1.0 - args.reward_decay) * raw_reward.mean() - centered_reward = raw_reward - reward_moving_average - - if args.train_execution_engine == 1: - ee_optimizer.zero_grad() - loss.backward() - ee_optimizer.step() - - if args.train_program_generator == 1: - pg_optimizer.zero_grad() - program_generator.reinforce_backward(centered_reward.cuda()) - pg_optimizer.step() - elif args.model_type == 'FiLM': - if args.set_execution_engine_eval == 1: - set_mode('eval', [execution_engine]) - programs_pred = program_generator(questions_var) - scores = execution_engine(feats_var, programs_pred) - loss = loss_fn(scores, answers_var) - - pg_optimizer.zero_grad() - ee_optimizer.zero_grad() - if args.debug_every <= -2: - pdb.set_trace() - loss.backward() - if args.debug_every < float('inf'): - check_grad_num_nans(execution_engine, 'FiLMedNet') - check_grad_num_nans(program_generator, 'FiLMGen') - - if args.train_program_generator == 1: - if args.grad_clip > 0: - torch.nn.utils.clip_grad_norm(program_generator.parameters(), args.grad_clip) - pg_optimizer.step() - if args.train_execution_engine == 1: - if args.grad_clip > 0: - torch.nn.utils.clip_grad_norm(execution_engine.parameters(), args.grad_clip) - ee_optimizer.step() - - if t % args.record_loss_every == 0: - running_loss += loss.data[0] - avg_loss = running_loss / 
args.record_loss_every - print(t, avg_loss) - stats['train_losses'].append(avg_loss) - stats['train_losses_ts'].append(t) - if reward is not None: - stats['train_rewards'].append(reward) - running_loss = 0.0 - else: - running_loss += loss.data[0] - - if t % args.checkpoint_every == 0: - num_checkpoints += 1 - print('Checking training accuracy ... ') - start = time.time() - train_acc = check_accuracy(args, program_generator, execution_engine, - baseline_model, train_loader) - if args.time == 1: - train_pass_time = (time.time() - start) - train_pass_total_time += train_pass_time - print(colored('TRAIN PASS AVG TIME: ' + str(train_pass_total_time / num_checkpoints), 'red')) - print(colored('Train Pass Time : ' + str(train_pass_time), 'red')) - print('train accuracy is', train_acc) - print('Checking validation accuracy ...') - start = time.time() - val_acc = check_accuracy(args, program_generator, execution_engine, - baseline_model, val_loader) - if args.time == 1: - val_pass_time = (time.time() - start) - val_pass_total_time += val_pass_time - print(colored('VAL PASS AVG TIME: ' + str(val_pass_total_time / num_checkpoints), 'cyan')) - print(colored('Val Pass Time : ' + str(val_pass_time), 'cyan')) - print('val accuracy is ', val_acc) - stats['train_accs'].append(train_acc) - stats['val_accs'].append(val_acc) - stats['val_accs_ts'].append(t) - - if val_acc > stats['best_val_acc']: - stats['best_val_acc'] = val_acc - stats['model_t'] = t - best_pg_state = get_state(program_generator) - best_ee_state = get_state(execution_engine) - best_baseline_state = get_state(baseline_model) - - checkpoint = { - 'args': args.__dict__, - 'program_generator_kwargs': pg_kwargs, - 'program_generator_state': best_pg_state, - 'execution_engine_kwargs': ee_kwargs, - 'execution_engine_state': best_ee_state, - 'baseline_kwargs': baseline_kwargs, - 'baseline_state': best_baseline_state, - 'baseline_type': baseline_type, - 'vocab': vocab - } - for k, v in stats.items(): - checkpoint[k] = v - 
print('Saving checkpoint to %s' % args.checkpoint_path) - torch.save(checkpoint, args.checkpoint_path) - del checkpoint['program_generator_state'] - del checkpoint['execution_engine_state'] - del checkpoint['baseline_state'] - with open(args.checkpoint_path + '.json', 'w') as f: - json.dump(checkpoint, f) + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + + if args.randomize_checkpoint_path == 1: + name, ext = os.path.splitext(args.checkpoint_path) + num = random.randint(1, 1000000) + args.checkpoint_path = "%s_%06d%s" % (name, num, ext) + print("Will save checkpoints to %s" % args.checkpoint_path) + + vocab = utils.load_vocab(args.vocab_json) + + if args.use_local_copies == 1: + shutil.copy(args.train_question_h5, "/tmp/train_questions.h5") + shutil.copy(args.train_features_h5, "/tmp/train_features.h5") + shutil.copy(args.val_question_h5, "/tmp/val_questions.h5") + shutil.copy(args.val_features_h5, "/tmp/val_features.h5") + args.train_question_h5 = "/tmp/train_questions.h5" + args.train_features_h5 = "/tmp/train_features.h5" + args.val_question_h5 = "/tmp/val_questions.h5" + args.val_features_h5 = "/tmp/val_features.h5" + + question_families = None + if args.family_split_file is not None: + with open(args.family_split_file, "r") as f: + question_families = json.load(f) + + train_loader_kwargs = { + "question_h5": args.train_question_h5, + "feature_h5": args.train_features_h5, + "vocab": vocab, + "batch_size": args.batch_size, + "shuffle": args.shuffle_train_data == 1, + "question_families": question_families, + "max_samples": args.num_train_samples, + "num_workers": args.loader_num_workers, + } + val_loader_kwargs = { + "question_h5": args.val_question_h5, + "feature_h5": args.val_features_h5, + "vocab": vocab, + "batch_size": args.batch_size, + "question_families": question_families, + "max_samples": args.num_val_samples, + "num_workers": 
args.loader_num_workers, + } - if t == args.num_iterations: - break + with ( + ClevrDataLoader(**train_loader_kwargs) as train_loader, + ClevrDataLoader(**val_loader_kwargs) as val_loader, + ): + train_loop(args, train_loader, val_loader, device) + + if args.use_local_copies == 1 and args.cleanup_local_copies == 1: + os.remove("/tmp/train_questions.h5") + os.remove("/tmp/train_features.h5") + os.remove("/tmp/val_questions.h5") + os.remove("/tmp/val_features.h5") + now = datetime.now() + print("Current date and time:", now.strftime("%Y-%m-%d %H:%M")) + + +def train_loop(args, train_loader, val_loader, device): + vocab = utils.load_vocab(args.vocab_json) + program_generator, pg_kwargs, pg_optimizer = None, None, None + execution_engine, ee_kwargs, ee_optimizer = None, None, None + baseline_model, baseline_kwargs, baseline_optimizer = None, None, None + baseline_type = None + + pg_best_state, ee_best_state, baseline_best_state = None, None, None # noqa: F841 + + # Set up model + optim_method = getattr(torch.optim, args.optimizer) + if args.model_type in ["FiLM", "PG", "PG+EE"]: + program_generator, pg_kwargs = get_program_generator(args, device) + pg_optimizer = optim_method( + program_generator.parameters(), + lr=args.learning_rate, + weight_decay=args.weight_decay, + ) + print("Here is the conditioning network:") + print(program_generator) + if args.model_type in ["FiLM", "EE", "PG+EE"]: + execution_engine, ee_kwargs = get_execution_engine(args, device) + ee_optimizer = optim_method( + execution_engine.parameters(), + lr=args.learning_rate, + weight_decay=args.weight_decay, + ) + print("Here is the conditioned network:") + print(execution_engine) + if args.model_type in ["LSTM", "CNN+LSTM", "CNN+LSTM+SA"]: + baseline_model, baseline_kwargs = get_baseline_model(args, device) + params = baseline_model.parameters() + if args.baseline_train_only_rnn == 1: + params = baseline_model.rnn.parameters() + baseline_optimizer = optim_method( + params, lr=args.learning_rate, 
weight_decay=args.weight_decay + ) + print("Here is the baseline model") + print(baseline_model) + baseline_type = args.model_type + loss_fn = torch.nn.CrossEntropyLoss().to(device) + + stats = { + "train_losses": [], + "train_rewards": [], + "train_losses_ts": [], + "train_accs": [], + "val_accs": [], + "val_accs_ts": [], + "best_val_acc": -1, + "model_t": 0, + } + t, epoch, reward_moving_average = 0, 0, 0 + + set_mode("train", [program_generator, execution_engine, baseline_model]) + + print("train_loader has %d samples" % len(train_loader.dataset)) + print("val_loader has %d samples" % len(val_loader.dataset)) + + num_checkpoints = 0 + epoch_start_time = 0.0 + epoch_total_time = 0.0 + train_pass_total_time = 0.0 + val_pass_total_time = 0.0 + running_loss = 0.0 + while t < args.num_iterations: + if (epoch > 0) and (args.time == 1): + epoch_time = time.time() - epoch_start_time + epoch_total_time += epoch_time + print( + colored( + "EPOCH PASS AVG TIME: " + str(epoch_total_time / epoch), "white" + ) + ) + print(colored("Epoch Pass Time : " + str(epoch_time), "white")) + epoch_start_time = time.time() + + epoch += 1 + print("Starting epoch %d" % epoch) + now = datetime.now() + print("Current date and time (epoch):", now.strftime("%Y-%m-%d %H:%M")) + for batch in train_loader: + t += 1 + questions, _, feats, answers, programs, _ = batch + if isinstance(questions, list): + questions = questions[0] + questions_var = Variable(questions.to(device)) + feats_var = Variable(feats.to(device)) + answers_var = Variable(answers.to(device)) + if programs[0] is not None: + programs_var = Variable(programs.to(device)) + + reward = None + if args.model_type == "PG": + # Train program generator with ground-truth programs + pg_optimizer.zero_grad() + loss = program_generator(questions_var, programs_var) + loss.backward() + pg_optimizer.step() + elif args.model_type == "EE": + # Train execution engine with ground-truth programs + ee_optimizer.zero_grad() + scores = 
execution_engine(feats_var, programs_var) + loss = loss_fn(scores, answers_var) + loss.backward() + ee_optimizer.step() + elif args.model_type in ["LSTM", "CNN+LSTM", "CNN+LSTM+SA"]: + baseline_optimizer.zero_grad() + baseline_model.zero_grad() + scores = baseline_model(questions_var, feats_var) + loss = loss_fn(scores, answers_var) + loss.backward() + baseline_optimizer.step() + elif args.model_type == "PG+EE": + programs_pred = program_generator.reinforce_sample(questions_var) + scores = execution_engine(feats_var, programs_pred) + + loss = loss_fn(scores, answers_var) + _, preds = scores.data.cpu().max(1) + raw_reward = (preds == answers).float() + reward_moving_average *= args.reward_decay + reward_moving_average += (1.0 - args.reward_decay) * raw_reward.mean() + centered_reward = raw_reward - reward_moving_average + + if args.train_execution_engine == 1: + ee_optimizer.zero_grad() + loss.backward() + ee_optimizer.step() + + if args.train_program_generator == 1: + pg_optimizer.zero_grad() + program_generator.reinforce_backward(centered_reward.to(device)) + pg_optimizer.step() + elif args.model_type == "FiLM": + if args.set_execution_engine_eval == 1: + set_mode("eval", [execution_engine]) + programs_pred = program_generator(questions_var) + scores = execution_engine(feats_var, programs_pred) + loss = loss_fn(scores, answers_var) + + pg_optimizer.zero_grad() + ee_optimizer.zero_grad() + if args.debug_every <= -2: + pdb.set_trace() + loss.backward() + if args.debug_every < float("inf"): + check_grad_num_nans(execution_engine, "FiLMedNet") + check_grad_num_nans(program_generator, "FiLMGen") + + if args.train_program_generator == 1: + if args.grad_clip > 0: + torch.nn.utils.clip_grad_norm( + program_generator.parameters(), args.grad_clip + ) + pg_optimizer.step() + if args.train_execution_engine == 1: + if args.grad_clip > 0: + torch.nn.utils.clip_grad_norm( + execution_engine.parameters(), args.grad_clip + ) + ee_optimizer.step() + + if t % args.record_loss_every 
== 0: + running_loss += loss.item() + avg_loss = running_loss / args.record_loss_every + print("t:", t, "avg_loss", avg_loss) + stats["train_losses"].append(avg_loss) + stats["train_losses_ts"].append(t) + if reward is not None: + stats["train_rewards"].append(reward) + running_loss = 0.0 + else: + running_loss += loss.item() + + if t % args.checkpoint_every == 0: + num_checkpoints += 1 + print("Checking training accuracy ... ") + now = datetime.now() + print("Current date and time:", now.strftime("%Y-%m-%d %H:%M")) + start = time.time() + train_acc = check_accuracy( + args, + program_generator, + execution_engine, + baseline_model, + train_loader, + device, + ) + if args.time == 1: + train_pass_time = time.time() - start + train_pass_total_time += train_pass_time + print( + colored( + "TRAIN PASS AVG TIME: " + + str(train_pass_total_time / num_checkpoints), + "red", + ) + ) + print( + colored("Train Pass Time : " + str(train_pass_time), "red") + ) + print("train accuracy is", train_acc) + print("Checking validation accuracy ...") + now = datetime.now() + print("Current date and time:", now.strftime("%Y-%m-%d %H:%M")) + start = time.time() + val_acc = check_accuracy( + args, + program_generator, + execution_engine, + baseline_model, + val_loader, + device, + ) + if args.time == 1: + val_pass_time = time.time() - start + val_pass_total_time += val_pass_time + print( + colored( + "VAL PASS AVG TIME: " + + str(val_pass_total_time / num_checkpoints), + "cyan", + ) + ) + print( + colored("Val Pass Time : " + str(val_pass_time), "cyan") + ) + print("val accuracy is ", val_acc) + stats["train_accs"].append(train_acc) + stats["val_accs"].append(val_acc) + stats["val_accs_ts"].append(t) + + if val_acc > stats["best_val_acc"]: + stats["best_val_acc"] = val_acc + stats["model_t"] = t + best_pg_state = get_state(program_generator) + best_ee_state = get_state(execution_engine) + best_baseline_state = get_state(baseline_model) + + checkpoint = { + "args": args.__dict__, + 
"program_generator_kwargs": pg_kwargs, + "program_generator_state": best_pg_state, + "execution_engine_kwargs": ee_kwargs, + "execution_engine_state": best_ee_state, + "baseline_kwargs": baseline_kwargs, + "baseline_state": best_baseline_state, + "baseline_type": baseline_type, + "vocab": vocab, + } + for k, v in stats.items(): + checkpoint[k] = v + print("Saving checkpoint to %s" % args.checkpoint_path) + torch.save(checkpoint, args.checkpoint_path) + del checkpoint["program_generator_state"] + del checkpoint["execution_engine_state"] + del checkpoint["baseline_state"] + with open(args.checkpoint_path + ".json", "w") as f: + json.dump(checkpoint, f) + + if t == args.num_iterations: + break def parse_int_list(s): - if s == '': return () - return tuple(int(n) for n in s.split(',')) + if s == "": + return () + return tuple(int(n) for n in s.split(",")) def get_state(m): - if m is None: - return None - state = {} - for k, v in m.state_dict().items(): - state[k] = v.clone() - return state - - -def get_program_generator(args): - vocab = utils.load_vocab(args.vocab_json) - if args.program_generator_start_from is not None: - pg, kwargs = utils.load_program_generator( - args.program_generator_start_from, model_type=args.model_type) - cur_vocab_size = pg.encoder_embed.weight.size(0) - if cur_vocab_size != len(vocab['question_token_to_idx']): - print('Expanding vocabulary of program generator') - pg.expand_encoder_vocab(vocab['question_token_to_idx']) - kwargs['encoder_vocab_size'] = len(vocab['question_token_to_idx']) - else: - kwargs = { - 'encoder_vocab_size': len(vocab['question_token_to_idx']), - 'decoder_vocab_size': len(vocab['program_token_to_idx']), - 'wordvec_dim': args.rnn_wordvec_dim, - 'hidden_dim': args.rnn_hidden_dim, - 'rnn_num_layers': args.rnn_num_layers, - 'rnn_dropout': args.rnn_dropout, - } - if args.model_type == 'FiLM': - kwargs['parameter_efficient'] = args.program_generator_parameter_efficient == 1 - kwargs['output_batchnorm'] = 
args.rnn_output_batchnorm == 1 - kwargs['bidirectional'] = args.bidirectional == 1 - kwargs['encoder_type'] = args.encoder_type - kwargs['decoder_type'] = args.decoder_type - kwargs['gamma_option'] = args.gamma_option - kwargs['gamma_baseline'] = args.gamma_baseline - kwargs['num_modules'] = args.num_modules - kwargs['module_num_layers'] = args.module_num_layers - kwargs['module_dim'] = args.module_dim - kwargs['debug_every'] = args.debug_every - pg = FiLMGen(**kwargs) + if m is None: + return None + state = {} + for k, v in m.state_dict().items(): + state[k] = v.clone() + return state + + +def get_program_generator(args, device): + vocab = utils.load_vocab(args.vocab_json) + if args.program_generator_start_from is not None: + pg, kwargs = utils.load_program_generator( + args.program_generator_start_from, model_type=args.model_type + ) + cur_vocab_size = pg.encoder_embed.weight.size(0) + if cur_vocab_size != len(vocab["question_token_to_idx"]): + print("Expanding vocabulary of program generator") + pg.expand_encoder_vocab(vocab["question_token_to_idx"]) + kwargs["encoder_vocab_size"] = len(vocab["question_token_to_idx"]) else: - pg = Seq2Seq(**kwargs) - pg.cuda() - pg.train() - return pg, kwargs - - -def get_execution_engine(args): - vocab = utils.load_vocab(args.vocab_json) - if args.execution_engine_start_from is not None: - ee, kwargs = utils.load_execution_engine( - args.execution_engine_start_from, model_type=args.model_type) - else: - kwargs = { - 'vocab': vocab, - 'feature_dim': parse_int_list(args.feature_dim), - 'stem_batchnorm': args.module_stem_batchnorm == 1, - 'stem_num_layers': args.module_stem_num_layers, - 'module_dim': args.module_dim, - 'module_residual': args.module_residual == 1, - 'module_batchnorm': args.module_batchnorm == 1, - 'classifier_proj_dim': args.classifier_proj_dim, - 'classifier_downsample': args.classifier_downsample, - 'classifier_fc_layers': parse_int_list(args.classifier_fc_dims), - 'classifier_batchnorm': 
args.classifier_batchnorm == 1, - 'classifier_dropout': args.classifier_dropout, - } - if args.model_type == 'FiLM': - kwargs['num_modules'] = args.num_modules - kwargs['stem_kernel_size'] = args.module_stem_kernel_size - kwargs['stem_stride'] = args.module_stem_stride - kwargs['stem_padding'] = args.module_stem_padding - kwargs['module_num_layers'] = args.module_num_layers - kwargs['module_batchnorm_affine'] = args.module_batchnorm_affine == 1 - kwargs['module_dropout'] = args.module_dropout - kwargs['module_input_proj'] = args.module_input_proj - kwargs['module_kernel_size'] = args.module_kernel_size - kwargs['use_gamma'] = args.use_gamma == 1 - kwargs['use_beta'] = args.use_beta == 1 - kwargs['use_coords'] = args.use_coords - kwargs['debug_every'] = args.debug_every - kwargs['print_verbose_every'] = args.print_verbose_every - kwargs['condition_method'] = args.condition_method - kwargs['condition_pattern'] = parse_int_list(args.condition_pattern) - ee = FiLMedNet(**kwargs) + kwargs = { + "encoder_vocab_size": len(vocab["question_token_to_idx"]), + "decoder_vocab_size": len(vocab["program_token_to_idx"]), + "wordvec_dim": args.rnn_wordvec_dim, + "hidden_dim": args.rnn_hidden_dim, + "rnn_num_layers": args.rnn_num_layers, + "rnn_dropout": args.rnn_dropout, + } + if args.model_type == "FiLM": + kwargs["parameter_efficient"] = ( + args.program_generator_parameter_efficient == 1 + ) + kwargs["output_batchnorm"] = args.rnn_output_batchnorm == 1 + kwargs["bidirectional"] = args.bidirectional == 1 + kwargs["encoder_type"] = args.encoder_type + kwargs["decoder_type"] = args.decoder_type + kwargs["gamma_option"] = args.gamma_option + kwargs["gamma_baseline"] = args.gamma_baseline + kwargs["num_modules"] = args.num_modules + kwargs["module_num_layers"] = args.module_num_layers + kwargs["module_dim"] = args.module_dim + kwargs["debug_every"] = args.debug_every + pg = FiLMGen(**kwargs) + else: + pg = Seq2Seq(**kwargs) + pg.to(device) + pg.train() + return pg, kwargs + + +def 
get_execution_engine(args, device): + vocab = utils.load_vocab(args.vocab_json) + if args.execution_engine_start_from is not None: + ee, kwargs = utils.load_execution_engine( + args.execution_engine_start_from, model_type=args.model_type + ) else: - ee = ModuleNet(**kwargs) - ee.cuda() - ee.train() - return ee, kwargs - - -def get_baseline_model(args): - vocab = utils.load_vocab(args.vocab_json) - if args.baseline_start_from is not None: - model, kwargs = utils.load_baseline(args.baseline_start_from) - elif args.model_type == 'LSTM': - kwargs = { - 'vocab': vocab, - 'rnn_wordvec_dim': args.rnn_wordvec_dim, - 'rnn_dim': args.rnn_hidden_dim, - 'rnn_num_layers': args.rnn_num_layers, - 'rnn_dropout': args.rnn_dropout, - 'fc_dims': parse_int_list(args.classifier_fc_dims), - 'fc_use_batchnorm': args.classifier_batchnorm == 1, - 'fc_dropout': args.classifier_dropout, - } - model = LstmModel(**kwargs) - elif args.model_type == 'CNN+LSTM': - kwargs = { - 'vocab': vocab, - 'rnn_wordvec_dim': args.rnn_wordvec_dim, - 'rnn_dim': args.rnn_hidden_dim, - 'rnn_num_layers': args.rnn_num_layers, - 'rnn_dropout': args.rnn_dropout, - 'cnn_feat_dim': parse_int_list(args.feature_dim), - 'cnn_num_res_blocks': args.cnn_num_res_blocks, - 'cnn_res_block_dim': args.cnn_res_block_dim, - 'cnn_proj_dim': args.cnn_proj_dim, - 'cnn_pooling': args.cnn_pooling, - 'fc_dims': parse_int_list(args.classifier_fc_dims), - 'fc_use_batchnorm': args.classifier_batchnorm == 1, - 'fc_dropout': args.classifier_dropout, - } - model = CnnLstmModel(**kwargs) - elif args.model_type == 'CNN+LSTM+SA': - kwargs = { - 'vocab': vocab, - 'rnn_wordvec_dim': args.rnn_wordvec_dim, - 'rnn_dim': args.rnn_hidden_dim, - 'rnn_num_layers': args.rnn_num_layers, - 'rnn_dropout': args.rnn_dropout, - 'cnn_feat_dim': parse_int_list(args.feature_dim), - 'stacked_attn_dim': args.stacked_attn_dim, - 'num_stacked_attn': args.num_stacked_attn, - 'fc_dims': parse_int_list(args.classifier_fc_dims), - 'fc_use_batchnorm': 
args.classifier_batchnorm == 1, - 'fc_dropout': args.classifier_dropout, - } - model = CnnLstmSaModel(**kwargs) - if model.rnn.token_to_idx != vocab['question_token_to_idx']: - # Make sure new vocab is superset of old - for k, v in model.rnn.token_to_idx.items(): - assert k in vocab['question_token_to_idx'] - assert vocab['question_token_to_idx'][k] == v - for token, idx in vocab['question_token_to_idx'].items(): - model.rnn.token_to_idx[token] = idx - kwargs['vocab'] = vocab - model.rnn.expand_vocab(vocab['question_token_to_idx']) - model.cuda() - model.train() - return model, kwargs + kwargs = { + "vocab": vocab, + "feature_dim": parse_int_list(args.feature_dim), + "stem_batchnorm": args.module_stem_batchnorm == 1, + "stem_num_layers": args.module_stem_num_layers, + "module_dim": args.module_dim, + "module_residual": args.module_residual == 1, + "module_batchnorm": args.module_batchnorm == 1, + "classifier_proj_dim": args.classifier_proj_dim, + "classifier_downsample": args.classifier_downsample, + "classifier_fc_layers": parse_int_list(args.classifier_fc_dims), + "classifier_batchnorm": args.classifier_batchnorm == 1, + "classifier_dropout": args.classifier_dropout, + } + if args.model_type == "FiLM": + kwargs["num_modules"] = args.num_modules + kwargs["stem_kernel_size"] = args.module_stem_kernel_size + kwargs["stem_stride"] = args.module_stem_stride + kwargs["stem_padding"] = args.module_stem_padding + kwargs["module_num_layers"] = args.module_num_layers + kwargs["module_batchnorm_affine"] = args.module_batchnorm_affine == 1 + kwargs["module_dropout"] = args.module_dropout + kwargs["module_input_proj"] = args.module_input_proj + kwargs["module_kernel_size"] = args.module_kernel_size + kwargs["use_gamma"] = args.use_gamma == 1 + kwargs["use_beta"] = args.use_beta == 1 + kwargs["use_coords"] = args.use_coords + kwargs["debug_every"] = args.debug_every + kwargs["print_verbose_every"] = args.print_verbose_every + kwargs["condition_method"] = args.condition_method 
+ kwargs["condition_pattern"] = parse_int_list(args.condition_pattern) + ee = FiLMedNet(**kwargs) + else: + ee = ModuleNet(**kwargs) + ee.to(device) + ee.train() + return ee, kwargs + + +def get_baseline_model(args, device): + vocab = utils.load_vocab(args.vocab_json) + if args.baseline_start_from is not None: + model, kwargs = utils.load_baseline(args.baseline_start_from) + elif args.model_type == "LSTM": + kwargs = { + "vocab": vocab, + "rnn_wordvec_dim": args.rnn_wordvec_dim, + "rnn_dim": args.rnn_hidden_dim, + "rnn_num_layers": args.rnn_num_layers, + "rnn_dropout": args.rnn_dropout, + "fc_dims": parse_int_list(args.classifier_fc_dims), + "fc_use_batchnorm": args.classifier_batchnorm == 1, + "fc_dropout": args.classifier_dropout, + } + model = LstmModel(**kwargs) + elif args.model_type == "CNN+LSTM": + kwargs = { + "vocab": vocab, + "rnn_wordvec_dim": args.rnn_wordvec_dim, + "rnn_dim": args.rnn_hidden_dim, + "rnn_num_layers": args.rnn_num_layers, + "rnn_dropout": args.rnn_dropout, + "cnn_feat_dim": parse_int_list(args.feature_dim), + "cnn_num_res_blocks": args.cnn_num_res_blocks, + "cnn_res_block_dim": args.cnn_res_block_dim, + "cnn_proj_dim": args.cnn_proj_dim, + "cnn_pooling": args.cnn_pooling, + "fc_dims": parse_int_list(args.classifier_fc_dims), + "fc_use_batchnorm": args.classifier_batchnorm == 1, + "fc_dropout": args.classifier_dropout, + } + model = CnnLstmModel(**kwargs) + elif args.model_type == "CNN+LSTM+SA": + kwargs = { + "vocab": vocab, + "rnn_wordvec_dim": args.rnn_wordvec_dim, + "rnn_dim": args.rnn_hidden_dim, + "rnn_num_layers": args.rnn_num_layers, + "rnn_dropout": args.rnn_dropout, + "cnn_feat_dim": parse_int_list(args.feature_dim), + "stacked_attn_dim": args.stacked_attn_dim, + "num_stacked_attn": args.num_stacked_attn, + "fc_dims": parse_int_list(args.classifier_fc_dims), + "fc_use_batchnorm": args.classifier_batchnorm == 1, + "fc_dropout": args.classifier_dropout, + } + model = CnnLstmSaModel(**kwargs) + if model.rnn.token_to_idx != 
vocab["question_token_to_idx"]: + # Make sure new vocab is superset of old + for k, v in model.rnn.token_to_idx.items(): + assert k in vocab["question_token_to_idx"] + assert vocab["question_token_to_idx"][k] == v + for token, idx in vocab["question_token_to_idx"].items(): + model.rnn.token_to_idx[token] = idx + kwargs["vocab"] = vocab + model.rnn.expand_vocab(vocab["question_token_to_idx"]) + model.to(device) + model.train() + return model, kwargs def set_mode(mode, models): - assert mode in ['train', 'eval'] - for m in models: - if m is None: continue - if mode == 'train': m.train() - if mode == 'eval': m.eval() - - -def check_accuracy(args, program_generator, execution_engine, baseline_model, loader): - set_mode('eval', [program_generator, execution_engine, baseline_model]) - num_correct, num_samples = 0, 0 - for batch in loader: - questions, _, feats, answers, programs, _ = batch - if isinstance(questions, list): - questions = questions[0] - - questions_var = Variable(questions.cuda(), volatile=True) - feats_var = Variable(feats.cuda(), volatile=True) - answers_var = Variable(feats.cuda(), volatile=True) - if programs[0] is not None: - programs_var = Variable(programs.cuda(), volatile=True) - - scores = None # Use this for everything but PG - if args.model_type == 'PG': - vocab = utils.load_vocab(args.vocab_json) - for i in range(questions.size(0)): - program_pred = program_generator.sample(Variable(questions[i:i+1].cuda(), volatile=True)) - program_pred_str = vr.preprocess.decode(program_pred, vocab['program_idx_to_token']) - program_str = vr.preprocess.decode(programs[i], vocab['program_idx_to_token']) - if program_pred_str == program_str: - num_correct += 1 - num_samples += 1 - elif args.model_type == 'EE': - scores = execution_engine(feats_var, programs_var) - elif args.model_type == 'PG+EE': - programs_pred = program_generator.reinforce_sample( - questions_var, argmax=True) - scores = execution_engine(feats_var, programs_pred) - elif args.model_type == 
'FiLM': - programs_pred = program_generator(questions_var) - scores = execution_engine(feats_var, programs_pred) - elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']: - scores = baseline_model(questions_var, feats_var) - - if scores is not None: - _, preds = scores.data.cpu().max(1) - num_correct += (preds == answers).sum() - num_samples += preds.size(0) - - if args.num_val_samples is not None and num_samples >= args.num_val_samples: - break - - set_mode('train', [program_generator, execution_engine, baseline_model]) - acc = float(num_correct) / num_samples - return acc - -def check_grad_num_nans(model, model_name='model'): + assert mode in ["train", "eval"] + for m in models: + if m is None: + continue + if mode == "train": + m.train() + if mode == "eval": + m.eval() + + +def check_accuracy( + args, program_generator, execution_engine, baseline_model, loader, device +): + set_mode("eval", [program_generator, execution_engine, baseline_model]) + num_correct, num_samples = 0, 0 + for batch in loader: + questions, _, feats, answers, programs, _ = batch + if isinstance(questions, list): + questions = questions[0] + + with torch.no_grad(): + questions_var = Variable(questions.to(device)) + feats_var = Variable(feats.to(device)) + answers_var = Variable(feats.to(device)) # noqa: F841 + if programs[0] is not None: + programs_var = Variable(programs.to(device)) + + scores = None # Use this for everything but PG + if args.model_type == "PG": + vocab = utils.load_vocab(args.vocab_json) + for i in range(questions.size(0)): + program_pred = program_generator.sample( + Variable(questions[i : i + 1].to(device)) + ) + program_pred_str = vr.preprocess.decode( + program_pred, vocab["program_idx_to_token"] + ) + program_str = vr.preprocess.decode( + programs[i], vocab["program_idx_to_token"] + ) + if program_pred_str == program_str: + num_correct += 1 + num_samples += 1 + elif args.model_type == "EE": + scores = execution_engine(feats_var, programs_var) + elif 
args.model_type == "PG+EE": + programs_pred = program_generator.reinforce_sample( + questions_var, argmax=True + ) + scores = execution_engine(feats_var, programs_pred) + elif args.model_type == "FiLM": + programs_pred = program_generator(questions_var) + scores = execution_engine(feats_var, programs_pred) + elif args.model_type in ["LSTM", "CNN+LSTM", "CNN+LSTM+SA"]: + scores = baseline_model(questions_var, feats_var) + + if scores is not None: + _, preds = scores.data.cpu().max(1) + num_correct += (preds == answers).sum() + num_samples += preds.size(0) + + if args.num_val_samples is not None and num_samples >= args.num_val_samples: + break + + set_mode("train", [program_generator, execution_engine, baseline_model]) + acc = float(num_correct) / num_samples + return acc + + +def check_grad_num_nans(model, model_name="model"): grads = [p.grad for p in model.parameters() if p.grad is not None] num_nans = [np.sum(np.isnan(grad.data.cpu().numpy())) for grad in grads] nan_checks = [num_nan == 0 for num_nan in num_nans] if False in nan_checks: - print('Nans in ' + model_name + ' gradient!') - print(num_nans) - pdb.set_trace() - raise(Exception) - -if __name__ == '__main__': - args = parser.parse_args() - main(args) + print("Nans in " + model_name + " gradient!") + print(num_nans) + pdb.set_trace() + raise (Exception) + + +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/vr/data.py b/vr/data.py index 894cae5..306dcf9 100644 --- a/vr/data.py +++ b/vr/data.py @@ -16,163 +16,211 @@ def _dataset_to_tensor(dset, mask=None): - arr = np.asarray(dset, dtype=np.int64) - if mask is not None: - arr = arr[mask] - tensor = torch.LongTensor(arr) - return tensor + arr = np.asarray(dset, dtype=np.int64) + if mask is not None: + arr = arr[mask] + tensor = torch.LongTensor(arr) + return tensor class ClevrDataset(Dataset): - def __init__(self, question_h5, feature_h5, vocab, mode='prefix', - image_h5=None, max_samples=None, question_families=None, - 
image_idx_start_from=None): - mode_choices = ['prefix', 'postfix'] - if mode not in mode_choices: - raise ValueError('Invalid mode "%s"' % mode) - self.image_h5 = image_h5 - self.vocab = vocab - self.feature_h5 = feature_h5 - self.mode = mode - self.max_samples = max_samples - - mask = None - if question_families is not None: - # Use only the specified families - all_families = np.asarray(question_h5['question_families']) - N = all_families.shape[0] - print(question_families) - target_families = np.asarray(question_families)[:, None] - mask = (all_families == target_families).any(axis=0) - if image_idx_start_from is not None: - all_image_idxs = np.asarray(question_h5['image_idxs']) - mask = all_image_idxs >= image_idx_start_from - - # Data from the question file is small, so read it all into memory - print('Reading question data into memory') - self.all_types = None - if 'types' in question_h5: - self.all_types = _dataset_to_tensor(question_h5['types'], mask) - self.all_question_families = None - if 'question_families' in question_h5: - self.all_question_families = _dataset_to_tensor(question_h5['question_families'], mask) - self.all_questions = _dataset_to_tensor(question_h5['questions'], mask) - self.all_image_idxs = _dataset_to_tensor(question_h5['image_idxs'], mask) - self.all_programs = None - if 'programs' in question_h5: - self.all_programs = _dataset_to_tensor(question_h5['programs'], mask) - self.all_answers = None - if 'answers' in question_h5: - self.all_answers = _dataset_to_tensor(question_h5['answers'], mask) - - def __getitem__(self, index): - if self.all_question_families is not None: - question_family = self.all_question_families[index] - q_type = None if self.all_types is None else self.all_types[index] - question = self.all_questions[index] - image_idx = self.all_image_idxs[index] - answer = None - if self.all_answers is not None: - answer = self.all_answers[index] - program_seq = None - if self.all_programs is not None: - program_seq = 
self.all_programs[index]
-
-    image = None
-    if self.image_h5 is not None:
-      image = self.image_h5['images'][image_idx]
-      image = torch.FloatTensor(np.asarray(image, dtype=np.float32))
-
-    feats = self.feature_h5['features'][image_idx]
-    feats = torch.FloatTensor(np.asarray(feats, dtype=np.float32))
-
-    program_json = None
-    if program_seq is not None:
-      program_json_seq = []
-      for fn_idx in program_seq:
-        fn_str = self.vocab['program_idx_to_token'][fn_idx]
-        if fn_str == '<START>' or fn_str == '<END>': continue
-        fn = vr.programs.str_to_function(fn_str)
-        program_json_seq.append(fn)
-      if self.mode == 'prefix':
-        program_json = vr.programs.prefix_to_list(program_json_seq)
-      elif self.mode == 'postfix':
-        program_json = vr.programs.postfix_to_list(program_json_seq)
-
-    if q_type is None:
-      return (question, image, feats, answer, program_seq, program_json)
-    return ([question, q_type], image, feats, answer, program_seq, program_json)
-
-  def __len__(self):
-    if self.max_samples is None:
-      return self.all_questions.size(0)
-    else:
-      return min(self.max_samples, self.all_questions.size(0))
+    def __init__(
+        self,
+        question_h5_path,
+        feature_h5_path,
+        vocab,
+        mode="prefix",
+        image_h5_path=None,
+        max_samples=None,
+        question_families=None,
+        image_idx_start_from=None,
+    ):
+        mode_choices = ["prefix", "postfix"]
+        if mode not in mode_choices:
+            raise ValueError('Invalid mode "%s"' % mode)
+        self.vocab = vocab
+        self.mode = mode
+        self.max_samples = max_samples
+
+        # Store file paths for lazy loading
+        self.feature_h5_path = feature_h5_path
+        self.image_h5_path = image_h5_path
+        self.feature_h5 = None
+        self.image_h5 = None
+
+        # Read question data into memory (file is small, so it's okay to load here)
+        print("Reading question data into memory from", question_h5_path)
+        with h5py.File(question_h5_path, "r") as question_h5:
+            mask = None
+            if question_families is not None:
+                # Use only the specified families
+                all_families = np.asarray(question_h5["question_families"])
+                
target_families = np.asarray(question_families)[:, None] + mask = (all_families == target_families).any(axis=0) + if image_idx_start_from is not None: + all_image_idxs = np.asarray(question_h5["image_idxs"]) + mask = all_image_idxs >= image_idx_start_from + + self.all_types = None + if "types" in question_h5: + self.all_types = _dataset_to_tensor(question_h5["types"], mask) + self.all_question_families = None + if "question_families" in question_h5: + self.all_question_families = _dataset_to_tensor( + question_h5["question_families"], mask + ) + self.all_questions = _dataset_to_tensor(question_h5["questions"], mask) + self.all_image_idxs = _dataset_to_tensor(question_h5["image_idxs"], mask) + self.all_programs = None + if "programs" in question_h5: + self.all_programs = _dataset_to_tensor(question_h5["programs"], mask) + self.all_answers = None + if "answers" in question_h5: + self.all_answers = _dataset_to_tensor(question_h5["answers"], mask) + + def _lazy_open_files(self): + # Open hdf5 files only if they haven't been opened yet. 
+        if self.feature_h5 is None:
+            self.feature_h5 = h5py.File(self.feature_h5_path, "r")
+        if self.image_h5 is None and self.image_h5_path is not None:
+            self.image_h5 = h5py.File(self.image_h5_path, "r")
+
+    def __getitem__(self, index):
+        self._lazy_open_files()
+
+        if self.all_question_families is not None:
+            question_family = self.all_question_families[index]
+        q_type = None if self.all_types is None else self.all_types[index]
+        question = self.all_questions[index]
+        image_idx = self.all_image_idxs[index]
+        answer = None
+        if self.all_answers is not None:
+            answer = self.all_answers[index]
+        program_seq = None
+        if self.all_programs is not None:
+            program_seq = self.all_programs[index]
+
+        image = None
+        if self.image_h5 is not None:
+            image = self.image_h5["images"][image_idx]
+            image = torch.FloatTensor(np.asarray(image, dtype=np.float32))
+
+        feats = self.feature_h5["features"][image_idx]
+        feats = torch.FloatTensor(np.asarray(feats, dtype=np.float32))
+
+        program_json = None
+        if program_seq is not None:
+            program_json_seq = []
+            for fn_idx in program_seq:
+                # Convert fn_idx from a tensor to an int
+                key = (
+                    int(fn_idx.item())
+                    if isinstance(fn_idx, torch.Tensor)
+                    else int(fn_idx)
+                )
+                fn_str = self.vocab["program_idx_to_token"][key]
+                if fn_str == "<START>" or fn_str == "<END>":
+                    continue
+                fn = vr.programs.str_to_function(fn_str)
+                program_json_seq.append(fn)
+            if self.mode == "prefix":
+                program_json = vr.programs.prefix_to_list(program_json_seq)
+            elif self.mode == "postfix":
+                program_json = vr.programs.postfix_to_list(program_json_seq)
+
+        if q_type is None:
+            return (question, image, feats, answer, program_seq, program_json)
+        return ([question, q_type], image, feats, answer, program_seq, program_json)
+
+    def __len__(self):
+        if self.max_samples is None:
+            return self.all_questions.size(0)
+        else:
+            return min(self.max_samples, self.all_questions.size(0))
+
+    def close(self):
+        # Close any opened hdf5 files
+        if self.image_h5 is not None:
+            
self.image_h5.close() + self.image_h5 = None + if self.feature_h5 is not None: + self.feature_h5.close() + self.feature_h5 = None + + def __del__(self): + self.close() class ClevrDataLoader(DataLoader): - def __init__(self, **kwargs): - if 'question_h5' not in kwargs: - raise ValueError('Must give question_h5') - if 'feature_h5' not in kwargs: - raise ValueError('Must give feature_h5') - if 'vocab' not in kwargs: - raise ValueError('Must give vocab') - - feature_h5_path = kwargs.pop('feature_h5') - print('Reading features from', feature_h5_path) - self.feature_h5 = h5py.File(feature_h5_path, 'r') - - self.image_h5 = None - if 'image_h5' in kwargs: - image_h5_path = kwargs.pop('image_h5') - print('Reading images from ', image_h5_path) - self.image_h5 = h5py.File(image_h5_path, 'r') - - vocab = kwargs.pop('vocab') - mode = kwargs.pop('mode', 'prefix') - - question_families = kwargs.pop('question_families', None) - max_samples = kwargs.pop('max_samples', None) - question_h5_path = kwargs.pop('question_h5') - image_idx_start_from = kwargs.pop('image_idx_start_from', None) - print('Reading questions from ', question_h5_path) - with h5py.File(question_h5_path, 'r') as question_h5: - self.dataset = ClevrDataset(question_h5, self.feature_h5, vocab, mode, - image_h5=self.image_h5, - max_samples=max_samples, - question_families=question_families, - image_idx_start_from=image_idx_start_from) - kwargs['collate_fn'] = clevr_collate - super(ClevrDataLoader, self).__init__(self.dataset, **kwargs) - - def close(self): - if self.image_h5 is not None: - self.image_h5.close() - if self.feature_h5 is not None: - self.feature_h5.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() + def __init__(self, **kwargs): + if "question_h5" not in kwargs: + raise ValueError("Must give question_h5") + if "feature_h5" not in kwargs: + raise ValueError("Must give feature_h5") + if "vocab" not in kwargs: + raise ValueError("Must give 
vocab") + + # Instead of opening hdf5 files here, we just store the file paths. + feature_h5_path = kwargs.pop("feature_h5") + print("Using features from", feature_h5_path) + + image_h5_path = None + if "image_h5" in kwargs: + image_h5_path = kwargs.pop("image_h5") + print("Using images from", image_h5_path) + + vocab = kwargs.pop("vocab") + mode = kwargs.pop("mode", "prefix") + question_families = kwargs.pop("question_families", None) + max_samples = kwargs.pop("max_samples", None) + question_h5_path = kwargs.pop("question_h5") + print("Reading questions from", question_h5_path) + image_idx_start_from = kwargs.pop("image_idx_start_from", None) + self.dataset = ClevrDataset( + question_h5_path, + feature_h5_path, + vocab, + mode, + image_h5_path=image_h5_path, + max_samples=max_samples, + question_families=question_families, + image_idx_start_from=image_idx_start_from, + ) + kwargs["collate_fn"] = clevr_collate + super(ClevrDataLoader, self).__init__(self.dataset, **kwargs) + + def close(self): + if hasattr(self, "dataset"): + self.dataset.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() def clevr_collate(batch): - transposed = list(zip(*batch)) - question_batch = default_collate(transposed[0]) - image_batch = transposed[1] - if any(img is not None for img in image_batch): - image_batch = default_collate(image_batch) - feat_batch = transposed[2] - if any(f is not None for f in feat_batch): - feat_batch = default_collate(feat_batch) - answer_batch = transposed[3] - if transposed[3][0] is not None: - answer_batch = default_collate(transposed[3]) - program_seq_batch = transposed[4] - if transposed[4][0] is not None: - program_seq_batch = default_collate(transposed[4]) - program_struct_batch = transposed[5] - return [question_batch, image_batch, feat_batch, answer_batch, program_seq_batch, program_struct_batch] + transposed = list(zip(*batch)) + question_batch = default_collate(transposed[0]) + 
image_batch = transposed[1] + if any(img is not None for img in image_batch): + image_batch = default_collate(image_batch) + feat_batch = transposed[2] + if any(f is not None for f in feat_batch): + feat_batch = default_collate(feat_batch) + answer_batch = transposed[3] + if transposed[3][0] is not None: + answer_batch = default_collate(transposed[3]) + program_seq_batch = transposed[4] + if transposed[4][0] is not None: + program_seq_batch = default_collate(transposed[4]) + program_struct_batch = transposed[5] + return [ + question_batch, + image_batch, + feat_batch, + answer_batch, + program_seq_batch, + program_struct_batch, + ] diff --git a/vr/models/filmed_net.py b/vr/models/filmed_net.py index aff9483..285dec5 100644 --- a/vr/models/filmed_net.py +++ b/vr/models/filmed_net.py @@ -16,296 +16,395 @@ class FiLM(nn.Module): - """ - A Feature-wise Linear Modulation Layer from - 'FiLM: Visual Reasoning with a General Conditioning Layer' - """ - def forward(self, x, gammas, betas): - gammas = gammas.unsqueeze(2).unsqueeze(3).expand_as(x) - betas = betas.unsqueeze(2).unsqueeze(3).expand_as(x) - return (gammas * x) + betas + """ + A Feature-wise Linear Modulation Layer from + 'FiLM: Visual Reasoning with a General Conditioning Layer' + """ + + def forward(self, x, gammas, betas): + gammas = gammas.unsqueeze(2).unsqueeze(3).expand_as(x) + betas = betas.unsqueeze(2).unsqueeze(3).expand_as(x) + return (gammas * x) + betas class FiLMedNet(nn.Module): - def __init__(self, vocab, feature_dim=(1024, 14, 14), - stem_num_layers=2, - stem_batchnorm=False, - stem_kernel_size=3, - stem_stride=1, - stem_padding=None, - num_modules=4, - module_num_layers=1, - module_dim=128, - module_residual=True, - module_batchnorm=False, - module_batchnorm_affine=False, - module_dropout=0, - module_input_proj=1, - module_kernel_size=3, - classifier_proj_dim=512, - classifier_downsample='maxpool2', - classifier_fc_layers=(1024,), - classifier_batchnorm=False, - classifier_dropout=0, - 
condition_method='bn-film', - condition_pattern=[], - use_gamma=True, - use_beta=True, - use_coords=1, - debug_every=float('inf'), - print_verbose_every=float('inf'), - verbose=True, - ): - super(FiLMedNet, self).__init__() - - num_answers = len(vocab['answer_idx_to_token']) - - self.stem_times = [] - self.module_times = [] - self.classifier_times = [] - self.timing = False - - self.num_modules = num_modules - self.module_num_layers = module_num_layers - self.module_batchnorm = module_batchnorm - self.module_dim = module_dim - self.condition_method = condition_method - self.use_gamma = use_gamma - self.use_beta = use_beta - self.use_coords_freq = use_coords - self.debug_every = debug_every - self.print_verbose_every = print_verbose_every - - # Initialize helper variables - self.stem_use_coords = (stem_stride == 1) and (self.use_coords_freq > 0) - self.condition_pattern = condition_pattern - if len(condition_pattern) == 0: - self.condition_pattern = [] - for i in range(self.module_num_layers * self.num_modules): - self.condition_pattern.append(self.condition_method != 'concat') - else: - self.condition_pattern = [i > 0 for i in self.condition_pattern] - self.extra_channel_freq = self.use_coords_freq - self.block = FiLMedResBlock - self.num_cond_maps = 2 * self.module_dim if self.condition_method == 'concat' else 0 - self.fwd_count = 0 - self.num_extra_channels = 2 if self.use_coords_freq > 0 else 0 - if self.debug_every <= -1: - self.print_verbose_every = 1 - module_H = feature_dim[1] // (stem_stride ** stem_num_layers) # Rough calc: work for main cases - module_W = feature_dim[2] // (stem_stride ** stem_num_layers) # Rough calc: work for main cases - self.coords = coord_map((module_H, module_W)) - self.default_weight = Variable(torch.ones(1, 1, self.module_dim)).type(torch.cuda.FloatTensor) - self.default_bias = Variable(torch.zeros(1, 1, self.module_dim)).type(torch.cuda.FloatTensor) - - # Initialize stem - stem_feature_dim = feature_dim[0] + self.stem_use_coords 
* self.num_extra_channels - self.stem = build_stem(stem_feature_dim, module_dim, - num_layers=stem_num_layers, with_batchnorm=stem_batchnorm, - kernel_size=stem_kernel_size, stride=stem_stride, padding=stem_padding) - - # Initialize FiLMed network body - self.function_modules = {} - self.vocab = vocab - for fn_num in range(self.num_modules): - with_cond = self.condition_pattern[self.module_num_layers * fn_num: - self.module_num_layers * (fn_num + 1)] - mod = self.block(module_dim, with_residual=module_residual, with_batchnorm=module_batchnorm, - with_cond=with_cond, - dropout=module_dropout, - num_extra_channels=self.num_extra_channels, - extra_channel_freq=self.extra_channel_freq, - with_input_proj=module_input_proj, - num_cond_maps=self.num_cond_maps, - kernel_size=module_kernel_size, - batchnorm_affine=module_batchnorm_affine, - num_layers=self.module_num_layers, - condition_method=condition_method, - debug_every=self.debug_every) - self.add_module(str(fn_num), mod) - self.function_modules[fn_num] = mod - - # Initialize output classifier - self.classifier = build_classifier(module_dim + self.num_extra_channels, module_H, module_W, - num_answers, classifier_fc_layers, classifier_proj_dim, - classifier_downsample, with_batchnorm=classifier_batchnorm, - dropout=classifier_dropout) - - init_modules(self.modules()) - - def forward(self, x, film, save_activations=False): - # Initialize forward pass and externally viewable activations - self.fwd_count += 1 - if save_activations: - self.feats = None - self.module_outputs = [] - self.cf_input = None - - if self.debug_every <= -2: - pdb.set_trace() - - # Prepare FiLM layers - gammas = None - betas = None - if self.condition_method == 'concat': - # Use parameters usually used to condition via FiLM instead to condition via concatenation - cond_params = film[:,:,:2*self.module_dim] - cond_maps = cond_params.unsqueeze(3).unsqueeze(4).expand(cond_params.size() + x.size()[-2:]) - else: - gammas, betas = 
torch.split(film[:,:,:2*self.module_dim], self.module_dim, dim=-1) - if not self.use_gamma: - gammas = self.default_weight.expand_as(gammas) - if not self.use_beta: - betas = self.default_bias.expand_as(betas) - - # Propagate up image features CNN - batch_coords = None - if self.use_coords_freq > 0: - batch_coords = self.coords.unsqueeze(0).expand(torch.Size((x.size(0), *self.coords.size()))) - if self.stem_use_coords: - x = torch.cat([x, batch_coords], 1) - feats = self.stem(x) - if save_activations: - self.feats = feats - N, _, H, W = feats.size() - - # Propagate up the network from low-to-high numbered blocks - module_inputs = Variable(torch.zeros(feats.size()).unsqueeze(1).expand( - N, self.num_modules, self.module_dim, H, W)).type(torch.cuda.FloatTensor) - module_inputs[:,0] = feats - for fn_num in range(self.num_modules): - if self.condition_method == 'concat': - layer_output = self.function_modules[fn_num](module_inputs[:,fn_num], - extra_channels=batch_coords, cond_maps=cond_maps[:,fn_num]) - else: - layer_output = self.function_modules[fn_num](module_inputs[:,fn_num], - gammas[:,fn_num,:], betas[:,fn_num,:], batch_coords) - - # Store for future computation - if save_activations: - self.module_outputs.append(layer_output) - if fn_num == (self.num_modules - 1): - final_module_output = layer_output - else: - module_inputs_updated = module_inputs.clone() - module_inputs_updated[:,fn_num+1] = module_inputs_updated[:,fn_num+1] + layer_output - module_inputs = module_inputs_updated - - if self.debug_every <= -2: - pdb.set_trace() - - # Run the final classifier over the resultant, post-modulated features. 
- if self.use_coords_freq > 0: - final_module_output = torch.cat([final_module_output, batch_coords], 1) - if save_activations: - self.cf_input = final_module_output - out = self.classifier(final_module_output) - - if ((self.fwd_count % self.debug_every) == 0) or (self.debug_every <= -1): - pdb.set_trace() - return out + def __init__( + self, + vocab, + feature_dim=(1024, 14, 14), + stem_num_layers=2, + stem_batchnorm=False, + stem_kernel_size=3, + stem_stride=1, + stem_padding=None, + num_modules=4, + module_num_layers=1, + module_dim=128, + module_residual=True, + module_batchnorm=False, + module_batchnorm_affine=False, + module_dropout=0, + module_input_proj=1, + module_kernel_size=3, + classifier_proj_dim=512, + classifier_downsample="maxpool2", + classifier_fc_layers=(1024,), + classifier_batchnorm=False, + classifier_dropout=0, + condition_method="bn-film", + condition_pattern=[], + use_gamma=True, + use_beta=True, + use_coords=1, + debug_every=float("inf"), + print_verbose_every=float("inf"), + verbose=True, + ): + super(FiLMedNet, self).__init__() + if torch.cuda.is_available(): + self.device = torch.device("cuda") + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + + num_answers = len(vocab["answer_idx_to_token"]) + + self.stem_times = [] + self.module_times = [] + self.classifier_times = [] + self.timing = False + + self.num_modules = num_modules + self.module_num_layers = module_num_layers + self.module_batchnorm = module_batchnorm + self.module_dim = module_dim + self.condition_method = condition_method + self.use_gamma = use_gamma + self.use_beta = use_beta + self.use_coords_freq = use_coords + self.debug_every = debug_every + self.print_verbose_every = print_verbose_every + + # Initialize helper variables + self.stem_use_coords = (stem_stride == 1) and (self.use_coords_freq > 0) + self.condition_pattern = condition_pattern + if len(condition_pattern) == 0: + self.condition_pattern 
= [] + for i in range(self.module_num_layers * self.num_modules): + self.condition_pattern.append(self.condition_method != "concat") + else: + self.condition_pattern = [i > 0 for i in self.condition_pattern] + self.extra_channel_freq = self.use_coords_freq + self.block = FiLMedResBlock + self.num_cond_maps = ( + 2 * self.module_dim if self.condition_method == "concat" else 0 + ) + self.fwd_count = 0 + self.num_extra_channels = 2 if self.use_coords_freq > 0 else 0 + if self.debug_every <= -1: + self.print_verbose_every = 1 + module_H = feature_dim[1] // ( + stem_stride**stem_num_layers + ) # Rough calc: work for main cases + module_W = feature_dim[2] // ( + stem_stride**stem_num_layers + ) # Rough calc: work for main cases + self.coords = coord_map((module_H, module_W), self.device) + # self.default_weight = Variable(torch.ones(1, 1, self.module_dim)).type(torch.cuda.FloatTensor) + self.default_weight = torch.ones( + 1, 1, self.module_dim, device=self.device, dtype=torch.float + ) + # self.default_bias = Variable(torch.zeros(1, 1, self.module_dim)).type( + # torch.cuda.FloatTensor + # ) + self.default_bias = torch.zeros( + 1, 1, self.module_dim, device=self.device, dtype=torch.float + ) + + # Initialize stem + stem_feature_dim = ( + feature_dim[0] + self.stem_use_coords * self.num_extra_channels + ) + self.stem = build_stem( + stem_feature_dim, + module_dim, + num_layers=stem_num_layers, + with_batchnorm=stem_batchnorm, + kernel_size=stem_kernel_size, + stride=stem_stride, + padding=stem_padding, + ) + + # Initialize FiLMed network body + self.function_modules = {} + self.vocab = vocab + for fn_num in range(self.num_modules): + with_cond = self.condition_pattern[ + self.module_num_layers * fn_num : self.module_num_layers * (fn_num + 1) + ] + mod = self.block( + module_dim, + with_residual=module_residual, + with_batchnorm=module_batchnorm, + with_cond=with_cond, + dropout=module_dropout, + num_extra_channels=self.num_extra_channels, + 
extra_channel_freq=self.extra_channel_freq, + with_input_proj=module_input_proj, + num_cond_maps=self.num_cond_maps, + kernel_size=module_kernel_size, + batchnorm_affine=module_batchnorm_affine, + num_layers=self.module_num_layers, + condition_method=condition_method, + debug_every=self.debug_every, + ) + self.add_module(str(fn_num), mod) + self.function_modules[fn_num] = mod + + # Initialize output classifier + self.classifier = build_classifier( + module_dim + self.num_extra_channels, + module_H, + module_W, + num_answers, + classifier_fc_layers, + classifier_proj_dim, + classifier_downsample, + with_batchnorm=classifier_batchnorm, + dropout=classifier_dropout, + ) + + init_modules(self.modules()) + + def forward(self, x, film, save_activations=False): + # Initialize forward pass and externally viewable activations + self.fwd_count += 1 + if save_activations: + self.feats = None + self.module_outputs = [] + self.cf_input = None + + if self.debug_every <= -2: + pdb.set_trace() + + # Prepare FiLM layers + gammas = None + betas = None + if self.condition_method == "concat": + # Use parameters usually used to condition via FiLM instead to condition via concatenation + cond_params = film[:, :, : 2 * self.module_dim] + cond_maps = ( + cond_params.unsqueeze(3) + .unsqueeze(4) + .expand(cond_params.size() + x.size()[-2:]) + ) + else: + gammas, betas = torch.split( + film[:, :, : 2 * self.module_dim], self.module_dim, dim=-1 + ) + if not self.use_gamma: + gammas = self.default_weight.expand_as(gammas) + if not self.use_beta: + betas = self.default_bias.expand_as(betas) + + # Propagate up image features CNN + batch_coords = None + if self.use_coords_freq > 0: + batch_coords = self.coords.unsqueeze(0).expand( + torch.Size((x.size(0), *self.coords.size())) + ) + if self.stem_use_coords: + x = torch.cat([x, batch_coords], 1) + feats = self.stem(x) + if save_activations: + self.feats = feats + N, _, H, W = feats.size() + + # Propagate up the network from low-to-high numbered 
blocks + # module_inputs = Variable( + # torch.zeros(feats.size()) + # .unsqueeze(1) + # .expand(N, self.num_modules, self.module_dim, H, W) + # ).type(torch.cuda.FloatTensor) + module_inputs = ( + torch.zeros(feats.size(), device=self.device, dtype=torch.float) + .unsqueeze(1) + .expand(N, self.num_modules, self.module_dim, H, W) + ) + + module_inputs[:, 0] = feats + for fn_num in range(self.num_modules): + if self.condition_method == "concat": + layer_output = self.function_modules[fn_num]( + module_inputs[:, fn_num], + extra_channels=batch_coords, + cond_maps=cond_maps[:, fn_num], + ) + else: + layer_output = self.function_modules[fn_num]( + module_inputs[:, fn_num], + gammas[:, fn_num, :], + betas[:, fn_num, :], + batch_coords, + ) + + # Store for future computation + if save_activations: + self.module_outputs.append(layer_output) + if fn_num == (self.num_modules - 1): + final_module_output = layer_output + else: + module_inputs_updated = module_inputs.clone() + module_inputs_updated[:, fn_num + 1] = ( + module_inputs_updated[:, fn_num + 1] + layer_output + ) + module_inputs = module_inputs_updated + + if self.debug_every <= -2: + pdb.set_trace() + + # Run the final classifier over the resultant, post-modulated features. 
+ if self.use_coords_freq > 0: + final_module_output = torch.cat([final_module_output, batch_coords], 1) + if save_activations: + self.cf_input = final_module_output + out = self.classifier(final_module_output) + + if ((self.fwd_count % self.debug_every) == 0) or (self.debug_every <= -1): + pdb.set_trace() + return out class FiLMedResBlock(nn.Module): - def __init__(self, in_dim, out_dim=None, with_residual=True, with_batchnorm=True, - with_cond=[False], dropout=0, num_extra_channels=0, extra_channel_freq=1, - with_input_proj=0, num_cond_maps=0, kernel_size=3, batchnorm_affine=False, - num_layers=1, condition_method='bn-film', debug_every=float('inf')): - if out_dim is None: - out_dim = in_dim - super(FiLMedResBlock, self).__init__() - self.with_residual = with_residual - self.with_batchnorm = with_batchnorm - self.with_cond = with_cond - self.dropout = dropout - self.extra_channel_freq = 0 if num_extra_channels == 0 else extra_channel_freq - self.with_input_proj = with_input_proj # Kernel size of input projection - self.num_cond_maps = num_cond_maps - self.kernel_size = kernel_size - self.batchnorm_affine = batchnorm_affine - self.num_layers = num_layers - self.condition_method = condition_method - self.debug_every = debug_every - - if self.with_input_proj % 2 == 0: - raise(NotImplementedError) - if self.kernel_size % 2 == 0: - raise(NotImplementedError) - if self.num_layers >= 2: - raise(NotImplementedError) - - if self.condition_method == 'block-input-film' and self.with_cond[0]: - self.film = FiLM() - if self.with_input_proj: - self.input_proj = nn.Conv2d(in_dim + (num_extra_channels if self.extra_channel_freq >= 1 else 0), - in_dim, kernel_size=self.with_input_proj, padding=self.with_input_proj // 2) - - self.conv1 = nn.Conv2d(in_dim + self.num_cond_maps + - (num_extra_channels if self.extra_channel_freq >= 2 else 0), - out_dim, kernel_size=self.kernel_size, - padding=self.kernel_size // 2) - if self.condition_method == 'conv-film' and self.with_cond[0]: - 
self.film = FiLM() - if self.with_batchnorm: - self.bn1 = nn.BatchNorm2d(out_dim, affine=((not self.with_cond[0]) or self.batchnorm_affine)) - if self.condition_method == 'bn-film' and self.with_cond[0]: - self.film = FiLM() - if dropout > 0: - self.drop = nn.Dropout2d(p=self.dropout) - if ((self.condition_method == 'relu-film' or self.condition_method == 'block-output-film') - and self.with_cond[0]): - self.film = FiLM() - - init_modules(self.modules()) - - def forward(self, x, gammas=None, betas=None, extra_channels=None, cond_maps=None): - if self.debug_every <= -2: - pdb.set_trace() - - if self.condition_method == 'block-input-film' and self.with_cond[0]: - x = self.film(x, gammas, betas) - - # ResBlock input projection - if self.with_input_proj: - if extra_channels is not None and self.extra_channel_freq >= 1: - x = torch.cat([x, extra_channels], 1) - x = F.relu(self.input_proj(x)) - out = x - - # ResBlock body - if cond_maps is not None: - out = torch.cat([out, cond_maps], 1) - if extra_channels is not None and self.extra_channel_freq >= 2: - out = torch.cat([out, extra_channels], 1) - out = self.conv1(out) - if self.condition_method == 'conv-film' and self.with_cond[0]: - out = self.film(out, gammas, betas) - if self.with_batchnorm: - out = self.bn1(out) - if self.condition_method == 'bn-film' and self.with_cond[0]: - out = self.film(out, gammas, betas) - if self.dropout > 0: - out = self.drop(out) - out = F.relu(out) - if self.condition_method == 'relu-film' and self.with_cond[0]: - out = self.film(out, gammas, betas) - - # ResBlock remainder - if self.with_residual: - out = x + out - if self.condition_method == 'block-output-film' and self.with_cond[0]: - out = self.film(out, gammas, betas) - return out - - -def coord_map(shape, start=-1, end=1): - """ - Gives, a 2d shape tuple, returns two mxn coordinate maps, - Ranging min-max in the x and y directions, respectively. 
- """ - m, n = shape - x_coord_row = torch.linspace(start, end, steps=n).type(torch.cuda.FloatTensor) - y_coord_row = torch.linspace(start, end, steps=m).type(torch.cuda.FloatTensor) - x_coords = x_coord_row.unsqueeze(0).expand(torch.Size((m, n))).unsqueeze(0) - y_coords = y_coord_row.unsqueeze(1).expand(torch.Size((m, n))).unsqueeze(0) - return Variable(torch.cat([x_coords, y_coords], 0)) + def __init__( + self, + in_dim, + out_dim=None, + with_residual=True, + with_batchnorm=True, + with_cond=[False], + dropout=0, + num_extra_channels=0, + extra_channel_freq=1, + with_input_proj=0, + num_cond_maps=0, + kernel_size=3, + batchnorm_affine=False, + num_layers=1, + condition_method="bn-film", + debug_every=float("inf"), + ): + if out_dim is None: + out_dim = in_dim + super(FiLMedResBlock, self).__init__() + self.with_residual = with_residual + self.with_batchnorm = with_batchnorm + self.with_cond = with_cond + self.dropout = dropout + self.extra_channel_freq = 0 if num_extra_channels == 0 else extra_channel_freq + self.with_input_proj = with_input_proj # Kernel size of input projection + self.num_cond_maps = num_cond_maps + self.kernel_size = kernel_size + self.batchnorm_affine = batchnorm_affine + self.num_layers = num_layers + self.condition_method = condition_method + self.debug_every = debug_every + + if self.with_input_proj % 2 == 0: + raise (NotImplementedError) + if self.kernel_size % 2 == 0: + raise (NotImplementedError) + if self.num_layers >= 2: + raise (NotImplementedError) + + if self.condition_method == "block-input-film" and self.with_cond[0]: + self.film = FiLM() + if self.with_input_proj: + self.input_proj = nn.Conv2d( + in_dim + (num_extra_channels if self.extra_channel_freq >= 1 else 0), + in_dim, + kernel_size=self.with_input_proj, + padding=self.with_input_proj // 2, + ) + + self.conv1 = nn.Conv2d( + in_dim + + self.num_cond_maps + + (num_extra_channels if self.extra_channel_freq >= 2 else 0), + out_dim, + kernel_size=self.kernel_size, + 
padding=self.kernel_size // 2, + ) + if self.condition_method == "conv-film" and self.with_cond[0]: + self.film = FiLM() + if self.with_batchnorm: + self.bn1 = nn.BatchNorm2d( + out_dim, affine=((not self.with_cond[0]) or self.batchnorm_affine) + ) + if self.condition_method == "bn-film" and self.with_cond[0]: + self.film = FiLM() + if dropout > 0: + self.drop = nn.Dropout2d(p=self.dropout) + if ( + self.condition_method == "relu-film" + or self.condition_method == "block-output-film" + ) and self.with_cond[0]: + self.film = FiLM() + + init_modules(self.modules()) + + def forward(self, x, gammas=None, betas=None, extra_channels=None, cond_maps=None): + if self.debug_every <= -2: + pdb.set_trace() + + if self.condition_method == "block-input-film" and self.with_cond[0]: + x = self.film(x, gammas, betas) + + # ResBlock input projection + if self.with_input_proj: + if extra_channels is not None and self.extra_channel_freq >= 1: + x = torch.cat([x, extra_channels], 1) + x = F.relu(self.input_proj(x)) + out = x + + # ResBlock body + if cond_maps is not None: + out = torch.cat([out, cond_maps], 1) + if extra_channels is not None and self.extra_channel_freq >= 2: + out = torch.cat([out, extra_channels], 1) + out = self.conv1(out) + if self.condition_method == "conv-film" and self.with_cond[0]: + out = self.film(out, gammas, betas) + if self.with_batchnorm: + out = self.bn1(out) + if self.condition_method == "bn-film" and self.with_cond[0]: + out = self.film(out, gammas, betas) + if self.dropout > 0: + out = self.drop(out) + out = F.relu(out) + if self.condition_method == "relu-film" and self.with_cond[0]: + out = self.film(out, gammas, betas) + + # ResBlock remainder + if self.with_residual: + out = x + out + if self.condition_method == "block-output-film" and self.with_cond[0]: + out = self.film(out, gammas, betas) + return out + + +def coord_map(shape, device, start=-1, end=1): + """ + Gives, a 2d shape tuple, returns two mxn coordinate maps, + Ranging min-max in the x 
and y directions, respectively. + """ + m, n = shape + # x_coord_row = torch.linspace(start, end, steps=n).type(torch.cuda.FloatTensor) + # y_coord_row = torch.linspace(start, end, steps=m).type(torch.cuda.FloatTensor) + x_coord_row = torch.linspace(start, end, steps=n, device=device, dtype=torch.float) + y_coord_row = torch.linspace(start, end, steps=m, device=device, dtype=torch.float) + + x_coords = x_coord_row.unsqueeze(0).expand(torch.Size((m, n))).unsqueeze(0) + y_coords = y_coord_row.unsqueeze(1).expand(torch.Size((m, n))).unsqueeze(0) + return Variable(torch.cat([x_coords, y_coords], 0)) diff --git a/vr/utils.py b/vr/utils.py index 18ab122..fd94d9a 100644 --- a/vr/utils.py +++ b/vr/utils.py @@ -15,107 +15,111 @@ from vr.models import FiLMedNet from vr.models import FiLMGen + def invert_dict(d): - return {v: k for k, v in d.items()} + return {v: k for k, v in d.items()} def load_vocab(path): - with open(path, 'r') as f: - vocab = json.load(f) - vocab['question_idx_to_token'] = invert_dict(vocab['question_token_to_idx']) - vocab['program_idx_to_token'] = invert_dict(vocab['program_token_to_idx']) - vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) - # Sanity check: make sure , , and are consistent - assert vocab['question_token_to_idx'][''] == 0 - assert vocab['question_token_to_idx'][''] == 1 - assert vocab['question_token_to_idx'][''] == 2 - assert vocab['program_token_to_idx'][''] == 0 - assert vocab['program_token_to_idx'][''] == 1 - assert vocab['program_token_to_idx'][''] == 2 - return vocab + with open(path, "r") as f: + vocab = json.load(f) + vocab["question_idx_to_token"] = invert_dict(vocab["question_token_to_idx"]) + vocab["program_idx_to_token"] = invert_dict(vocab["program_token_to_idx"]) + vocab["answer_idx_to_token"] = invert_dict(vocab["answer_token_to_idx"]) + # Sanity check: make sure , , and are consistent + assert vocab["question_token_to_idx"][""] == 0 + assert vocab["question_token_to_idx"][""] == 1 + assert 
def load_cpu(path):
    """
    Loads a torch checkpoint, remapping all Tensors to CPU.
    """
    # NOTE(review): torch.load without weights_only unpickles arbitrary
    # objects -- only load checkpoints from trusted sources.
    return torch.load(path, map_location=lambda storage, loc: storage)


def load_program_generator(path, model_type="PG+EE"):
    """Load a program generator (FiLMGen or Seq2Seq) from a checkpoint.

    Returns (model, kwargs) where kwargs are the (possibly updated)
    constructor arguments stored in the checkpoint.
    """
    checkpoint = load_cpu(path)
    kwargs = checkpoint["program_generator_kwargs"]
    state = checkpoint["program_generator_state"]
    if model_type == "FiLM":
        kwargs = get_updated_args(kwargs, FiLMGen)
        model = FiLMGen(**kwargs)
    else:
        print("Loading PG from " + path)
        model = Seq2Seq(**kwargs)
    model.load_state_dict(state)
    return model, kwargs


def load_execution_engine(path, verbose=True, model_type="PG+EE"):
    """Load an execution engine (FiLMedNet or ModuleNet) from a checkpoint.

    Returns (model, kwargs); `verbose` is forwarded to the model constructor.
    """
    checkpoint = load_cpu(path)
    kwargs = checkpoint["execution_engine_kwargs"]
    state = checkpoint["execution_engine_state"]
    kwargs["verbose"] = verbose
    if model_type == "FiLM":
        kwargs = get_updated_args(kwargs, FiLMedNet)
        model = FiLMedNet(**kwargs)
    else:
        print("Loading EE from " + path)
        model = ModuleNet(**kwargs)
    # Removed dead `cur_state = model.state_dict()` assignment (unused).
    model.load_state_dict(state)
    return model, kwargs


def load_baseline(path):
    """Load one of the LSTM / CNN+LSTM / CNN+LSTM+SA baselines."""
    model_cls_dict = {
        "LSTM": LstmModel,
        "CNN+LSTM": CnnLstmModel,
        "CNN+LSTM+SA": CnnLstmSaModel,
    }
    checkpoint = load_cpu(path)
    baseline_type = checkpoint["baseline_type"]
    kwargs = checkpoint["baseline_kwargs"]
    state = checkpoint["baseline_state"]

    model = model_cls_dict[baseline_type](**kwargs)
    model.load_state_dict(state)
    return model, kwargs


def get_updated_args(kwargs, object_class):
    """
    Returns kwargs with renamed args or arg values and deleted, deprecated,
    unused args.  Useful for loading older, trained models.
    If using this function is necessary, use immediately before initializing
    the object.
    """
    # Update arg values
    for arg in arg_value_updates:
        if arg in kwargs and kwargs[arg] in arg_value_updates[arg]:
            kwargs[arg] = arg_value_updates[arg][kwargs[arg]]

    # Delete deprecated, unused args.  getfullargspec replaces getargspec,
    # which was removed in Python 3.11.
    valid_args = inspect.getfullargspec(object_class.__init__).args
    return {
        valid_arg: kwargs[valid_arg] for valid_arg in valid_args if valid_arg in kwargs
    }


# Renames applied to hyperparameter values stored in older checkpoints.
arg_value_updates = {
    "condition_method": {
        "block-input-fac": "block-input-film",
        "block-output-fac": "block-output-film",
        "cbn": "bn-film",
        "conv-fac": "conv-film",
        "relu-fac": "relu-film",
    },
    "module_input_proj": {
        True: 1,
    },
}