-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathActionValue.py
More file actions
207 lines (178 loc) · 8.34 KB
/
ActionValue.py
File metadata and controls
207 lines (178 loc) · 8.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from collections import deque, namedtuple
from model.qcnn import Q_CNN_1, Q_CNN_2
from model.qnn import QNN
from torch.utils.tensorboard import SummaryWriter
import os
import datetime
import numpy as np
import random as rand
import torch
import torch.nn as nn
import torch.optim as optim
# A single replay-buffer entry: one (state, action, nextState, reward) step.
# `nextState` is None for terminal transitions (see how it is pushed and
# masked in ActionValueFunction.update / _optimize below).
Transition = namedtuple(
    'Transition',
    ('state', 'action', 'nextState', 'reward'))
class ReplayBuffer(object):
    """Fixed-capacity FIFO store of Transitions for experience replay.

    The deque silently evicts the oldest transition once `capacity` is
    reached. Sampling draws uniformly without replacement.
    """

    def __init__(self, seed, capacity):
        # Use a dedicated RNG instance instead of seeding the global
        # `random` module: the original `rand.seed(seed)` mutated
        # process-wide RNG state as a constructor side effect. A private
        # Random(seed) keeps sampling just as reproducible without
        # touching unrelated code's randomness. Random(None) seeds from
        # system entropy, matching the original's unseeded behavior.
        self._rng = rand.Random(seed)
        self.buffer = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.buffer.append(Transition(*args))

    def sample(self, batchSize):
        """Return `batchSize` transitions drawn uniformly without replacement."""
        return self._rng.sample(self.buffer, batchSize)

    def __len__(self):
        return len(self.buffer)
