@@ -466,15 +466,16 @@ def view_cone(transparency_map: Array, origin: Array, radius: int) -> Array:
466466
467467 Returns:
468468 Array: The view cone of the given origin in the grid with the given radius."""
469- # transparency_map is a boolean map of transparent (1) and opaque (0) tiles
470469
471470 def fin_diff (array , _ ):
472471 array = jnp .roll (array , - 1 , axis = 0 ) + array + jnp .roll (array , + 1 , axis = 0 )
473472 array = jnp .roll (array , - 1 , axis = 1 ) + array + jnp .roll (array , + 1 , axis = 1 )
474473 return array * transparency_map , ()
475474
475+ # initialise the field to all zeros, except at the source (agent's position)
476476 mask = jnp .zeros_like (transparency_map ).at [tuple (origin )].set (1 )
477477
478+ # start the diffusion process using finite differences
478479 # if radius is small, it should be fast enough to compile
479480 MIN_SCAN_RADIUS = 10
480481 if radius <= MIN_SCAN_RADIUS :
@@ -484,15 +485,32 @@ def fin_diff(array, _):
484485 else :
485486 view = jax .lax .scan (fin_diff , mask , None , radius )[0 ]
486487
488+ # view has anything that is visible > 0
487489 # we now set a hard threshold > 0, but we can also think in the future
488490 # to use a cutoff at a different value to mimic the effect of a torch
489- # (or eyesight for what matters)
490- view = jnp .where (view > 0 , 1 , 0 )
491+ vis_free = view > 0
492+
493+ # add frontier obstacles
494+ # frontier obstacles = opaque cells neighbouring any visible-free cell (8-neighbourhood)
495+ opaque = transparency_map == 0
496+ nb = (
497+ vis_free
498+ | jnp .roll (vis_free , + 1 , 0 )
499+ | jnp .roll (vis_free , - 1 , 0 )
500+ | jnp .roll (vis_free , + 1 , 1 )
501+ | jnp .roll (vis_free , - 1 , 1 )
502+ | jnp .roll (jnp .roll (vis_free , + 1 , 0 ), + 1 , 1 )
503+ | jnp .roll (jnp .roll (vis_free , + 1 , 0 ), - 1 , 1 )
504+ | jnp .roll (jnp .roll (vis_free , - 1 , 0 ), + 1 , 1 )
505+ | jnp .roll (jnp .roll (vis_free , - 1 , 0 ), - 1 , 1 )
506+ )
507+ frontier = nb & opaque
491508
492- # we add back the opaque tiles
493- view = jnp .where (transparency_map == 0 , 1 , view )
509+ # final visible = transparent region plus blocking frontier
510+ visible = vis_free | frontier
511+ visible = visible .at [tuple (origin )].set (True )
494512
495- return view
513+ return visible . astype ( transparency_map . dtype )
496514
497515
498516def from_ascii_map (ascii_map : str , mapping : Dict [str , int ] = {}) -> Array :
0 commit comments