
Issues with SceneBatch.to_agent #37

@bmacadam-sfu

Hi there,

I caught the following errors in the to_agent method for SceneBatch while trying to visualize some scenes:

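  1. The index_neighbors function does not preserve StateArrays/StateTensors: after indexing with others_mask and reshaping, the result is no longer a StateTensor with the original format. This can be fixed by checking whether a StateTensor was passed and re-wrapping the result with its format:
        def index_neighbors(x: Tensor | StateTensor) -> Tensor | StateTensor:
            # Keep the entries selected by others_mask (the other agents), then
            # restore the [batch_size, num_agents - 1, ...] layout.
            neighbors = x[others_mask].reshape([batch_size, num_agents - 1] + list(x.shape[2:]))
            # Re-wrap StateTensors so their state format is not lost.
            if isinstance(x, StateTensor):
                neighbors = StateTensor.from_array(neighbors, x._format)
            return neighbors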
  2. The index_agent function does not handle map names correctly. This can be fixed by wrapping each map name in a singleton list:
            map_names=index_agent_list([[m] for m in self.map_names]),
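For anyone reproducing the first fix outside the library, here is a minimal, self-contained sketch of the same pattern using NumPy instead of PyTorch. TaggedArray, make_others_mask and make_index_neighbors are hypothetical stand-ins (not trajdata APIs) for StateArray, others_mask and the closure inside to_agent. It assumes others_mask has shape [batch_size, num_agents] with exactly one False per row, which is what the reshape to num_agents - 1 requires.

```python
import numpy as np


class TaggedArray(np.ndarray):
    """Hypothetical stand-in for a StateArray: an ndarray tagged with a `_format` string."""

    @classmethod
    def from_array(cls, array, fmt):
        # View the raw data as a TaggedArray and attach the format tag.
        obj = np.asarray(array).view(cls)
        obj._format = fmt
        return obj


def make_others_mask(agent_idx, num_agents):
    """Boolean [batch_size, num_agents] mask that is False only at each row's chosen agent."""
    mask = np.ones((len(agent_idx), num_agents), dtype=bool)
    mask[np.arange(len(agent_idx)), agent_idx] = False
    return mask


def make_index_neighbors(others_mask):
    batch_size, num_agents = others_mask.shape

    def index_neighbors(x):
        # np.asarray drops the subclass (and its tag), mimicking the lost format.
        neighbors = np.asarray(x)[others_mask].reshape(
            [batch_size, num_agents - 1] + list(x.shape[2:])
        )
        # The fix: re-wrap with the input's format when a tagged array was passed.
        if isinstance(x, TaggedArray):
            neighbors = TaggedArray.from_array(neighbors, x._format)
        return neighbors

    return index_neighbors


# Two scenes, three agents each, 2-D positions.
# Agent 0 is the chosen agent in scene 0, agent 2 in scene 1.
positions = TaggedArray.from_array(np.arange(12).reshape(2, 3, 2), "x,y")
mask = make_others_mask(np.array([0, 2]), num_agents=3)
neighbors = make_index_neighbors(mask)(positions)
print(type(neighbors).__name__, neighbors._format, neighbors.shape)
```

Because each row of the mask has exactly one False entry, boolean indexing yields batch_size * (num_agents - 1) rows in row-major order, so the reshape regroups them per scene without mixing scenes. Plain arrays pass through unchanged, since the re-wrap only happens for tagged inputs.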
