Skip to content

Commit d9d4208

Browse files
gasnicaACMLCZH
authored andcommitted
[MISC] Add viewer code to handle moving objects (Genesis-Embodied-AI#1378)
* Add simple box-moving interaction. * Lazy load tensor in vec3.py, add Vec3/Quat.from_tensor/from_array. * Rename oobb to obb. * Vec3/Quat.as_tensor() is now a method.
1 parent 2b54138 commit d9d4208

File tree

6 files changed

+155
-73
lines changed

6 files changed

+155
-73
lines changed

genesis/engine/entities/rigid_entity/rigid_link.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,24 @@ def __init__(
5555
self._entity_idx_in_solver = entity.idx
5656

5757
self._uid = gs.UID()
58-
self._idx = idx
59-
self._parent_idx = parent_idx
60-
self._root_idx = root_idx
61-
self._child_idxs = list()
62-
self._invweight = invweight
63-
64-
self._joint_start = joint_start
65-
self._n_joints = n_joints
66-
67-
self._geom_start = geom_start
68-
self._cell_start = cell_start
69-
self._vert_start = vert_start
70-
self._face_start = face_start
71-
self._edge_start = edge_start
72-
self._verts_state_start = verts_state_start
73-
self._vgeom_start = vgeom_start
74-
self._vvert_start = vvert_start
75-
self._vface_start = vface_start
58+
self._idx: int = idx
59+
self._parent_idx: int = parent_idx # -1 if no parent
60+
self._root_idx: int | None = root_idx # None if no root
61+
self._child_idxs: list[int] = list()
62+
self._invweight: float | None = invweight
63+
64+
self._joint_start: int = joint_start
65+
self._n_joints: int = n_joints
66+
67+
self._geom_start: int = geom_start
68+
self._cell_start: int = cell_start
69+
self._vert_start: int = vert_start
70+
self._face_start: int = face_start
71+
self._edge_start: int = edge_start
72+
self._verts_state_start: int = verts_state_start
73+
self._vgeom_start: int = vgeom_start
74+
self._vvert_start: int = vvert_start
75+
self._vface_start: int = vface_start
7676

7777
# Link position & rotation at creation time:
7878
self._pos: ArrayLike = pos

genesis/ext/pyrender/interaction/aabb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def raycast(self, ray: Ray) -> RayHit:
5454
hit_pos = ray.origin + ray.direction * enter
5555
return RayHit(enter, hit_pos, normal)
5656

57-
def raycast_oobb(self, pose: Pose, ray: Ray) -> RayHit:
57+
def raycast_obb(self, pose: Pose, ray: Ray) -> RayHit:
5858
inv_pose = pose.get_inverse()
5959
origin2 = inv_pose.transform_point(ray.origin)
6060
direction2 = inv_pose.transform_direction(ray.direction)

genesis/ext/pyrender/interaction/vec3.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,20 @@
44
import numpy as np
55
from numpy.typing import NDArray
66

7+
from genesis.utils.misc import tensor_to_array
8+
79
# If not needing runtime checks, we can just use annotated types:
810
# Vec3 = Annotated[npt.NDArray[np.float32], (3,)]
911
# Aabb = Annotated[npt.NDArray[np.float32], (2, 3)]
1012

1113

14+
def _ensure_torch_imported() -> None:
15+
global gs
16+
import genesis as gs
17+
global torch
18+
import torch
19+
20+
1221
class Vec3:
1322
"""
1423
Use this wrapper around np.array if you want to ensure adherence to float32 arithmethic
@@ -35,6 +44,9 @@ def __mul__(self, other: float) -> 'Vec3':
3544
def __rmul__(self, other: float) -> 'Vec3':
3645
return Vec3(self.v * np.float32(other))
3746

47+
def __neg__(self) -> 'Vec3':
48+
return Vec3(-self.v)
49+
3850
def dot(self, other: 'Vec3') -> float:
3951
return np.dot(self.v, other.v).item()
4052

@@ -44,12 +56,19 @@ def cross(self, other: 'Vec3') -> 'Vec3':
4456
def normalized(self) -> 'Vec3':
4557
return Vec3(self.v / (np.linalg.norm(self.v) + 1e-24))
4658

59+
def magnitude(self) -> float:
60+
return np.linalg.norm(self.v)
61+
4762
def copy(self) -> 'Vec3':
4863
return Vec3(self.v.copy())
4964

5065
def __repr__(self) -> str:
5166
return f"Vec3({self.v[0]}, {self.v[1]}, {self.v[2]})"
5267

68+
def as_tensor(self) -> 'torch.Tensor':
69+
_ensure_torch_imported()
70+
return torch.tensor(self.v, dtype=gs.tc_float)
71+
5372
@property
5473
def x(self) -> float:
5574
return self.v[0]
@@ -67,27 +86,17 @@ def from_xyz(cls, x: float, y: float, z: float) -> 'Vec3':
6786
return cls(np.array([x, y, z], dtype=np.float32))
6887

6988
@classmethod
70-
def from_int32(cls, v: NDArray[np.int32]) -> 'Vec3':
71-
assert v.shape == (3,), f"Vec3 must be initialized with a 3-element array, got {v.shape}"
72-
assert v.dtype == np.int32, f"from_int32 must be initialized with a int32 array, got {v.dtype}"
73-
return cls.from_xyz(*v)
74-
75-
@classmethod
76-
def from_int64(cls, v: NDArray[np.int64]) -> 'Vec3':
89+
def from_array(cls, v: np.ndarray) -> 'Vec3':
7790
assert v.shape == (3,), f"Vec3 must be initialized with a 3-element array, got {v.shape}"
78-
assert v.dtype == np.int64, f"from_int64 must be initialized with a int64 array, got {v.dtype}"
91+
assert v.dtype == np.int32 or v.dtype == np.int64 or v.dtype == np.float32 or v.dtype == np.float64, \
92+
f"from_array must be initialized with a array of ints/floats 32/64-bit, got {v.dtype}"
7993
return cls.from_xyz(*v)
8094

8195
@classmethod
82-
def from_float64(cls, v: NDArray[np.float64]) -> 'Vec3':
83-
assert v.shape == (3,), f"Vec3 must be initialized with a 3-element array, got {v.shape}"
84-
assert v.dtype == np.float64, f"from_float64 must be initialized with a float64 array, got {v.dtype}"
85-
return cls.from_xyz(*v)
86-
87-
@classmethod
88-
def from_any_array(cls, v: np.ndarray) -> 'Vec3':
89-
assert v.shape == (3,), f"Vec3 must be initialized with a 3-element array, got {v.shape}"
90-
return cls.from_xyz(*v)
96+
def from_tensor(cls, v: 'torch.Tensor') -> 'Vec3':
97+
_ensure_torch_imported()
98+
array: np.ndarray = tensor_to_array(v)
99+
return cls.from_array(array)
91100

92101

93102
@classmethod
@@ -136,6 +145,10 @@ def copy(self) -> 'Quat':
136145
def __repr__(self) -> str:
137146
return f"Quat({self.v[0]}, {self.v[1]}, {self.v[2]}, {self.v[3]})"
138147

148+
def as_tensor(self) -> 'torch.Tensor':
149+
_ensure_torch_imported()
150+
return torch.tensor(self.v, dtype=gs.tc_float)
151+
139152
@property
140153
def w(self) -> float:
141154
return self.v[0]
@@ -152,16 +165,21 @@ def y(self) -> float:
152165
def z(self) -> float:
153166
return self.v[3]
154167

155-
156168
@classmethod
157169
def from_wxyz(cls, w: float, x: float, y: float, z: float) -> 'Quat':
158170
return cls(np.array([w, x, y, z], dtype=np.float32))
159171

160172
@classmethod
161-
def from_any_array(cls, v: np.ndarray) -> 'Quat':
173+
def from_array(cls, v: np.ndarray) -> 'Quat':
162174
assert v.shape == (4,), f"Quat must be initialized with a 4-element array, got {v.shape}"
163175
return cls.from_wxyz(*v)
164176

177+
@classmethod
178+
def from_tensor(cls, v: 'torch.Tensor') -> 'Quat':
179+
_ensure_torch_imported()
180+
array: np.ndarray = tensor_to_array(v)
181+
return cls.from_array(array)
182+
165183

166184
@dataclass
167185
class Pose:

genesis/ext/pyrender/interaction/viewer_interaction.py

Lines changed: 93 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
from typing_extensions import override
33

44
import numpy as np
5-
from numpy.typing import NDArray
65

76
import genesis as gs
87
from genesis.engine.entities.rigid_entity.rigid_entity import RigidEntity
98

109
from .aabb import AABB
1110
from .ray import Plane, Ray, RayHit
1211
from .vec3 import Pose, Quat, Vec3, Color
13-
from .viewer_interaction_base import ViewerInteractionBase, EVENT_HANDLE_STATE
12+
from .viewer_interaction_base import ViewerInteractionBase, EVENT_HANDLE_STATE, EVENT_HANDLED
1413

1514
if TYPE_CHECKING:
1615
from genesis.engine.entities.rigid_entity.rigid_geom import RigidGeom
@@ -24,14 +23,6 @@ class ViewerInteraction(ViewerInteractionBase):
2423
- mouse dragging
2524
"""
2625

27-
camera: 'Node'
28-
scene: 'Scene'
29-
viewport_size: tuple[int, int]
30-
camera_yfov: float
31-
32-
tan_half_fov: float
33-
prev_mouse_pos: tuple[int, int]
34-
3526
def __init__(self,
3627
camera: 'Node',
3728
scene: 'Scene',
@@ -41,13 +32,18 @@ def __init__(self,
4132
camera_fov: float = 60.0,
4233
):
4334
super().__init__(log_events)
44-
self.camera = camera
45-
self.scene = scene
46-
self.viewport_size = viewport_size
47-
self.camera_yfov = camera_yfov
35+
self.camera: 'Node' = camera
36+
self.scene: 'Scene' = scene
37+
self.viewport_size: tuple[int, int] = viewport_size
38+
self.camera_yfov: float = camera_yfov
4839

49-
self.tan_half_fov = np.tan(0.5 * self.camera_yfov)
50-
self.prev_mouse_pos = tuple(np.array(viewport_size) / 2)
40+
self.tan_half_fov: float = np.tan(0.5 * self.camera_yfov)
41+
self.prev_mouse_pos: tuple[int, int] = (0.5 * viewport_size[0], 0.5 * viewport_size[1])
42+
43+
self.picked_entity: RigidEntity | None = None
44+
self.picked_point_in_local: Vec3 | None = None
45+
self.mouse_drag_plane: Plane | None = None
46+
self.prev_mouse_3d_pos: Vec3 | None = None
5147

5248
@override
5349
def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE:
@@ -58,17 +54,55 @@ def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STAT
5854
def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE:
5955
super().on_mouse_drag(x, y, dx, dy, buttons, modifiers)
6056
self.prev_mouse_pos = (x, y)
57+
if self.picked_entity:
58+
mouse_ray: Ray = self.screen_position_to_ray(x, y)
59+
ray_hit: RayHit = self.mouse_drag_plane.raycast(mouse_ray)
60+
assert ray_hit.is_hit
61+
if ray_hit.is_hit:
62+
new_mouse_3d_pos: Vec3 = ray_hit.position
63+
delta_3d_pos: Vec3 = new_mouse_3d_pos - self.prev_mouse_3d_pos
64+
self.prev_mouse_3d_pos = new_mouse_3d_pos
65+
66+
use_force: bool = False
67+
if use_force:
68+
# apply force
69+
pass
70+
else:
71+
#apply displacement
72+
pos = Vec3.from_tensor(self.picked_entity.get_pos())
73+
pos = pos + delta_3d_pos
74+
self.picked_entity.set_pos(pos.as_tensor())
75+
76+
return EVENT_HANDLED
6177

6278
@override
6379
def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE:
6480
super().on_mouse_press(x, y, button, modifiers)
81+
if button == 1: # left mouse button
82+
(ray_hit, self.picked_entity) = self.raycast_against_entities(self.screen_position_to_ray(x, y))
83+
if self.picked_entity and ray_hit:
84+
temp_fwd = self.get_camera_forward()
85+
temp_back = -temp_fwd
86+
87+
self.mouse_drag_plane = Plane(temp_back, ray_hit.position)
88+
self.prev_mouse_3d_pos = ray_hit.position
89+
90+
pose: Pose = self.get_pose_of_first_geom(self.picked_entity)
91+
self.picked_point_in_local = pose.inverse_transform_point(ray_hit.position)
92+
93+
@override
94+
def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE:
95+
super().on_mouse_release(x, y, button, modifiers)
96+
if button == 1: # left mouse button
97+
self.picked_entity = None
98+
self.picked_point_in_local = None
6599

66100
@override
67101
def on_draw(self) -> None:
68102
super().on_draw()
69103
if self.scene._visualizer is not None and self.scene._visualizer.viewer_lock is not None:
70104
self.scene.clear_debug_objects()
71-
mouse_ray = self.screen_position_to_ray(*self.prev_mouse_pos)
105+
mouse_ray: Ray = self.screen_position_to_ray(*self.prev_mouse_pos)
72106
closest_hit = None
73107
hit_entity: RigidEntity | None = None
74108

@@ -77,7 +111,7 @@ def on_draw(self) -> None:
77111
closest_hit = ray_hit
78112

79113
for entity in self.get_entities():
80-
ray_hit = self.raycast_against_entity_oobb(entity, mouse_ray)
114+
ray_hit = self.raycast_against_entity_obb(entity, mouse_ray)
81115
if ray_hit.is_hit:
82116
if closest_hit is None or ray_hit.distance < closest_hit.distance:
83117
closest_hit = ray_hit
@@ -87,7 +121,21 @@ def on_draw(self) -> None:
87121
self.scene.draw_debug_sphere(closest_hit.position.v, 0.01, (0, 1, 0, 1))
88122
self._draw_arrow(closest_hit.position, 0.25 * closest_hit.normal, (0, 1, 0, 1))
89123
if hit_entity:
90-
self._draw_entity_unrotated_oobb(hit_entity)
124+
self._draw_entity_unrotated_obb(hit_entity)
125+
126+
if self.picked_entity:
127+
assert self.mouse_drag_plane is not None
128+
assert self.picked_point_in_local is not None
129+
130+
# draw held point
131+
pose: Pose = self.get_pose_of_first_geom(self.picked_entity)
132+
held_point: Vec3 = pose.transform_point(self.picked_point_in_local)
133+
self.scene.draw_debug_sphere(held_point.v, 0.02, Color.red().tuple())
134+
135+
plane_hit: RayHit = self.mouse_drag_plane.raycast(mouse_ray)
136+
if plane_hit.is_hit:
137+
self.scene.draw_debug_sphere(plane_hit.position.v, 0.02, Color.red().tuple())
138+
self.scene.draw_debug_line(held_point.v, plane_hit.position.v, color=Color.red().tuple())
91139

92140
def screen_position_to_ray(self, x: float, y: float) -> Ray:
93141
# convert screen position to ray
@@ -107,35 +155,49 @@ def screen_position_to_ray(self, x: float, y: float) -> Ray:
107155
# Note: ignoring pixel aspect ratio
108156

109157
mtx = self.camera.matrix
110-
position = Vec3.from_float64(mtx[:3, 3])
111-
forward = Vec3.from_float64(-mtx[:3, 2])
112-
right = Vec3.from_float64(mtx[:3, 0])
113-
up = Vec3.from_float64(mtx[:3, 1])
158+
position = Vec3.from_array(mtx[:3, 3])
159+
forward = Vec3.from_array(-mtx[:3, 2])
160+
right = Vec3.from_array(mtx[:3, 0])
161+
up = Vec3.from_array(mtx[:3, 1])
114162

115163
direction = forward + right * x + up * y
116164
return Ray(position, direction)
117165

166+
def get_camera_forward(self) -> Vec3:
167+
mtx = self.camera.matrix
168+
return Vec3.from_array(-mtx[:3, 2])
169+
118170
def get_camera_ray(self) -> Ray:
119171
mtx = self.camera.matrix
120-
position = Vec3.from_float64(mtx[:3, 3])
121-
forward = Vec3.from_float64(-mtx[:3, 2])
172+
position = Vec3.from_array(mtx[:3, 3])
173+
forward = Vec3.from_array(-mtx[:3, 2])
122174
return Ray(position, forward)
123175

124176
def _raycast_against_ground_plane(self, ray: Ray) -> RayHit:
125177
ground_plane = Plane(Vec3.from_xyz(0, 0, 1), Vec3.zero())
126178
return ground_plane.raycast(ray)
127179

128-
def raycast_against_entity_oobb(self, entity: RigidEntity, ray: Ray) -> RayHit:
180+
def raycast_against_entity_obb(self, entity: RigidEntity, ray: Ray) -> RayHit:
129181
if isinstance(entity.morph, gs.morphs.Box):
130182
box: gs.morphs.Box = entity.morph
131183
size = Vec3.from_xyz(*box.size)
132184
pose = self.get_pose_of_first_geom(entity)
133185
aabb = AABB.from_center_and_size(Vec3.zero(), size)
134-
ray_hit = aabb.raycast_oobb(pose, ray)
186+
ray_hit = aabb.raycast_obb(pose, ray)
135187
return ray_hit
136188
else:
137189
return RayHit.no_hit()
138190

191+
def raycast_against_entities(self, ray: Ray) -> tuple[RayHit | None, RigidEntity | None]:
192+
closest_hit = None
193+
hit_entity: RigidEntity | None = None
194+
for entity in self.get_entities():
195+
ray_hit = self.raycast_against_entity_obb(entity, ray)
196+
if ray_hit.is_hit and (closest_hit is None or ray_hit.distance < closest_hit.distance):
197+
closest_hit = ray_hit
198+
hit_entity = entity
199+
return (closest_hit, hit_entity)
200+
139201
def get_entities(self) -> list[RigidEntity]:
140202
return self.scene.sim.rigid_solver.entities
141203

@@ -152,7 +214,7 @@ def _draw_arrow(
152214
self.scene.draw_debug_arrow(pos.v, dir.v, color=color) # Only draws arrowhead -- bug?
153215
self.scene.draw_debug_line(pos.v, pos.v + dir.v, color=color)
154216

155-
def _draw_entity_unrotated_oobb(self, entity: RigidEntity) -> None:
217+
def _draw_entity_unrotated_obb(self, entity: RigidEntity) -> None:
156218
if isinstance(entity.morph, gs.morphs.Plane):
157219
plane: gs.morphs.Plane = entity.morph
158220
pass
@@ -161,10 +223,9 @@ def _draw_entity_unrotated_oobb(self, entity: RigidEntity) -> None:
161223
size = Vec3.from_xyz(*box.size)
162224
geom: RigidGeom = entity.geoms[0]
163225
assert geom._solver.n_envs == 0, "ViewerInteraction only supports single-env for now"
164-
gpos = geom.get_pos() # squeezed if n_envs == 0
165-
gquat = geom.get_quat() # squeezed if n_envs == 0
166-
pos = Vec3.from_any_array(gpos.cpu().numpy())
167-
quat = Quat.from_any_array(gquat.cpu().numpy())
226+
# geom.get_pos() and .get_quat() are squeezed if n_envs == 0
227+
pos = Vec3.from_tensor(geom.get_pos())
228+
quat = Quat.from_tensor(geom.get_quat())
168229
aabb = AABB.from_center_and_size(pos, size)
169230
aabb.expand(0.01)
170231
self.scene.draw_debug_box(aabb.v, color=Color.red().with_alpha(0.5).tuple(), wireframe=False)

0 commit comments

Comments
 (0)