Skip to content

Commit de30951

Browse files
author
Karan Desai
authored
Bug fixes and tweaks for a stronger baseline
2 parents feba6de + 711689c commit de30951

File tree

5 files changed

+66
-25
lines changed

5 files changed

+66
-25
lines changed

configs/lf_disc_faster_rcnn_x101_bs32.yml configs/lf_disc_faster_rcnn_x101.yml

+9-5
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@ model:
2525

2626
# Optimization related arguments
2727
solver:
28-
batch_size: 32
28+
batch_size: 128 # 32 x num_gpus is a good rule of thumb
2929
num_epochs: 20
30-
initial_lr: 0.001
31-
lr_gamma: 0.9997592083
32-
minimum_lr: 0.00005
30+
initial_lr: 0.01
3331
training_splits: "train" # "trainval"
34-
32+
lr_gamma: 0.1
33+
lr_milestones: # epochs when lr -> lr * lr_gamma
34+
- 4
35+
- 7
36+
- 10
37+
warmup_factor: 0.2
38+
warmup_epochs: 1

evaluate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
parser = argparse.ArgumentParser("Evaluate and/or generate EvalAI submission file.")
2020
parser.add_argument(
21-
"--config-yml", default="configs/lf_disc_faster_rcnn_x101_bs32.yml",
21+
"--config-yml", default="configs/lf_disc_faster_rcnn_x101.yml",
2222
help="Path to a config file listing reader, model and optimization parameters."
2323
)
2424
parser.add_argument(

train.py

+42-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch.utils.data import DataLoader
99
from tqdm import tqdm
1010
import yaml
11+
from bisect import bisect
1112

1213
from visdialch.data.dataset import VisDialDataset
1314
from visdialch.encoders import Encoder
@@ -19,7 +20,7 @@
1920

2021
parser = argparse.ArgumentParser()
2122
parser.add_argument(
22-
"--config-yml", default="configs/lf_disc_faster_rcnn_x101_bs32.yml",
23+
"--config-yml", default="configs/lf_disc_faster_rcnn_x101.yml",
2324
help="Path to a config file listing reader, model and solver parameters."
2425
)
2526
parser.add_argument(
@@ -76,6 +77,7 @@
7677
torch.backends.cudnn.benchmark = False
7778
torch.backends.cudnn.deterministic = True
7879

80+
7981
# ================================================================================================
8082
# INPUT ARGUMENTS AND CONFIG
8183
# ================================================================================================
@@ -95,14 +97,14 @@
9597

9698

9799
# ================================================================================================
98-
# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER
100+
# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER
99101
# ================================================================================================
100102

101103
train_dataset = VisDialDataset(
102104
config["dataset"], args.train_json, overfit=args.overfit, in_memory=args.in_memory
103105
)
104106
train_dataloader = DataLoader(
105-
train_dataset, batch_size=config["solver"]["batch_size"], num_workers=args.cpu_workers
107+
train_dataset, batch_size=config["solver"]["batch_size"], num_workers=args.cpu_workers, shuffle=True
106108
)
107109

108110
val_dataset = VisDialDataset(
@@ -126,9 +128,31 @@
126128
if -1 not in args.gpu_ids:
127129
model = nn.DataParallel(model, args.gpu_ids)
128130

131+
# Loss function.
129132
criterion = nn.CrossEntropyLoss()
130-
optimizer = optim.Adam(model.parameters(), lr=config["solver"]["initial_lr"])
131-
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=config["solver"]["lr_gamma"])
133+
134+
if config["solver"]["training_splits"] == "trainval":
135+
iterations = (len(train_dataset) + len(val_dataset)) // config["solver"]["batch_size"] + 1
136+
else:
137+
iterations = len(train_dataset) // config["solver"]["batch_size"] + 1
138+
139+
140+
def lr_lambda_fun(current_iteration: int) -> float:
141+
"""Returns a learning rate multiplier.
142+
143+
Until `warmup_epochs`, the learning rate multiplier increases linearly from `warmup_factor` to 1,
144+
and then gets multiplied by `lr_gamma` every time a milestone is crossed.
145+
"""
146+
current_epoch = float(current_iteration) / iterations
147+
if current_epoch <= config["solver"]["warmup_epochs"]:
148+
alpha = current_epoch / float(config["solver"]["warmup_epochs"])
149+
return config["solver"]["warmup_factor"] * (1. - alpha) + alpha
150+
else:
151+
idx = bisect(config["solver"]["lr_milestones"], current_epoch)
152+
return pow(config["solver"]["lr_gamma"], idx)
153+
154+
optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"])
155+
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun)
132156

133157

134158
# ================================================================================================
@@ -159,14 +183,10 @@
159183
# TRAINING LOOP
160184
# ================================================================================================
161185

162-
# Forever increasing counter keeping track of iterations completed.
163-
if config["solver"]["training_splits"] == "trainval":
164-
iterations = (len(train_dataset) + len(val_dataset)) // config["solver"]["batch_size"] + 1
165-
else:
166-
iterations = len(train_dataset) // config["solver"]["batch_size"] + 1
167-
186+
# Forever increasing counter keeping track of iterations completed (for tensorboard logging and stepping the lr scheduler).
168187
global_iteration_step = start_epoch * iterations
169-
for epoch in range(start_epoch, config["solver"]["num_epochs"] + 1):
188+
189+
for epoch in range(start_epoch, config["solver"]["num_epochs"]):
170190

171191
# --------------------------------------------------------------------------------------------
172192
# ON EPOCH START (combine dataloaders if training on train + val)
@@ -189,9 +209,10 @@
189209

190210
summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step)
191211
summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step)
192-
if optimizer.param_groups[0]["lr"] > config["solver"]["minimum_lr"]:
193-
scheduler.step()
212+
213+
scheduler.step(global_iteration_step)
194214
global_iteration_step += 1
215+
torch.cuda.empty_cache()
195216

