
Commit a982d70

GPU Code Migration Part 2.1 (#1029)
* Rearrange approximator setup
* Drop Cuda references from RLCore
* Fix test error
* Move device out of RLCore
* Remove CUDA from RLCore dependencies
* Drop excess test
* Fix RLEnv version
* Dropping device is breaking change, bump version
* Version bump RLCore in other pkg compat entries
* Syntax error
* Fix RLEnv dependencies
* Mostly fixed...
* Fix deps
* Drop RLCore import
* Fix dependency issue
* Drop excess entries
* Install Python in RLEnv tests
* Fix refs
* Fix conda caching issue
* Drop python install
* Port DQN to new gpu syntax
* Fix type piracy
* Add back explorer / learner plan method
* Fix duplicate method, add missing method
* Port dqn
* Revert dqn fix
* Tweak dqn optimise call
* Drop cache for RLEnv due to cache issues
* dqn works
* iqn fixes
* iqn passes
* NFQ works
* NFQ works
* prio dqn works
* qr_dqn works
* rem_dqn works
* rainbow passes
* drop policy gradients temporarily
* deactivate further experiments for now
* drop device
* temporarily drop cql_sac
* Fix runtests to use testsets
* Fix rainbow
* Move CUDA to extras
* bump RLZoo version
* Update Project.toml
* Update Project.toml
* cuda missing from tests

---------

Co-authored-by: Jeremiah Lewis <--get>
1 parent 7ecfb2e commit a982d70

File tree: 35 files changed, +295 -289 lines


.github/workflows/ci.yml
+1 -1

@@ -163,7 +163,7 @@ jobs:
         with:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
-      - uses: julia-actions/cache@v1
+      # - uses: julia-actions/cache@v1
       - name: Get changed files
         id: RLEnvironments-changed
         uses: tj-actions/changed-files@v42

Project.toml
+3 -5

@@ -8,17 +8,15 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
-ReinforcementLearningDatasets = "dd1544ca-2576-438c-a599-ae96278fd687"
 ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 ReinforcementLearningZoo = "d607f57d-ee1e-4ba7-bcf2-7734c1e31854"
 
 [compat]
 Reexport = "0.2, 1"
-julia = "1.6"
-ReinforcementLearningBase = "0.10"
-ReinforcementLearningCore = "0.13"
+ReinforcementLearningBase = "0.12"
+ReinforcementLearningCore = "0.14"
 ReinforcementLearningEnvironments = "0.8"
-ReinforcementLearningZoo = "0.6"
+julia = "1.6"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

src/DistributedReinforcementLearning/Project.toml
+2 -2

@@ -13,8 +13,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 
 [compat]
 julia = "1"
-Flux = "0.11"
+Flux = "0.14"
 ReinforcementLearningBase = "0.8.5"
-ReinforcementLearningCore = "0.5.1"
+ReinforcementLearningCore = "0.14"
 ReinforcementLearningEnvironments = "0.3.3"
 StatsBase = "0.33, 0.34"

src/ReinforcementLearningCore/Project.toml
+5 -7

@@ -1,11 +1,10 @@
 name = "ReinforcementLearningCore"
 uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
-version = "0.13.1"
+version = "0.14.0"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
 Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"

@@ -25,18 +24,16 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
 UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
-cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [compat]
 AbstractTrees = "0.3, 0.4"
 Adapt = "3, 4"
-CUDA = "4, 5"
 ChainRulesCore = "1"
 CircularArrayBuffers = "0.1.12"
 Crayons = "4"
 Distributions = "0.25"
 FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
-Flux = "0.13, 0.14"
+Flux = "0.14"
 Functors = "0.1, 0.2, 0.3, 0.4"
 GPUArrays = "8, 9, 10"
 Metal = "1.0"

@@ -49,17 +46,18 @@ Statistics = "1"
 StatsBase = "0.32, 0.33, 0.34"
 TimerOutputs = "0.5"
 UnicodePlots = "1.3, 2, 3"
-cuDNN = "1"
 julia = "1.9"
 
 [extras]
 CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [targets]
-test = ["CommonRLInterface", "DomainSets", "Metal", "Preferences", "Test", "UUIDs"]
+test = ["CommonRLInterface", "CUDA", "cuDNN", "DomainSets", "Metal", "Preferences", "Test", "UUIDs"]

src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl
+4 -4

@@ -126,7 +126,7 @@ RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, x::Vector{I}, mask::Trues) w
 function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::Vector{I}, mask::M) where {I<:Real, M<:Union{BitVector, Vector{Bool}}}
     ϵ = get_ϵ(s)
     s.step += 1
-    rand(s.rng) >= ϵ ? findmax(values, mask)[2] : rand(s.rng, findall(mask))
+    rand(s.rng) >= ϵ ? findmax_masked(values, mask)[2] : rand(s.rng, findall(mask))
 end
 
 #####

@@ -188,7 +188,7 @@ function RLBase.prob(s::EpsilonGreedyExplorer{<:Any,false}, values, mask)
     ϵ, n = get_ϵ(s), length(values)
     probs = zeros(n)
     probs[mask] .= ϵ / sum(mask)
-    probs[findmax(values, mask)[2]] += 1 - ϵ
+    probs[findmax_masked(values, mask)[2]] += 1 - ϵ
     Categorical(probs; check_args=false)
 end
 

@@ -201,7 +201,7 @@ struct GreedyExplorer <: AbstractExplorer end
 RLBase.plan!(s::GreedyExplorer, x, mask::Trues) = s(x)
 
 RLBase.plan!(s::GreedyExplorer, values) = findmax(values)[2]
-RLBase.plan!(s::GreedyExplorer, values, mask) = findmax(values, mask)[2]
+RLBase.plan!(s::GreedyExplorer, values, mask) = findmax_masked(values, mask)[2]
 
 RLBase.prob(s::GreedyExplorer, values) =
     Categorical(onehot(findmax(values)[2], 1:length(values)); check_args=false)

@@ -210,4 +210,4 @@ RLBase.prob(s::GreedyExplorer, values, action::Integer) =
     findmax(values)[2] == action ? 1.0 : 0.0
 
 RLBase.prob(s::GreedyExplorer, values, mask) =
-    Categorical(onehot(findmax(values, mask)[2], length(values)); check_args=false)
+    Categorical(onehot(findmax_masked(values, mask)[2], length(values)); check_args=false)
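For context, a minimal sketch of how the masked explorer methods behave after this rename (the values and mask below are illustrative, not from the repository):

using ReinforcementLearningBase, ReinforcementLearningCore

values = [1.0, 5.0, 3.0]      # action values estimated by a learner
mask   = [true, false, true]  # action 2 is illegal in the current state

# GreedyExplorer now routes through findmax_masked, so the masked-out 5.0 cannot be selected
RLBase.plan!(GreedyExplorer(), values, mask)  # returns 3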

src/ReinforcementLearningCore/src/policies/learners.jl

-26
This file was deleted.
src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl (new file)
+51

@@ -0,0 +1,51 @@
+export AbstractLearner, Approximator
+
+using Flux
+using Functors: @functor
+
+abstract type AbstractLearner end
+
+Base.show(io::IO, m::MIME"text/plain", L::AbstractLearner) = show(io, m, convert(AnnotatedStructTree, L))
+
+# Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
+function forward(L::Le, env::E) where {Le <: AbstractLearner, E <: AbstractEnv}
+    env |> state |> Flux.gpu |> (x -> forward(L, x)) |> Flux.cpu
+end
+
+function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::Trajectory) end
+
+
+"""
+    Approximator(model, optimiser)
+
+Wraps a Flux trainable model and implements the `RLBase.optimise!(::Approximator, ::Gradient)`
+interface. See the RLCore documentation for more information on proper usage.
+"""
+struct Approximator{M,O} <: AbstractLearner
+    model::M
+    optimiser_state::O
+end
+
+function Approximator(; model, optimiser)
+    optimiser_state = Flux.setup(optimiser, model)
+    Approximator(gpu(model), gpu(optimiser_state)) # Pass model to GPU (if available) upon creation
+end
+
+Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(AnnotatedStructTree, A))
+
+@functor Approximator (model,)
+
+function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv)
+    legal_action_space_ = RLBase.legal_action_space_mask(env)
+    RLBase.plan!(explorer, forward(learner, env), legal_action_space_)
+end
+
+function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv, player::Symbol)
+    legal_action_space_ = RLBase.legal_action_space_mask(env, player)
+    return RLBase.plan!(explorer, forward(learner, env), legal_action_space_)
+end
+
+forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...)
+
+RLBase.optimise!(A::Approximator, grad) =
+    Flux.Optimise.update!(A.optimiser_state, A.model, grad)
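For context, a rough usage sketch of the reworked `Approximator` with explicit optimiser state (the network sizes, optimiser settings, and dummy data are illustrative; `gpu` is a no-op unless a GPU back end such as CUDA.jl is loaded):

using Flux
using ReinforcementLearningBase, ReinforcementLearningCore

approx = Approximator(
    model = Chain(Dense(4 => 32, relu), Dense(32 => 2)),  # any Flux model
    optimiser = Adam(1e-3),                               # Flux.setup is called inside the constructor
)

# Toy supervised step on a CPU-only session, just to show the optimise! path
x, y = rand(Float32, 4, 16), rand(Float32, 2, 16)
grads = Flux.gradient(m -> Flux.mse(m(x), y), approx.model)
RLBase.optimise!(approx, grads[1])  # updates the model via the stored optimiser state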
src/ReinforcementLearningCore/src/policies/learners/learners.jl (new file)
+3

@@ -0,0 +1,3 @@
+include("abstract_learner.jl")
+include("tabular_approximator.jl")
+include("target_network.jl")
src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl (new file)
+49

@@ -0,0 +1,49 @@
+export TabularApproximator, TabularVApproximator, TabularQApproximator
+
+using Flux: gpu
+
+"""
+    TabularApproximator(table<:AbstractArray, opt)
+
+For `table` of 1-d, it will serve as a state value approximator.
+For `table` of 2-d, it will serve as a state-action value approximator.
+
+!!! warning
+    For `table` of 2-d, the first dimension is action and the second dimension is state.
+"""
+# TODO: add back missing AbstractApproximator
+struct TabularApproximator{N,A,O} <: AbstractLearner
+    table::A
+    optimizer::O
+    function TabularApproximator(table::A, opt::O) where {A<:AbstractArray,O}
+        n = ndims(table)
+        n <= 2 || throw(ArgumentError("the dimension of table must be <= 2"))
+        new{n,A,O}(table, opt)
+    end
+end
+
+TabularVApproximator(; n_state, init = 0.0, opt = InvDecay(1.0)) =
+    TabularApproximator(fill(init, n_state), opt)
+
+TabularQApproximator(; n_state, n_action, init = 0.0, opt = InvDecay(1.0)) =
+    TabularApproximator(fill(init, n_action, n_state), opt)
+
+# Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
+function forward(L::TabularApproximator, env::E) where {E <: AbstractEnv}
+    env |> state |> (x -> forward(L, x))
+end
+
+RLCore.forward(
+    app::TabularApproximator{1,R,O},
+    s::I,
+) where {R<:AbstractArray,O,I<:Integer} = @views app.table[s]
+
+RLCore.forward(
+    app::TabularApproximator{2,R,O},
+    s::I,
+) where {R<:AbstractArray,O,I<:Integer} = @views app.table[:, s]
+RLCore.forward(
+    app::TabularApproximator{2,R,O},
+    s::I1,
+    a::I2,
+) where {R<:AbstractArray,O,I1<:Integer,I2<:Integer} = @views app.table[a, s]
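For context, a small sketch of the tabular constructors and the unexported `forward` methods defined above (sizes and the explicit optimiser are illustrative; `RLCore` here is just an alias created for qualified access):

using Flux
using ReinforcementLearningCore
import ReinforcementLearningCore as RLCore  # qualified access to the unexported forward

q = TabularQApproximator(n_state = 5, n_action = 3, opt = Flux.InvDecay(1.0))  # 3×5 table of zeros

RLCore.forward(q, 2)     # Q-values of all 3 actions in state 2 (a view of table column 2)
RLCore.forward(q, 2, 1)  # Q-value of action 1 in state 2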

src/ReinforcementLearningCore/src/policies/approximator.jl → src/ReinforcementLearningCore/src/policies/learners/target_network.jl
+7 -23

@@ -3,25 +3,6 @@ export Approximator, TargetNetwork, target, model
 using Flux
 
 
-"""
-    Approximator(model, optimiser)
-
-Wraps a Flux trainable model and implements the `RLBase.optimise!(::Approximator, ::Gradient)`
-interface. See the RLCore documentation for more information on proper usage.
-"""
-Base.@kwdef mutable struct Approximator{M,O}
-    model::M
-    optimiser::O
-end
-
-Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(AnnotatedStructTree, A))
-
-@functor Approximator (model,)
-
-forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...)
-
-RLBase.optimise!(A::Approximator, gs) = Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
-
 target(ap::Approximator) = ap.model #see TargetNetwork
 model(ap::Approximator) = ap.model #see TargetNetwork
 

@@ -52,9 +33,11 @@ mutable struct TargetNetwork{M}
     n_optimise::Int
 end
 
-function TargetNetwork(x; sync_freq=1, ρ=0.0f0)
+function TargetNetwork(network; sync_freq = 1, ρ = 0f0)
     @assert 0 <= ρ <= 1 "ρ must in [0,1]"
-    TargetNetwork(x, deepcopy(x.model), sync_freq, ρ, 0)
+    # NOTE: model is pushed to gpu in Approximator, need to transfer to cpu before deepcopy, then push target model to gpu
+    target = gpu(deepcopy(cpu(network.model)))
+    TargetNetwork(network, target, sync_freq, ρ, 0)
 end
 
 @functor TargetNetwork (network, target)

@@ -66,9 +49,10 @@ forward(tn::TargetNetwork, args...) = forward(tn.network, args...)
 model(tn::TargetNetwork) = model(tn.network)
 target(tn::TargetNetwork) = tn.target
 
-function RLBase.optimise!(tn::TargetNetwork, gs)
+function RLBase.optimise!(tn::TargetNetwork, grad)
     A = tn.network
-    Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
+    optimise!(A, grad)
+
     tn.n_optimise += 1
 
     if tn.n_optimise % tn.sync_freq == 0
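For context, a hedged sketch of the updated `TargetNetwork` constructor (hyperparameters are illustrative; the target copy ends up on the same device as the online model because of the cpu/deepcopy/gpu round trip noted in the diff):

using Flux
using ReinforcementLearningCore

online = Approximator(
    model = Chain(Dense(4 => 32, relu), Dense(32 => 2)),
    optimiser = Adam(1e-3),
)

tn = TargetNetwork(online; sync_freq = 100)  # copy the online weights into the target every 100 optimise! calls

model(tn)   # the trainable online network
target(tn)  # the lagged copy used for bootstrapped targets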
src/ReinforcementLearningCore/src/policies/policies.jl
+1 -2

@@ -1,6 +1,5 @@
 include("agent/agent.jl")
 include("random_policy.jl")
 include("explorers/explorers.jl")
-include("learners.jl")
+include("learners/learners.jl")
 include("q_based_policy.jl")
-include("approximator.jl")

src/ReinforcementLearningCore/src/utils/basic.jl
+2 -7

@@ -113,16 +113,11 @@ function find_all_max(x, mask::AbstractVector{Bool})
     v, [k for (m, k) in zip(mask, keys(x)) if m && x[k] == v]
 end
 
-# !!! watch https://github.com/JuliaLang/julia/pull/35316#issuecomment-622629895
-# Base.findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)
-# _rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m)
 
-# !!! type piracy
-Base.findmax(A::AbstractVector{T}, mask::AbstractVector{Bool}) where {T} =
+findmax_masked(A::AbstractVector{T}, mask::AbstractVector{Bool}) where {T} =
     findmax(ifelse.(mask, A, typemin(T)))
 
-Base.findmax(A::AbstractVector, mask::Trues) = findmax(A)
-
+findmax_masked(A::AbstractVector, mask::Trues) = findmax(A)
 
 const VectorOrMatrix = Union{AbstractMatrix,AbstractVector}
 