Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions genesis/engine/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
PBDOptions,
ProfilingOptions,
RigidOptions,
SensorOptions,
SFOptions,
SimOptions,
SPHOptions,
Expand Down Expand Up @@ -514,6 +515,10 @@ def add_light(
else:
gs.raise_exception("Adding lights is only supported by 'RayTracer' and 'BatchRenderer'.")

@gs.assert_unbuilt
def add_sensor(self, sensor_options: SensorOptions):
return self._sim._sensor_manager.create_sensor(sensor_options)

@gs.assert_unbuilt
def add_camera(
self,
Expand Down
8 changes: 8 additions & 0 deletions genesis/engine/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from .states.cache import QueriedStates
from .states.solvers import SimState
from genesis.sensors.sensor_manager import SensorManager

if TYPE_CHECKING:
from genesis.engine.scene import Scene
Expand Down Expand Up @@ -151,6 +152,9 @@ def __init__(
# entities
self._entities: list[Entity] = gs.List()

# sensors
self._sensor_manager = SensorManager(self)

def _add_entity(self, morph: Morph, material, surface, visualize_contact=False):
if isinstance(material, gs.materials.Tool):
entity = self.tool_solver.add_entity(self.n_entities, material, morph, surface)
Expand Down Expand Up @@ -206,6 +210,8 @@ def build(self):
if self.n_envs > 0 and self.sf_solver.is_active():
gs.raise_exception("Batching is not supported for SF solver as of now.")

self._sensor_manager.build()

# hybrid
for entity in self._entities:
if isinstance(entity, HybridEntity):
Expand Down Expand Up @@ -275,6 +281,8 @@ def step(self, in_backward=False):
if self.rigid_solver.is_active():
self.rigid_solver.clear_external_force()

self._sensor_manager.step()

def _step_grad(self):
for _ in range(self._substeps - 1, -1, -1):

Expand Down
2 changes: 1 addition & 1 deletion genesis/options/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from .solvers import *
from .vis import *
from .profiling import ProfilingOptions

from .sensors import SensorOptions

__all__ = ["ProfilingOptions"]
16 changes: 16 additions & 0 deletions genesis/options/sensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from genesis.options import Options


class SensorOptions(Options):
    """
    Common configuration shared by every sensor type.

    Each concrete sensor defines its own options subclass deriving from this
    one and registers it with the SensorManager via the @register_sensor
    decorator.

    Parameters
    ----------
    read_delay : float
        Delay, in seconds, before the sensor data is read.
    """

    read_delay: float = 0.0
123 changes: 116 additions & 7 deletions genesis/sensors/base_sensor.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,130 @@
from typing import TYPE_CHECKING, Any, List, Type

import numpy as np
import taichi as ti
import genesis as gs
import torch

from typing import List, Optional
import genesis as gs
from genesis.repr_base import RBC

if TYPE_CHECKING:
from genesis.options.sensors import SensorOptions
from genesis.utils.ring_buffer import TensorRingBuffer

from .sensor_manager import SensorManager


@ti.data_oriented
class Sensor(RBC):
"""
Base class for all types of sensors.
A sensor must have a read() method that returns the sensor data.
"""

    def __init__(self, sensor_options: "SensorOptions", sensor_idx: int, sensor_manager: "SensorManager"):
        """
        Store construction-time references; heavy initialization is deferred to build().

        Parameters
        ----------
        sensor_options : SensorOptions
            Options describing this sensor (e.g. read_delay).
        sensor_idx : int
            Index of this sensor among the sensors of the same class.
        sensor_manager : SensorManager
            Manager that owns this sensor and its shared caches.
        """
        self._options: "SensorOptions" = sensor_options
        self._idx: int = sensor_idx
        self._manager: "SensorManager" = sensor_manager

        # initialized by SensorManager during build
        self._read_delay_steps: int = 0  # read_delay converted to a whole number of simulation steps
        self._shape_indices: list[tuple[int, int]] = []  # (start, end) column index per return-format field
        self._shared_metadata: dict[str, Any] | None = None  # metadata shared by all sensors of this class
        self._cache: "TensorRingBuffer" | None = None  # this sensor's slice of the shared ring buffer

# =============================== implementable methods ===============================

def build(self):
"""
This method is called by the SensorManager during the scene build phase to initialize the sensor.
This is where any shared metadata should be initialized.
"""
raise NotImplementedError("Sensors must implement `build()`.")

def _get_return_format(self) -> dict[str, tuple[int, ...]] | tuple[int, ...]:
"""
Data format of the read() return value.

Returns
-------
return_format : dict | tuple
- If tuple, the final shape of the read() return value.
e.g. (2, 3) means read() will return a tensor of shape (2, 3).
- If dict a dictionary with string keys and tensor values will be returned.
e.g. {"pos": (3,), "quat": (4,)} returns a dict of tensors [0:3] and [3:7] from the cache.
"""
raise NotImplementedError("Sensors must implement `return_format()`.")

def _get_cache_length(self) -> int:
"""
The length of the cache for this sensor instance, e.g. number of points for a Lidar point cloud.
"""
raise NotImplementedError("Sensors must implement `cache_length()`.")

@classmethod
def _update_shared_ground_truth_cache(
cls, shared_metadata: dict[str, Any], shared_ground_truth_cache: torch.Tensor
):
"""
Update the shared sensor ground truth cache for all sensors of this class using metadata in SensorManager.
"""
raise NotImplementedError("Sensors must implement `update_shared_ground_truth_cache()`.")

@classmethod
def _update_shared_cache(
cls, shared_metadata: dict[str, Any], shared_ground_truth_cache: torch.Tensor, shared_cache: "TensorRingBuffer"
):
"""
Update the shared sensor cache for all sensors of this class using metadata in SensorManager.
"""
raise NotImplementedError("Sensors must implement `update_shared_cache()`.")

@classmethod
def _get_cache_dtype(cls) -> torch.dtype:
"""
The dtype of the cache for this sensor.
"""
raise NotImplementedError("Sensors must implement `get_cache_dtype()`.")

# =============================== shared methods ===============================

@gs.assert_built
def read(self, envs_idx: Optional[List[int]] = None):
def read(self, envs_idx: List[int] | None = None):
"""
Read the sensor data.
Sensor implementations should ideally cache the data to avoid unnecessary computations.
Read the sensor data (with noise applied if applicable).
"""
raise NotImplementedError("The Sensor subclass must implement `read()`.")
return self._get_formatted_data(self._cache.get(self._read_delay_steps), envs_idx)

@gs.assert_built
def read_ground_truth(self, envs_idx: List[int] | None = None):
"""
Read the ground truth sensor data (without noise).
"""
return self._get_formatted_data(self._manager.get_cloned_from_ground_truth_cache(self), envs_idx)

    def _get_formatted_data(
        self, tensor: torch.Tensor, envs_idx: list[int] | None
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """
        Slice and reshape flat cache data according to this sensor's return format.

        Parameters
        ----------
        tensor : torch.Tensor
            Flat cache data; presumably shape (n_envs, cache_width) with columns
            laid out per self._shape_indices — TODO confirm against SensorManager.build().
        envs_idx : list[int] | None
            Environments to select; None selects all environments of the scene.

        Returns
        -------
        torch.Tensor | dict[str, torch.Tensor]
            A single tensor when the return format is a tuple, otherwise a dict
            keyed by the return format's field names.
        """
        # Note: This method does not clone the data tensor, it should have been cloned by the caller.

        if envs_idx is None:
            envs_idx = self._manager._sim._scene._envs_idx

        return_format = self._get_return_format()
        # A tuple format describes one tensor; a dict maps field name -> field shape.
        return_shapes = return_format.values() if isinstance(return_format, dict) else (return_format,)
        return_values = []

        for i, shape in enumerate(return_shapes):
            start_idx, end_idx = self._shape_indices[i]
            # NOTE(review): .squeeze() with no args drops ALL size-1 dims (including the env
            # dim when len(envs_idx) == 1), after which the conditional squeeze(0) below is
            # usually a no-op — confirm the double squeeze is intended.
            value = tensor[envs_idx, start_idx:end_idx].reshape(len(envs_idx), *shape).squeeze()
            if self._manager._sim.n_envs == 0:
                value = value.squeeze(0)
            return_values.append(value)

        if isinstance(return_format, dict):
            return dict(zip(return_format.keys(), return_values))
        else:
            return return_values[0]

@property
def is_built(self) -> bool:
return self._manager._sim._scene._is_built
114 changes: 114 additions & 0 deletions genesis/sensors/sensor_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import TYPE_CHECKING, Any, Type

import numpy as np
import torch

import genesis as gs
from genesis.utils.ring_buffer import TensorRingBuffer

if TYPE_CHECKING:
from genesis.options.sensors import SensorOptions

from .base_sensor import Sensor


class SensorManager:
    """
    Owns every sensor in the simulation and the shared caches they read from.

    Sensors of the same class occupy one contiguous column slice of a per-dtype
    cache, so each sensor class can update all of its instances in a single
    batched operation per simulation step.
    """

    # Maps each SensorOptions subclass to the Sensor subclass it configures.
    # Populated by the @register_sensor decorator.
    SENSOR_TYPES_MAP: dict[Type["SensorOptions"], Type["Sensor"]] = {}

    def __init__(self, sim):
        self._sim = sim
        self._sensors_by_type: dict[Type["Sensor"], list["Sensor"]] = {}
        self._sensors_metadata: dict[Type["Sensor"], dict[str, Any]] = {}
        # Flat per-dtype caches: current ground truth values and the ring buffer of history.
        self._ground_truth_cache: dict[Type[torch.dtype], torch.Tensor] = {}
        self._cache: dict[Type[torch.dtype], TensorRingBuffer] = {}
        self._cache_slices_by_type: dict[Type["Sensor"], slice] = {}

        # Per-dtype lazily cloned snapshots of the ground truth cache, refreshed at
        # most once per simulation step (see get_cloned_from_ground_truth_cache).
        self._last_ground_truth_cache_cloned_step: dict[Type[torch.dtype], int] = {}
        self._cloned_ground_truth_cache: dict[Type[torch.dtype], torch.Tensor] = {}

    def create_sensor(self, sensor_options: "SensorOptions"):
        """Instantiate, register, and return the Sensor matching the given options."""
        sensor_cls = SensorManager.SENSOR_TYPES_MAP[type(sensor_options)]
        self._sensors_by_type.setdefault(sensor_cls, [])
        # The per-class list length doubles as the new sensor's index within its class.
        sensor = sensor_cls(sensor_options, len(self._sensors_by_type[sensor_cls]), self)
        self._sensors_by_type[sensor_cls].append(sensor)
        return sensor

    def build(self):
        """
        Lay out the shared caches and build every sensor.

        Called once during the scene build phase: computes each sensor's column
        slice of the per-dtype caches, converts read delays into step counts,
        allocates the cache tensors, then calls build() on every sensor.
        """
        max_cache_buf_len = 0  # ring buffer depth: enough to serve the largest read delay
        cache_size_per_dtype = {}
        for sensor_cls, sensors in self._sensors_by_type.items():
            self._sensors_metadata[sensor_cls] = {}
            dtype = sensor_cls._get_cache_dtype()
            cache_size_per_dtype.setdefault(dtype, 0)
            cls_cache_start_idx = cache_size_per_dtype[dtype]

            for sensor in sensors:
                return_format = sensor._get_return_format()
                return_shapes = return_format.values() if isinstance(return_format, dict) else (return_format,)

                tensor_size = 0
                for shape in return_shapes:
                    # int() keeps the indices plain Python ints — np.prod(()) returns the float 1.0.
                    data_size = int(np.prod(shape))
                    sensor._shape_indices.append((tensor_size, tensor_size + data_size))
                    tensor_size += data_size

                delay_steps_float = sensor._options.read_delay / self._sim.dt
                sensor._read_delay_steps = round(delay_steps_float)
                if not np.isclose(delay_steps_float, sensor._read_delay_steps, atol=1e-6):
                    # The effective delay is the rounded step count times dt
                    # (the original message printed 1/steps, which was wrong and
                    # divided by zero for delays rounding to 0 steps).
                    gs.logger.warn(
                        f"Read delay should be a multiple of the simulation time step. Got {sensor._options.read_delay}"
                        f" and {self._sim.dt}. Actual read delay will be {sensor._read_delay_steps * self._sim.dt}."
                    )

                sensor._cache_size = sensor._get_cache_length() * tensor_size
                sensor._cache_idx = cache_size_per_dtype[dtype]
                cache_size_per_dtype[dtype] += sensor._cache_size

                max_cache_buf_len = max(max_cache_buf_len, sensor._read_delay_steps + 1)

            cls_cache_end_idx = cache_size_per_dtype[dtype]
            self._cache_slices_by_type[sensor_cls] = slice(cls_cache_start_idx, cls_cache_end_idx)

        for dtype in cache_size_per_dtype.keys():
            cache_shape = (self._sim._B, cache_size_per_dtype[dtype])
            self._ground_truth_cache[dtype] = torch.zeros(cache_shape, dtype=dtype)
            self._cache[dtype] = TensorRingBuffer(max_cache_buf_len, cache_shape, dtype=dtype)

        for sensor_cls, sensors in self._sensors_by_type.items():
            dtype = sensor_cls._get_cache_dtype()
            for sensor in sensors:
                sensor._shared_metadata = self._sensors_metadata[sensor_cls]
                sensor._cache = self._cache[dtype][:, sensor._cache_idx : sensor._cache_idx + sensor._cache_size]
                sensor.build()

    def step(self):
        """Refresh the ground truth and cached sensor data for every sensor class."""
        for sensor_cls in self._sensors_by_type.keys():
            dtype = sensor_cls._get_cache_dtype()
            cache_slice = self._cache_slices_by_type[sensor_cls]
            # cache_slice is a COLUMN range (see build()); the ground truth tensor is
            # (n_envs, width), so slice dim 1 — the original sliced dim 0 (envs) by mistake.
            sensor_cls._update_shared_ground_truth_cache(
                self._sensors_metadata[sensor_cls], self._ground_truth_cache[dtype][:, cache_slice]
            )
            sensor_cls._update_shared_cache(
                self._sensors_metadata[sensor_cls],
                self._ground_truth_cache[dtype][:, cache_slice],
                # NOTE(review): TensorRingBuffer slicing semantics are defined elsewhere;
                # build() slices it as [:, cols] while this passes a bare slice — confirm
                # which indexing convention the ring buffer expects here.
                self._cache[dtype][cache_slice],
            )

    def get_cloned_from_ground_truth_cache(self, sensor: "Sensor") -> torch.Tensor:
        """
        Return the given sensor's slice of a cloned ground truth cache.

        The clone is shared: it is refreshed at most once per simulation step per
        dtype, so multiple sensors reading ground truth in the same step reuse
        the same snapshot.
        """
        dtype = sensor._get_cache_dtype()
        # .get() so the very first call (no snapshot recorded yet) triggers a
        # clone instead of raising KeyError on the empty dict.
        if self._last_ground_truth_cache_cloned_step.get(dtype) != self._sim.cur_step_global:
            self._last_ground_truth_cache_cloned_step[dtype] = self._sim.cur_step_global
            self._cloned_ground_truth_cache[dtype] = self._ground_truth_cache[dtype].clone()
        return self._cloned_ground_truth_cache[dtype][:, sensor._cache_idx : sensor._cache_idx + sensor._cache_size]

    @property
    def sensors(self):
        """All managed sensors, flattened across classes, as a tuple."""
        return tuple(sensor for sensor_list in self._sensors_by_type.values() for sensor in sensor_list)


def register_sensor(sensor_cls: Type["Sensor"]):
    """
    Decorator factory that links a SensorOptions subclass to its Sensor class.

    Apply as @register_sensor(MySensor) on the options class; the mapping is
    recorded in SensorManager.SENSOR_TYPES_MAP and the options class is
    returned unchanged.
    """

    def _register(options_cls: Type["SensorOptions"]):
        SensorManager.SENSOR_TYPES_MAP[options_cls] = sensor_cls
        return options_cls

    return _register
Loading
Loading