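"""Train a teacher network.

Builds the dataset and teacher model from the runtime configuration, then
runs a supervised training loop with a per-epoch test pass and optional
Weights & Biases logging.
"""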
import datetime
import pprint

import torch
import torch.nn as nn
import wandb

from models import *
from utils import *
from dataloaders import *
from optimizers import *
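# The wildcard imports supply the project-local helpers used below:
# model_dict, exec_configurator, initialize, get_logging_name,
# get_dataloader, get_optimizer and loop_one_epoch.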
################################
#### 0. SETUP CONFIGURATION
################################
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
cfg = exec_configurator()
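# Lightweight architectures are typically trained with a smaller initial
# learning rate, so override whatever the config specifies.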
if cfg['teacher_model']['model_name'] in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']:
    cfg['optimizer']['lr'] = 0.01
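# initialize() comes from utils; it is assumed to seed the RNGs so runs are
# reproducible for a given cfg['trainer']['seed'].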
initialize(cfg['trainer']['seed'])
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc, start_epoch, logging_dict = 0, 0, {}
# Total number of training epochs
EPOCHS = cfg['trainer']['epochs']
print('==> Initialize Logging Framework..')
logging_name = 'T_' + get_logging_name(cfg)
logging_name += ('_' + current_time)
framework_name = cfg['logging']['framework_name']
if framework_name == 'wandb':
    wandb.init(project=cfg['logging']['project_name'], name=logging_name, config=cfg)
pprint.pprint(cfg)
################################
#### 1. BUILD THE DATASET
################################
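# get_dataloader unpacks cfg['dataloader'] and returns the train/test
# loaders together with the dataset's number of classes.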
train_dataloader, test_dataloader, num_classes = get_dataloader(**cfg['dataloader'])
################################
#### 2. BUILD THE NEURAL NETWORK
################################
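# pop() removes model_name so every remaining key in cfg['teacher_model']
# can be forwarded to the model constructor as a keyword argument.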
teacher_model_name = cfg['teacher_model'].pop('model_name', None)
teacher_model = model_dict[teacher_model_name](num_classes=num_classes, **cfg['teacher_model'])
teacher_model = teacher_model.to(device)
total_params = sum(p.numel() for p in teacher_model.parameters())
print(f'==> Number of parameters in teacher {teacher_model_name}: {total_params}')
################################
#### 3.a OPTIMIZING MODEL PARAMETERS
################################
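# Plain cross-entropy for the teacher; opt_name is popped from
# cfg['optimizer'] so only genuine optimizer hyperparameters remain in the
# dict handed to get_optimizer.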
criterion = nn.CrossEntropyLoss()
opt_name = cfg['optimizer'].pop('opt_name', None)
optimizer = get_optimizer(teacher_model, opt_name, cfg['optimizer'])
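# Decay the learning rate by 10x at epochs 150, 180 and 210.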
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 180, 210], gamma=0.1)
################################
#### 3.b TRAINING
################################
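# Standard epoch loop: train, evaluate, step the LR schedule, then flush
# the accumulated metrics to wandb when that backend is enabled.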
if __name__ == "__main__":
    for epoch in range(1, EPOCHS + 1):
        print('\nEpoch: %d' % epoch)
        # One pass over the training set.
        loop_one_epoch(
            dataloader=train_dataloader,
            model=teacher_model,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            logging_dict=logging_dict,
            epoch=epoch,
            loop_type='train',
            logging_name=logging_name)
        # One pass over the test set; tracks and returns the best accuracy so far.
        best_acc, acc = loop_one_epoch(
            dataloader=test_dataloader,
            model=teacher_model,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            logging_dict=logging_dict,
            epoch=epoch,
            loop_type='test',
            logging_name=logging_name,
            best_acc=best_acc)
        if scheduler is not None:
            scheduler.step()
        if framework_name == 'wandb':
            wandb.log(logging_dict)