-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtrain_alphas.py
More file actions
175 lines (139 loc) · 6.74 KB
/
train_alphas.py
File metadata and controls
175 lines (139 loc) · 6.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import argparse
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import *
import random
import time
import os
import sys
from common.datasets import get_data_loaders
from common import utils
from common.trainer import Trainer
from common.callback import *
from common.losses import *
from thrifty.models import get_model, get_model_exact_params
"""
Callback for exponentially increasing the temperature of the alpha loss
"""
class AlphaCallback(Callback):
    """Scale the trainer's ``temperature`` by a constant factor on each call.

    Every invocation of ``callOnEndForward`` multiplies ``trainer.temperature``
    by ``1 + alph``, so the temperature grows geometrically with the number of
    invocations (presumably one per forward pass -- confirm against Trainer).
    """

    def __init__(self, alph):
        # Relative increase applied to the temperature at every invocation.
        self.alph = alph

    def callOnEndForward(self, trainer):
        """Multiply ``trainer.temperature`` in place by ``1 + self.alph``."""
        growth_factor = 1 + self.alph
        trainer.temperature *= growth_factor
class AlphaLoss(LossFun):
    """Penalty that pushes the entries of ``model.Lblock.alpha`` towards 0 or 1.

    The per-entry term ``alpha**2 * (1 - alpha)**2`` is zero exactly at 0 and 1.
    It is weighted by ``trainer.temperature`` and summed over all entries.
    """

    name = "AlphaLoss"

    def call(self, output, target, trainer):
        """Return the temperature-weighted binarization penalty.

        ``output`` and ``target`` are accepted for interface compatibility but
        are not used; the loss depends only on the trainer's model and
        temperature.
        """
        temp = trainer.temperature
        # Use the parameter itself, not ``.data``: ``.data`` is detached from
        # autograd, so a penalty built from it is a constant and can never
        # contribute a gradient to alpha during backward().
        alpha = trainer.model.Lblock.alpha
        penalty = alpha * alpha * (1 - alpha) * (1 - alpha)
        return torch.sum(temp * penalty)
def _append_log(name, text):
    """Append ``text`` to ``logs/<name>.log``; do nothing when ``name`` is None."""
    if name is None:
        return
    with open("logs/{}.log".format(name), "a") as f:
        f.write(text)


