
Commit b1c5c12

Author: Jeremiah Lewis
Commit message: naming
1 parent 87760ff commit b1c5c12

File tree

1 file changed (+16, -16 lines)

src/ReinforcementLearningCore/src/policies/learners/td_learner.jl

Lines changed: 16 additions & 16 deletions
@@ -45,32 +45,32 @@ Update the Q-value of the given state-action pair.
 """
 function bellman_update!(
     approx::TabularApproximator,
-    s::I1,
-    s_plus_one::I2,
-    a::I3,
-    r::F1, # reward
+    state::I1,
+    next_state::I2,
+    action::I3,
+    reward::F1,
     γ::Float64, # discount factor
 ) where {I1<:Integer,I2<:Integer,I3<:Integer,F1<:AbstractFloat}
     # Q-learning formula following https://github.com/JuliaPOMDP/TabularTDLearning.jl/blob/25c4d3888e178c51ed1ff448f36b0fcaf7c1d8e8/src/q_learn.jl#LL63C26-L63C95
     # Terminology following https://en.wikipedia.org/wiki/Q-learning
-    estimate_optimal_future_value = maximum(Q(approx, s_plus_one))
-    current_value = Q(approx, s, a)
-    raw_q_value = (r + γ * estimate_optimal_future_value - current_value) # Discount factor γ is applied here
+    estimate_optimal_future_value = maximum(Q(approx, next_state))
+    current_value = Q(approx, state, action)
+    raw_q_value = (reward + γ * estimate_optimal_future_value - current_value) # Discount factor γ is applied here
     q_value_updated = Flux.Optimise.update!(approx.optimiser_state, :learning, [raw_q_value])[] # adust according to optimiser learning rate
-    approx.model[a, s] += q_value_updated
-    return Q(approx, s, a)
+    approx.model[action, state] += q_value_updated
+    return Q(approx, state, action)
 end
 
 function _optimise!(
     n::I1,
     γ::F,
     approx::Approximator{Ar},
-    s::I2,
-    s_next::I2,
-    a::I3,
-    r::F,
+    state::I2,
+    next_state::I2,
+    action::I3,
+    reward::F,
 ) where {I1<:Number,I2<:Number,I3<:Number,Ar<:AbstractArray,F<:AbstractFloat}
-    bellman_update!(approx, s, s_next, a, r, γ)
+    bellman_update!(approx, state, next_state, action, reward, γ)
 end
 
 function RLBase.optimise!(
@@ -80,12 +80,12 @@ function RLBase.optimise!(
     _optimise!(L.n, L.γ, L.approximator, t.state, t.next_state, t.action, t.reward)
 end
 
-function RLBase.optimise!(learner::TDLearner, stage, trajectory::Trajectory)
+function RLBase.optimise!(learner::TDLearner, stage::AbstractStage, trajectory::Trajectory)
     for batch in trajectory.container
         optimise!(learner, stage, batch)
     end
 end
 
 # TDLearner{:SARS} optimises at the PostActStage
-RLBase.optimise!(L::TDLearner{:SARS}, stage::PostActStage, trace::NamedTuple) = RLBase.optimise!(L, trace)
+RLBase.optimise!(learner::TDLearner{:SARS}, stage::PostActStage, trace::NamedTuple) = RLBase.optimise!(learner, trace)

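For context, the renamed bellman_update! performs the standard tabular Q-learning update: Q(state, action) is moved toward reward + γ * max over actions of Q(next_state, ·). The snippet below is a minimal sketch of that update on a plain matrix; q_update!, Q_table, and the fixed step size α are illustrative names only and not part of the ReinforcementLearningCore API, which instead applies the step size through the approximator's Flux optimiser state.

# Minimal sketch of the tabular Q-learning step that bellman_update! performs.
# Rows of the table index actions and columns index states, matching approx.model[action, state].
function q_update!(Q_table, state, action, reward, next_state; α = 0.1, γ = 0.99)
    estimate_optimal_future_value = maximum(Q_table[:, next_state])
    td_error = reward + γ * estimate_optimal_future_value - Q_table[action, state]
    Q_table[action, state] += α * td_error   # the package applies the step size via the optimiser instead of a fixed α
    return Q_table[action, state]
end

Q_table = zeros(4, 10)             # 4 actions, 10 states
q_update!(Q_table, 2, 1, 1.0, 3)   # state 2, action 1, reward 1.0, next state 3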