196217
# --------------------------------------------------------------------------------------------
197218
# ON EPOCH END (checkpointing and validation)
@@ -200,6 +221,10 @@
200221

201222
# Validate and report automatic metrics.
202223
if args.validate:
224+
225+
# Switch dropout, batchnorm etc to the correct mode.
226+
model.eval()
227+
203228
print(f"\nValidation after epoch {epoch}:")
204229
for i, batch in enumerate(tqdm(val_dataloader)):
205230
for key in batch:
@@ -217,3 +242,6 @@
217242
for metric_name, metric_value in all_metrics.items():
218243
print(f"{metric_name}: {metric_value}")
219244
summary_writer.add_scalars("metrics", all_metrics, global_iteration_step)
245+
246+
model.train()
247+
torch.cuda.empty_cache()

visdialch/decoders/disc.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ def __init__(self, config, vocabulary):
1414
padding_idx=vocabulary.PAD_INDEX)
1515
self.option_rnn = nn.LSTM(config["word_embedding_size"],
1616
config["lstm_hidden_size"],
17-
batch_first=True)
17+
config["lstm_num_layers"],
18+
batch_first=True,
19+
dropout=config["dropout"])
1820

1921
# Options are variable length padded sequences, use DynamicRNN.
2022
self.option_rnn = DynamicRNN(self.option_rnn)

visdialch/encoders/lf.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def __init__(self, config, vocabulary):
3939
config["img_feature_size"], config["lstm_hidden_size"]
4040
)
4141

42+
# fc layer for image * question to attention weights
43+
self.attention_proj = nn.Linear(config["lstm_hidden_size"], 1)
44+
4245
# fusion layer (attended_image_features + question + history)
4346
fusion_size = config["img_feature_size"] + config["lstm_hidden_size"] * 2
4447
self.fusion = nn.Linear(fusion_size, config["lstm_hidden_size"])
@@ -78,10 +81,14 @@ def forward(self, batch):
7881
batch_size * num_rounds, -1, self.config["lstm_hidden_size"]
7982
)
8083

81-
# attend the features using question
84+
# computing attention weights
8285
# shape: (batch_size * num_rounds, num_proposals)
83-
image_attention_weights = projected_image_features.bmm(
84-
ques_embed.unsqueeze(-1)).squeeze()
86+
projected_ques_features = ques_embed.unsqueeze(1).repeat(
87+
1, img.shape[1], 1)
88+
projected_ques_image = projected_ques_features * projected_image_features
89+
projected_ques_image = self.dropout(projected_ques_image)
90+
image_attention_weights = self.attention_proj(
91+
projected_ques_image).squeeze()
8592
image_attention_weights = F.softmax(image_attention_weights, dim=-1)
8693

8794
# shape: (batch_size * num_rounds, num_proposals, img_features_size)
@@ -105,7 +112,7 @@ def forward(self, batch):
105112
hist_embed = self.word_embed(hist)
106113

107114
# shape: (batch_size * num_rounds, lstm_hidden_size)
108-
_ , (hist_embed, _) = self.hist_rnn(hist_embed, batch["hist_len"])
115+
_, (hist_embed, _) = self.hist_rnn(hist_embed, batch["hist_len"])
109116

110117
fused_vector = torch.cat((img, ques_embed, hist_embed), 1)
111118
fused_vector = self.dropout(fused_vector)

0 commit comments

Comments
 (0)