Skip to content

Commit 0d306d1

Browse files
authored
compatibility with python3.9 (#55)
* compatibility with python 3.9 * bump version
1 parent 9d2fb53 commit 0d306d1

File tree

12 files changed

+29
-23
lines changed

12 files changed

+29
-23
lines changed

src/xminigrid/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .registration import make, register, registered_environments
33

44
# TODO: add __all__
5-
__version__ = "0.9.2"
5+
__version__ = "0.9.3"
66

77
# ---------- XLand-MiniGrid environments ----------
88

src/xminigrid/benchmarks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import pickle
66
import urllib.request
7-
from typing import Callable
7+
from typing import Callable, Union
88

99
import jax
1010
import jax.numpy as jnp
@@ -38,7 +38,7 @@ class Benchmark(struct.PyTreeNode):
3838
def num_rulesets(self) -> int:
3939
return len(self.goals)
4040

41-
def get_ruleset(self, ruleset_id: int | jax.Array) -> RuleSet:
41+
def get_ruleset(self, ruleset_id: Union[int, jax.Array]) -> RuleSet:
4242
return get_ruleset(self.goals, self.rules, self.init_tiles, ruleset_id)
4343

4444
def sample_ruleset(self, key: jax.Array) -> RuleSet:
@@ -114,7 +114,7 @@ def get_ruleset(
114114
goals: jax.Array,
115115
rules: jax.Array,
116116
init_tiles: jax.Array,
117-
ruleset_id: int | jax.Array,
117+
ruleset_id: Union[int, jax.Array],
118118
) -> RuleSet:
119119
goal = jax.lax.dynamic_index_in_dim(goals, ruleset_id, keepdims=False)
120120
rules = jax.lax.dynamic_index_in_dim(rules, ruleset_id, keepdims=False)

src/xminigrid/core/goals.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4+
from typing import Union
45

56
import jax
67
import jax.numpy as jnp
@@ -14,7 +15,7 @@
1415

1516

1617
def check_goal(
17-
encoding: jax.Array, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array
18+
encoding: jax.Array, grid: GridState, agent: AgentState, action: Union[int, jax.Array], position: jax.Array
1819
) -> jax.Array:
1920
check = jax.lax.switch(
2021
encoding[0],
@@ -45,7 +46,7 @@ def check_goal(
4546
class BaseGoal(struct.PyTreeNode):
4647
@abc.abstractmethod
4748
def __call__(
48-
self, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array
49+
self, grid: GridState, agent: AgentState, action: Union[int, jax.Array], position: jax.Array
4950
) -> jax.Array: ...
5051

5152
@classmethod

src/xminigrid/core/grid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Callable
3+
from typing import Callable, Optional, Union
44

55
import jax
66
import jax.numpy as jnp
@@ -158,7 +158,7 @@ def coordinates_mask(grid: GridState, address: tuple[IntOrArray, IntOrArray], co
158158
return mask
159159

160160

161-
def sample_coordinates(key: jax.Array, grid: GridState, num: int, mask: jax.Array | None = None) -> jax.Array:
161+
def sample_coordinates(key: jax.Array, grid: GridState, num: int, mask: Optional[jax.Array] = None) -> jax.Array:
162162
if mask is None:
163163
mask = jnp.ones((grid.shape[0], grid.shape[1]), dtype=jnp.bool_)
164164

src/xminigrid/core/rules.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4+
from typing import Union
45

56
import jax
67
import jax.numpy as jnp
@@ -18,7 +19,7 @@
1819
# In general, we need a way to select specific function/class based on ID number.
1920
# We can not just decode without evaluation, as then return type will be different between branches
2021
def check_rule(
21-
encodings: jax.Array, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array
22+
encodings: jax.Array, grid: GridState, agent: AgentState, action: Union[int, jax.Array], position: jax.Array
2223
) -> tuple[GridState, AgentState]:
2324
def _check(carry, encoding):
2425
grid, agent = carry
@@ -51,7 +52,7 @@ def _check(carry, encoding):
5152
class BaseRule(struct.PyTreeNode):
5253
@abc.abstractmethod
5354
def __call__(
54-
self, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array
55+
self, grid: GridState, agent: AgentState, action: Union[int, jax.Array], position: jax.Array
5556
) -> tuple[GridState, AgentState]: ...
5657

5758
@classmethod

src/xminigrid/environment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4-
from typing import Any, Generic, Optional, TypeVar
4+
from typing import Any, Generic, Optional, TypeVar, Union
55

66
import jax
77
import jax.numpy as jnp
@@ -40,7 +40,7 @@ def default_params(self, **kwargs: Any) -> EnvParamsT: ...
4040
def num_actions(self, params: EnvParamsT) -> int:
4141
return int(NUM_ACTIONS)
4242

43-
def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int] | dict[str, Any]:
43+
def observation_shape(self, params: EnvParamsT) -> Union[tuple[int, int, int], dict[str, Any]]:
4444
return params.view_size, params.view_size, NUM_LAYERS
4545

4646
@abc.abstractmethod
@@ -89,7 +89,7 @@ def step(self, params: EnvParamsT, timestep: TimeStep[EnvCarryT], action: IntOrA
8989
)
9090
return timestep
9191

92-
def render(self, params: EnvParamsT, timestep: TimeStep[EnvCarryT]) -> np.ndarray | str:
92+
def render(self, params: EnvParamsT, timestep: TimeStep[EnvCarryT]) -> Union[np.ndarray, str]:
9393
if params.render_mode == "rgb_array":
9494
return rgb_render(np.asarray(timestep.state.grid), timestep.state.agent, params.view_size)
9595
elif params.render_mode == "rich_text":

src/xminigrid/manual_control.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import argparse
44
import os
5+
from typing import Optional
56

67
import imageio.v3 as iio
78
import jax
@@ -24,7 +25,7 @@ def __init__(
2425
env_params: EnvParamsT,
2526
agent_view: bool = False,
2627
save_video: bool = False,
27-
video_path: str | None = None,
28+
video_path: Optional[str] = None,
2829
video_format: str = ".mp4",
2930
video_fps: int = 8,
3031
):

src/xminigrid/rendering/rgb_render.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# TODO: this is rendering mostly ported or adapted from the original Minigrid. A bit dirty right now...
22
import functools
33
import math
4+
from typing import Optional, Union
45

56
import numpy as np
67

@@ -181,7 +182,7 @@ def _render_player(img: np.ndarray, direction: int):
181182

182183

183184
# TODO: add highlight for can_see_through_walls=Fasle
184-
def get_highlight_mask(grid: np.ndarray, agent: AgentState | None, view_size: int) -> np.ndarray:
185+
def get_highlight_mask(grid: np.ndarray, agent: Optional[AgentState], view_size: int) -> np.ndarray:
185186
mask = np.zeros((grid.shape[0] + 2 * view_size, grid.shape[1] + 2 * view_size), dtype=np.bool_)
186187
if agent is None:
187188
return mask
@@ -207,7 +208,7 @@ def get_highlight_mask(grid: np.ndarray, agent: AgentState | None, view_size: in
207208

208209
@functools.cache
209210
def render_tile(
210-
tile: tuple, agent_direction: int | None = None, highlight: bool = False, tile_size: int = 32, subdivs: int = 3
211+
tile: tuple, agent_direction: Optional[int] = None, highlight: bool = False, tile_size: int = 32, subdivs: int = 3
211212
) -> np.ndarray:
212213
img = np.full((tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8, fill_value=255)
213214
# draw tile
@@ -228,7 +229,7 @@ def render_tile(
228229
# WARN: will NOT work under jit and needed for debugging/presentation mainly.
229230
def render(
230231
grid: np.ndarray,
231-
agent: AgentState | None = None,
232+
agent: Optional[AgentState] = None,
232233
view_size: IntOrArray = 7,
233234
tile_size: IntOrArray = 32,
234235
) -> np.ndarray:

src/xminigrid/rendering/text_render.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional, Union
2+
13
import jax
24
import jax.numpy as jnp
35

@@ -55,7 +57,7 @@ def _wrap_with_color(string: str, color: str) -> str:
5557

5658

5759
# WARN: will NOT work under jit and needed for debugging mainly.
58-
def render(grid: jax.Array, agent: AgentState | None = None) -> str:
60+
def render(grid: jax.Array, agent: Optional[AgentState] = None) -> str:
5961
string = ""
6062

6163
for y in range(grid.shape[0]):

src/xminigrid/rendering/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from __future__ import annotations
33

44
import math
5-
from typing import Callable
5+
from typing import Callable, Union
66

77
import numpy as np
88
from typing_extensions import TypeAlias
99

10-
Color: TypeAlias = tuple[int, int, int] | int | np.ndarray
10+
Color: TypeAlias = Union[tuple[int, int, int], int, np.ndarray]
1111
Point: TypeAlias = tuple[float, float] # | np.ndarray
1212

1313

0 commit comments

Comments
 (0)