-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrainmodel.py
More file actions
64 lines (64 loc) · 1.79 KB
/
trainmodel.py
File metadata and controls
64 lines (64 loc) · 1.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# import sys, time
# import numpy as np
# import torch
#
#
# torch.manual_seed(0)
# torch.cuda.manual_seed(0)
# np.random.seed(0)
#
#
#
# ########################################################################################################################
#
# class Appr(object):
# def __init__(self, model,args=None):
# self.model = model
# self.nepochs = args.num_epochs
# self.sbatch = args.batch_size
# self.lr = args.learning_rate
# self.ce = torch.nn.CrossEntropyLoss()
# self.optimizer = self._get_optimizer()
# self.gpu = args.gpu_id
#
#
# return
#
# def update_lr(self,optimizer):
# for param_group in optimizer.param_groups:
# param_group['lr'] = param_group['lr']/10.
#
#
# def _get_optimizer(self, lr=None):
# lr = self.lr if lr is None else lr
# optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), lr=lr,momentum = 0.5)
# return optimizer
#
# def train(self, train_loader):
#
# lr = self.lr
# self.optimizer = self._get_optimizer(lr)
# nepochs = self.nepochs
# # Loop epochs
# try:
# for e in range(nepochs):
# self.train_epoch(train_loader, cur_epoch=e, nepoch=nepochs)
# except KeyboardInterrupt:
# print()
#
#
# def train_epoch(self,train_loader, cur_epoch=0, nepoch=0):
# self.model.train()
# for i, (images, labels) in enumerate(train_loader):
#
# images = images.cuda(self.gpu)
# targets = labels.cuda(self.gpu)
# output,_= self.model.forward(images)
# loss = self.ce(output, targets)
# self.optimizer.zero_grad()
# loss.backward()
# self.optimizer.step()
# return
#
#
#