 from __future__ import print_function
 import numpy as np
+import scipy.stats as st
 import multiprocessing as mp
+from collections.abc import Iterable
 
 np.random.seed(0)
 
@@ -12,21 +14,33 @@ def worker_process(arg):
 
 
 class EvolutionStrategy(object):
     def __init__(self, weights, get_reward_func, population_size=50, sigma=0.1, learning_rate=0.03, decay=0.999,
-                 num_threads=1):
-
+                 num_threads=1, limits=None, printer=None, distributions=None):
+        if limits is None:
+            limits = (np.inf, -np.inf)
         self.weights = weights
+        self.limits = limits
         self.get_reward = get_reward_func
         self.POPULATION_SIZE = population_size
-        self.SIGMA = sigma
+        if distributions is None:
+            distributions = st.norm(loc=0., scale=sigma)
+        if isinstance(distributions, Iterable):
+            distributions = list(distributions)
+            self.SIGMA = np.array([d.std() for d in distributions])
+        else:
+            self.SIGMA = distributions.std()
+
+        self.distributions = distributions
         self.learning_rate = learning_rate
         self.decay = decay
         self.num_threads = mp.cpu_count() if num_threads == -1 else num_threads
+        if printer is None:
+            printer = print
+        self.printer = printer
 
     def _get_weights_try(self, w, p):
         weights_try = []
         for index, i in enumerate(p):
-            jittered = self.SIGMA * i
-            weights_try.append(w[index] + jittered)
+            weights_try.append(w[index] + i)
         return weights_try
 
     def get_weights(self):
@@ -36,8 +50,13 @@ def _get_population(self):
         population = []
         for i in range(self.POPULATION_SIZE):
             x = []
-            for w in self.weights:
-                x.append(np.random.randn(*w.shape))
+            if isinstance(self.distributions, Iterable):
+                for j, w in enumerate(self.weights):
+                    x.append(self.distributions[j].rvs(size=w.shape))  # size as a tuple: rvs(*w.shape) would bind extra dims to random_state
+            else:
+                for w in self.weights:
+                    x.append(self.distributions.rvs(size=w.shape))
+
             population.append(x)
         return population
 
@@ -59,10 +78,14 @@ def _update_weights(self, rewards, population):
         if std == 0:
             return
         rewards = (rewards - rewards.mean()) / std
+        update_factor = self.learning_rate / (self.POPULATION_SIZE * self.SIGMA ** 2)  # draws already carry the scale, so normalise by sigma**2
+
         for index, w in enumerate(self.weights):
             layer_population = np.array([p[index] for p in population])
-            update_factor = self.learning_rate / (self.POPULATION_SIZE * self.SIGMA)
-            self.weights[index] = w + update_factor * np.dot(layer_population.T, rewards).T
+            if not isinstance(update_factor, np.ndarray):
+                self.weights[index] = w + update_factor * np.dot(layer_population.T, rewards).T
+            else:
+                self.weights[index] = w + update_factor[index] * np.dot(layer_population.T, rewards).T
         self.learning_rate *= self.decay
 
     def run(self, iterations, print_step=10):
@@ -75,7 +98,7 @@ def run(self, iterations, print_step=10):
             self._update_weights(rewards, population)
 
             if (iteration + 1) % print_step == 0:
-                print('iter %d. reward: %f' % (iteration + 1, self.get_reward(self.weights)))
+                self.printer('iter %d. reward: %f' % (iteration + 1, self.get_reward(self.weights)))
         if pool is not None:
             pool.close()
             pool.join()
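
A minimal usage sketch of the new parameters (the two-layer toy problem, target, and reward function below are hypothetical; only the EvolutionStrategy signature and its distributions/printer arguments come from the diff above):

import numpy as np
import scipy.stats as st
# assumes the patched EvolutionStrategy above is importable, e.g.
# from evolution_strategy import EvolutionStrategy

# Hypothetical toy problem: two weight "layers", reward = negative squared error.
target = [np.ones((3, 2)), np.ones(2)]

def get_reward(weights):
    return -sum(np.sum((w - t) ** 2) for w, t in zip(weights, target))

weights = [np.zeros((3, 2)), np.zeros(2)]

# One zero-mean perturbation distribution per layer: a normal for the first,
# a uniform on [-0.05, 0.05] for the second (scipy's uniform spans [loc, loc + scale]).
distributions = [
    st.norm(loc=0., scale=0.1),
    st.uniform(loc=-0.05, scale=0.1),
]

es = EvolutionStrategy(weights, get_reward,
                       population_size=50,
                       learning_rate=0.03,
                       distributions=distributions,
                       printer=print)  # printer: any callable that accepts a string
es.run(1000, print_step=100)

Passing a single frozen distribution instead of a list applies it to every layer; omitting the argument falls back to st.norm(loc=0., scale=sigma), which reproduces the previous Gaussian behaviour.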
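
A standalone numeric check (not part of the commit) of the SIGMA ** 2 normalisation in _update_weights: with noise eps drawn at scale sigma, as rvs() now returns it, E[f(w + eps) * eps] / sigma**2 estimates the gradient of f, which is what the removed code achieved by dividing standard-normal draws by a single sigma:

import numpy as np

rng = np.random.default_rng(0)
w, sigma, n = 1.0, 0.1, 200000

def f(x):
    return 3.0 * x  # toy objective with known gradient 3.0

eps = rng.normal(0.0, sigma, size=n)  # scale-carrying noise, like rvs() returns
grad_est = np.mean(f(w + eps) * eps) / sigma ** 2
print(grad_est)  # close to 3.0; dividing by sigma alone would give about 0.3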