Commit bf37d4d

Add SARS tdlearning back to lib (#1050)
1 parent 9e06129 commit bf37d4d

29 files changed: +402 −134 lines changed (only a subset of the changed files is shown below)

.buildkite/pipeline.yml (+1)

@@ -17,6 +17,7 @@ steps:
 Pkg.develop(path="src/ReinforcementLearningBase")
 Pkg.develop(path="src/ReinforcementLearningEnvironments")
 Pkg.develop(path="src/ReinforcementLearningCore")
+Pkg.develop(path="src/ReinforcementLearningFarm")

 println("+++ :julia: Running tests")
 Pkg.test("ReinforcementLearningCore", coverage=true)

.github/workflows/ci.yml (+1)

@@ -95,6 +95,7 @@ jobs:
 Pkg.develop(path="src/ReinforcementLearningBase")
 Pkg.develop(path="src/ReinforcementLearningCore")
 Pkg.develop(path="src/ReinforcementLearningEnvironments")
+Pkg.develop(path="src/ReinforcementLearningFarm")
 Pkg.test("ReinforcementLearningCore", coverage=true)'
 - uses: julia-actions/julia-processcoverage@v1
   with:

src/ReinforcementLearningCore/Project.toml (+3 −1)

@@ -37,6 +37,7 @@ Metal = "1.0"
 ProgressMeter = "1"
 Reexport = "1"
 ReinforcementLearningBase = "0.12"
+ReinforcementLearningFarm = "0.0.1"
 ReinforcementLearningTrajectories = "0.3.7"
 Statistics = "1"
 StatsBase = "0.32, 0.33, 0.34"
@@ -52,9 +53,10 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
+ReinforcementLearningFarm = "14eff660-7080-4cec-bba2-cfb12cd77ac3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

 [targets]
-test = ["CommonRLInterface", "CUDA", "cuDNN", "DomainSets", "Metal", "Preferences", "ReinforcementLearningEnvironments", "Test", "UUIDs"]
+test = ["CommonRLInterface", "CUDA", "cuDNN", "DomainSets", "Metal", "Preferences", "ReinforcementLearningEnvironments", "ReinforcementLearningFarm", "Test", "UUIDs"]

src/ReinforcementLearningCore/src/policies/agent/agent_base.jl (+1 −1)

@@ -37,7 +37,7 @@ RLBase.optimise!(::SyncTrajectoryStyle, agent::AbstractAgent, stage::S) where {S
 # already spawn a task to optimise inner policy when initializing the agent
 RLBase.optimise!(::AsyncTrajectoryStyle, agent::AbstractAgent, stage::S) where {S<:AbstractStage} = nothing

-#by default, optimise does nothing at all stage
+#by default, optimise does nothing at all stages
 function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end

 Flux.@layer Agent trainable=(policy,)

src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl (+8 −7)

@@ -99,31 +99,32 @@ get_ϵ(s::EpsilonGreedyExplorer) = get_ϵ(s, s.step)
 `NaN` will be filtered unless all the values are `NaN`.
 In that case, a random one will be returned.
 """
-function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::Vector{I}) where {I<:Real}
+function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::A) where {I<:Real, A<:AbstractArray{I}}
     ϵ = get_ϵ(s)
     s.step += 1
     rand(s.rng) >= ϵ ? rand(s.rng, find_all_max(values)[2]) : rand(s.rng, 1:length(values))
 end

-function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::Vector{I}) where {I<:Real}
+function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::A) where {I<:Real, A<:AbstractArray{I}}
     ϵ = get_ϵ(s)
     s.step += 1
     rand(s.rng) >= ϵ ? findmax(values)[2] : rand(s.rng, 1:length(values))
 end

 #####

-RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, x, mask::Trues) = RLBase.plan!(s, x)
+RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, x::A, mask::Trues) where {I<:Real, A<:AbstractArray{I}} = RLBase.plan!(s, x)

-function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::Vector{I}, mask::M) where {I<:Real, M<:Union{BitVector, Vector{Bool}}}
+function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::A, mask::M) where {I<:Real, A<:AbstractArray{I}, M<:Union{BitVector, Vector{Bool}}}
     ϵ = get_ϵ(s)
     s.step += 1
     rand(s.rng) >= ϵ ? rand(s.rng, find_all_max(values, mask)[2]) :
     rand(s.rng, findall(mask))
 end

-RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, x::Vector{I}, mask::Trues) where{I<:Real} = RLBase.plan!(s, x)
-function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::Vector{I}, mask::M) where {I<:Real, M<:Union{BitVector, Vector{Bool}}}
+RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, x::A, mask::Trues) where{I<:Real, A<:AbstractArray{I}} = RLBase.plan!(s, x)
+
+function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::A, mask::M) where {I<:Real, A<:AbstractArray{I}, M<:Union{BitVector, Vector{Bool}}}
     ϵ = get_ϵ(s)
     s.step += 1
     rand(s.rng) >= ϵ ? findmax_masked(values, mask)[2] : rand(s.rng, findall(mask))
@@ -137,7 +138,7 @@ end

 Return the probability of selecting each action given the estimated `values` of each action.
 """
-function RLBase.prob(s::EpsilonGreedyExplorer{<:Any,true}, values)
+function RLBase.prob(s::EpsilonGreedyExplorer{<:Any,true}, values::A) where {I<:Real, A<:AbstractArray{I}}
     ϵ, n = get_ϵ(s), length(values)
     probs = fill(ϵ / n, n)
     max_val_inds = find_all_max(values)[2]
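
The changes above widen `plan!` from `Vector` to any `AbstractArray` of real values, so views into a Q-table (or GPU-friendly array types) dispatch without copying. A minimal usage sketch, assuming the convenience constructor `EpsilonGreedyExplorer(ϵ)`; the ϵ value, table shape, and mask below are illustrative, not part of this commit:

```julia
using ReinforcementLearningBase, ReinforcementLearningCore

explorer = EpsilonGreedyExplorer(0.1)      # fixed ϵ = 0.1, illustrative
q_table  = zeros(Float32, 4, 10)           # (n_action, n_state)
q_values = @view q_table[:, 3]             # a SubArray, not a Vector

# Dispatches on AbstractArray{<:Real} after this change:
a = RLBase.plan!(explorer, q_values)

# Masked variant: pretend only actions 1 and 3 are legal in this state.
mask = [true, false, true, false]
a_legal = RLBase.plan!(explorer, q_values, mask)
```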

src/ReinforcementLearningCore/src/policies/learners/approximator.jl (−45)

This file was deleted.

src/ReinforcementLearningCore/src/policies/learners/flux_model_approximator.jl (+47)

@@ -0,0 +1,47 @@
+export FluxModelApproximator
+
+using Flux
+
+"""
+    FluxModelApproximator(model, optimiser)
+
+Wraps a Flux trainable model and implements the `RLBase.optimise!(::FluxModelApproximator, ::Gradient)`
+interface. See the RLCore documentation for more information on proper usage.
+"""
+struct FluxModelApproximator{M,O} <: AbstractLearner
+    model::M
+    optimiser_state::O
+end
+
+
+"""
+    FluxModelApproximator(; model, optimiser, use_gpu=false)
+
+Constructs a `FluxModelApproximator` object for reinforcement learning.
+
+# Arguments
+- `model`: The model used for approximation.
+- `optimiser`: The optimiser used for updating the model.
+- `use_gpu`: A boolean indicating whether to use GPU for computation. Default is `false`.
+
+# Returns
+A `FluxModelApproximator` object.
+"""
+function FluxModelApproximator(; model, optimiser, use_gpu=false)
+    optimiser_state = Flux.setup(optimiser, model)
+    if use_gpu # Pass model to GPU (if available) upon creation
+        return FluxModelApproximator(gpu(model), gpu(optimiser_state))
+    else
+        return FluxModelApproximator(model, optimiser_state)
+    end
+end
+
+FluxModelApproximator(model, optimiser::Flux.Optimise.AbstractOptimiser; use_gpu=false) = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=use_gpu)
+
+Flux.@layer FluxModelApproximator trainable=(model,)
+
+forward(A::FluxModelApproximator, args...; kwargs...) = A.model(args...; kwargs...)
+forward(A::FluxModelApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(A, x))
+
+RLBase.optimise!(A::FluxModelApproximator, grad::NamedTuple) =
+    Flux.Optimise.update!(A.optimiser_state, A.model, grad.model)
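
For orientation, a hedged sketch of how the new wrapper is meant to be used. The network shape, learning rate, and dummy loss are made up for the example, and the `RLCore` module alias is assumed to be exported:

```julia
using Flux
using ReinforcementLearningCore

# Wrap a Flux model together with its optimiser state (Adam here is illustrative).
approx = FluxModelApproximator(
    model = Chain(Dense(4 => 32, relu), Dense(32 => 2)),
    optimiser = Adam(1e-3),
    use_gpu = false,
)

x = rand(Float32, 4)
q = RLCore.forward(approx, x)    # forward pass through the wrapped model

# One optimisation step: gradients are passed as a NamedTuple with a `model`
# field, matching the `RLBase.optimise!(::FluxModelApproximator, ::NamedTuple)`
# method defined above.
grads = Flux.gradient(m -> sum(abs2, m(x)), approx.model)
RLBase.optimise!(approx, (model = grads[1],))
```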

src/ReinforcementLearningCore/src/policies/learners/learners.jl

@@ -1,4 +1,5 @@
 include("abstract_learner.jl")
-include("approximator.jl")
+include("flux_model_approximator.jl")
 include("tabular_approximator.jl")
+include("td_learner.jl")
 include("target_network.jl")

src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl

@@ -1,46 +1,49 @@
 export TabularApproximator, TabularVApproximator, TabularQApproximator

-const TabularApproximator = Approximator{A,O} where {A<:AbstractArray,O}
-const TabularQApproximator = Approximator{A,O} where {A<:AbstractArray,O}
-const TabularVApproximator = Approximator{A,O} where {A<:AbstractVector,O}
+struct TabularApproximator{A} <: AbstractLearner where {A<:AbstractArray}
+    model::A
+end
+
+const TabularQApproximator = TabularApproximator{A} where {A<:AbstractMatrix}
+const TabularVApproximator = TabularApproximator{A} where {A<:AbstractVector}

 """
-    TabularApproximator(table<:AbstractArray, opt)
+    TabularApproximator(table<:AbstractArray)

 For `table` of 1-d, it will serve as a state value approximator.
 For `table` of 2-d, it will serve as a state-action value approximator.

 !!! warning
     For `table` of 2-d, the first dimension is action and the second dimension is state.
 """
-function TabularApproximator(table::A, opt::O) where {A<:AbstractArray,O}
+function TabularApproximator(table::A) where {A<:AbstractArray}
     n = ndims(table)
     n <= 2 || throw(ArgumentError("the dimension of table must be <= 2"))
-    TabularApproximator{A,O}(table, opt)
+    TabularApproximator{A}(table)
 end

-TabularVApproximator(; n_state, init = 0.0, opt = InvDecay(1.0)) =
-    TabularApproximator(fill(init, n_state), opt)
+TabularVApproximator(; n_state, init = 0.0) =
+    TabularApproximator(fill(init, n_state))

-TabularQApproximator(; n_state, n_action, init = 0.0, opt = InvDecay(1.0)) =
-    TabularApproximator(fill(init, n_action, n_state), opt)
+TabularQApproximator(; n_state, n_action, init = 0.0) =
+    TabularApproximator(fill(init, n_action, n_state))

 # Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
 forward(L::TabularVApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(L, x))
 forward(L::TabularQApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(L, x))

 RLCore.forward(
-    app::TabularVApproximator{R,O},
+    app::TabularVApproximator{R},
     s::I,
-) where {R<:AbstractVector,O,I} = @views app.model[s]
+) where {R<:AbstractVector,I} = @views app.model[s]

 RLCore.forward(
-    app::TabularQApproximator{R,O},
+    app::TabularQApproximator{R},
     s::I,
-) where {R<:AbstractArray,O,I} = @views app.model[:, s]
+) where {R<:AbstractArray,I} = @views app.model[:, s]

 RLCore.forward(
-    app::TabularQApproximator{R,O},
+    app::TabularQApproximator{R},
     s::I1,
     a::I2,
-) where {R<:AbstractArray,O,I1,I2} = @views app.model[a, s]
+) where {R<:AbstractArray,I1,I2} = @views app.model[a, s]
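
With the optimiser argument gone, a tabular approximator is now just a value table; the update rule itself lives in the TD-learning code added in `td_learner.jl` (included above but not shown in this excerpt). A small sketch of the resulting constructors and lookups, with arbitrary sizes and assuming the `RLCore` alias is available:

```julia
using ReinforcementLearningCore

# Tabular approximators no longer carry an optimiser; they are plain tables of `init` values.
q = TabularQApproximator(n_state = 10, n_action = 4)   # 4×10 table of zeros
v = TabularVApproximator(n_state = 10)                  # length-10 vector of zeros

RLCore.forward(q, 3)       # view of the Q-values of all actions in state 3
RLCore.forward(q, 3, 2)    # Q-value of action 2 in state 3
RLCore.forward(v, 3)       # state value of state 3
```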

src/ReinforcementLearningCore/src/policies/learners/target_network.jl (+10 −10)

@@ -1,14 +1,14 @@
-export Approximator, TargetNetwork, target, model
+export TargetNetwork, target, model

 using Flux

-target(ap::Approximator) = ap.model #see TargetNetwork
-model(ap::Approximator) = ap.model #see TargetNetwork
+target(ap::FluxModelApproximator) = ap.model #see TargetNetwork
+model(ap::FluxModelApproximator) = ap.model #see TargetNetwork

 """
-    TargetNetwork(network::Approximator; sync_freq::Int = 1, ρ::Float32 = 0f0)
+    TargetNetwork(network::FluxModelApproximator; sync_freq::Int = 1, ρ::Float32 = 0f0)

-Wraps an Approximator to hold a target network that is updated towards the model of the
+Wraps a FluxModelApproximator to hold a target network that is updated towards the model of the
 approximator.
 - `sync_freq` is the number of updates of `network` between each update of the `target`.
 - ρ (\rho) is "how much of the target is kept when updating it".
@@ -21,11 +21,11 @@ Implements the `RLBase.optimise!(::TargetNetwork, ::Gradient)` interface to upda
 and the target with weights replacement or Polyak averaging.

 Note to developers: `model(::TargetNetwork)` will return the trainable Flux model
-and `target(::TargetNetwork)` returns the target model and `target(::Approximator)`
+and `target(::TargetNetwork)` returns the target model and `target(::FluxModelApproximator)`
 returns the non-trainable Flux model. See the RLCore documentation.
 """
 mutable struct TargetNetwork{M}
-    network::Approximator{M}
+    network::FluxModelApproximator{M}
     target::M
     sync_freq::Int
     ρ::Float32
@@ -46,13 +46,13 @@ Constructs a target network for reinforcement learning.
 # Returns
 A `TargetNetwork` object.
 """
-function TargetNetwork(network::Approximator; sync_freq = 1, ρ = 0f0, use_gpu = false)
+function TargetNetwork(network::FluxModelApproximator; sync_freq = 1, ρ = 0f0, use_gpu = false)
     @assert 0 <= ρ <= 1 "ρ must in [0,1]"
     ρ = Float32(ρ)

     if use_gpu
-        @assert typeof(gpu(network.model)) == typeof(network.model) "`Approximator` model is not on GPU. Please set `use_gpu=false` or ensure model is on GPU, by setting `use_gpu=true` when constructing `Approximator`."
-        # NOTE: model is pushed to gpu in Approximator, need to transfer to cpu before deepcopy, then push target model to gpu
+        @assert typeof(gpu(network.model)) == typeof(network.model) "`FluxModelApproximator` model is not on GPU. Please set `use_gpu=false` or ensure model is on GPU, by setting `use_gpu=true` when constructing `FluxModelApproximator`."
+        # NOTE: model is pushed to gpu in FluxModelApproximator, need to transfer to cpu before deepcopy, then push target model to gpu
         target = gpu(deepcopy(cpu(network.model)))
     else
         target = deepcopy(network.model)
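
The rename only swaps `Approximator` for `FluxModelApproximator` in the `TargetNetwork` API; usage stays the same. A hedged sketch, with an illustrative network and hyperparameters:

```julia
using Flux
using ReinforcementLearningCore

approx = FluxModelApproximator(
    model = Chain(Dense(4 => 32, relu), Dense(32 => 2)),
    optimiser = Adam(1e-3),
)

# Hard sync of the target every 100 optimisation steps (ρ = 0 keeps none of the old target;
# a ρ > 0 with sync_freq = 1 would give Polyak averaging instead).
tn = TargetNetwork(approx; sync_freq = 100)

model(tn)    # the trainable Flux model (same object as approx.model)
target(tn)   # the non-trainable copy used for bootstrapping targets
```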