class ActionValueFunction:
    """DQN-style action-value function with replay buffer and target network.

    Maintains an online Q-network for evaluation and, in train mode, an
    AdamW optimizer, a replay buffer, and a delayed target network. The
    training loop is the standard DQN recipe: every `trainingFrequency`
    recorded transitions one minibatch TD step runs; the target network is
    hard-synced every `targetNetworkUpdateFrequency` optimization steps and
    a checkpoint is written every `checkpointRate` steps.
    """

    def __init__(
            self,
            seed: int,
            numActions: int,
            writer: SummaryWriter,
            modelPath: str = None,
            train: bool = True,
            learningRate = 1e-2,
            discountFactor = 0.99,
            replayBufferCapacity = 100_000,
            batchTransitionSampleSize = 32,
            trainingFrequency = 4,
            targetNetworkUpdateFrequency = 10_000,
            checkpointRate = 1_000_000):
        """Build the networks and (in train mode) the training machinery.

        seed: RNG seed for reproducibility, or None for nondeterminism.
        numActions: size of the discrete action space (network output width).
        writer: TensorBoard SummaryWriter (stored; not used in this class).
        modelPath: optional checkpoint path to restore from.
        train: when False, only the online network is built (inference only).
        """
        # Seeding. self.seed == -1 is the "unseeded" sentinel used by
        # _saveModel to decide the checkpoint filename.
        self.seed = -1
        if seed is not None:
            self.seed = seed
            rand.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        self.writer = writer
        # numInputs = 216 # Todo: make this configurable
        # self.onlineNetwork = QNN(numInputs=numInputs, numOutputs=numActions).to(QNN.device)
        self.onlineNetwork = Q_CNN_2(numOutputs=numActions).to(QNN.device)
        self.train = train
        if self.train:
            self.optimizer = optim.AdamW(
                self.onlineNetwork.parameters(),
                lr=learningRate,
                amsgrad=True)
            # self.targetNetwork = QNN(numInputs=numInputs, numOutputs=numActions).to(QNN.device)
            self.targetNetwork = Q_CNN_2(numOutputs=numActions).to(QNN.device)
            self._hardUpdateTarget()
            self.targetNetworkUpdateFrequency = targetNetworkUpdateFrequency
            self.discountFactor = discountFactor
            self.replayBuffer = ReplayBuffer(
                seed,
                replayBufferCapacity)
            self.batchTransitionSampleSize = batchTransitionSampleSize
            self.trainingFrequency = trainingFrequency
            self.checkpointRate = checkpointRate
        self.numUpdates = 0  # Number of calls to update state info
        self.numTrainingSteps = 0  # Number of optimizations
        # Load AFTER zeroing the counters: the original zeroed them after
        # _loadModel, which clobbered the numTrainingSteps value that had
        # just been restored from the checkpoint.
        if modelPath is not None:
            self._loadModel(modelPath)

    def evaluate(self, state):
        """Return the online network's Q-values for `state` as a numpy array."""
        self.onlineNetwork.eval()  # Set the model to evaluation mode
        with torch.no_grad():
            qValues = self.onlineNetwork(Q_CNN_2.preProcess(state))
        return qValues.cpu().numpy()

    def update(self, state, action, reward, nextState, runTDUpdate):
        """Record one transition; periodically run a TD optimization step.

        `nextState` must be None for terminal transitions. Returns a
        (tdHuberLoss, meanQValue) pair when an optimization step produced
        metrics, otherwise None. No-op (returns None) in inference mode.
        """
        if not self.train:
            return None
        state = Q_CNN_2.preProcess(state)
        action = torch.tensor([[action]], device=QNN.device)
        nextState = Q_CNN_2.preProcess(nextState) if nextState is not None else None
        reward = torch.tensor([reward], device=QNN.device)
        self.replayBuffer.push(
            state,
            action,
            nextState,
            reward)
        self.numUpdates += 1
        tdLossAvgQValuePair = None
        if runTDUpdate and self.numUpdates % self.trainingFrequency == 0:
            tdLossAvgQValuePair = self._optimize()
            self.numTrainingSteps += 1
            # Periodic maintenance, counted in optimization steps.
            if self.numTrainingSteps % self.targetNetworkUpdateFrequency == 0:
                self._hardUpdateTarget()
            if self.numTrainingSteps % self.checkpointRate == 0:
                self._saveModel()
        return tdLossAvgQValuePair

    def close(self):
        """Write a final checkpoint when training."""
        if self.train:
            self._saveModel()

    def _optimize(self):
        """Run one DQN optimization step on a sampled replay minibatch.

        Returns (huberTDLoss, meanQValue), or None while the buffer does
        not yet hold a full batch.
        """
        batchSize = self.batchTransitionSampleSize
        if len(self.replayBuffer) < batchSize:
            return None
        self.onlineNetwork.train()  # Set the model to training mode
        transitions = self.replayBuffer.sample(batchSize)
        # Transpose the batch-array of Transitions into a Transition of
        # batch-arrays (see https://stackoverflow.com/a/19343/3343043).
        batch = Transition(*zip(*transitions))
        stateBatch = torch.cat(batch.state)
        actionBatch = torch.cat(batch.action)
        rewardBatch = torch.cat(batch.reward)
        # Q(s_t, a): run the online network, then select the Q-value of the
        # action actually taken in each sampled transition.
        stateActionValues = self.onlineNetwork(stateBatch).gather(1, actionBatch)
        # V(s_{t+1}) for non-terminal next states USING DELAYED TARGET
        # NETWORK; terminal states keep a value of 0.
        nextStateValues = torch.zeros(batchSize, device=QNN.device)
        nonFinalMask = torch.tensor(
            tuple(s is not None for s in batch.nextState),
            device=QNN.device,
            dtype=torch.bool)
        nonFinalNextStates = [s for s in batch.nextState if s is not None]
        # Guard: torch.cat([]) raises when every sampled transition is
        # terminal (the original crashed on an all-terminal batch).
        if nonFinalNextStates:
            with torch.no_grad():
                nextStateValues[nonFinalMask] = self.targetNetwork(
                    torch.cat(nonFinalNextStates)).max(1).values
        # TD target, then Huber (SmoothL1) TD error.
        expectedStateActionValues = rewardBatch + (self.discountFactor * nextStateValues)
        lossCriterion = nn.SmoothL1Loss()
        tdHuberError = lossCriterion(stateActionValues, expectedStateActionValues.unsqueeze(1))
        # Optimize the model
        self.optimizer.zero_grad()
        tdHuberError.backward()
        # torch.nn.utils.clip_grad_norm_(self.onlineNetwork.parameters(), 100) # Clip the gradient in-place
        self.optimizer.step()
        return (tdHuberError.item(), stateActionValues.mean().item())

    def _hardUpdateTarget(self):
        """Copy the online network's weights into the target network."""
        self.targetNetwork.load_state_dict(self.onlineNetwork.state_dict())

    def _saveModel(self):
        """Write a timestamped checkpoint under ./model/."""
        now = datetime.datetime.now().strftime("%Y%m%d_%H%M")
        path = f'./model/{self.onlineNetwork.NAME}_checkpoint_{now}.pth' if self.seed < 0 else f'./model/{self.onlineNetwork.NAME}_checkpoint_{now}_seed_{self.seed}.pth'
        # Ensure the checkpoint directory exists; the original assumed it did.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(
            {
                'discountFactor': self.discountFactor, # override any discount factor input
                'numTrainingSteps': self.numTrainingSteps,
                'model_state_dict': self.onlineNetwork.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(), # learning rate baked in
            },
            path)

    def _loadModel(self, path):
        """Restore a checkpoint; in train mode also restore optimizer,
        target network, discount factor, and step counter."""
        if path is None:
            print("No model found to load")
            return
        try:
            if not os.path.exists(path):
                print(f"Path does not exist: {path}")
                return
            # map_location keeps GPU-saved checkpoints loadable on
            # CPU-only hosts (the original torch.load had no mapping).
            checkpoint = torch.load(path, map_location=QNN.device)
            self.onlineNetwork.load_state_dict(checkpoint['model_state_dict'])
            if self.train:
                self.targetNetwork.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.discountFactor = checkpoint['discountFactor']
                self.numTrainingSteps = checkpoint['numTrainingSteps']
            print(f"Loaded model from {path}")
            numParameters = sum(p.numel() for p in self.onlineNetwork.parameters())
            print(f"Model has {numParameters} parameters")
            for name, param in self.onlineNetwork.named_parameters():
                print(f"{name}: {param.numel()} parameters")
        except Exception as e:
            print(f"Error loading model: {e}")