|
1 | 1 | export TabularApproximator, TabularVApproximator, TabularQApproximator
|
2 | 2 |
|
3 |
| -const TabularApproximator = Approximator{A,O} where {A<:AbstractArray,O} |
4 |
| -const TabularQApproximator = Approximator{A,O} where {A<:AbstractArray,O} |
5 |
| -const TabularVApproximator = Approximator{A,O} where {A<:AbstractVector,O} |
# Tabular learner: `model` holds the backing table.
#
# The `AbstractArray` bound is declared on the type parameter itself so that it applies
# to every instantiation of the type. The explicit inner constructor replaces the
# auto-generated constructors, so the array-accepting outer constructor defined in this
# file does not share a signature with a default one.
struct TabularApproximator{A<:AbstractArray} <: AbstractLearner
    model::A

    TabularApproximator{A}(model) where {A<:AbstractArray} = new{A}(model)
end
| 6 | + |
# Approximator backed by a matrix table (state-action values).
const TabularQApproximator = TabularApproximator{<:AbstractMatrix}
# Approximator backed by a vector table (state values).
const TabularVApproximator = TabularApproximator{<:AbstractVector}
6 | 9 |
|
"""
    TabularApproximator(table::AbstractArray)

Wrap `table` as a tabular value approximator.

- A 1-d `table` serves as a state value approximator.
- A 2-d `table` serves as a state-action value approximator.

Throws an `ArgumentError` when `table` has more than two dimensions.

!!! warning
    For a 2-d `table`, the first dimension is action and the second dimension is state.
"""
function TabularApproximator(table::A) where {A<:AbstractArray}
    # Only vector- and matrix-shaped tables have a defined meaning here.
    if ndims(table) > 2
        throw(ArgumentError("the dimension of table must be <= 2"))
    end
    return TabularApproximator{A}(table)
end
|
21 | 24 |
|
22 |
| -TabularVApproximator(; n_state, init = 0.0, opt = InvDecay(1.0)) = |
23 |
| - TabularApproximator(fill(init, n_state), opt) |
# Keyword constructor: a table of length `n_state` with every entry set to `init`.
function TabularVApproximator(; n_state, init = 0.0)
    return TabularApproximator(fill(init, n_state))
end
24 | 27 |
|
25 |
| -TabularQApproximator(; n_state, n_action, init = 0.0, opt = InvDecay(1.0)) = |
26 |
| - TabularApproximator(fill(init, n_action, n_state), opt) |
# Keyword constructor: an `n_action × n_state` table with every entry set to `init`
# (actions along the first dimension, states along the second).
function TabularQApproximator(; n_state, n_action, init = 0.0)
    return TabularApproximator(fill(init, n_action, n_state))
end
27 | 30 |
|
28 | 31 | # Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
|
# Read the state from `env` and evaluate the state value approximator on it.
forward(L::TabularVApproximator, env::E) where {E<:AbstractEnv} = forward(L, state(env))
|
# Read the state from `env` and evaluate the state-action value approximator on it.
forward(L::TabularQApproximator, env::E) where {E<:AbstractEnv} = forward(L, state(env))
|
31 | 34 |
|
# Look up the table entry for state `s`; `@views` makes slice-style indices return a
# view instead of a copy.
function RLCore.forward(app::TabularVApproximator{R}, s::I) where {R<:AbstractVector,I}
    return @views app.model[s]
end
36 | 39 |
|
# Return column `s` of the table as a view, i.e. the entries along the first (action)
# dimension for state `s`.
#
# `R` is bounded by `AbstractMatrix` explicitly: a method type variable applied to the
# `TabularQApproximator` alias is not narrowed to the alias' own bound, so a looser
# bound here would also match vector-backed approximators.
function RLCore.forward(app::TabularQApproximator{R}, s::I) where {R<:AbstractMatrix,I}
    return @views app.model[:, s]
end
41 | 44 |
|
# Look up the table entry for action `a` in state `s` (first index is action, second is
# state).
#
# `R` is bounded by `AbstractMatrix` explicitly: a method type variable applied to the
# `TabularQApproximator` alias is not narrowed to the alias' own bound, so a looser
# bound here would also match vector-backed approximators.
function RLCore.forward(
    app::TabularQApproximator{R}, s::I1, a::I2
) where {R<:AbstractMatrix,I1,I2}
    return @views app.model[a, s]
end
0 commit comments