@@ -1,6 +1,8 @@
 from __future__ import print_function
 import numpy as np
+import scipy.stats as st
 import multiprocessing as mp
+from collections.abc import Iterable
 
 np.random.seed(0)
 
@@ -9,24 +11,66 @@ def worker_process(arg):
     get_reward_func, weights = arg
     return get_reward_func(weights)
 
+class WeightUpdateStrategy:
+    __slots__ = ("learning_rate",)
+    def __init__(self, dim, learning_rate):
+        self.learning_rate = learning_rate
+
+
+class strategies:
+    class GD(WeightUpdateStrategy):
+        def update(self, i, g):
+            # vanilla gradient ascent: scale the estimated gradient by the learning rate
+            return self.learning_rate * g
+
+
+    class Adam(WeightUpdateStrategy):
+        __slots__ = ("eps", "beta1", "beta2", "m", "v")
+        def __init__(self, dim, learning_rate, eps=1e-8, beta1=0.9, beta2=0.999):
+            super().__init__(dim, learning_rate)
+            self.eps = eps
+            self.beta1 = beta1
+            self.beta2 = beta2
+            # per-layer first and second moment estimates; stored as plain lists
+            # so each entry can take its layer's array shape on the first update
+            self.m = [0.0] * dim
+            self.v = [0.0] * dim
+
+        def update(self, i, g):
+            # exponential moving averages of the gradient and its square
+            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
+            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * (g ** 2)
+            return self.learning_rate * np.sqrt(1 - self.beta2) / (1 - self.beta1) * self.m[i] / (np.sqrt(self.v[i]) + self.eps)
+
 
 class EvolutionStrategy(object):
     def __init__(self, weights, get_reward_func, population_size=50, sigma=0.1, learning_rate=0.03, decay=0.999,
-                 num_threads=1):
-
+                 num_threads=1, limits=None, printer=None, distributions=None, strategy=None):
+        if limits is None:
+            limits = (np.inf, -np.inf)
         self.weights = weights
+        self.limits = limits
         self.get_reward = get_reward_func
         self.POPULATION_SIZE = population_size
-        self.SIGMA = sigma
+        if distributions is None:
+            # default: zero-mean Gaussian noise with standard deviation sigma
+            distributions = st.norm(loc=0., scale=sigma)
+        if isinstance(distributions, Iterable):
+            # one noise distribution per weight layer
+            distributions = list(distributions)
+            self.SIGMA = np.array([d.std() for d in distributions])
+        else:
+            self.SIGMA = distributions.std()
+
+        self.distributions = distributions
         self.learning_rate = learning_rate
         self.decay = decay
         self.num_threads = mp.cpu_count() if num_threads == -1 else num_threads
+        if printer is None:
+            printer = print
+        self.printer = printer
+        if strategy is None:
+            strategy = strategies.GD
+        self.strategy = strategy(len(weights), self.learning_rate)
 
     def _get_weights_try(self, w, p):
         weights_try = []
         for index, i in enumerate(p):
-            jittered = self.SIGMA * i
-            weights_try.append(w[index] + jittered)
+            # samples are already drawn at the distribution's scale, so no
+            # extra SIGMA scaling is needed here
+            weights_try.append(w[index] + i)
         return weights_try
 
     def get_weights(self):
@@ -36,8 +80,13 @@ def _get_population(self):
         population = []
         for i in range(self.POPULATION_SIZE):
             x = []
-            for w in self.weights:
-                x.append(np.random.randn(*w.shape))
+            if isinstance(self.distributions, Iterable):
+                # per-layer case: draw layer j's noise from distributions[j]
+                for j, w in enumerate(self.weights):
+                    x.append(self.distributions[j].rvs(size=w.shape))
+            else:
+                for w in self.weights:
+                    x.append(self.distributions.rvs(size=w.shape))
+
             population.append(x)
         return population
 
@@ -59,10 +108,17 @@ def _update_weights(self, rewards, population):
         if std == 0:
             return
         rewards = (rewards - rewards.mean()) / std
+        # ES gradient estimate: g ~= sum_i(noise_i * reward_i) / (N * sigma^2),
+        # where SIGMA may be a scalar or a per-layer array of noise stds
+        grad_factor = 1. / (self.POPULATION_SIZE * (self.SIGMA ** 2))
+
         for index, w in enumerate(self.weights):
             layer_population = np.array([p[index] for p in population])
-            update_factor = self.learning_rate / (self.POPULATION_SIZE * self.SIGMA)
-            self.weights[index] = w + update_factor * np.dot(layer_population.T, rewards).T
+            corr = np.dot(layer_population.T, rewards).T
+
+            if not isinstance(grad_factor, np.ndarray):
+                g = grad_factor * corr
+            else:
+                # per-layer SIGMA: pick the factor matching this layer
+                g = grad_factor[index] * corr
+            # the update strategy (GD or Adam) turns the gradient into a step
+            self.weights[index] = w + self.strategy.update(index, g)
         self.learning_rate *= self.decay
 
     def run(self, iterations, print_step=10):
@@ -75,7 +131,8 @@ def run(self, iterations, print_step=10):
             self._update_weights(rewards, population)
 
             if (iteration + 1) % print_step == 0:
-                print('iter %d. reward: %f' % (iteration + 1, self.get_reward(self.weights)))
+                # the printer callback also receives the current weights
+                self.printer('iter %d. reward: %f' % (iteration + 1, self.get_reward(self.weights)), self.weights)
         if pool is not None:
             pool.close()
             pool.join()
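For reference, a minimal usage sketch of the parameters this diff introduces (distributions, strategy, printer). The toy get_reward objective, the layer shapes, and the scale values below are illustrative assumptions, not part of the change:

import numpy as np
import scipy.stats as st

def get_reward(weights):
    # hypothetical toy objective: reward peaks when all weights are zero
    return -sum(float(np.sum(w ** 2)) for w in weights)

def printer(message, weights):
    # custom printers receive the current weights alongside the log line
    print(message)

weights = [np.zeros((4, 3)), np.zeros(3)]
es = EvolutionStrategy(
    weights, get_reward,
    population_size=50,
    distributions=[st.norm(scale=0.1), st.norm(scale=0.05)],  # one distribution per layer
    strategy=strategies.Adam,
    printer=printer,
)
es.run(1000, print_step=100)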