11# TODO: this is rendering mostly ported or adapted from the original Minigrid. A bit dirty right now...
22import functools
33import math
4+ from typing import Optional , Union
45
56import numpy as np
67
@@ -181,7 +182,7 @@ def _render_player(img: np.ndarray, direction: int):
181182
182183
183184# TODO: add highlight for can_see_through_walls=Fasle
184- def get_highlight_mask (grid : np .ndarray , agent : AgentState | None , view_size : int ) -> np .ndarray :
185+ def get_highlight_mask (grid : np .ndarray , agent : Optional [ AgentState ] , view_size : int ) -> np .ndarray :
185186 mask = np .zeros ((grid .shape [0 ] + 2 * view_size , grid .shape [1 ] + 2 * view_size ), dtype = np .bool_ )
186187 if agent is None :
187188 return mask
@@ -207,7 +208,7 @@ def get_highlight_mask(grid: np.ndarray, agent: AgentState | None, view_size: in
207208
208209@functools .cache
209210def render_tile (
210- tile : tuple , agent_direction : int | None = None , highlight : bool = False , tile_size : int = 32 , subdivs : int = 3
211+ tile : tuple , agent_direction : Optional [ int ] = None , highlight : bool = False , tile_size : int = 32 , subdivs : int = 3
211212) -> np .ndarray :
212213 img = np .full ((tile_size * subdivs , tile_size * subdivs , 3 ), dtype = np .uint8 , fill_value = 255 )
213214 # draw tile
@@ -228,7 +229,7 @@ def render_tile(
228229# WARN: will NOT work under jit and needed for debugging/presentation mainly.
229230def render (
230231 grid : np .ndarray ,
231- agent : AgentState | None = None ,
232+ agent : Optional [ AgentState ] = None ,
232233 view_size : IntOrArray = 7 ,
233234 tile_size : IntOrArray = 32 ,
234235) -> np .ndarray :
0 commit comments