def _sgd_with_schedule(params, args):
    """Build an SGD optimizer over ``params`` and its LambdaLR scheduler."""
    optimizer = optim.SGD(
        params,
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    def schedule_fun(epoch, gamma=args.gamma, steps=args.steps):
        return utils.reduceLR(epoch, gamma, steps)

    scheduler = LambdaLR(optimizer, lr_lambda=schedule_fun)
    return optimizer, scheduler


def _freeze_alpha(model, frozen_alpha):
    """Install a private copy of ``frozen_alpha`` in the model and stop training it."""
    # Clone so that successive models never share one underlying tensor.
    model.Lblock.alpha.data = frozen_alpha.clone()
    model.Lblock.alpha.requires_grad = False


def _train_frozen_phase(device, model, dataset, args, topk):
    """Train ``model`` with cross-entropy only, using SGD on its trainable parameters."""
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    optimizer, scheduler = _sgd_with_schedule(trainable, args)
    trainer = Trainer(device, model, dataset, optimizer, CrossEntropy(),
                      name=args.name, topk=topk, checkpointFreq=args.checkpoint_freq)
    trainer.callbacks.append(SchedulerCB(scheduler))
    trainer.train(args.epochs, args.epochs)


if __name__ == '__main__':
    os.makedirs("logs", exist_ok=True)

    parser = utils.args()
    parser.add_argument("-alpha", "--alpha", type=float, default=1.5e-4)
    parser.add_argument("-st", "--starting-temp", type=float, default=3e-4)
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)

    dataset = get_data_loaders(args)
    metadata = dataset[2]

    # Accuracy ranks to report: explicit request first, otherwise a
    # dataset-dependent default.
    if args.topk is not None:
        topk = tuple(args.topk)
    elif args.dataset == "imagenet":
        topk = (1, 5)
    else:
        topk = (1,)

    model = get_model(args, metadata)
    if args.n_params is not None and args.model not in ["block_thrifty", "blockthrifty"]:
        model, args = get_model_exact_params(model, args, metadata)

    # Snapshot the initial convolution weights. ``.data`` aliases the live
    # parameter storage, so without ``clone()`` the "backup" can be overwritten
    # by training (e.g. when the model stays on the CPU).
    CONV_WEIGHT_BACKUP1 = model.Lblock.Lconv.conv1.weight.data.clone()
    CONV_WEIGHT_BACKUP2 = model.Lblock.Lconv.conv2.weight.data.clone()

    # Log for parameters, filters and pooling strategy
    info_dict = utils.get_info(model, metadata)
    for key, val in info_dict.items():
        print(key, " : ", val)
    print("")

    model = model.to(device)

    if args.resume is None:
        # First phase of training: cross-entropy plus the alpha penalty.
        # Init optimizer and scheduler
        scheduler = None
        if args.optimizer == "sgd":
            optimizer, scheduler = _sgd_with_schedule(model.parameters(), args)
        elif args.optimizer == "adam":
            optimizer = optim.Adam(model.parameters(), lr=args.learning_rate,
                                   weight_decay=args.weight_decay)
        else:
            # Fail early instead of hitting an unbound ``optimizer`` below.
            raise ValueError(
                "Unsupported optimizer {!r}; expected 'sgd' or 'adam'".format(args.optimizer))

        # Counted directly from the model; the original referenced an
        # undefined ``n_parameters`` name here.
        n_parameters = sum(p.numel() for p in model.parameters())
        if hasattr(model, "n_filters"):
            filters_line = "\nFilters : {}".format(model.n_filters)
        else:
            filters_line = "\nFilters : _ "
        _append_log(
            args.name,
            str(args)
            + "\nParameters : {}".format(n_parameters)
            + filters_line
            + "\n*******\n",
        )

        print("-"*80 + "\n")

        trainer1 = Trainer(device, model, dataset, optimizer, [CrossEntropy(), AlphaLoss()],
                           name=args.name, topk=topk, checkpointFreq=args.checkpoint_freq)
        trainer1.temperature = args.starting_temp
        trainer1.callbacks.append(AlphaCallback(args.alpha))
        if scheduler is not None:
            trainer1.callbacks.append(SchedulerCB(scheduler))
        trainer1.train(args.epochs)

        # The checkpoint path is derived from the run name; without a name
        # there is nowhere to save, so skip instead of raising TypeError.
        if args.name is not None:
            torch.save(model.state_dict(), args.name + ".model")
        else:
            print("No run name given; phase-1 model not saved")
    else:  # args.resume is not None
        # map_location lets a checkpoint saved on another device be loaded here.
        model.load_state_dict(torch.load(args.resume, map_location=device))

    print("-"*80)
    print("Binarize and fine tune\n")
    print("")

    # Binarize alpha: entries above 0.2 become 1.0, the rest 0.0.
    FROZEN_ALPHA = (model.Lblock.alpha.data > 0.2).float().to(device)

    _append_log(args.name, "*******\nFine tuning after binarization\n*******\n")

    _freeze_alpha(model, FROZEN_ALPHA)

    # Beginning of second training phase
    _train_frozen_phase(device, model, dataset, args, topk)

    print("\n"+"-"*80)
    print("Train again from scratch with same initialization\n")
    print("")

    _append_log(args.name, "*******\nTrain from scratch with same init\n*******\n")

    # Reinitialize model, then restore the saved initial conv weights.
    model = get_model(args, metadata).to(device)
    model.Lblock.alpha.data = FROZEN_ALPHA.clone()
    model.Lblock.Lconv.conv1.weight.data = CONV_WEIGHT_BACKUP1.clone().to(device)
    model.Lblock.Lconv.conv2.weight.data = CONV_WEIGHT_BACKUP2.clone().to(device)
    model.Lblock.alpha.requires_grad = False

    # Beginning of third training phase
    _train_frozen_phase(device, model, dataset, args, topk)

    print("\n"+"-"*80)
    print("Train again from scratch, another init\n")
    print("")

    _append_log(args.name, "*******\nTrain from scratch another init\n*******\n")

    # Reinitialize model; keep its fresh weights and only freeze alpha.
    model = get_model(args, metadata).to(device)
    _freeze_alpha(model, FROZEN_ALPHA)

    # Beginning of fourth training phase
    _train_frozen_phase(device, model, dataset, args, topk)