
Commit 24578f3

Merge pull request #13 from JuliaReinforcementLearning/improve_doc
improve docs
2 parents e435e20 + efe3668 commit 24578f3

5 files changed (+193 −7 lines)

docs/make.jl (+4 −2)

@@ -7,17 +7,19 @@ makedocs(modules = [ReinforcementLearning],
          linkcheck = !("skiplinks" in ARGS),
          pages = [ "Introduction" => "index.md",
                    "Usage" => "usage.md",
+                   "Tutorial" => "tutorial.md",
                    "Reference" => ["Comparison" => "comparison.md",
                                    "Learning" => "learning.md",
                                    "Learners" => "learners.md",
+                                   "Buffers" => "buffers.md",
                                    "Environments" => "environments.md",
                                    "Stopping Criteria" => "stop.md",
                                    "Preprocessors" => "preprocessors.md",
                                    "Policies" => "policies.md",
                                    "Callbacks" => "callbacks.md",
                                    "Evaluation Metrics" => "metrics.md",
-                                   ],
-                   "API" => "api.md"],
+                                   ]
+                  ],
          html_prettyurls = true
         )

docs/src/buffers.md (new file, +7 lines)

# [Buffers](@id buffers)

```@autodocs
Modules = [ReinforcementLearning]
Pages = ["buffers.jl"]
```

docs/src/tutorial.md (new file, +98 lines)

# Tutorial
Would you like to test existing reinforcement learning methods on your own
environment, or try your own method on existing environments? Extending this
package is a piece of cake. Please consider registering the binding to your own
environment as a new package (see e.g.
[RLEnvAtari](https://github.com/JuliaReinforcementLearning/RLEnvAtari.jl)) and
open a [pull request](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pulls)
for any other extension.

## Write your own learner

For a new learner you need to implement the functions
```
update!(learner, buffer)                          # returns nothing
selectaction(learner, policy, state)              # returns an action
defaultbuffer(learner, environment, preprocessor) # returns a buffer
```

Let's assume you want to implement plain, simple Q-learning (you don't need to
do this; it is already implemented). Your file `qlearning.jl` could contain
```julia
import ReinforcementLearning: update!, selectaction, defaultbuffer, Buffer

struct MyQLearning
    Q::Array{Float64, 2} # number of actions x number of states
    alpha::Float64       # learning rate
end

function update!(learner::MyQLearning, buffer)
    s = buffer.states[1]
    snext = buffer.states[2]
    r = buffer.rewards[1]
    a = buffer.actions[1]
    Q = learner.Q
    Q[a, s] += learner.alpha * (r + maximum(Q[:, snext]) - Q[a, s])
end

function selectaction(learner::MyQLearning, policy, state)
    selectaction(policy, learner.Q[:, state])
end

function defaultbuffer(learner::MyQLearning, environment, preprocessor)
    state, done = getstate(environment)
    processedstate = preprocessstate(preprocessor, state)
    Buffer(statetype = typeof(processedstate), capacity = 2)
end
```
The function `defaultbuffer` gets called during the construction of an
`RLSetup`. It returns a buffer that is filled with states, actions and rewards
during interaction with the environment. Currently there are three types of
buffers implemented:
```julia
import ReinforcementLearning: Buffer, EpisodeBuffer, ArrayStateBuffer
?Buffer
```
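
As a rough sketch (not from the original docs; the keyword interface is the one
documented in `src/buffers.jl` below), a default `Buffer` can be constructed and
filled by hand:
```julia
import ReinforcementLearning: Buffer

b = Buffer(statetype = Int64, actiontype = Int64, capacity = 2)
push!(b.states, 3)   # states, actions, rewards and done are circular buffers
push!(b.actions, 1)  # and accept push!, as in the package's pushreturn!
```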

## [Bind your own environment](@id api_environments)
For new environments you need to implement the functions
```
interact!(action, environment)  # returns state, reward, done
getstate(environment)           # returns state, done
reset!(environment)             # returns state
```

Optionally you may also implement the function
```
plotenv(environment, state, action, reward, done)
```

Please have a look at the
[cartpole](https://github.com/JuliaReinforcementLearning/RLEnvClassicControl.jl/blob/master/src/cartpole.jl)
for an example.
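
For illustration only (the chain environment below is made up and not part of
any package), a minimal binding that implements the three required functions
could look like:
```julia
import ReinforcementLearning: interact!, getstate, reset!

mutable struct MyChain
    nstates::Int64
    state::Int64
end
MyChain(nstates) = MyChain(nstates, 1)

function interact!(action, env::MyChain)
    # action 1 moves left, action 2 moves right
    env.state = clamp(env.state + (action == 2 ? 1 : -1), 1, env.nstates)
    done = env.state == env.nstates       # reaching the last state ends the episode
    env.state, (done ? 1. : 0.), done     # returns state, reward, done
end

getstate(env::MyChain) = (env.state, env.state == env.nstates)

function reset!(env::MyChain)
    env.state = 1
    env.state                             # returns state
end
```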

## Preprocessors
```
preprocessstate(preprocessor, state) # returns the preprocessed state
```
Optional:
```
preprocess(preprocessor, reward, state, done) # returns a preprocessed (state, reward, done) tuple
```
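
For illustration (the type `StateToOneHot` is made up), a preprocessor that
turns an integer state into a one-hot vector could be written as:
```julia
import ReinforcementLearning: preprocessstate

struct StateToOneHot
    ns::Int64                 # number of discrete states
end

function preprocessstate(p::StateToOneHot, state)
    x = zeros(p.ns)
    x[state] = 1.             # one-hot encoding of the integer state
    x
end
```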

## Policies
```
selectaction(policy, values)            # returns an action
getactionprobabilities(policy, state)   # returns a normalized (1-norm) vector with non-negative entries
```
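
For illustration (the package ships its own policies; `MySoftmaxPolicy` is made
up, and the second argument is assumed here to be the vector of action values,
as in `selectaction`), a softmax policy could look like:
```julia
import ReinforcementLearning: selectaction, getactionprobabilities

struct MySoftmaxPolicy
    beta::Float64                                    # inverse temperature
end

function getactionprobabilities(p::MySoftmaxPolicy, values)
    w = exp.(p.beta .* (values .- maximum(values)))  # shift by max for stability
    w ./ sum(w)                                      # normalized, non-negative
end

function selectaction(p::MySoftmaxPolicy, values)
    probs = getactionprobabilities(p, values)
    r = rand()
    acc = 0.
    for a in 1:length(probs)                         # sample from the distribution
        acc += probs[a]
        r <= acc && return a
    end
    length(probs)
end
```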

## Callbacks
```
callback!(callback, rlsetup, state, action, reward, done) # returns nothing
```
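
For illustration (the type is made up), a callback that accumulates the total
reward could be:
```julia
import ReinforcementLearning: callback!

mutable struct TotalReward
    total::Float64
end
TotalReward() = TotalReward(0.)

function callback!(c::TotalReward, rlsetup, state, action, reward, done)
    c.total += reward         # accumulate the reward of every step
    nothing
end
```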

## Stopping Criteria
```
isbreak!(stoppingcriterion, state, action, reward, done) # returns true or false
```
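
For illustration (the type is made up; the package provides its own stopping
criteria), a criterion that stops after a fixed number of steps could be:
```julia
import ReinforcementLearning: isbreak!

mutable struct StopAfterNSteps
    n::Int64
    counter::Int64
end
StopAfterNSteps(n) = StopAfterNSteps(n, 0)

function isbreak!(c::StopAfterNSteps, state, action, reward, done)
    c.counter += 1
    c.counter >= c.n          # true stops the interaction loop
end
```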

src/buffers.jl (+53)

@@ -1,9 +1,21 @@
+"""
+    struct Buffer{Ts, Ta}
+        states::CircularBuffer{Ts}
+        actions::CircularBuffer{Ta}
+        rewards::CircularBuffer{Float64}
+        done::CircularBuffer{Bool}
+"""
 struct Buffer{Ts, Ta}
     states::CircularBuffer{Ts}
     actions::CircularBuffer{Ta}
     rewards::CircularBuffer{Float64}
     done::CircularBuffer{Bool}
 end
+"""
+    Buffer(; statetype = Int64, actiontype = Int64,
+             capacity = 2, capacitystates = capacity,
+             capacityrewards = capacity - 1)
+"""
 function Buffer(; statetype = Int64, actiontype = Int64,
                   capacity = 2, capacitystates = capacity,
                   capacityrewards = capacity - 1)
@@ -23,12 +35,23 @@ function pushreturn!(b, r, done)
     push!(b.done, done)
 end
 
+"""
+    struct EpisodeBuffer{Ts, Ta}
+        states::Array{Ts, 1}
+        actions::Array{Ta, 1}
+        rewards::Array{Float64, 1}
+        done::Array{Bool, 1}
+"""
 struct EpisodeBuffer{Ts, Ta}
     states::Array{Ts, 1}
     actions::Array{Ta, 1}
     rewards::Array{Float64, 1}
     done::Array{Bool, 1}
 end
+"""
+    EpisodeBuffer(; statetype = Int64, actiontype = Int64) =
+        EpisodeBuffer(statetype[], actiontype[], Float64[], Bool[])
+"""
 EpisodeBuffer(; statetype = Int64, actiontype = Int64) =
     EpisodeBuffer(statetype[], actiontype[], Float64[], Bool[])
 function pushreturn!(b::EpisodeBuffer, r, done)
@@ -42,13 +65,24 @@ function pushreturn!(b::EpisodeBuffer, r, done)
     push!(b.done, done)
 end
 
+"""
+    mutable struct ArrayCircularBuffer{T}
+        data::T
+        capacity::Int64
+        start::Int64
+        counter::Int64
+        full::Bool
+"""
 mutable struct ArrayCircularBuffer{T}
     data::T
     capacity::Int64
     start::Int64
     counter::Int64
     full::Bool
 end
+"""
+    ArrayCircularBuffer(arraytype, datatype, elemshape, capacity)
+"""
 function ArrayCircularBuffer(arraytype, datatype, elemshape, capacity)
     ArrayCircularBuffer(arraytype(zeros(datatype,
                                         convert(Dims, (elemshape..., capacity)))),
@@ -96,12 +130,31 @@ for N in 2:5
 end
 lastindex(a::ArrayCircularBuffer) = a.full ? a.capacity : a.counter
 
+"""
+    struct ArrayStateBuffer{Ts, Ta}
+        states::ArrayCircularBuffer{Ts}
+        actions::CircularBuffer{Ta}
+        rewards::CircularBuffer{Float64}
+        done::CircularBuffer{Bool}
+"""
 struct ArrayStateBuffer{Ts, Ta}
     states::ArrayCircularBuffer{Ts}
     actions::CircularBuffer{Ta}
     rewards::CircularBuffer{Float64}
     done::CircularBuffer{Bool}
 end
+"""
+    ArrayStateBuffer(; arraytype = Array, datatype = Float64,
+                       elemshape = (1), actiontype = Int64,
+                       capacity = 2, capacitystates = capacity,
+                       capacityrewards = capacity - 1)
+
+An `ArrayStateBuffer` is similar to a [`Buffer`](@ref) but the states are stored
+in a preallocated array of size `(elemshape..., capacity)`. `K` consecutive
+states at position `i` in the state buffer can efficiently be retrieved with
+`nmarkovview(buffer.states, i, K)` or `nmarkovgetindex(buffer.states, i, K)`.
+See the implementation of DQN for an example.
+"""
 function ArrayStateBuffer(; arraytype = Array, datatype = Float64,
                             elemshape = (1), actiontype = Int64,
                             capacity = 2, capacitystates = capacity,

src/learner/tdlearning.jl (+31 −5)

@@ -1,3 +1,18 @@
+"""
+    mutable struct TDLearner{T,Tp}
+        ns::Int64 = 10
+        na::Int64 = 4
+        γ::Float64 = .9
+        λ::Float64 = .8
+        α::Float64 = .1
+        nsteps::Int64 = 1
+        initvalue::Float64 = 0.
+        unseenvalue::Float64 = initvalue == Inf64 ? 0. : initvalue
+        params::Array{Float64, 2} = zeros(na, ns) .+ initvalue
+        tracekind::DataType = λ == 0 ? NoTraces : ReplacingTraces
+        traces::T = tracekind == NoTraces ? NoTraces() : tracekind(ns, na, λ, γ)
+        endvaluepolicy::Tp = SarsaEndPolicy()
+"""
 @with_kw mutable struct TDLearner{T,Tp}
     ns::Int64 = 10
     na::Int64 = 4
@@ -17,18 +32,29 @@ struct QLearningEndPolicy end
 struct ExpectedSarsaEndPolicy{Tp}
     policy::Tp
 end
-Sarsa(; kargs...) = TDLearner(; kargs...)
-QLearning(; kargs...) = TDLearner(; endvaluepolicy = QLearningEndPolicy(), kargs...)
-ExpectedSarsa(; kargs...) = TDLearner(; endvaluepolicy = ExpectedSarsaEndPolicy(VeryOptimisticEpsilonGreedyPolicy(.1)), kargs...)
+"""
+    Sarsa(; kargs...) = TDLearner(; kargs...)
+"""
+function Sarsa(; kargs...) TDLearner(; kargs...) end
+"""
+    QLearning(; kargs...) = TDLearner(; endvaluepolicy = QLearningEndPolicy(), kargs...)
+"""
+function QLearning(; kargs...)
+    TDLearner(; endvaluepolicy = QLearningEndPolicy(), kargs...)
+end
+"""
+    ExpectedSarsa(; kargs...) = TDLearner(; endvaluepolicy = ExpectedSarsaEndPolicy(VeryOptimisticEpsilonGreedyPolicy(.1)), kargs...)
+"""
+function ExpectedSarsa(; kargs...)
+    TDLearner(; endvaluepolicy = ExpectedSarsaEndPolicy(VeryOptimisticEpsilonGreedyPolicy(.1)), kargs...)
+end
 export Sarsa, QLearning, ExpectedSarsa
 
 @inline function selectaction(learner::Union{TDLearner, AbstractPolicyGradient},
                               policy,
                               state)
     selectaction(policy, getvalue(learner.params, state))
 end
-params(learner::TDLearner) = learner.params
-reconstructwithparams(learner::TDLearner, w) = reconstruct(learner, params = w)
 
 # td error
 
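As a rough usage sketch (not part of the diff; the keyword names are the
`TDLearner` fields documented in the new docstring, and the values below are
arbitrary), the convenience constructors are called like:
```julia
import ReinforcementLearning: QLearning, Sarsa

qlearner = QLearning(ns = 10, na = 4)           # tabular Q-learning, 10 states, 4 actions
sarsalearner = Sarsa(ns = 10, na = 4, α = .05)  # Sarsa with a smaller learning rate
```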