-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptimizers_schedulers.py
61 lines (50 loc) · 1.83 KB
/
optimizers_schedulers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
# Optimizers
def optimizers(
    parameters,
    args
):
    """
    Build an optimizer for the given model parameters.

    Args:
        parameters: Iterable of parameters (or parameter groups) to optimize.
        args: Namespace-like object with at least ``optimizer`` (str),
            ``lr`` (float) and ``weight_decay`` (float); ``momentum`` (float)
            is additionally required when ``optimizer == "SGD"``.

    Returns:
        A configured ``torch.optim.Optimizer`` instance.

    Raises:
        ValueError: If ``args.optimizer`` is not one of
            "Adam", "AdamW", "SGD", "Adagrad".
    """
    if args.optimizer == "Adam":
        return torch.optim.Adam(parameters, args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "AdamW":
        return torch.optim.AdamW(parameters, args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "SGD":
        # SGD is the only option here that consumes args.momentum.
        return torch.optim.SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == "Adagrad":
        return torch.optim.Adagrad(parameters, args.lr, weight_decay=args.weight_decay)
    else:
        # NOTE: the original code raised a bare string, which is a TypeError
        # in Python 3 ("exceptions must derive from BaseException") and would
        # mask the intended message. Raise a proper exception instead.
        raise ValueError('Consider a optimizer among ("Adam", "AdamW", "SGD", "Adagrad").')
# Schedulers
def schedulers(
    optimizer,
    args
):
    """
    Build a learning-rate scheduler for the given optimizer.

    Args:
        optimizer: The ``torch.optim.Optimizer`` to schedule.
        args: Namespace-like object with ``scheduler`` (str or None).
            Additionally requires ``epochs`` for "CosineAnnealingLR", and
            ``warmup_epochs`` / ``multiplier`` for
            "CosineAnnealingWarmRestarts".

    Returns:
        A configured LR scheduler instance, or ``None`` when
        ``args.scheduler`` is ``None`` or the string "None".

    Raises:
        ValueError: If ``args.warmup_epochs`` is 0 for
            "CosineAnnealingWarmRestarts", or if ``args.scheduler`` is not a
            supported name.
    """
    if args.scheduler == "CosineAnnealingLR":
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs, eta_min=1e-10)
    elif args.scheduler == "CosineAnnealingWarmRestarts":
        # T_0 must be a positive int; fail fast with a real exception
        # (the original raised a bare string, which is a TypeError in Python 3).
        if args.warmup_epochs == 0:
            raise ValueError("No.of warmup epochs should be greater than 0.")
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=args.warmup_epochs, T_mult=args.multiplier, eta_min=1e-10)
    elif args.scheduler == "ReduceLROnPlateau":
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001)
    elif args.scheduler is None or args.scheduler == "None":
        # Training without a scheduler is a supported configuration.
        return None
    else:
        raise ValueError('Consider a scheduler among ("CosineAnnealingLR", "CosineAnnealingWarmRestarts", "ReduceLROnPlateau").')