Commit 30f3105

Make QBasedPolicy general for AbstractLearner s (#1069)
1 parent 1b4d449 commit 30f3105

1 file changed: +6 -6 lines changed

src/ReinforcementLearningCore/src/policies/q_based_policy.jl (+6 -6)
@@ -10,32 +10,32 @@ action of an environment at its current state. It is typically a table or a neur
 QBasedPolicy can be queried for an action with `RLBase.plan!`, the explorer will affect the action selection
 accordingly.
 """
-struct QBasedPolicy{L<:TDLearner,E<:AbstractExplorer} <: AbstractPolicy
+struct QBasedPolicy{L<:AbstractLearner,E<:AbstractExplorer} <: AbstractPolicy
     "estimate the Q value"
     learner::L
     "select the action based on Q values calculated by the learner"
     explorer::E
 
-    function QBasedPolicy(; learner::L, explorer::E) where {L<:TDLearner, E<:AbstractExplorer}
+    function QBasedPolicy(; learner::L, explorer::E) where {L<:AbstractLearner, E<:AbstractExplorer}
         new{L,E}(learner, explorer)
     end
 
-    function QBasedPolicy(learner::L, explorer::E) where {L<:TDLearner, E<:AbstractExplorer}
+    function QBasedPolicy(learner::L, explorer::E) where {L<:AbstractLearner, E<:AbstractExplorer}
        new{L,E}(learner, explorer)
    end
 end
 
 Flux.@layer QBasedPolicy trainable=(learner,)
 
-function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E) where {Ex<:AbstractExplorer,L<:TDLearner,E<:AbstractEnv}
+function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv}
     RLBase.plan!(policy.explorer, policy.learner, env)
 end
 
-function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E, player::Player) where {Ex<:AbstractExplorer,L<:TDLearner,E<:AbstractEnv, Player<:AbstractPlayer}
+function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E, player::Player) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv, Player<:AbstractPlayer}
     RLBase.plan!(policy.explorer, policy.learner, env, player)
 end
 
-RLBase.prob(policy::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:TDLearner,Ex<:AbstractExplorer} =
+RLBase.prob(policy::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} =
     prob(policy.explorer, forward(policy.learner, env), legal_action_space_mask(env))
 
 #the internal learner defines the optimization stage.
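
The change relaxes the learner type bound from `TDLearner` to `AbstractLearner`, so any learner subtype can now be wrapped in a `QBasedPolicy`. Below is a minimal, hypothetical sketch of what that enables. It assumes the learner contract is to implement `forward(learner, env)` returning a vector of Q-values (the method `RLBase.prob` calls above); the `RandomQLearner` type, the epsilon value, and the environment variable `env` are invented for illustration.

using ReinforcementLearningBase, ReinforcementLearningCore

# Hypothetical learner: returns a dummy Q-value per action. A real learner
# would look up a table or evaluate a neural network on the current state.
struct RandomQLearner <: ReinforcementLearningCore.AbstractLearner
    n_actions::Int
end

ReinforcementLearningCore.forward(learner::RandomQLearner, ::AbstractEnv) =
    rand(Float32, learner.n_actions)

# Before this commit the constructors required L<:TDLearner; now any
# AbstractLearner is accepted, paired here with an epsilon-greedy explorer.
policy = QBasedPolicy(
    learner = RandomQLearner(2),
    explorer = EpsilonGreedyExplorer(0.1),
)

# Querying the policy delegates to the explorer, which selects among the
# learner's Q-values (illustrative; `env` stands for any AbstractEnv that
# supports the queries the explorer dispatch needs):
# action = RLBase.plan!(policy, env)

The same policy can also be passed to `RLBase.prob(policy, env)` to obtain the explorer's action distribution over the learner's Q-values, as shown in the `RLBase.prob` method in the diff.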
