Skip to content

Commit e61836a

Browse files
Update FluxModelApproximator references to FluxApproximator (#1051)
Co-authored-by: Jeremiah Lewis <--get>
1 parent bf37d4d commit e61836a

File tree

7 files changed

+72
-72
lines changed

7 files changed

+72
-72
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
export FluxApproximator
2+
3+
using Flux
4+
5+
"""
6+
FluxApproximator(model, optimiser)
7+
8+
Wraps a Flux trainable model and implements the `RLBase.optimise!(::FluxApproximator, ::Gradient)`
9+
interface. See the RLCore documentation for more information on proper usage.
10+
"""
11+
struct FluxApproximator{M,O} <: AbstractLearner
12+
model::M
13+
optimiser_state::O
14+
end
15+
16+
17+
"""
18+
FluxApproximator(; model, optimiser, usegpu=false)
19+
20+
Constructs an `FluxApproximator` object for reinforcement learning.
21+
22+
# Arguments
23+
- `model`: The model used for approximation.
24+
- `optimiser`: The optimizer used for updating the model.
25+
- `usegpu`: A boolean indicating whether to use GPU for computation. Default is `false`.
26+
27+
# Returns
28+
An `FluxApproximator` object.
29+
"""
30+
function FluxApproximator(; model, optimiser, use_gpu=false)
31+
optimiser_state = Flux.setup(optimiser, model)
32+
if use_gpu # Pass model to GPU (if available) upon creation
33+
return FluxApproximator(gpu(model), gpu(optimiser_state))
34+
else
35+
return FluxApproximator(model, optimiser_state)
36+
end
37+
end
38+
39+
FluxApproximator(model, optimiser::Flux.Optimise.AbstractOptimiser; use_gpu=false) = FluxApproximator(model=model, optimiser=optimiser, use_gpu=use_gpu)
40+
41+
Flux.@layer FluxApproximator trainable=(model,)
42+
43+
forward(A::FluxApproximator, args...; kwargs...) = A.model(args...; kwargs...)
44+
forward(A::FluxApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(A, x))
45+
46+
RLBase.optimise!(A::FluxApproximator, grad::NamedTuple) =
47+
Flux.Optimise.update!(A.optimiser_state, A.model, grad.model)

src/ReinforcementLearningCore/src/policies/learners/flux_model_approximator.jl

-47
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
include("abstract_learner.jl")
2-
include("flux_model_approximator.jl")
2+
include("flux_approximator.jl")
33
include("tabular_approximator.jl")
44
include("td_learner.jl")
55
include("target_network.jl")

src/ReinforcementLearningCore/src/policies/learners/target_network.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ export TargetNetwork, target, model
22

33
using Flux
44

5-
target(ap::FluxModelApproximator) = ap.model #see TargetNetwork
6-
model(ap::FluxModelApproximator) = ap.model #see TargetNetwork
5+
target(ap::FluxApproximator) = ap.model #see TargetNetwork
6+
model(ap::FluxApproximator) = ap.model #see TargetNetwork
77

88
"""
9-
TargetNetwork(network::FluxModelApproximator; sync_freq::Int = 1, ρ::Float32 = 0f0)
9+
TargetNetwork(network::FluxApproximator; sync_freq::Int = 1, ρ::Float32 = 0f0)
1010
11-
Wraps an FluxModelApproximator to hold a target network that is updated towards the model of the
11+
Wraps an FluxApproximator to hold a target network that is updated towards the model of the
1212
approximator.
1313
- `sync_freq` is the number of updates of `network` between each update of the `target`.
1414
- ρ (\rho) is "how much of the target is kept when updating it".
@@ -21,11 +21,11 @@ Implements the `RLBase.optimise!(::TargetNetwork, ::Gradient)` interface to upda
2121
and the target with weights replacement or Polyak averaging.
2222
2323
Note to developers: `model(::TargetNetwork)` will return the trainable Flux model
24-
and `target(::TargetNetwork)` returns the target model and `target(::FluxModelApproximator)`
24+
and `target(::TargetNetwork)` returns the target model and `target(::FluxApproximator)`
2525
returns the non-trainable Flux model. See the RLCore documentation.
2626
"""
2727
mutable struct TargetNetwork{M}
28-
network::FluxModelApproximator{M}
28+
network::FluxApproximator{M}
2929
target::M
3030
sync_freq::Int
3131
ρ::Float32
@@ -46,13 +46,13 @@ Constructs a target network for reinforcement learning.
4646
# Returns
4747
A `TargetNetwork` object.
4848
"""
49-
function TargetNetwork(network::FluxModelApproximator; sync_freq = 1, ρ = 0f0, use_gpu = false)
49+
function TargetNetwork(network::FluxApproximator; sync_freq = 1, ρ = 0f0, use_gpu = false)
5050
@assert 0 <= ρ <= 1 "ρ must in [0,1]"
5151
ρ = Float32(ρ)
5252

5353
if use_gpu
54-
@assert typeof(gpu(network.model)) == typeof(network.model) "`FluxModelApproximator` model is not on GPU. Please set `use_gpu=false`` or ensure model is on GPU, by setting `use_gpu=true` when constructing `FluxModelApproximator`."
55-
# NOTE: model is pushed to gpu in FluxModelApproximator, need to transfer to cpu before deepcopy, then push target model to gpu
54+
@assert typeof(gpu(network.model)) == typeof(network.model) "`FluxApproximator` model is not on GPU. Please set `use_gpu=false`` or ensure model is on GPU, by setting `use_gpu=true` when constructing `FluxApproximator`."
55+
# NOTE: model is pushed to gpu in FluxApproximator, need to transfer to cpu before deepcopy, then push target model to gpu
5656
target = gpu(deepcopy(cpu(network.model)))
5757
else
5858
target = deepcopy(network.model)

src/ReinforcementLearningCore/test/policies/learners/flux_model_approximator.jl renamed to src/ReinforcementLearningCore/test/policies/learners/flux_approximator.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
using Test
22
using Flux
33

4-
@testset "FluxModelApproximator Tests" begin
4+
@testset "FluxApproximator Tests" begin
55
@testset "Creation, with use_gpu = true toggle" begin
66
model = Chain(Dense(10, 5, relu), Dense(5, 2))
77
optimiser = Adam()
8-
approximator = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=true)
8+
approximator = FluxApproximator(model=model, optimiser=optimiser, use_gpu=true)
99

10-
@test approximator isa FluxModelApproximator
10+
@test approximator isa FluxApproximator
1111
@test typeof(approximator.model) == typeof(gpu(model))
1212
@test approximator.optimiser_state isa NamedTuple
1313
end
1414

1515
@testset "Forward" begin
1616
model = Chain(Dense(10, 5, relu), Dense(5, 2))
1717
optimiser = Adam()
18-
approximator = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=false)
18+
approximator = FluxApproximator(model=model, optimiser=optimiser, use_gpu=false)
1919

2020
input = rand(Float32, 10)
2121
output = RLCore.forward(approximator, input)
@@ -27,7 +27,7 @@ using Flux
2727
@testset "Forward to environment" begin
2828
model = Chain(Dense(4, 5, relu), Dense(5, 2))
2929
optimiser = Adam()
30-
approximator = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=false)
30+
approximator = FluxApproximator(model=model, optimiser=optimiser, use_gpu=false)
3131

3232
env = CartPoleEnv(T=Float32)
3333
output = RLCore.forward(approximator, env)
@@ -38,7 +38,7 @@ using Flux
3838
@testset "Optimise" begin
3939
model = Chain(Dense(10, 5, relu), Dense(5, 2))
4040
optimiser = Adam()
41-
approximator = FluxModelApproximator(model=model, optimiser=optimiser)
41+
approximator = FluxApproximator(model=model, optimiser=optimiser)
4242

4343
input = rand(Float32, 10)
4444

src/ReinforcementLearningCore/test/policies/learners/learners.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@testset "approximators.jl" begin
22
include("abstract_learner.jl")
3-
include("flux_model_approximator.jl")
3+
include("flux_approximator.jl")
44
include("tabular_approximator.jl")
55
include("target_network.jl")
66
include("td_learner.jl")

src/ReinforcementLearningCore/test/policies/learners/target_network.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ using ReinforcementLearningCore
77
model = Chain(Dense(10, 5, relu), Dense(5, 2))
88
optimiser = Adam()
99
if ((@isdefined CUDA) && CUDA.functional()) || ((@isdefined Metal) && Metal.functional())
10-
@test_throws "AssertionError: `FluxModelApproximator` model is not on GPU." TargetNetwork(FluxModelApproximator(model, optimiser), use_gpu=true)
10+
@test_throws "AssertionError: `FluxApproximator` model is not on GPU." TargetNetwork(FluxApproximator(model, optimiser), use_gpu=true)
1111
end
12-
@test TargetNetwork(FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork
13-
@test TargetNetwork(FluxModelApproximator(model, optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork
12+
@test TargetNetwork(FluxApproximator(model=model, optimiser=optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork
13+
@test TargetNetwork(FluxApproximator(model, optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork
1414

15-
approx = FluxModelApproximator(model, optimiser, use_gpu=false)
15+
approx = FluxApproximator(model, optimiser, use_gpu=false)
1616
target_network = TargetNetwork(approx, use_gpu=false)
1717

1818

@@ -26,7 +26,7 @@ using ReinforcementLearningCore
2626

2727
@testset "Forward" begin
2828
model = Chain(Dense(10, 5, relu), Dense(5, 2))
29-
target_network = TargetNetwork(FluxModelApproximator(model, Adam()))
29+
target_network = TargetNetwork(FluxApproximator(model, Adam()))
3030

3131
input = rand(Float32, 10)
3232
output = RLCore.forward(target_network, input)
@@ -38,7 +38,7 @@ using ReinforcementLearningCore
3838
@testset "Optimise" begin
3939
optimiser = Adam()
4040
model = Chain(Dense(10, 5, relu), Dense(5, 2))
41-
approximator = FluxModelApproximator(model, optimiser)
41+
approximator = FluxApproximator(model, optimiser)
4242
target_network = TargetNetwork(approximator)
4343
input = rand(Float32, 10)
4444
grad = Flux.Zygote.gradient(target_network) do model
@@ -54,7 +54,7 @@ using ReinforcementLearningCore
5454

5555
@testset "Sync" begin
5656
optimiser = Adam()
57-
model = FluxModelApproximator(Chain(Dense(10, 5, relu), Dense(5, 2)), optimiser)
57+
model = FluxApproximator(Chain(Dense(10, 5, relu), Dense(5, 2)), optimiser)
5858
target_network = TargetNetwork(model, sync_freq=2, ρ=0.5)
5959

6060
input = rand(Float32, 10)
@@ -73,7 +73,7 @@ end
7373

7474
@testset "TargetNetwork" begin
7575
m = Chain(Dense(4,1))
76-
app = FluxModelApproximator(model = m, optimiser = Flux.Adam(), use_gpu=true)
76+
app = FluxApproximator(model = m, optimiser = Flux.Adam(), use_gpu=true)
7777
tn = TargetNetwork(app, sync_freq = 3, use_gpu=true)
7878
@test typeof(model(tn)) == typeof(target(tn))
7979
p1 = Flux.destructure(model(tn))[1]

0 commit comments

Comments
 (0)