To pretrain a ViLT model from scratch on the COCO dataset, run:
```
mmf_run config=projects/vilt/configs/masked_coco/pretrain.yaml run_type=train_val dataset=masked_coco model=vilt
```

## Using the ViLT model from code
Here is an example of running the ViLT model from code to do visual question answering (VQA) on a raw image and text.
The forward pass takes ~15 ms, which is very fast compared to UNITER's ~600 ms.

```python
from argparse import Namespace

import torch
from mmf.common.sample import SampleList
from mmf.datasets.processors.bert_processors import VILTTextTokenizer
from mmf.datasets.processors.image_processors import VILTImageProcessor
from mmf.utils.build import build_model
from mmf.utils.configuration import Configuration, load_yaml
from mmf.utils.general import get_current_device
from mmf.utils.text import VocabDict
from omegaconf import OmegaConf
from PIL import Image
```

First, build the model config and instantiate the ViLT model.
```python
# make model config for vilt vqa2
model_name = "vilt"
config_args = Namespace(
    config_override=None,
    opts=["model=vilt", "dataset=vqa2", "config=configs/defaults.yaml"],
)
default_config = Configuration(config_args).get_config()
model_vqa_config = load_yaml(
    "/private/home/your/path/to/mmf/projects/vilt/configs/vqa2/defaults.yaml"
)
config = OmegaConf.merge(default_config, model_vqa_config)
OmegaConf.resolve(config)
model_config = config.model_config[model_name]
model_config.model = model_name
vilt_model = build_model(model_config)
```
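
Optionally, a quick sanity check that the model was built; the parameter count printed here is just illustrative (the exact number depends on the config):
```python
# count trainable and frozen parameters of the built model
num_params = sum(p.numel() for p in vilt_model.parameters())
print(f"ViLT parameters: {num_params / 1e6:.1f}M")
```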

Load the model weights. `model_checkpoint_path` is the model checkpoint downloaded from model zoo path `vilt.vqa`, with current url `s3://dl.fbaipublicfiles.com/mmf/data/models/vilt/vilt.finetuned.vqa2.tar.gz`.
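
If you have not downloaded the checkpoint yet, here is a minimal sketch of fetching and unpacking it in Python, assuming the `s3://` path above is also served over `https://` at the same path, and that the archive contains the `.pth` file loaded below (exact file names may differ):
```python
import tarfile
import urllib.request

# assumption: the s3 bucket is mirrored over https at the same path
url = "https://dl.fbaipublicfiles.com/mmf/data/models/vilt/vilt.finetuned.vqa2.tar.gz"
urllib.request.urlretrieve(url, "vilt.finetuned.vqa2.tar.gz")
with tarfile.open("vilt.finetuned.vqa2.tar.gz") as tar:
    # should yield the .pth checkpoint used below (name may differ)
    tar.extractall(".")
```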
```python
# build model and load weights
model_checkpoint_path = "./vilt_vqa2.pth"
state_dict = torch.load(model_checkpoint_path)
vilt_model.load_state_dict(state_dict, strict=False)
vilt_model.eval()
vilt_model = vilt_model.to(get_current_device())
```

Prepare the input image and text.
This example uses an image of a man with a hat kissing his daughter.
The text is the question posed to the ViLT model for visual question answering.
```python
# get image input
image_processor = VILTImageProcessor({"size": [384, 384]})
image_path = "./kissing_image.jpg"
raw_img = Image.open(image_path).convert("RGB")
image = image_processor(raw_img)

# get text input
text_tokenizer = VILTTextTokenizer({})
question = "What is on his head?"
processed_text_dict = text_tokenizer({"text": question})
```
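
As a quick sanity check you can inspect the processed inputs. The shapes below are assumptions based on the 384x384 processor config and BERT-style tokenization; in particular, `input_ids` is assumed to be among the tokenizer's output keys:
```python
# sanity-check the processed inputs
print(image.shape)  # expected: torch.Size([3, 384, 384]) given the config above
# assumption: the tokenizer output includes BERT-style "input_ids"
print(processed_text_dict["input_ids"].shape)
```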

Wrap everything up in a `SampleList` as expected by the ViLT `BaseModel`.
The dummy `targets` tensor is only there because the loss computation expects it: 3129 is the size of the VQA2 answer vocabulary, and the chosen index is arbitrary for inference.
```python
# make batch inputs
sample_dict = {**processed_text_dict, "image": image}
sample_dict = {
    k: v.unsqueeze(0) for k, v in sample_dict.items() if isinstance(v, torch.Tensor)
}
sample_dict["targets"] = torch.zeros((1, 3129))
sample_dict["targets"][0, 1358] = 1
sample_dict["dataset_name"] = "vqa2"
sample_dict["dataset_type"] = "test"
sample_list = SampleList(sample_dict).to(get_current_device())
```

Load the VQA answer id -> word string map to understand what the model says!
The file is currently at `s3://dl.fbaipublicfiles.com/mmf/data/datasets/vqa2/defaults/extras/vocabs/answers_vqa.txt`.
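
As with the checkpoint, a sketch of fetching this file, again assuming the `s3://` path is mirrored over `https://`:
```python
import urllib.request

# assumption: the s3 bucket is mirrored over https at the same path;
# downloads to ./answers_vqa.txt, so point vocab_file_path below at it
urllib.request.urlretrieve(
    "https://dl.fbaipublicfiles.com/mmf/data/datasets/vqa2/defaults/extras/vocabs/answers_vqa.txt",
    "answers_vqa.txt",
)
```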
```python
# load vqa2 id -> answers
vocab_file_path = "/private/home/path/to/answers_vqa.txt"
answer_vocab = VocabDict(vocab_file_path)
```

And here's the part you've been waiting for!
```python
# do prediction
with torch.no_grad():
    vqa_logits = vilt_model(sample_list)["scores"]
answer_id = vqa_logits.argmax().item()
answer = answer_vocab.idx2word(answer_id)
print(chr(27) + "[2J")  # clear the terminal
print(f"{question}: {answer}")
```

Expected output: `What is on his head?: hat`
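
To check the ~15 ms forward-pass figure quoted above, here is a minimal timing sketch (with CUDA synchronization so GPU kernels are actually measured; numbers will vary by hardware):
```python
import time

# warm up, then average the forward-pass latency over several runs
with torch.no_grad():
    for _ in range(5):
        vilt_model(sample_list)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(20):
        vilt_model(sample_list)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) / 20 * 1000
print(f"forward pass: {elapsed_ms:.1f} ms")
```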