# Optimizer.
# Optimizers update an object's ndarray attributes in place via
# getattr/setattr, so an optimizer instance can act as a drop-in replacement
# for a bare learning-rate scalar in a training loop.
import abc
import dataclasses
import numpy as np


class Optimizer(metaclass=abc.ABCMeta):
    """Base class: applies gradient updates to named ndarray attributes."""

    def update(self, obj: object, attribute: str,
               gradient: np.ndarray) -> None:
        # Key per-variable optimizer state by object id plus attribute name,
        # so one optimizer instance can track many variables.
        identifier = f'{id(obj)}.{attribute}'
        variable = getattr(obj, attribute)
        variable = self.update_variable(identifier, variable, gradient)
        setattr(obj, attribute, variable)

    @abc.abstractmethod
    def update_variable(self, identifier: str, variable: np.ndarray,
                        gradient: np.ndarray) -> np.ndarray:
        """Returns the updated variable; may also mutate it in place."""


class SGDOptimizer(Optimizer):
    """Plain stochastic gradient descent: variable -= lr * gradient."""

    def __init__(self, learning_rate: float) -> None:
        self._learning_rate = learning_rate

    def update_variable(self, identifier: str, variable: np.ndarray,
                        gradient: np.ndarray) -> np.ndarray:
        variable -= self._learning_rate * gradient
        return variable
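
# A minimal sketch of the intended call pattern. `Model` is a hypothetical
# holder with an ndarray attribute, not part of this module:
#
#     model = Model()
#     opt = SGDOptimizer(learning_rate=0.1)
#     opt.update(model, 'weights', grad)  # model.weights -= 0.1 * grad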


@dataclasses.dataclass
class AdamOptimizerConfig:
    learning_rate: float
    beta1: float = 0.9
    beta2: float = 0.999
    epsilon: float = 1e-7

    def __post_init__(self) -> None:
        # Per-variable state, keyed by the identifier passed to
        # update_variable.
        self._steps = {}
        self._momentums = {}
        self._velocities = {}


class AdamOptimizer(AdamOptimizerConfig, Optimizer):
    """Adam (Kingma & Ba, 2015) with bias-corrected moment estimates."""

    def update_variable(self, identifier: str, variable: np.ndarray,
                        gradient: np.ndarray) -> np.ndarray:
        t = self._steps.get(identifier, 1)
        m = self._momentums.get(identifier, np.zeros_like(gradient))
        v = self._velocities.get(identifier, np.zeros_like(gradient))
        # Exponential moving averages of the gradient and its square.
        new_m = self.beta1 * m + (1 - self.beta1) * gradient
        new_v = self.beta2 * v + (1 - self.beta2) * gradient**2
        # Correct the bias introduced by initializing m and v at zero.
        corrected_m = new_m / (1 - self.beta1**t)
        corrected_v = new_v / (1 - self.beta2**t)
        # Per the Adam paper, epsilon sits outside the square root.
        variable -= self.learning_rate * (
            corrected_m / (np.sqrt(corrected_v) + self.epsilon))
        self._steps[identifier] = t + 1
        self._momentums[identifier] = new_m
        self._velocities[identifier] = new_v
        return variable
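

# A small self-contained demo, included as an illustrative sketch rather than
# part of the module's API: the `_Toy` holder and the quadratic loss are
# hypothetical. Minimizes 0.5 * weight**2, whose gradient is simply `weight`,
# so the parameter should approach zero.
if __name__ == '__main__':

    class _Toy:
        def __init__(self) -> None:
            self.weight = np.array([5.0])

    toy = _Toy()
    adam = AdamOptimizer(learning_rate=0.1)
    for _ in range(200):
        # Copy so the gradient array is not aliased with the variable that
        # update_variable mutates in place.
        adam.update(toy, 'weight', toy.weight.copy())
    print('weight after 200 Adam steps:', toy.weight)  # approaches 0.0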