"""
File: cartpolev1.py
Created: 2017-03-09
By Peter Caven, [email protected]
Description:
-- Python 3.6 --
Solve the CartPole-v1 problem:
- this is the same solution as for the 'CartPole-v0' problem, with the episode length extended.
- observed features are expanded with a random transform to ensure linear separability.
- action selection is by dot product of an expanded observation with a weight vector.
- a queued history of recent observations is shuffled and replayed to update the output weights.
- output weights are updated at the end of each incomplete episode by Widrow-Hoff LMS update.
- the target outputs for the LMS algorithm are the means of the past outputs.
- output weights are maintained at a fixed norm for regularization.
"""
import gym
from gym import wrappers
from numpy import *
from numpy.random import uniform, normal
from numpy.linalg import norm
from random import shuffle
from collections import deque
from statistics import mean
env = gym.make('CartPole-v1')
# env = wrappers.Monitor(env, '../experiments/cartpole/v1/experiment-1')
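# Note: this script targets the 2017-era gym API, where env.reset() returns only
# the observation and env.step() returns a 4-tuple (observation, reward, done, info).
# Newer gym/gymnasium releases changed both signatures, so porting would be needed there.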
#------------------------------------------------------------------
# Hyperparameters
alpha = 2.0e-1 # the 'learning rate'
maxEpisodes = 1000 # run the agent for 'maxEpisodes'
maxTimeSteps = 500 # maximum number of steps per episode
fixedNorm = 0.5 # output weights are scaled to have norm == 'fixedNorm'
maxHistory = 5000 # maximum number of recent observations for replay
solvedEpisodes = 100 # number of recent episodes averaged for the 'solved' test (v0 uses average reward > 195; this script requires the full 'episodeLength')
episodeLength = 500 # the target for CartPole-v1
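# With these settings a run lasts at most 1000 episodes of up to 500 steps each, and
# (per the avgReward check at the bottom of CartPoleAgent) success is declared only
# when the 100 most recent episodes all reach the full 500-step cap.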
#------------------------------------------------------------------
# Observations Transform
inputLength = 4 # length of an observation vector
expansionFactor = 30 # expand observation dimensions by 'expansionFactor'
expandedLength = expansionFactor*inputLength # length of transformed observations
# Feature transform with fixed random weights.
V = normal(scale=1.0, size=(expandedLength, inputLength))
# Output weights, randomly initialized.
W = uniform(low=-1.0, high=1.0, size=expandedLength)
# Fix the norm of the output weights to 'fixedNorm'.
W *= fixedNorm/norm(W)
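# Quick shape check (illustrative addition, not part of the original algorithm):
# with the defaults above, V is (120, 4) and W has length 120, so
# tanh(dot(V, observation)) lines up with W for the dot product.
assert V.shape == (expandedLength, inputLength)
assert W.shape == (expandedLength,)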
#------------------------------------------------------------------
def CartPoleAgent(alpha, W, V):
    """
    CartPoleAgent solves 'CartPole-v1'.
    """
    #--------------------------------------------------
    # observation history
    H = deque([], maxHistory)
    # episode total reward history
    R = deque([], solvedEpisodes)
    # histories of positive and negative outputs
    PO = deque([0], maxHistory)
    NO = deque([0], maxHistory)
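    # PO and NO are seeded with a single 0 so that statistics.mean() has data
    # before the first replay (mean() raises StatisticsError on an empty sequence).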
    #--------------------------------------------------
    for episode in range(maxEpisodes):
        observation = env.reset()
        H.append(observation)
        totalReward = 0
        for t in range(1, maxTimeSteps + 1):
            env.render()
            #--------------------------------------------------
            out = dot(tanh(dot(V, observation)), W)
            if out < 0:
                NO.append(out)
                action = 0
            else:
                PO.append(out)
                action = 1
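            # The policy is a linear threshold on the expanded features:
            # a negative output pushes the cart left (action 0),
            # a non-negative output pushes it right (action 1).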
            #--------------------------------------------------
            observation, reward, done, info = env.step(action)
            H.append(observation)
            totalReward += reward
            #--------------------------------------------------
            if done:
                R.append(totalReward)
                if t < episodeLength:
                    #------------------------------------------
                    # Replay shuffled past observations using the
                    # latest weights.
                    # Use the means of past outputs as
                    # LMS algorithm target outputs.
                    #------------------------------------------
                    mn = mean(NO)
                    mp = mean(PO)
                    shuffle(H)
                    for obs in H:
                        h = tanh(dot(V, obs))         # transform the observation
                        out = dot(h, W)
                        if out < 0:
                            e = mn - out
                        else:
                            e = mp - out
                        W += alpha * e * h            # Widrow-Hoff LMS update
                        W *= fixedNorm / norm(W)      # keep the weights at a fixed norm
                    #------------------------------------------
                #--------------------------------------------------
                avgReward = sum(R) / solvedEpisodes
                # print(f"[{episode:3d}:{totalReward:3.0f}] R:{avgReward:6.2f} mp:{mean(PO):7.3f} mn:{mean(NO):7.3f} len(H):{len(H):4d} W:{W[:2]}", flush=True)
                #--------------------------------------------------
                if avgReward == episodeLength:
                    print("Solved.")
                    return
                #--------------------------------------------------
                break
#------------------------------------------------------------------
#------------------------------------------------------------------
CartPoleAgent(alpha, W, V)
env.close()
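# Note: env.render() is called on every step above, which slows training
# considerably; commenting it out is a common way to speed up experiments.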