Skip to content

Commit fb6e5e8

Browse files
committed
feat: add obstruction calculation to fp obs
1 parent a01a449 commit fb6e5e8

2 files changed

Lines changed: 31 additions & 12 deletions

File tree

navix/grid.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,15 +466,16 @@ def view_cone(transparency_map: Array, origin: Array, radius: int) -> Array:
466466
467467
Returns:
468468
Array: The view cone of the given origin in the grid with the given radius."""
469-
# transparency_map is a boolean map of transparent (1) and opaque (0) tiles
470469

471470
def fin_diff(array, _):
472471
array = jnp.roll(array, -1, axis=0) + array + jnp.roll(array, +1, axis=0)
473472
array = jnp.roll(array, -1, axis=1) + array + jnp.roll(array, +1, axis=1)
474473
return array * transparency_map, ()
475474

475+
# initialise the field to all zeros, except at the source (agent's position)
476476
mask = jnp.zeros_like(transparency_map).at[tuple(origin)].set(1)
477477

478+
# start the diffusion process using finite differences
478479
# if radius is small, it should be fast enough to compile
479480
MIN_SCAN_RADIUS = 10
480481
if radius <= MIN_SCAN_RADIUS:
@@ -484,15 +485,32 @@ def fin_diff(array, _):
484485
else:
485486
view = jax.lax.scan(fin_diff, mask, None, radius)[0]
486487

488+
# view has anything that is visible > 0
487489
# we now set a hard threshold > 0, but we can also think in the future
488490
# to use a cutoff at a different value to mimic the effect of a torch
489-
# (or eyesight for what matters)
490-
view = jnp.where(view > 0, 1, 0)
491+
vis_free = view > 0
492+
493+
# add frontier obstacles
494+
# frontier obstacles = opaque cells neighbouring any visible-free cell (8-neighbourhood)
495+
opaque = transparency_map == 0
496+
nb = (
497+
vis_free
498+
| jnp.roll(vis_free, +1, 0)
499+
| jnp.roll(vis_free, -1, 0)
500+
| jnp.roll(vis_free, +1, 1)
501+
| jnp.roll(vis_free, -1, 1)
502+
| jnp.roll(jnp.roll(vis_free, +1, 0), +1, 1)
503+
| jnp.roll(jnp.roll(vis_free, +1, 0), -1, 1)
504+
| jnp.roll(jnp.roll(vis_free, -1, 0), +1, 1)
505+
| jnp.roll(jnp.roll(vis_free, -1, 0), -1, 1)
506+
)
507+
frontier = nb & opaque
491508

492-
# we add back the opaque tiles
493-
view = jnp.where(transparency_map == 0, 1, view)
509+
# final visible = transparent region plus blocking frontier
510+
visible = vis_free | frontier
511+
visible = visible.at[tuple(origin)].set(True)
494512

495-
return view
513+
return visible.astype(transparency_map.dtype)
496514

497515

498516
def from_ascii_map(ascii_map: str, mapping: Dict[str, int] = {}) -> Array:

navix/observations.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,23 +253,24 @@ def rgb_first_person(state: State) -> Array:
253253
*state.grid.shape, *patches.shape[1:]
254254
) # (H, W, TILE_SIZE, TILE_SIZE, 3)
255255

256+
# apply minigrid opacity
257+
patchwork = apply_minigrid_opacity(patchwork)
258+
256259
# apply fov
260+
dark_cell_colour = 0 # dark color for unseen tiles
257261
transparency_map = jnp.where(state.grid == 0, 1, 0) # (H, W)
258262
positions = state.get_positions()
259263
transparent = state.get_transparency()
260-
transparency_map = transparency_map.at[tuple(positions.T)].set(~transparent)
264+
transparency_map = transparency_map.at[tuple(positions.T)].set(transparent)
261265
view = view_cone(transparency_map, player.position, RADIUS) # (H, W)
262266
view = jnp.asarray(view, dtype=jnp.bool)
263-
patchwork = patchwork * view[..., None, None, None]
267+
patchwork = jnp.where(view[..., None, None, None], patchwork, dark_cell_colour)
264268

265269
# crop grid to agent's view
266270
patchwork = crop(
267-
patchwork, player.position, player.direction, RADIUS
271+
patchwork, player.position, player.direction, RADIUS, dark_cell_colour
268272
) # (RADIUS * 2 + 1, RADIUS * 2 + 1, TILE_SIZE, TILE_SIZE, 3)
269273

270-
# apply minigrid opacity
271-
patchwork = apply_minigrid_opacity(patchwork)
272-
273274
# reconstruct image
274275
obs = jnp.swapaxes(patchwork, 1, 2)
275276
shape = obs.shape

0 commit comments

Comments
 (0)