
Commit 89a46d9

Add missing legal_action_space_mask default methods (#1075)

* Fix devcontainer
* Add patch for missing legal_action_space_mask defaults, add test to StockTradingEnv

1 parent cf14bf0, commit 89a46d9


41 files changed: +139 −58 lines

.devcontainer/devcontainer.json

+1 −1

@@ -11,5 +11,5 @@
         "--privileged"
     ],
     "dockerFile": "Dockerfile",
-    "updateContentCommand": "julia -e 'using Pkg; Pkg.develop(path=\"src/ReinforcementLearningBase\"); Pkg.develop(path=\"src/ReinforcementLearningEnvironments\"); Pkg.develop(path=\"src/ReinforcementLearningCore\"); Pkg.develop(path=\"src/ReinforcementLearningFarm\"); Pkg.develop(path=\"src/ReinforcementLearning\");'"
+    "updateContentCommand": "julia -e 'using Pkg; Pkg.develop(path=\"src/ReinforcementLearningBase\"); Pkg.develop(path=\"src/ReinforcementLearningEnvironments\"); Pkg.develop(path=\"src/ReinforcementLearningCore\"); Pkg.develop(path=\"src/ReinforcementLearningFarm\"); Pkg.develop(path=\".\");'"
 }

.github/workflows/CompatHelper.yml

+1 −1

@@ -15,7 +15,7 @@ jobs:
         run: which julia
         continue-on-error: true
       - name: Install Julia, but only if it is not already available in the PATH
-        uses: julia-actions/setup-julia@v1
+        uses: julia-actions/setup-julia@v2
         with:
           version: '1'
           arch: ${{ runner.arch }}

.github/workflows/ci.yml

+9 −9

@@ -34,11 +34,11 @@ jobs:
       - uses: actions/checkout@v4
         with:
           fetch-depth: 100
-      - uses: julia-actions/setup-julia@v1
+      - uses: julia-actions/setup-julia@v2
         with:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - name: Get changed files
         id: RLBase-changed
         uses: tj-actions/changed-files@v42

@@ -75,11 +75,11 @@ jobs:
       - uses: actions/checkout@v4
         with:
           fetch-depth: 100
-      - uses: julia-actions/setup-julia@v1
+      - uses: julia-actions/setup-julia@v2
         with:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - name: Get changed files
         id: RLCore-changed
         uses: tj-actions/changed-files@v42

@@ -121,11 +121,11 @@ jobs:
      - uses: actions/checkout@v4
         with:
           fetch-depth: 100
-      - uses: julia-actions/setup-julia@v1
+      - uses: julia-actions/setup-julia@v2
         with:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - name: Get changed files
         id: RLFarm-changed
         uses: tj-actions/changed-files@v42

@@ -168,11 +168,11 @@ jobs:
       - uses: actions/checkout@v4
         with:
           fetch-depth: 100
-      - uses: julia-actions/setup-julia@v1
+      - uses: julia-actions/setup-julia@v2
         with:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
-      # - uses: julia-actions/cache@v1
+      # - uses: julia-actions/cache@v2
       - name: Get changed files
         id: RLEnvironments-changed
         uses: tj-actions/changed-files@v42

@@ -205,7 +205,7 @@ jobs:
         with:
           fetch-depth: 0
       - run: python -m pip install --user matplotlib
-      - uses: julia-actions/setup-julia@v1
+      - uses: julia-actions/setup-julia@v2
         with:
           version: "1"
       - name: Build homepage

Project.toml

−1

@@ -4,7 +4,6 @@ authors = ["Johanni Brea <[email protected]>", "Jun Tian <tianjun.c
 version = "0.11.0"

 [deps]
-Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"

docs/homepage/blog/a_practical_introduction_to_RL.jl/index.html

+1 −1

@@ -15415,7 +15415,7 @@ <h2 id="Environments">Environments<a class="anchor-link" href="#Environments">&#
 <div class="text_cell_render border-box-sizing rendered_html">

 <pre><code>RLBase.action_space(env::MultiArmBanditsEnv) = Base.OneTo(length(env.true_values))
-RLBase.state(env::MultiArmBanditsEnv) = 1
+RLBase.state(env::MultiArmBanditsEnv, ::Observation, ::DefaultPlayer) = 1
 RLBase.state_space(env::MultiArmBanditsEnv) = Base.OneTo(1)
 RLBase.is_terminated(env::MultiArmBanditsEnv) = env.is_terminated
 RLBase.reward(env::MultiArmBanditsEnv) = env.reward

docs/src/How_to_write_a_customized_environment.md

+2 −2

@@ -68,7 +68,7 @@ Here `RLBase` is just an alias for `ReinforcementLearningBase`.

 ```@repl customized_env
 RLBase.reward(env::LotteryEnv) = env.reward
-RLBase.state(env::LotteryEnv) = !isnothing(env.reward)
+RLBase.state(env::LotteryEnv, ::Observation, ::DefaultPlayer) = !isnothing(env.reward)
 RLBase.state_space(env::LotteryEnv) = [false, true]
 RLBase.is_terminated(env::LotteryEnv) = !isnothing(env.reward)
 RLBase.reset!(env::LotteryEnv) = env.reward = nothing

@@ -181,7 +181,7 @@ RLCore.forward(p.learner.approximator, false)

 OK, now we know where the problem is. But how to fix it?

-An initial idea is to rewrite the `RLBase.state(env::LotteryEnv)` function to
+An initial idea is to rewrite the `RLBase.state(env::LotteryEnv, ::Observation, ::DefaultPlayer)` function to
 force it return an `Int`. That's workable. But in some cases, we may be using
 environments written by others and it's not very easy to modify the code
 directly. Fortunatelly, some environment wrappers are provided to help us
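For reference, the full three-argument `state` signature spelled out in the updated docs looks like this for a hypothetical single-agent environment. This is a minimal sketch, not part of the commit; it assumes `Observation` and `DefaultPlayer` are accessible from `ReinforcementLearningBase` the same way the package's own tests use them.

```julia
import ReinforcementLearningBase as RLBase
using ReinforcementLearningBase: AbstractEnv, Observation, DefaultPlayer

mutable struct CoinFlipEnv <: AbstractEnv   # hypothetical example environment
    heads::Bool
end

# Dispatching on the state style and the (default) player explicitly means the
# method is also found by algorithms that always pass the full signature.
RLBase.state(env::CoinFlipEnv, ::Observation, ::DefaultPlayer) = env.heads
RLBase.state_space(::CoinFlipEnv) = [false, true]

RLBase.state(CoinFlipEnv(true), Observation{Any}(), DefaultPlayer())   # expected: true
```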

src/ReinforcementLearningBase/NEWS.md

+8

@@ -1,5 +1,13 @@
 ### ReinforcementLearningBase.jl Release Notes

+#### v0.13.1
+
+- Don't call `legal_action_space_mask` methods when `ActionStyle` is `MinimalActionSet`
+
+#### v0.13.0
+
+- Breaking release compatible with RL.jl v0.11
+
 #### v0.12.0

 - Transition to `RLCore.forward`, `RLBase.act!`, `RLBase.plan!` and `Base.push!` syntax instead of functional objects for hooks, policies and environments

src/ReinforcementLearningBase/Project.toml

+1 −1

@@ -1,7 +1,7 @@
 name = "ReinforcementLearningBase"
 uuid = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 authors = ["Johanni Brea <[email protected]>", "Jun Tian <[email protected]>"]
-version = "0.13.0"
+version = "0.13.1"

 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/ReinforcementLearningBase/src/interface.jl

+4 −1

@@ -487,7 +487,7 @@ For environments of [`MINIMAL_ACTION_SET`](@ref), the result is the same with
 @multi_agent_env_api legal_action_space(env::AbstractEnv, player=current_player(env)) =
     legal_action_space(ActionStyle(env), env, player)

-legal_action_space(::MinimalActionSet, env, player::AbstractPlayer) = action_space(env)
+legal_action_space(::MinimalActionSet, env::AbstractEnv, player::AbstractPlayer) = action_space(env)

 """
     legal_action_space_mask(env, player=current_player(env)) -> AbstractArray{Bool}

@@ -497,6 +497,9 @@ Required for environments of [`FULL_ACTION_SET`](@ref). As a default implementat
 the subset [`legal_action_space`](@ref).
 """
 @multi_agent_env_api legal_action_space_mask(env::AbstractEnv, player=current_player(env)) =
+    legal_action_space_mask(ActionStyle(env), env, player)
+
+legal_action_space_mask(::FullActionSet, env::AbstractEnv, player=current_player(env)) =
     map(action_space(env, player)) do action
         action in legal_action_space(env, player)
     end
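The net effect of this hunk is that `legal_action_space_mask` is now resolved through `ActionStyle(env)`, so only `FullActionSet` environments fall back to building a `Bool` mask from `legal_action_space`. A dependency-free sketch of the same trait-dispatch pattern; every name below is a local stand-in, not the RLBase API.

```julia
abstract type ActionStyleTrait end
struct MinimalActionSetStyle <: ActionStyleTrait end
struct FullActionSetStyle <: ActionStyleTrait end

struct GatedEnv end                                  # hypothetical environment
action_space(::GatedEnv) = 1:3
legal_action_space(::GatedEnv) = 1:2                 # action 3 is currently illegal
action_style(::GatedEnv) = FullActionSetStyle()

# Entry point dispatches on the trait; only the full-action-set branch builds a mask.
legal_action_space_mask(env) = legal_action_space_mask(action_style(env), env)
legal_action_space_mask(::FullActionSetStyle, env) =
    map(a -> a in legal_action_space(env), action_space(env))
# No MinimalActionSetStyle method is needed: callers that see the minimal trait
# never ask for a mask in the first place.

legal_action_space_mask(GatedEnv())                  # Bool[1, 1, 0]
```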

src/ReinforcementLearningBase/test/interface.jl

+1 −1

@@ -4,7 +4,7 @@ struct TestEnv <: RLBase.AbstractEnv
     state::Int
 end

-function RLBase.state(env::TestEnv, ::Observation{Any}, ::DefaultPlayer)
+function RLBase.state(env::TestEnv, ::Observation, ::DefaultPlayer)
     return env.state
 end

src/ReinforcementLearningCore/NEWS.md

+4

@@ -1,5 +1,9 @@
 # ReinforcementLearningCore.jl Release Notes

+#### v0.15.3
+
+- Make `FluxApproximator` work with `QBasedPolicy`
+
 #### v0.15.2

 - Make QBasedPolicy general for AbstractLearner s (#1069)

src/ReinforcementLearningCore/Project.toml

+1 −1

@@ -1,6 +1,6 @@
 name = "ReinforcementLearningCore"
 uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
-version = "0.15.2"
+version = "0.15.3"

 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl

+7 −4

@@ -25,12 +25,15 @@ function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::Trajectory) end

 function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::NamedTuple) end

-function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv)
-    legal_action_space_ = RLBase.legal_action_space_mask(env)
-    RLBase.plan!(explorer, forward(learner, env), legal_action_space_)
+function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv, player=current_player(env))
+    return RLBase.plan!(ActionStyle(env), explorer, learner, env, player)
 end

-function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv, player::AbstractPlayer)
+function RLBase.plan!(::FullActionSet, explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv, player=current_player(env))
     legal_action_space_ = RLBase.legal_action_space_mask(env, player)
     return RLBase.plan!(explorer, forward(learner, env, player), legal_action_space_)
 end
+
+function RLBase.plan!(::MinimalActionSet, explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv, player=current_player(env))
+    return RLBase.plan!(explorer, forward(learner, env, player))
+end

src/ReinforcementLearningCore/src/policies/learners/flux_approximator.jl

+1 −1

@@ -41,7 +41,7 @@ FluxApproximator(model, optimiser::Flux.Optimise.AbstractOptimiser; use_gpu=fals
 Flux.@layer FluxApproximator trainable=(model,)

 forward(A::FluxApproximator, args...; kwargs...) = A.model(args...; kwargs...)
-forward(A::FluxApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(A, x))
+forward(A::FluxApproximator, env::E, player::AbstractPlayer=current_player(env)) where {E <: AbstractEnv} = env |> (x -> state(x, player)) |> (x -> forward(A, x))

 RLBase.optimise!(A::FluxApproximator, grad::NamedTuple) =
     Flux.Optimise.update!(A.optimiser_state, A.model, grad.model)

src/ReinforcementLearningCore/src/policies/q_based_policy.jl

+7

@@ -36,7 +36,14 @@ function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E, player::Player) where
 end

 RLBase.prob(policy::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} =
+    prob(ActionStyle(env), policy, env)
+
+RLBase.prob(::MinimalActionSet, policy::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} =
+    prob(policy.explorer, forward(policy.learner, env))
+
+RLBase.prob(::FullActionSet, policy::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} =
     prob(policy.explorer, forward(policy.learner, env), legal_action_space_mask(env))

+
 #the internal learner defines the optimization stage.
 RLBase.optimise!(policy::QBasedPolicy, stage::AbstractStage, trajectory::Trajectory) = RLBase.optimise!(policy.learner, stage, trajectory)
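`prob` follows the same split: the `MinimalActionSet` branch hands the explorer only the learner's values, while the `FullActionSet` branch also passes the legality mask. A dependency-free stand-in (not the RLCore explorer) showing why the two branches produce different action distributions:

```julia
# Hypothetical ε-greedy probabilities, with and without a legality mask.
eps_greedy_prob(values; ϵ = 0.1) =
    [ϵ / length(values) + (i == argmax(values) ? 1 - ϵ : 0.0) for i in eachindex(values)]

function eps_greedy_prob(values, mask; ϵ = 0.1)
    masked = [m ? v : -Inf for (v, m) in zip(values, mask)]
    [m ? ϵ / count(mask) : 0.0 for m in mask] .+
        [i == argmax(masked) ? 1 - ϵ : 0.0 for i in eachindex(values)]
end

eps_greedy_prob([1.0, 2.0, 3.0])                       # ≈ [0.033, 0.033, 0.933]
eps_greedy_prob([1.0, 2.0, 3.0], [true, true, false])  # ≈ [0.05, 0.95, 0.0] (action 3 excluded)
```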

src/ReinforcementLearningCore/src/policies/random_policy.jl

+3 −3

@@ -24,7 +24,7 @@ RandomPolicy(s = nothing; rng = Random.default_rng()) = RandomPolicy(s, rng)

 RLBase.optimise!(::RandomPolicy, x::NamedTuple) = nothing

-RLBase.plan!(p::RandomPolicy{S,RNG}, env::AbstractEnv) where {S,RNG<:AbstractRNG} = rand(p.rng, p.action_space)
+RLBase.plan!(p::RandomPolicy{S,RNG}, ::AbstractEnv) where {S,RNG<:AbstractRNG} = rand(p.rng, p.action_space)

 function RLBase.plan!(p::RandomPolicy{Nothing,RNG}, env::AbstractEnv) where {RNG<:AbstractRNG}
     legal_action_space_ = RLBase.legal_action_space(env)

@@ -45,7 +45,7 @@ function RLBase.prob(p::RandomPolicy{S,RNG}, s) where {S,RNG<:AbstractRNG}
     Categorical(Fill(1 / n, n); check_args = false)
 end

-RLBase.prob(p::RandomPolicy{Nothing,RNG}, x) where {RNG<:AbstractRNG} =
+RLBase.prob(::RandomPolicy{Nothing,RNG}, x) where {RNG<:AbstractRNG} =
     @error "no I really don't know how to calculate the prob from nothing"

 #####

@@ -54,7 +54,7 @@ RLBase.prob(p::RandomPolicy{Nothing,RNG}, env::AbstractEnv) where {RNG<:Abstract
     prob(p, env, ChanceStyle(env))

 function RLBase.prob(
-    p::RandomPolicy{Nothing,RNG},
+    ::RandomPolicy{Nothing,RNG},
     env::AbstractEnv,
     ::RLBase.AbstractChanceStyle,
 ) where {RNG<:AbstractRNG}

src/ReinforcementLearningCore/test/policies/learners/abstract_learner.jl

+1 −1

@@ -14,7 +14,7 @@ struct MockLearner <: AbstractLearner end
     return [1.0, 2.0]
 end

-RLBase.state(::MockEnv, ::Observation{Any}, ::DefaultPlayer) = 1
+RLBase.state(::MockEnv, ::Observation, ::DefaultPlayer) = 1
 RLBase.state(::MockEnv, ::Observation{Any}, ::Player) = 1

 env = MockEnv()

src/ReinforcementLearningEnvironments/NEWS.md

+9 −1

@@ -1,5 +1,13 @@
 ### ReinforcementLearningEnvironments.jl Release Notes

+#### v0.9.1
+
+- Update `state` calls to use full signature (so compatible with more algorithms)
+
+#### v0.9.0
+
+- Compatible with RL.jl v0.11
+
 #### v0.8

 - Transition to `RLCore.forward`, `RLBase.act!`, `RLBase.plan!` and `Base.push!` syntax instead of functional objects for hooks, policies and environments

@@ -63,4 +71,4 @@

 #### v0.6.0

-- Set `AcrobotEnv` into lazy loading to reduce the dependency of `OrdinaryDiffEq`.
+- Set `AcrobotEnv` into lazy loading to reduce the dependency of `OrdinaryDiffEq`.

src/ReinforcementLearningEnvironments/Project.toml

+2

@@ -34,6 +34,7 @@ ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
 DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"

@@ -48,6 +49,7 @@ test = [
     "JLD2",
     "Conda",
     "DomainSets",
+    "Flux",
     "OpenSpiel",
     "OrdinaryDiffEq",
     "PyCall",

src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl

+1 −1

@@ -85,7 +85,7 @@ RLBase.state_space(env::AcrobotEnv) = ArrayProductDomain(
 )

 RLBase.is_terminated(env::AcrobotEnv) = env.done
-RLBase.state(env::AcrobotEnv) = acrobot_observation(env.state)
+RLBase.state(env::AcrobotEnv, ::Observation, ::DefaultPlayer) = acrobot_observation(env.state)
 RLBase.reward(env::AcrobotEnv) = env.reward

 function RLBase.reset!(env::AcrobotEnv{T}) where {T<:Number}

src/ReinforcementLearningEnvironments/src/environments/3rd_party/atari.jl

+1 −1

@@ -121,7 +121,7 @@ RLBase.nameof(env::AtariEnv) = "AtariEnv($(env.name))"
 RLBase.action_space(env::AtariEnv) = env.action_space
 RLBase.reward(env::AtariEnv) = env.reward
 RLBase.is_terminated(env::AtariEnv) = is_terminal(env)
-RLBase.state(env::AtariEnv) = env.screens[1]
+RLBase.state(env::AtariEnv, ::Observation, ::DefaultPlayer) = env.screens[1]
 RLBase.state_space(env::AtariEnv) = env.observation_space

 function Random.seed!(env::AtariEnv, s)

src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl

+1 −1

@@ -86,7 +86,7 @@ function RLBase.is_terminated(env::GymEnv{T}) where {T}
     end
 end

-function RLBase.state(env::GymEnv{T}) where {T}
+function RLBase.state(env::GymEnv{T}, ::Observation, ::DefaultPlayer) where {T}
     if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4
         obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
         obs

src/ReinforcementLearningEnvironments/src/environments/3rd_party/snake.jl

+1 −1

@@ -42,7 +42,7 @@ RLBase.act!(env::SnakeGameEnv, action::Int) = env([SNAKE_GAME_ACTIONS[action]])
 RLBase.act!(env::SnakeGameEnv, actions::Vector{Int}) = env(map(a -> SNAKE_GAME_ACTIONS[a], actions))

 RLBase.action_space(env::SnakeGameEnv) = Base.OneTo(4)
-RLBase.state(env::SnakeGameEnv) = env.game.board
+RLBase.state(env::SnakeGameEnv, ::Observation, ::DefaultPlayer) = env.game.board
 RLBase.state_space(env::SnakeGameEnv) = ArrayProductDomain(fill(false:true, size(env.game.board)))
 RLBase.reward(env::SnakeGameEnv{<:Any,SINGLE_AGENT}) =
     length(env.game.snakes[]) - env.latest_snakes_length[]

src/ReinforcementLearningEnvironments/src/environments/examples/BitFlippingEnv.jl

+1 −1

@@ -37,7 +37,7 @@ function RLBase.act!(env::BitFlippingEnv, action::Int)
     end
 end

-RLBase.state(env::BitFlippingEnv) = state(env::BitFlippingEnv, Observation{BitArray{1}}())
+RLBase.state(env::BitFlippingEnv, ::Observation, ::DefaultPlayer) = state(env::BitFlippingEnv, Observation{BitArray{1}}())
 RLBase.state(env::BitFlippingEnv, ::Observation) = env.state
 RLBase.state(env::BitFlippingEnv, ::GoalState) = env.goal_state
 RLBase.state_space(env::BitFlippingEnv, ::Observation) = ArrayProductDomain(fill(false:true, env.N))

src/ReinforcementLearningEnvironments/src/environments/examples/CartPoleEnv.jl

+1 −1

@@ -83,7 +83,7 @@ CartPoleEnv{T}(; kwargs...) where {T} = CartPoleEnv(T=T, kwargs...)
 Random.seed!(env::CartPoleEnv, seed) = Random.seed!(env.rng, seed)
 RLBase.reward(env::CartPoleEnv{T}) where {T} = env.done ? zero(T) : one(T)
 RLBase.is_terminated(env::CartPoleEnv) = env.done
-RLBase.state(env::CartPoleEnv) = env.state
+RLBase.state(env::CartPoleEnv, ::Observation, ::DefaultPlayer) = env.state

 function RLBase.state_space(env::CartPoleEnv{T}) where {T}
     ((-2 * env.params.xthreshold) .. (2 * env.params.xthreshold)) ×

src/ReinforcementLearningEnvironments/src/environments/examples/GraphShortestPathEnv.jl

+1 −1

@@ -54,7 +54,7 @@ function RLBase.act!(env::GraphShortestPathEnv, action)
     env.reward = env.pos == env.goal ? 0 : -1
 end

-RLBase.state(env::GraphShortestPathEnv) = env.pos
+RLBase.state(env::GraphShortestPathEnv, ::Observation, ::DefaultPlayer) = env.pos
 RLBase.state_space(env::GraphShortestPathEnv) = axes(env.graph, 2)
 RLBase.action_space(env::GraphShortestPathEnv) = axes(env.graph, 2)
 RLBase.legal_action_space(env::GraphShortestPathEnv) = (env.graph[:, env.pos]).nzind

src/ReinforcementLearningEnvironments/src/environments/examples/KuhnPokerEnv.jl

+1 −1

@@ -107,7 +107,7 @@ RLBase.action_space(env::KuhnPokerEnv, ::ChancePlayer) = Base.OneTo(length(KUHN_

 RLBase.legal_action_space(env::KuhnPokerEnv, p::ChancePlayer) = Tuple(x for x in action_space(env, p) if KUHN_POKER_CARDS[x] ∉ env.cards)

-function RLBase.legal_action_space_mask(env::KuhnPokerEnv, p::ChancePlayer)
+function RLBase.legal_action_space_mask(env::KuhnPokerEnv, ::ChancePlayer)
     m = fill(true, 3)
     m[env.cards] .= false
     m

src/ReinforcementLearningEnvironments/src/environments/examples/MontyHallEnv.jl

+1 −1

@@ -58,7 +58,7 @@ function RLBase.legal_action_space_mask(env::MontyHallEnv)
     mask
 end

-function RLBase.state(env::MontyHallEnv)
+function RLBase.state(env::MontyHallEnv, ::Observation, ::DefaultPlayer)
     if isnothing(env.host_action)
         1
     else

src/ReinforcementLearningEnvironments/src/environments/examples/MountainCarEnv.jl

+1 −1

@@ -94,7 +94,7 @@ RLBase.action_space(::MountainCarEnv{<:AbstractFloat,<:AbstractFloat}) = -1.0 ..

 RLBase.reward(env::MountainCarEnv{T}) where {T} = env.done ? zero(T) : -one(T)
 RLBase.is_terminated(env::MountainCarEnv) = env.done
-RLBase.state(env::MountainCarEnv) = env.state
+RLBase.state(env::MountainCarEnv, ::Observation, ::DefaultPlayer) = env.state

 function RLBase.reset!(env::MountainCarEnv{T}) where {T}
     env.state[1] = 0.2 * rand(env.rng, T) - 0.6
