
Commit 5aed0b2

Author: Jeremiah Lewis
Commit message: fixes
1 parent: b1c5c12

File tree: 2 files changed (+5 additions, -2 deletions)

src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl

Lines changed: 2 additions & 1 deletion

@@ -16,7 +16,8 @@ For `table` of 2-d, it will serve as a state-action value approximator.
 function TabularApproximator(table::A, opt::O) where {A<:AbstractArray,O}
     n = ndims(table)
     n <= 2 || throw(ArgumentError("the dimension of table must be <= 2"))
-    TabularApproximator{A,O}(table, opt)
+    optimiser_state = Flux.setup(optimiser, table)
+    TabularApproximator{A,O}(table, optimiser_state)
 end

 TabularVApproximator(; n_state, opt, init = 0.0) =
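
For context, a minimal sketch of what the new constructor line relies on: Flux.setup builds per-parameter optimiser state for an array. The table shape, the Descent optimiser, its learning rate, and the placeholder gradient below are illustrative choices, not values from the repository.

using Flux

# Minimal sketch: Flux.setup builds the per-parameter optimiser state that the
# updated constructor now stores in place of the raw optimiser object.
table = zeros(Float64, 5, 5)        # illustrative 2-d state-action table
opt = Flux.Descent(0.1)             # illustrative optimiser and learning rate
opt_state = Flux.setup(opt, table)  # optimiser state tied to `table`

# The stored state can later drive an in-place update from a gradient:
grad = ones(size(table))            # placeholder gradient
Flux.update!(opt_state, table, grad)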

src/ReinforcementLearningCore/test/policies/q_based_policy.jl

Lines changed: 3 additions & 1 deletion

@@ -77,8 +77,10 @@
     )
     t = (state=2, action=3)
     push!(trajectory, t)
-    t = (next_state=3, reward=5.0, terminal=false)
+    next_state = 4
+    t = (action=3, state=next_state, reward=5.0, terminal=false)
     push!(trajectory, t)
+    trajectory.container[1]
     RLBase.optimise!(policy, PostActStage(), trajectory)
     # Add assertions here
 end
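
For orientation, a rough plain-Julia illustration (not the ReinforcementLearningTrajectories API) of how the two partial named tuples pushed in this test can be read as one transition; the variable names and the stitched tuple below are assumptions made for illustration only.

# The first push records a state-action pair; the second push supplies the
# reward/terminal flag for that pair plus the state that begins the next step.
first_push  = (state=2, action=3)
second_push = (action=3, state=4, reward=5.0, terminal=false)

# Conceptually, the trace container pairs consecutive pushes into a full
# state-action-reward-terminal-next_state transition:
transition = (
    state      = first_push.state,
    action     = first_push.action,
    reward     = second_push.reward,
    terminal   = second_push.terminal,
    next_state = second_push.state,
)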
