Skip to content

Commit 9e06129

Browse files
Fix naming consistency and add missing hook tests (#1049)
* Make nsteps naming consistent * add missing hook tests * basic hook tests and nesting fixes --------- Co-authored-by: Jeremiah Lewis <--get>
1 parent 55f60b0 commit 9e06129

File tree

4 files changed

+59
-26
lines changed

4 files changed

+59
-26
lines changed

docs/homepage/guide/index.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ Usually a closure or a functional object will be used to store some intermediate
8484

8585
In most cases, you don't need to write a customized hook. Some generic hooks are provided so that you can inject logic at the appropriate time:
8686

87-
- [`DoEveryNStep`](https://juliareinforcementlearning.org/ReinforcementLearning.jl/latest/rl_core/#ReinforcementLearningCore.DoEveryNStep)
87+
- [`DoEveryNSteps`](https://juliareinforcementlearning.org/ReinforcementLearning.jl/latest/rl_core/#ReinforcementLearningCore.DoEveryNSteps)
8888
- [`DoEveryNEpisode`](https://juliareinforcementlearning.org/ReinforcementLearning.jl/latest/rl_core/#ReinforcementLearningCore.DoEveryNEpisode)
8989

9090
However, if you do need to write a customized hook, the following methods must be provided:
@@ -98,10 +98,10 @@ If your hook is a subtype of `AbstractHook`, then all the above methods will hav
9898

9999
## How to use TensorBoard?
100100

101-
This package adopts a non-invasive way for logging. So you can log everything you like with a hook. For example, to log the loss of each step. You can use the [`DoEveryNStep`](https://juliareinforcementlearning.org/ReinforcementLearning.jl/latest/rl_core/#ReinforcementLearningCore.DoEveryNStep).
101+
This package adopts a non-invasive way for logging. So you can log everything you like with a hook. For example, to log the loss of each step. You can use the [`DoEveryNSteps`](https://juliareinforcementlearning.org/ReinforcementLearning.jl/latest/rl_core/#ReinforcementLearningCore.DoEveryNSteps).
102102

103103
```julia
104-
DoEveryNStep() do t, agent, env
104+
DoEveryNSteps() do t, agent, env
105105
with_logger(lg) do
106106
@info "training" loss = agent.policy.learner.loss
107107
end
@@ -117,7 +117,7 @@ run(
117117
agent,
118118
env,
119119
stop_condition,
120-
DoEveryNStep(EVALUATION_FREQ) do t, agent, env
120+
DoEveryNSteps(EVALUATION_FREQ) do t, agent, env
121121
run(agent, env, eval_stop_condition, eval_hook)
122122
end
123123
)

docs/src/How_to_use_hooks.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ Sometimes, we'd like to periodically run some functions. Two handy hooks are
8585
provided for this kind of tasks:
8686

8787
- [`DoEveryNEpisode`](@ref)
88-
- [`DoEveryNStep`](@ref)
88+
- [`DoEveryNSteps`](@ref)
8989

9090
Following are some typical usages.
9191

@@ -160,7 +160,7 @@ run(
160160
policy,
161161
env,
162162
StopAfterNSteps(10_000),
163-
DoEveryNStep(n=1_000) do t, p, e
163+
DoEveryNSteps(n=1_000) do t, p, e
164164
ps = params(p)
165165
f = joinpath(parameters_dir, "parameters_at_step_$t.bson")
166166
BSON.@save f ps

src/ReinforcementLearningCore/src/core/hooks.jl

+13-17
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ export AbstractHook,
77
BatchStepsPerEpisode,
88
TimePerStep,
99
DoEveryNEpisode,
10-
DoEveryNStep,
10+
DoEveryNSteps,
1111
DoOnExit
1212

1313
using UnicodePlots: lineplot, lineplot!
@@ -38,10 +38,10 @@ struct ComposedHook{T<:Tuple} <: AbstractHook
3838
ComposedHook(hooks...) = new{typeof(hooks)}(hooks)
3939
end
4040

41-
Base.:(+)(h1::AbstractHook, h2::AbstractHook) = ComposedHook((h1, h2))
42-
Base.:(+)(h1::ComposedHook, h2::AbstractHook) = ComposedHook((h1.hooks..., h2))
43-
Base.:(+)(h1::AbstractHook, h2::ComposedHook) = ComposedHook((h1, h2.hooks...))
44-
Base.:(+)(h1::ComposedHook, h2::ComposedHook) = ComposedHook((h1.hooks..., h2.hooks...))
41+
Base.:(+)(h1::AbstractHook, h2::AbstractHook) = ComposedHook(h1, h2)
42+
Base.:(+)(h1::ComposedHook, h2::AbstractHook) = ComposedHook(h1.hooks..., h2)
43+
Base.:(+)(h1::AbstractHook, h2::ComposedHook) = ComposedHook(h1, h2.hooks...)
44+
Base.:(+)(h1::ComposedHook, h2::ComposedHook) = ComposedHook(h1.hooks..., h2.hooks...)
4545

4646
@inline function _push!(stage::AbstractStage, policy::P, env::E, hook::H, hook_tuple...) where {P <: AbstractPolicy, E <: AbstractEnv, H <: AbstractHook}
4747
Base.push!(hook, stage, policy, env)
@@ -286,26 +286,22 @@ function Base.push!(hook::TimePerStep, ::PostActStage, agent, env)
286286
end
287287

288288
"""
289-
DoEveryNStep(f; n=1, t=0)
289+
DoEveryNSteps(f; n=1, t=0)
290290
291291
Execute `f(t, agent, env)` every `n` step.
292292
`t` is a counter of steps.
293293
"""
294-
mutable struct DoEveryNStep{F,T} <: AbstractHook where {F,T<:Integer}
294+
mutable struct DoEveryNSteps{F} <: AbstractHook where {F}
295295
f::F
296-
n::T
297-
t::T
298-
299-
function DoEveryNStep(f; n=1, t=0)
300-
new{typeof(f),Int64}(f, n, t)
301-
end
302-
303-
function DoEveryNStep{T}(f; n=1, t=0) where {T<:Integer}
304-
new{typeof(f),T}(f, n, t)
296+
n::Int
297+
t::Int
298+
299+
function DoEveryNSteps(f::F; n=1, t=0) where {F}
300+
new{F}(f, n, t)
305301
end
306302
end
307303

308-
function Base.push!(hook::DoEveryNStep, ::PostActStage, agent, env)
304+
function Base.push!(hook::DoEveryNSteps, ::PostActStage, agent, env)
309305
hook.t += 1
310306
if hook.t % hook.n == 0
311307
hook.f(hook.t, agent, env)

src/ReinforcementLearningCore/test/core/hooks.jl

+40-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
struct MockHook <: AbstractHook end
2+
13
"""
24
test_noop!(hook; stages=[PreActStage()])
35
@@ -36,6 +38,41 @@ function test_run!(hook::AbstractHook)
3638
return hook_
3739
end
3840

41+
@testset "AbstractHook + AbstractHook" begin
42+
@test MockHook() + MockHook() == ComposedHook(MockHook(), MockHook())
43+
end
44+
45+
@testset "ComposedHook + AbstractHook" begin
46+
struct MockHook <: AbstractHook end
47+
@test ComposedHook(MockHook()) + MockHook() == ComposedHook(MockHook(), MockHook())
48+
end
49+
50+
@testset "AbstractHook + ComposedHook" begin
51+
@test MockHook() + ComposedHook(MockHook()) == ComposedHook(MockHook(), MockHook())
52+
end
53+
54+
@testset "ComposedHook + ComposedHook" begin
55+
@test ComposedHook(MockHook()) + ComposedHook(MockHook()) == ComposedHook(MockHook(), MockHook())
56+
end
57+
58+
@testset "push! method for ComposedHook" begin
59+
stage = PreActStage()
60+
policy = RandomPolicy()
61+
env = TicTacToeEnv()
62+
composed_hook = ComposedHook(MockHook(), MockHook())
63+
push!(composed_hook, stage, policy, env)
64+
@test composed_hook.hooks == (MockHook(), MockHook())
65+
end
66+
67+
@testset "push! method for ComposedHook with multiple hooks" begin
68+
stage = PreActStage()
69+
policy = RandomPolicy()
70+
env = TicTacToeEnv()
71+
composed_hook = ComposedHook(MockHook(), MockHook())
72+
push!(composed_hook, stage, policy, env)
73+
@test composed_hook.hooks == (MockHook(), MockHook())
74+
end
75+
3976
@testset "TotalRewardPerEpisode" begin
4077
h_1 = TotalRewardPerEpisode(; is_display_on_exit=true)
4178
h_2 = TotalRewardPerEpisode(; is_display_on_exit=false)
@@ -57,9 +94,9 @@ end
5794
end
5895
end
5996

60-
@testset "DoEveryNStep" begin
61-
h_1 = DoEveryNStep((hook, agent, env) -> (env.pos += 1); n=2)
62-
h_2 = DoEveryNStep((hook, agent, env) -> (env.pos += 1); n=1)
97+
@testset "DoEveryNSteps" begin
98+
h_1 = DoEveryNSteps((hook, agent, env) -> (env.pos += 1); n=2)
99+
h_2 = DoEveryNSteps((hook, agent, env) -> (env.pos += 1); n=1)
63100

64101
for h in (h_1, h_2)
65102
env = RandomWalk1D()

0 commit comments

Comments
 (0)