@@ -8,6 +8,7 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 import yaml
+from bisect import bisect

 from visdialch.data.dataset import VisDialDataset
 from visdialch.encoders import Encoder
@@ -19,7 +20,7 @@

 parser = argparse.ArgumentParser()
 parser.add_argument(
-    "--config-yml", default="configs/lf_disc_faster_rcnn_x101_bs32.yml",
+    "--config-yml", default="configs/lf_disc_faster_rcnn_x101.yml",
     help="Path to a config file listing reader, model and solver parameters."
 )
 parser.add_argument(
@@ -76,6 +77,7 @@
 torch.backends.cudnn.benchmark = False
 torch.backends.cudnn.deterministic = True

+
 # ================================================================================================
 # INPUT ARGUMENTS AND CONFIG
 # ================================================================================================
@@ -95,14 +97,14 @@


 # ================================================================================================
-# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER
+# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER
 # ================================================================================================

 train_dataset = VisDialDataset(
     config["dataset"], args.train_json, overfit=args.overfit, in_memory=args.in_memory
 )
 train_dataloader = DataLoader(
-    train_dataset, batch_size=config["solver"]["batch_size"], num_workers=args.cpu_workers
+    train_dataset, batch_size=config["solver"]["batch_size"], num_workers=args.cpu_workers, shuffle=True
 )

 val_dataset = VisDialDataset(
@@ -126,9 +128,31 @@
 if -1 not in args.gpu_ids:
     model = nn.DataParallel(model, args.gpu_ids)

+# Loss function.
 criterion = nn.CrossEntropyLoss()
-optimizer = optim.Adam(model.parameters(), lr=config["solver"]["initial_lr"])
-scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=config["solver"]["lr_gamma"])
+
+if config["solver"]["training_splits"] == "trainval":
+    iterations = (len(train_dataset) + len(val_dataset)) // config["solver"]["batch_size"] + 1
+else:
+    iterations = len(train_dataset) // config["solver"]["batch_size"] + 1
+
+
+def lr_lambda_fun(current_iteration: int) -> float:
+    """Returns a learning rate multiplier.
+
+    Till `warmup_epochs`, learning rate linearly increases to `initial_lr`,
+    and then gets multiplied by `lr_gamma` every time a milestone is crossed.
+    """
+    current_epoch = float(current_iteration) / iterations
+    if current_epoch <= config["solver"]["warmup_epochs"]:
+        alpha = current_epoch / float(config["solver"]["warmup_epochs"])
+        return config["solver"]["warmup_factor"] * (1. - alpha) + alpha
+    else:
+        idx = bisect(config["solver"]["lr_milestones"], current_epoch)
+        return pow(config["solver"]["lr_gamma"], idx)
+
+optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"])
+scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun)


 # ================================================================================================
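This hunk replaces the fixed per-epoch StepLR schedule with a warmup-then-decay schedule driven by `lr_lambda_fun`, where the multiplier is recomputed from the (fractional) epoch on every iteration. Below is a minimal standalone sketch of how that multiplier behaves; it takes the epoch directly instead of dividing an iteration count by `iterations`, and the solver values (warmup_factor=0.2, warmup_epochs=1, lr_milestones=[4, 7, 10], lr_gamma=0.1) are illustrative assumptions, not necessarily what lf_disc_faster_rcnn_x101.yml contains:

from bisect import bisect

# Assumed solver settings for illustration only.
solver = {"warmup_factor": 0.2, "warmup_epochs": 1, "lr_milestones": [4, 7, 10], "lr_gamma": 0.1}

def lr_multiplier(current_epoch: float) -> float:
    # Linear warmup from warmup_factor * initial_lr up to initial_lr ...
    if current_epoch <= solver["warmup_epochs"]:
        alpha = current_epoch / float(solver["warmup_epochs"])
        return solver["warmup_factor"] * (1. - alpha) + alpha
    # ... then decay by lr_gamma once for each milestone (in epochs) already crossed.
    return solver["lr_gamma"] ** bisect(solver["lr_milestones"], current_epoch)

for epoch in [0.0, 0.5, 1.0, 3.0, 5.0, 8.0, 12.0]:
    print(epoch, lr_multiplier(epoch))
# 0.0 -> 0.2, 0.5 -> 0.6, 1.0 -> 1.0, 3.0 -> 1.0, 5.0 -> 0.1, 8.0 -> 0.01, 12.0 -> 0.001

LambdaLR then multiplies each parameter group's initial learning rate by this factor, which is why the scheduler below is stepped every iteration rather than once per epoch.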
@@ -159,14 +183,10 @@
 # TRAINING LOOP
 # ================================================================================================

-# Forever increasing counter keeping track of iterations completed.
-if config["solver"]["training_splits"] == "trainval":
-    iterations = (len(train_dataset) + len(val_dataset)) // config["solver"]["batch_size"] + 1
-else:
-    iterations = len(train_dataset) // config["solver"]["batch_size"] + 1
-
+# Forever increasing counter keeping track of iterations completed (for tensorboard logging).
 global_iteration_step = start_epoch * iterations
-for epoch in range(start_epoch, config["solver"]["num_epochs"] + 1):
+
+for epoch in range(start_epoch, config["solver"]["num_epochs"]):

     # --------------------------------------------------------------------------------------------
     # ON EPOCH START (combine dataloaders if training on train + val)
@@ -189,9 +209,10 @@

         summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step)
         summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step)
-        if optimizer.param_groups[0]["lr"] > config["solver"]["minimum_lr"]:
-            scheduler.step()
+
+        scheduler.step(global_iteration_step)
         global_iteration_step += 1
+        torch.cuda.empty_cache()

     # --------------------------------------------------------------------------------------------
     # ON EPOCH END (checkpointing and validation)
@@ -200,6 +221,10 @@

     # Validate and report automatic metrics.
     if args.validate:
+
+        # Switch dropout, batchnorm etc to the correct mode.
+        model.eval()
+
         print(f"\nValidation after epoch {epoch}:")
         for i, batch in enumerate(tqdm(val_dataloader)):
             for key in batch:
@@ -217,3 +242,6 @@
         for metric_name, metric_value in all_metrics.items():
             print(f"{metric_name}: {metric_value}")
         summary_writer.add_scalars("metrics", all_metrics, global_iteration_step)
+
+        model.train()
+        torch.cuda.empty_cache()
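The last two hunks put the model into eval mode before validation and back into train mode afterwards, so dropout and batch norm behave correctly during metric computation. A common companion pattern, shown here only as a sketch with a hypothetical forward call and not as a description of the rest of this script, also disables autograd for the validation loop to save memory:

model.eval()
with torch.no_grad():              # no gradients are needed while computing metrics
    for batch in val_dataloader:
        output = model(batch)      # hypothetical forward call; actual batch handling differs
        # ... accumulate metrics ...
model.train()                      # restore training mode for the next epoch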