using ReinforcementLearningCore
using ReinforcementLearningBase
import Base.push!
import Base.getindex
using CircularArrayBuffers: CircularVectorBuffer, CircularArrayBuffer

| 7 | +""" |
| 8 | +TotalRewardPerLastNEpisodes{F}(; max_episodes = 100) |
| 9 | +
|
| 10 | +A hook that keeps track of the total reward per episode for the last `max_episodes` episodes. |
| 11 | +""" |
struct TotalRewardPerLastNEpisodes{B<:CircularArrayBuffer} <: AbstractHook
    rewards::B

    function TotalRewardPerLastNEpisodes(; max_episodes = 100)
        # The circular buffer silently drops the oldest total once more than
        # `max_episodes` episodes have been recorded.
        buffer = CircularVectorBuffer{Float64}(max_episodes)
        new{typeof(buffer)}(buffer)
    end
end

# Forward indexing to the underlying buffer, so `hook[end]` reads the most
# recent episode's total reward.
Base.getindex(h::TotalRewardPerLastNEpisodes{B}, inds...) where {B<:CircularArrayBuffer} =
    getindex(h.rewards, inds...)

# After every action, accumulate the player's reward into the slot opened for
# the current episode.
Base.push!(
    h::TotalRewardPerLastNEpisodes{B},
    ::PostActStage,
    agent::P,
    env::E,
    player::Symbol,
) where {P<:AbstractPolicy,E<:AbstractEnv,B<:CircularArrayBuffer} =
    h.rewards[end] += reward(env, player)

# At the start of each episode, open a fresh slot initialized to 0.0.
Base.push!(
    hook::TotalRewardPerLastNEpisodes{B},
    ::PreEpisodeStage,
    agent,
    env,
) where {B<:CircularArrayBuffer} = Base.push!(hook.rewards, 0.0)

# Multi-player calls for these stages are player-agnostic: drop the `player`
# argument and forward to the four-argument `push!`.
Base.push!(
    hook::TotalRewardPerLastNEpisodes{B},
    stage::Union{PreEpisodeStage,PostEpisodeStage,PostExperimentStage},
    agent,
    env,
    player::Symbol,
) where {B<:CircularArrayBuffer} = Base.push!(hook, stage, agent, env)
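
# --- Usage sketch (illustrative, not part of the hook) ---
# A minimal example of how the stage methods cooperate: `PreEpisodeStage`
# opens a 0.0 slot and `PostActStage` accumulates into it. `DummyEnv` is a
# hypothetical stand-in for any `AbstractEnv`; in a real experiment you would
# pass the hook to `run` instead of driving `push!` by hand.
struct DummyEnv <: AbstractEnv end
ReinforcementLearningBase.reward(::DummyEnv, ::Symbol) = 1.0

hook = TotalRewardPerLastNEpisodes(; max_episodes = 3)
env = DummyEnv()
policy = RandomPolicy()  # any AbstractPolicy satisfies the PostActStage signature

for _ in 1:5
    push!(hook, PreEpisodeStage(), policy, env, :player)  # new episode slot
    push!(hook, PostActStage(), policy, env, :player)     # += reward(env, :player)
end

hook[end]             # total reward of the latest episode: 1.0
length(hook.rewards)  # 3 — older episodes were dropped by the circular buffer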