Skip to content

Commit 94b5213

Browse files
Move UnicodePlots to extension (#1088)
* Update Flux and GPUArrays compatibility in Project.toml; refactor FluxApproximator and TargetNetwork implementations * Refactor target network optimization and update test assertions for consistency * Simplify FluxApproximator's optimise! method by using a single-line function definition * Bump version to 0.15.4 in Project.toml * Update NEWS.md for v0.15.4: Upgrade Flux.jl to v0.16 and resolve deprecation warnings * Add Conda dependency and update test environment setup * Update test environment setup to use pip for gym installation * Fix RLEnv tests * Fix optimizer reference in stock trading environment example * Fix optimizer reference in stock trading environment example * Refactor optimizer implementation in DDPGPolicy to use OptimiserChain * Refactor UnicodePlots integration into extension * Move UnicodePlots to package extension in release notes
1 parent 25ec21e commit 94b5213

File tree

5 files changed

+29
-18
lines changed

5 files changed

+29
-18
lines changed

src/ReinforcementLearningCore/NEWS.md

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# ReinforcementLearningCore.jl Release Notes
22

3+
#### v0.15.5
4+
5+
- Move `UnicodePlots` to package extension
6+
37
#### v0.15.4
48

59
- Update `Flux.jl` to `v0.16` and fix deprecation warnings and method errors

src/ReinforcementLearningCore/Project.toml

+6-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ version = "0.15.4"
44

55
[deps]
66
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
7-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
87
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
98
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
109
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
@@ -21,11 +20,15 @@ ReinforcementLearningTrajectories = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
2120
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2221
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2322
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
23+
24+
[weakdeps]
2425
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
2526

27+
[extensions]
28+
UnicodePlotsExt = "UnicodePlots"
29+
2630
[compat]
2731
AbstractTrees = "0.3, 0.4"
28-
Adapt = "3, 4"
2932
ChainRulesCore = "1"
3033
CircularArrayBuffers = "0.1.12"
3134
Crayons = "4"
@@ -57,4 +60,4 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
5760
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
5861

5962
[targets]
60-
test = ["CommonRLInterface", "CUDA", "cuDNN", "DomainSets", "Metal", "Preferences", "ReinforcementLearningEnvironments", "Test", "UUIDs"]
63+
test = ["CommonRLInterface", "CUDA", "cuDNN", "DomainSets", "Metal", "Preferences", "ReinforcementLearningEnvironments", "Test", "UnicodePlots", "UUIDs"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
module UnicodePlotsExt
2+
using ReinforcementLearningCore
3+
using UnicodePlots: lineplot, lineplot!
4+
5+
function Base.show(io::IO, hook::TotalRewardPerEpisode{true, F}) where {F<:Number}
6+
if length(hook.rewards) > 0
7+
println(io, lineplot(
8+
hook.rewards,
9+
title="Total reward per episode",
10+
xlabel="Episode",
11+
ylabel="Score",
12+
))
13+
else
14+
println(io, typeof(hook))
15+
end
16+
return
17+
end
18+
end

src/ReinforcementLearningCore/src/core/hooks.jl

-15
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ export AbstractHook,
1010
DoEveryNSteps,
1111
DoOnExit
1212

13-
using UnicodePlots: lineplot, lineplot!
1413
using Statistics: mean, std
1514
using CircularArrayBuffers: CircularVectorBuffer
1615
import ReinforcementLearningBase: RLBase
@@ -172,20 +171,6 @@ function Base.push!(hook::TotalRewardPerEpisode,
172171
return
173172
end
174173

175-
function Base.show(io::IO, hook::TotalRewardPerEpisode{true, F}) where {F<:Number}
176-
if length(hook.rewards) > 0
177-
println(io, lineplot(
178-
hook.rewards,
179-
title="Total reward per episode",
180-
xlabel="Episode",
181-
ylabel="Score",
182-
))
183-
else
184-
println(io, typeof(hook))
185-
end
186-
return
187-
end
188-
189174
function Base.push!(hook::TotalRewardPerEpisode{true, F},
190175
::PostExperimentStage,
191176
agent::AbstractPolicy,

src/ReinforcementLearningCore/test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Test
2+
using UnicodePlots
23
using UUIDs
34
using Preferences
45

0 commit comments

Comments
 (0)