Skip to content

Commit 06cabb9

Browse files
Add TotalRewardPerEpisodeLastN hook (#1053)
1 parent e61836a commit 06cabb9

File tree

7 files changed

+81
-0
lines changed

7 files changed

+81
-0
lines changed

src/ReinforcementLearningFarm/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ uuid = "14eff660-7080-4cec-bba2-cfb12cd77ac3"
33
version = "0.0.1"
44

55
[deps]
6+
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
67
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
78
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
1011
ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
1112

1213
[compat]
14+
CircularArrayBuffers = "0.1.12"
1315
ReinforcementLearningBase = "0.12"
1416
ReinforcementLearningCore = "0.14"
1517
ReinforcementLearningEnvironments = "0.8"

src/ReinforcementLearningFarm/src/ReinforcementLearningFarm.jl

+1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ const RLFarm = ReinforcementLearningFarm
66
export RLFarm
77

88
include("algorithms/algorithms.jl")
9+
include("hooks/hooks.jl")
910

1011
end # module
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include("total_reward_per_last_n_episodes.jl")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
using ReinforcementLearningCore
2+
using ReinforcementLearningBase
3+
import Base.push!
4+
import Base.getindex
5+
using CircularArrayBuffers: CircularVectorBuffer, CircularArrayBuffer
6+
7+
"""
    TotalRewardPerLastNEpisodes(; max_episodes = 100)

A hook that keeps track of the total reward per episode for the last `max_episodes`
episodes.

The per-episode totals are stored in the `rewards` field, a
`CircularVectorBuffer{Float64}` created with capacity `max_episodes`.

# Keywords
- `max_episodes = 100`: capacity of the reward buffer.
"""
struct TotalRewardPerLastNEpisodes{B<:CircularArrayBuffer} <: AbstractHook
    rewards::B

    # The bound on `B` lives in the parameter list above: a trailing `where` after the
    # supertype would attach to `AbstractHook` and leave `B` unconstrained.
    function TotalRewardPerLastNEpisodes(; max_episodes = 100)
        buffer = CircularVectorBuffer{Float64}(max_episodes)
        return new{typeof(buffer)}(buffer)
    end
end
20+
21+
"""
    getindex(h::TotalRewardPerLastNEpisodes, inds...)

Index into the hook's `rewards` buffer, forwarding `inds` unchanged.
"""
function Base.getindex(
    h::TotalRewardPerLastNEpisodes{B}, inds...
) where {B<:CircularArrayBuffer}
    return h.rewards[inds...]
end
23+
24+
"""
    push!(h::TotalRewardPerLastNEpisodes, ::PostActStage, agent, env[, player])

Add the reward reported by `env` to the running total of the current episode, i.e. the
last slot of `h.rewards`. With `player`, `reward(env, player)` is used; without it,
`reward(env)` is used.

A slot must already exist (one is pushed at `PreEpisodeStage`); indexing `end` on an
empty buffer is not guarded here.
"""
function Base.push!(
    h::TotalRewardPerLastNEpisodes{B},
    ::PostActStage,
    agent::P,
    env::E,
    player::Symbol,
) where {P<:AbstractPolicy,E<:AbstractEnv,B<:CircularArrayBuffer}
    return h.rewards[end] += reward(env, player)
end

# Single-agent variant: without it, a player-less `PostActStage` push has no matching
# method on this hook and the episode total is never accumulated.
function Base.push!(
    h::TotalRewardPerLastNEpisodes{B},
    ::PostActStage,
    agent::P,
    env::E,
) where {P<:AbstractPolicy,E<:AbstractEnv,B<:CircularArrayBuffer}
    return h.rewards[end] += reward(env)
end
32+
33+
"""
    push!(hook::TotalRewardPerLastNEpisodes, ::PreEpisodeStage, agent, env)

Open a new episode by appending a fresh `0.0` total to `hook.rewards`.
"""
function Base.push!(
    hook::TotalRewardPerLastNEpisodes{B}, ::PreEpisodeStage, agent, env
) where {B<:CircularArrayBuffer}
    return push!(hook.rewards, 0.0)
end
39+
40+
"""
    push!(hook::TotalRewardPerLastNEpisodes, stage, agent, env, player::Symbol)

For `PreEpisodeStage`, `PostEpisodeStage` and `PostExperimentStage`, ignore `player` and
forward to the player-less `push!(hook, stage, agent, env)` method.
"""
function Base.push!(
    hook::TotalRewardPerLastNEpisodes{B},
    stage::Union{PreEpisodeStage,PostEpisodeStage,PostExperimentStage},
    agent,
    env,
    player::Symbol,
) where {B<:CircularArrayBuffer}
    # `player` plays no role for these stages; delegate to the shared implementation.
    return push!(hook, stage, agent, env)
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include("total_reward_per_last_n_episodes.jl")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using ReinforcementLearningFarm: TotalRewardPerLastNEpisodes
2+
3+
@testset "TotalRewardPerLastNEpisodes" begin
    # Both scenarios run 15 episodes against a buffer of capacity 10, so the last five
    # iterations exercise the buffer once it is full.
    @testset "Single Agent" begin
        hook = TotalRewardPerLastNEpisodes(; max_episodes = 10)
        env = TicTacToeEnv()
        agent = RandomPolicy()

        for i in 1:15
            push!(hook, PreEpisodeStage(), agent, env)
            push!(hook, PostActStage(), agent, env)

            # Index of the newest entry: grows with `i` until the capacity is reached.
            newest = min(i, 10)
            @test length(hook.rewards) == newest
            @test hook.rewards[newest] == reward(env)
        end
    end

    @testset "MultiAgent" begin
        hook = TotalRewardPerLastNEpisodes(; max_episodes = 10)
        env = TicTacToeEnv()
        agent = RandomPolicy()

        for i in 1:15
            push!(hook, PreEpisodeStage(), agent, env, :Cross)
            push!(hook, PostActStage(), agent, env, :Cross)

            # Index of the newest entry: grows with `i` until the capacity is reached.
            newest = min(i, 10)
            @test length(hook.rewards) == newest
            @test hook.rewards[newest] == reward(env, :Cross)
        end
    end
end

src/ReinforcementLearningFarm/test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ using ReinforcementLearningFarm
1919

2020
@testset "ReinforcementLearningFarm.jl" begin
2121
include("algorithms/algorithms.jl")
22+
include("hooks/hooks.jl")
2223
end

0 commit comments

Comments
 (0)