Skip to content

Commit 78fe890

Browse files
Fixed key in Dict.keys() always being false (#608)
1 parent 7eb32fd commit 78fe890

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

gymnasium/spaces/dict.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import collections.abc
55
import typing
66
from collections import OrderedDict
7-
from typing import Any, Sequence
7+
from typing import Any, KeysView, Sequence
88

99
import numpy as np
1010

@@ -180,6 +180,10 @@ def __getitem__(self, key: str) -> Space[Any]:
180180
"""Get the space that is associated to `key`."""
181181
return self.spaces[key]
182182

183+
def keys(self) -> KeysView:
184+
"""Returns the keys of the Dict."""
185+
return KeysView(self.spaces)
186+
183187
def __setitem__(self, key: str, value: Space[Any]):
184188
"""Set the space that is associated to `key`."""
185189
assert isinstance(

tests/spaces/test_dict.py

+11
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,14 @@ def test_iterator():
145145
assert key in DICT_SPACE.spaces
146146

147147
assert {key for key in DICT_SPACE} == DICT_SPACE.spaces.keys()
148+
149+
150+
def test_keys_contains():
151+
"""Test that `Dict.keys()` will correctly assess if the key is in the space."""
152+
space = Dict(a=Box(0, 1), b=Box(1, 2))
153+
154+
for key in space.keys():
155+
assert key in space.keys()
156+
assert "a" in space.keys()
157+
158+
assert "c" not in space.keys()

0 commit comments

Comments
 (0)