import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mpi4py import MPI
# define actor class as per the HER paper: https://arxiv.org/pdf/1707.01495.pdf
# 3 hidden layers with 64 hidden units each
# ReLU activation for the hidden layers, tanh activation for the actor output
# rescale the tanh output to the [-5cm, 5cm] range
# add the square of the preactivations to the actor's cost function
# input dims -> size of the state space + size of the goal
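# Note: the squared-preactivation penalty mentioned above is not applied inside this module;
# it would be added to the actor loss in the agent's update step. A minimal sketch with
# hypothetical names (`preactivation` being the input to tanh, `preact_coeff` a tunable weight):
#     actor_loss = -critic(states, actor(states)).mean() + preact_coeff * (preactivation ** 2).mean()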
class Actor(nn.Module):
    def __init__(self, lr, input_dims, fc1_dims, fc2_dims, fc3_dims, nA, name, chkpt_dir='models'):
        '''
        Parameters:
        ----------
        lr: float
            Learning rate for the actor model
        input_dims: int
            Input dimension of the actor network
            (size of the state space + size of the goal)
        fc1_dims: int
            Number of hidden units in the first hidden layer
        fc2_dims: int
            Number of hidden units in the second hidden layer
        fc3_dims: int
            Number of hidden units in the third hidden layer
        nA: int
            Number of actions in the action space
        name: string
            Name of the model
        chkpt_dir: string
            Directory to save the model in

        Returns:
        -------
        None
        '''
        super(Actor, self).__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.fc3_dims = fc3_dims
        self.lr = lr
        self.nA = nA
        self.chkpt_file = os.path.join(chkpt_dir, name)
        # define network architecture
        self.fc1 = nn.Linear(self.input_dims, self.fc1_dims)  # hidden layer 1
        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)    # hidden layer 2
        self.fc3 = nn.Linear(self.fc2_dims, self.fc3_dims)    # hidden layer 3
        self.mu = nn.Linear(self.fc3_dims, self.nA)           # output layer -> policy action (mu)
        # define optimiser
        self.optimiser = torch.optim.Adam(self.parameters(), lr=self.lr)
        # move the model on to a device
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        '''
        Parameters:
        ----------
        state: torch.Tensor
            Concatenated (state, goal) input

        Returns:
        -------
        output: torch.Tensor
            Action for the given (state, goal) pair, squashed to [-1, 1] by tanh
        '''
        output = self.fc1(state)
        output = F.relu(output)
        output = self.fc2(output)
        output = F.relu(output)
        output = self.fc3(output)
        output = F.relu(output)
        output = torch.tanh(self.mu(output))
        return output

    def save_model(self, obs_mean, obs_std, goal_mean, goal_std):
        print('Saving actor model at checkpoint')
        # only rank 0 writes the checkpoint to avoid concurrent writes across MPI workers
        if MPI.COMM_WORLD.Get_rank() == 0:
            torch.save([obs_mean, obs_std, goal_mean, goal_std, self.state_dict()], self.chkpt_file)

    def load_model(self):
        print('Loading actor model from checkpoint')
        # only rank 0 reads the checkpoint; other ranks receive None
        if MPI.COMM_WORLD.Get_rank() == 0:
            obs_mean, obs_std, goal_mean, goal_std, state_dict = torch.load(self.chkpt_file)
            self.load_state_dict(state_dict)
            return obs_mean, obs_std, goal_mean, goal_std
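
# Note: the Actor above returns actions squashed to [-1, 1] by tanh; the rescaling to the
# physical action range mentioned in the header comment (e.g. [-5cm, 5cm]) is assumed to be
# done by the agent or environment wrapper, not by this module. A minimal sketch, assuming
# a hypothetical `action_max` holding the environment's action bound:
#     env_action = action_max * actor(torch.cat([state, goal], dim=1))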


class Critic(nn.Module):
    def __init__(self, lr, input_dims, fc1_dims, fc2_dims, fc3_dims, nA, name, chkpt_dir='models'):
        '''
        Parameters:
        ----------
        lr: float
            Learning rate for the critic model
        input_dims: int
            Input dimension of the critic network
            (size of the state space + size of the goal + size of the action space)
        fc1_dims: int
            Number of hidden units in the first hidden layer
        fc2_dims: int
            Number of hidden units in the second hidden layer
        fc3_dims: int
            Number of hidden units in the third hidden layer
        nA: int
            Number of actions in the action space
        name: string
            Name of the model
        chkpt_dir: string
            Directory to save the model in

        Returns:
        -------
        None
        '''
        super(Critic, self).__init__()
        self.input_dims = input_dims  # nS + nG + nA
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.fc3_dims = fc3_dims
        self.lr = lr
        self.nA = nA
        self.chkpt_file = os.path.join(chkpt_dir, name)
        # define network architecture
        self.fc1 = nn.Linear(self.input_dims, self.fc1_dims)  # hidden layer 1
        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)    # hidden layer 2
        self.fc3 = nn.Linear(self.fc2_dims, self.fc3_dims)    # hidden layer 3
        self.Q = nn.Linear(self.fc3_dims, 1)                  # output -> Q value for the (state, goal, action) input
        # define optimiser
        self.optimiser = torch.optim.Adam(self.parameters(), lr=self.lr)
        # move the model on to a device
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    # define forward pass
    def forward(self, state, action):
        '''
        Parameters:
        ----------
        state: torch.Tensor
            Concatenated (state, goal) input
        action: torch.Tensor
            Action taken in the current state

        Returns:
        -------
        q_value: torch.Tensor
            Q value corresponding to the (s, a, g) pair
        '''
        # create the (state, action) input pair
        state_action_value = torch.cat([state, action], dim=1)
        state_action_value = self.fc1(state_action_value)
        state_action_value = F.relu(state_action_value)
        state_action_value = self.fc2(state_action_value)
        state_action_value = F.relu(state_action_value)
        state_action_value = self.fc3(state_action_value)
        state_action_value = F.relu(state_action_value)
        q_value = self.Q(state_action_value)
        return q_value

    def save_model(self, obs_mean, obs_std, goal_mean, goal_std):
        '''
        Parameters:
        ----------
        obs_mean: float32
            Mean of observations
        obs_std: float32
            Standard deviation of observations
        goal_mean: float32
            Mean of goals
        goal_std: float32
            Standard deviation of goals

        Returns:
        -------
        None
        '''
        # only rank 0 writes the checkpoint to avoid concurrent writes across MPI workers
        if MPI.COMM_WORLD.Get_rank() == 0:
            print('Saving critic model at checkpoint')
            torch.save([obs_mean, obs_std, goal_mean, goal_std, self.state_dict()], self.chkpt_file)

    def load_model(self):
        '''
        Parameters:
        -----------
        None

        Returns:
        --------
        obs_mean: float32
            Mean of observations
        obs_std: float32
            Standard deviation of observations
        goal_mean: float32
            Mean of goals
        goal_std: float32
            Standard deviation of goals
        '''
        # only rank 0 reads the checkpoint; other ranks receive None
        if MPI.COMM_WORLD.Get_rank() == 0:
            print('Loading critic model from checkpoint')
            obs_mean, obs_std, goal_mean, goal_std, state_dict = torch.load(self.chkpt_file)
            self.load_state_dict(state_dict)
            return obs_mean, obs_std, goal_mean, goal_std
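

# Minimal usage sketch (not part of the original module): shows how the Actor and Critic
# might be instantiated and called together. The dimensions below (10-dim observation,
# 3-dim goal, 4-dim action), the learning rates, and the checkpoint names are assumptions
# for illustration only.
if __name__ == '__main__':
    obs_dim, goal_dim, act_dim = 10, 3, 4
    actor = Actor(lr=1e-3, input_dims=obs_dim + goal_dim, fc1_dims=64, fc2_dims=64,
                  fc3_dims=64, nA=act_dim, name='actor.pt')
    critic = Critic(lr=1e-3, input_dims=obs_dim + goal_dim + act_dim, fc1_dims=64,
                    fc2_dims=64, fc3_dims=64, nA=act_dim, name='critic.pt')
    # a dummy (state, goal) batch of size 2
    state_goal = torch.randn(2, obs_dim + goal_dim).to(actor.device)
    action = actor(state_goal)            # actions in [-1, 1], shape (2, act_dim)
    q_value = critic(state_goal, action)  # Q(s, g, a), shape (2, 1)
    print(action.shape, q_value.shape)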