train.py
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from backbone import EmbedNetwork
from loss import TripletLoss
from triplet_selector import BatchHardTripletSelector
from batch_sampler import BatchSampler
from datasets.Market1501 import Market1501
from optimizer import AdamOptimWrapper
from logger import logger


def train():
    ## setup
    torch.multiprocessing.set_sharing_strategy('file_system')
    if not os.path.exists('./res'): os.makedirs('./res')

    ## model and loss
    logger.info('setting up backbone model and loss')
    net = EmbedNetwork().cuda()
    net = nn.DataParallel(net)
    triplet_loss = TripletLoss(margin = None).cuda()  # no margin means soft-margin

    ## optimizer
    logger.info('creating optimizer')
    # lr decays between iterations t0 and t1
    optim = AdamOptimWrapper(net.parameters(), lr = 3e-4, wd = 0, t0 = 15000, t1 = 25000)

    ## dataloader
    selector = BatchHardTripletSelector()
    ds = Market1501('datasets/Market-1501-v15.09.15/bounding_box_train', is_train = True)
    sampler = BatchSampler(ds, 18, 4)  # P x K sampling: 18 identities, 4 images each
    dl = DataLoader(ds, batch_sampler = sampler, num_workers = 4)
    diter = iter(dl)

    ## train
    logger.info('start training ...')
    loss_avg = []
    count = 0
    t_start = time.time()
    while True:
        try:
            imgs, lbs, _ = next(diter)
        except StopIteration:
            # restart the loader once an epoch is exhausted
            diter = iter(dl)
            imgs, lbs, _ = next(diter)

        net.train()
        imgs = imgs.cuda()
        lbs = lbs.cuda()
        embds = net(imgs)
        anchor, positives, negatives = selector(embds, lbs)
        loss = triplet_loss(anchor, positives, negatives)

        optim.zero_grad()
        loss.backward()
        optim.step()

        loss_avg.append(loss.detach().cpu().numpy())
        if count % 20 == 0 and count != 0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            t_end = time.time()
            time_interval = t_end - t_start
            logger.info('iter: {}, loss: {:.4f}, lr: {:.4f}, time: {:.3f}'.format(count, loss_avg, optim.lr, time_interval))
            loss_avg = []
            t_start = t_end

        count += 1
        if count == 25000: break

    ## dump model
    logger.info('saving trained model')
    torch.save(net.module.state_dict(), './res/model.pkl')

    logger.info('everything finished')


if __name__ == '__main__':
    train()