- import os.path as osp
+ import os
import time
import joblib
+
import numpy as np
import tensorflow as tf
- from baselines import logger
- from baselines.common import set_global_seeds, explained_variance
+ from baselines import logger
+ from baselines.common import set_global_seeds, explained_variance, tf_util
from baselines.common.runners import AbstractEnvRunner
- from baselines.common import tf_util
+ from baselines.a2c.utils import discount_with_dones, Scheduler, make_path, find_trainable_variables, calc_entropy, mse
- from baselines.a2c.utils import discount_with_dones
- from baselines.a2c.utils import Scheduler, make_path, find_trainable_variables
- from baselines.a2c.utils import cat_entropy, mse


class Model(object):
-
-     def __init__(self, policy, ob_space, ac_space, nenvs, nsteps,
-             ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
-             alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):
+     def __init__(self, policy, ob_space, ac_space, n_envs, n_steps,
+                  ent_coef=0.01, vf_coef=0.25, max_grad_norm=0.5, learning_rate=7e-4,
+                  alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lr_schedule='linear'):
+         """
+         The A2C (Advantage Actor Critic) model class, https://arxiv.org/abs/1602.01783
+
+         :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
+         :param ob_space: (Gym Space) Observation space
+         :param ac_space: (Gym Space) Action space
+         :param n_envs: (int) The number of environments
+         :param n_steps: (int) The number of steps to run for each environment
+         :param ent_coef: (float) Entropy coefficient for the loss calculation
+         :param vf_coef: (float) Value function coefficient for the loss calculation
+         :param max_grad_norm: (float) The maximum value for the gradient clipping
+         :param learning_rate: (float) The learning rate
+         :param alpha: (float) RMS prop optimizer decay
+         :param epsilon: (float) RMS prop optimizer epsilon
+         :param total_timesteps: (int) The total number of samples
+         :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
+                             'double_linear_con', 'middle_drop' or 'double_middle_drop')
+         """

        sess = tf_util.make_session()
-         nbatch = nenvs * nsteps
+         n_batch = n_envs * n_steps

-         A = tf.placeholder(tf.int32, [nbatch])
-         ADV = tf.placeholder(tf.float32, [nbatch])
-         R = tf.placeholder(tf.float32, [nbatch])
-         LR = tf.placeholder(tf.float32, [])
+         actions_ph = tf.placeholder(tf.int32, [n_batch])
+         advs_ph = tf.placeholder(tf.float32, [n_batch])
+         rewards_ph = tf.placeholder(tf.float32, [n_batch])
+         learning_rate_ph = tf.placeholder(tf.float32, [])

-         step_model = policy(sess, ob_space, ac_space, nenvs, 1, reuse=False)
-         train_model = policy(sess, ob_space, ac_space, nenvs * nsteps, nsteps, reuse=True)
+         step_model = policy(sess, ob_space, ac_space, n_envs, 1, reuse=False)
+         train_model = policy(sess, ob_space, ac_space, n_envs * n_steps, n_steps, reuse=True)

-         neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.pi, labels=A)
-         pg_loss = tf.reduce_mean(ADV * neglogpac)
-         vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.vf), R))
-         entropy = tf.reduce_mean(cat_entropy(train_model.pi))
-         loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
+         neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.policy, labels=actions_ph)
+         pg_loss = tf.reduce_mean(advs_ph * neglogpac)
+         vf_loss = mse(tf.squeeze(train_model.value_fn), rewards_ph)
+         entropy = tf.reduce_mean(calc_entropy(train_model.policy))
+         loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef

        params = find_trainable_variables("model")
        grads = tf.gradients(loss, params)
        if max_grad_norm is not None:
-             grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
+             grads, _ = tf.clip_by_global_norm(grads, max_grad_norm)
        grads = list(zip(grads, params))
-         trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
+         trainer = tf.train.RMSPropOptimizer(learning_rate=learning_rate_ph, decay=alpha, epsilon=epsilon)
        _train = trainer.apply_gradients(grads)

-         lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
+         learning_rate = Scheduler(initial_value=learning_rate, n_values=total_timesteps, schedule=lr_schedule)

        def train(obs, states, rewards, masks, actions, values):
            advs = rewards - values
-             for step in range(len(obs)):
-                 cur_lr = lr.value()
-             td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, LR:cur_lr}
+             for _ in range(len(obs)):
+                 cur_lr = learning_rate.value()
+             td_map = {train_model.obs_ph: obs, actions_ph: actions, advs_ph: advs,
+                       rewards_ph: rewards, learning_rate_ph: cur_lr}
            if states is not None:
-                 td_map[train_model.S] = states
-                 td_map[train_model.M] = masks
+                 td_map[train_model.states_ph] = states
+                 td_map[train_model.masks_ph] = masks
            policy_loss, value_loss, policy_entropy, _ = sess.run(
                [pg_loss, vf_loss, entropy, _train],
                td_map
            )
            return policy_loss, value_loss, policy_entropy

        def save(save_path):
-             ps = sess.run(params)
-             make_path(osp.dirname(save_path))
-             joblib.dump(ps, save_path)
+             parameters = sess.run(params)
+             make_path(os.path.dirname(save_path))
+             joblib.dump(parameters, save_path)

        def load(load_path):
            loaded_params = joblib.load(load_path)
            restores = []
-             for p, loaded_p in zip(params, loaded_params):
-                 restores.append(p.assign(loaded_p))
+             for param, loaded_p in zip(params, loaded_params):
+                 restores.append(param.assign(loaded_p))
            sess.run(restores)

        self.train = train
@@ -82,16 +98,30 @@ def load(load_path):
        self.load = load
        tf.global_variables_initializer().run(session=sess)
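
# Note: illustration only, not part of the diff. A minimal NumPy sketch of the scalar
# objective the graph above builds, on a hypothetical toy batch (n_batch = 3, two
# discrete actions); softmax/entropy here stand in for the TF ops and the
# mse/calc_entropy helpers, and ent_coef=0.01, vf_coef=0.25 match the new defaults.
import numpy as np

def softmax(logits):
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

logits = np.array([[1.0, -1.0], [0.5, 0.5], [-2.0, 2.0]])  # train_model.policy output
actions = np.array([0, 1, 1])                              # actions_ph
returns = np.array([1.0, 0.2, 2.0])                        # rewards_ph (discounted returns)
values = np.array([0.7, 0.3, 0.8])                         # squeezed train_model.value_fn
advs = returns - values                                    # advs_ph

probs = softmax(logits)
neglogpac = -np.log(probs[np.arange(3), actions])          # sparse softmax cross-entropy
pg_loss = np.mean(advs * neglogpac)                        # policy-gradient term
vf_loss = np.mean((values - returns) ** 2)                 # mse(value_fn, returns)
entropy = np.mean(-(probs * np.log(probs)).sum(axis=1))    # calc_entropy equivalent
loss = pg_loss - entropy * 0.01 + vf_loss * 0.25           # ent_coef=0.01, vf_coef=0.25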


- class Runner(AbstractEnvRunner):
-
-     def __init__(self, env, model, nsteps=5, gamma=0.99):
-         super().__init__(env=env, model=model, nsteps=nsteps)
+ class Runner(AbstractEnvRunner):
+     def __init__(self, env, model, n_steps=5, gamma=0.99):
+         """
+         A runner to learn the policy of an environment for a model
+
+         :param env: (Gym environment) The environment to learn from
+         :param model: (Model) The model to learn
+         :param n_steps: (int) The number of steps to run for each environment
+         :param gamma: (float) Discount factor
+         """
+         super(Runner, self).__init__(env=env, model=model, n_steps=n_steps)
        self.gamma = gamma

    def run(self):
-         mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
+         """
+         Run a learning step of the model
+
+         :return: ([float], [float], [float], [bool], [float], [float])
+                  observations, states, rewards, masks, actions, values
+         """
+         mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [], [], [], [], []
        mb_states = self.states
-         for n in range(self.nsteps):
+         for _ in range(self.n_steps):
            actions, values, states, _ = self.model.step(self.obs, self.states, self.dones)
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
@@ -102,11 +132,11 @@ def run(self):
            self.dones = dones
            for n, done in enumerate(dones):
                if done:
-                     self.obs[n] = self.obs[n]*0
+                     self.obs[n] = self.obs[n] * 0
            self.obs = obs
            mb_rewards.append(rewards)
            mb_dones.append(self.dones)
-         #batch of steps to batch of rollouts
+         # batch of steps to batch of rollouts
        mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0).reshape(self.batch_ob_shape)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
@@ -115,12 +145,12 @@ def run(self):
        mb_masks = mb_dones[:, :-1]
        mb_dones = mb_dones[:, 1:]
        last_values = self.model.value(self.obs, self.states, self.dones).tolist()
-         #discount/bootstrap off value fn
+         # discount/bootstrap off value fn
        for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
            rewards = rewards.tolist()
            dones = dones.tolist()
            if dones[-1] == 0:
-                 rewards = discount_with_dones(rewards+[value], dones+[0], self.gamma)[:-1]
+                 rewards = discount_with_dones(rewards + [value], dones + [0], self.gamma)[:-1]
            else:
                rewards = discount_with_dones(rewards, dones, self.gamma)
            mb_rewards[n] = rewards
@@ -130,31 +160,56 @@ def run(self):
        mb_masks = mb_masks.flatten()
        return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values
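
# Note: illustration only, not part of the diff. The bootstrapping above relies on the
# discount_with_dones helper from baselines.a2c.utils; this is a sketch of the behaviour
# it is assumed to have: accumulate gamma-discounted returns backwards, resetting the
# running return at episode boundaries (done == 1).
def discount_with_dones_sketch(rewards, dones, gamma):
    discounted, running = [], 0.0
    for reward, done in zip(rewards[::-1], dones[::-1]):
        running = reward + gamma * running * (1.0 - done)
        discounted.append(running)
    return discounted[::-1]

# e.g. discount_with_dones_sketch([1, 1, 1], [0, 0, 0], 0.99) -> [2.9701, 1.99, 1.0]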


- def learn(policy, env, seed, nsteps=5, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5, lr=7e-4, lrschedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=100):
+
+ def learn(policy, env, seed, n_steps=5, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5,
+           learning_rate=7e-4, lr_schedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=100):
+     """
+     Return a trained A2C model.
+
+     :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
+     :param env: (Gym environment) The environment to learn from
+     :param seed: (int) The initial seed for training
+     :param n_steps: (int) The number of steps to run for each environment
+     :param total_timesteps: (int) The total number of samples
+     :param vf_coef: (float) Value function coefficient for the loss calculation
+     :param ent_coef: (float) Entropy coefficient for the loss calculation
+     :param max_grad_norm: (float) The maximum value for the gradient clipping
+     :param learning_rate: (float) The learning rate
+     :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
+                         'double_linear_con', 'middle_drop' or 'double_middle_drop')
+     :param epsilon: (float) RMS prop optimizer epsilon
+     :param alpha: (float) RMS prop optimizer decay
+     :param gamma: (float) Discount factor
+     :param log_interval: (int) The number of timesteps before logging.
+     :return: (Model) A2C model
+     """
    set_global_seeds(seed)

-     nenvs = env.num_envs
+     n_envs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
-     model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
-                   max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
-     runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
-
-     nbatch = nenvs * nsteps
-     tstart = time.time()
-     for update in range(1, total_timesteps // nbatch + 1):
+     model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, n_envs=n_envs,
+                   n_steps=n_steps, ent_coef=ent_coef,
+                   vf_coef=vf_coef, max_grad_norm=max_grad_norm, learning_rate=learning_rate,
+                   alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps,
+                   lr_schedule=lr_schedule)
+     runner = Runner(env, model, n_steps=n_steps, gamma=gamma)
+
+     n_batch = n_envs * n_steps
+     t_start = time.time()
+     for update in range(1, total_timesteps // n_batch + 1):
        obs, states, rewards, masks, actions, values = runner.run()
-         policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
-         nseconds = time.time()-tstart
-         fps = int((update * nbatch) / nseconds)
+         _, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
+         n_seconds = time.time() - t_start
+         fps = int((update * n_batch) / n_seconds)
        if update % log_interval == 0 or update == 1:
-             ev = explained_variance(values, rewards)
+             explained_var = explained_variance(values, rewards)
            logger.record_tabular("nupdates", update)
-             logger.record_tabular("total_timesteps", update * nbatch)
+             logger.record_tabular("total_timesteps", update * n_batch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("value_loss", float(value_loss))
-             logger.record_tabular("explained_variance", float(ev))
+             logger.record_tabular("explained_variance", float(explained_var))
            logger.dump_tabular()
    env.close()
    return model
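
# Note: illustration only, not part of the diff. A hedged usage sketch of the refactored
# entry point; the CnnPolicy, make_atari_env and VecFrameStack import paths are assumptions
# about this fork and may differ from the actual module layout.
from baselines.a2c.policies import CnnPolicy            # assumed policy class
from baselines.common.cmd_util import make_atari_env    # assumed vectorized-env helper
from baselines.common.vec_env.vec_frame_stack import VecFrameStack

# 16 parallel Atari environments with 4-frame stacking, then a short training run.
env = VecFrameStack(make_atari_env('BreakoutNoFrameskip-v4', num_env=16, seed=0), 4)
model = learn(policy=CnnPolicy, env=env, seed=0, n_steps=5,
              total_timesteps=int(1e6), lr_schedule='linear')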