Hi there,
I caught the following errors in the `to_agent` method of `SceneBatch` while trying to visualize some scenes:
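- The `index_neighbors` function does not preserve `StateArray`s/`StateTensor`s. This can be fixed by checking whether a `StateTensor` was passed and, if so, re-applying its format to the result:

  ```python
  # Nested inside SceneBatch.to_agent: others_mask, batch_size and num_agents
  # come from the enclosing scope.
  def index_neighbors(x: Tensor | StateTensor) -> Tensor | StateTensor:
      # Drop the selected agent and restore the [batch_size, num_agents - 1, ...] layout.
      neighbors = x[others_mask].reshape([batch_size, num_agents - 1] + list(x.shape[2:]))
      # Masked indexing does not keep the StateTensor format, so re-wrap the result.
      if isinstance(x, StateTensor):
          neighbors = StateTensor.from_array(neighbors, x._format)
      return neighbors
  ```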
- The `index_agent` function doesn't play well with map names. This can be fixed by wrapping each map name in a singleton list:

  ```python
  map_names=index_agent_list([[m] for m in self.map_names]),
  ```
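For readers unfamiliar with the masking step in the `index_neighbors` fix, here is a minimal NumPy sketch of the mask-then-reshape logic, assuming `others_mask` is an all-`True` `[batch_size, num_agents]` mask with the selected agent's entry set to `False`. The standalone `index_neighbors(x, agent_idx)` signature and the `agent_idx` parameter are hypothetical, and the `StateTensor` re-wrapping is left out. Boolean indexing with a 2-D mask flattens the first two dimensions in row-major order, so the reshape regroups the remaining agents per batch element; PyTorch boolean indexing orders elements the same way.

```python
import numpy as np


def index_neighbors(x: np.ndarray, agent_idx: np.ndarray) -> np.ndarray:
    """Hypothetical standalone version: drop agent agent_idx[b] from each batch row b.

    x has shape (batch_size, num_agents, ...); agent_idx has shape (batch_size,).
    """
    batch_size, num_agents = x.shape[:2]
    # Assumed mask construction: True for every agent except the selected one.
    others_mask = np.ones((batch_size, num_agents), dtype=bool)
    others_mask[np.arange(batch_size), agent_idx] = False
    # x[others_mask] has shape (batch_size * (num_agents - 1), ...) in row-major order,
    # so reshaping restores the per-batch grouping of the remaining agents.
    return x[others_mask].reshape([batch_size, num_agents - 1] + list(x.shape[2:]))
```

For `x = np.arange(12).reshape(2, 3, 2)` and `agent_idx = np.array([0, 2])`, the result has shape `(2, 2, 2)` and holds agents 1 and 2 of the first batch element and agents 0 and 1 of the second, in their original order.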