Skip to content

Commit 21bc0c9

Browse files
gasnicaKashu7100
authored andcommitted
[MISC] Picking boxes and Panda links with a MouseSpring (Genesis-Embodied-AI#1443)
1 parent 0a5b75f commit 21bc0c9

File tree

8 files changed

+262
-153
lines changed

8 files changed

+262
-153
lines changed

genesis/engine/entities/rigid_entity/rigid_link.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
if TYPE_CHECKING:
1616
from .rigid_entity import RigidEntity
1717
from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver
18+
from genesis.ext.pyrender.interaction.vec3 import Pose
1819

1920

2021
@ti.data_oriented
@@ -591,14 +592,14 @@ def invweight(self):
591592
return self._invweight
592593

593594
@property
594-
def pos(self):
595+
def pos(self) -> ArrayLike:
595596
"""
596597
The initial position of the link. For real-time position, use `link.get_pos()`.
597598
"""
598599
return self._pos
599600

600601
@property
601-
def quat(self):
602+
def quat(self) -> ArrayLike:
602603
"""
603604
The initial quaternion of the link. For real-time quaternion, use `link.get_quat()`.
604605
"""
@@ -744,6 +745,11 @@ def is_free(self):
744745
"""
745746
return self.entity.is_free
746747

748+
@property
749+
def pose(self) -> "Pose":
750+
"""Return the current pose of the link (note, this is not necessarily the same as the principal axes frame)."""
751+
return Pose.from_link(self)
752+
747753
# ------------------------------------------------------------------------------------
748754
# -------------------------------------- repr ----------------------------------------
749755
# ------------------------------------------------------------------------------------

genesis/engine/solvers/base_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def gravity(self):
152152
return self._gravity.to_numpy() if self._gravity is not None else None
153153

154154
@property
155-
def entities(self):
155+
def entities(self) -> list[Entity]:
156156
return self._entities
157157

158158
@property

genesis/ext/pyrender/interaction/aabb.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from dataclasses import dataclass
2+
13
import numpy as np
24
from numpy.typing import NDArray
35

@@ -20,6 +22,10 @@ def min(self) -> Vec3:
2022
def max(self) -> Vec3:
2123
return Vec3(self.v[1])
2224

25+
@property
26+
def extents(self) -> Vec3:
27+
return self.max - self.min
28+
2329
def expand(self, padding: float) -> None:
2430
self.v[0] -= padding
2531
self.v[1] += padding
@@ -54,28 +60,37 @@ def raycast(self, ray: Ray) -> RayHit:
5460
hit_pos = ray.origin + ray.direction * enter
5561
return RayHit(enter, hit_pos, normal)
5662

57-
def raycast_obb(self, pose: Pose, ray: Ray) -> RayHit:
58-
inv_pose = pose.get_inverse()
59-
origin2 = inv_pose.transform_point(ray.origin)
60-
direction2 = inv_pose.transform_direction(ray.direction)
61-
ray2 = Ray(origin2, direction2)
62-
ray_hit = self.raycast(ray2)
63-
if ray_hit.is_hit:
64-
ray_hit.position = pose.transform_point(ray_hit.position)
65-
ray_hit.normal = pose.transform_direction(ray_hit.normal)
66-
return ray_hit
67-
6863
def __repr__(self) -> str:
69-
return f"Min({self.min.x}, {self.min.y}, {self.min.z}) Max({self.max.x}, {self.max.y}, {self.max.z})"
64+
return f"AABB: Min({self.min.x}, {self.min.y}, {self.min.z}) Max({self.max.x}, {self.max.y}, {self.max.z})"
7065

7166
@classmethod
7267
def from_min_max(cls, min: Vec3, max: Vec3) -> 'AABB':
73-
bounds = np.stack((min.v, max.v))
68+
bounds = np.stack((min.v, max.v), axis=0)
7469
return cls(bounds)
7570

7671
@classmethod
77-
def from_center_and_size(cls, center: Vec3, size: Vec3) -> 'AABB':
78-
min = center - 0.5 * size
79-
max = center + 0.5 * size
80-
bounds = np.stack((min.v, max.v))
72+
def from_center_and_half_extents(cls, center: Vec3, half_extents: Vec3) -> 'AABB':
73+
min = center - half_extents
74+
max = center + half_extents
75+
bounds = np.stack((min.v, max.v), axis=0)
8176
return cls(bounds)
77+
78+
79+
@dataclass
80+
class OBB():
81+
pose: Pose
82+
half_extents: Vec3
83+
84+
def raycast(self, ray: Ray) -> RayHit:
85+
origin2 = self.pose.inverse_transform_point(ray.origin)
86+
direction2 = self.pose.inverse_transform_direction(ray.direction)
87+
ray2 = Ray(origin2, direction2)
88+
aabb = AABB.from_center_and_half_extents(Vec3.zero(), self.half_extents)
89+
ray_hit = aabb.raycast(ray2)
90+
if ray_hit.is_hit:
91+
ray_hit.position = self.pose.transform_point(ray_hit.position)
92+
ray_hit.normal = self.pose.transform_direction(ray_hit.normal)
93+
return ray_hit
94+
95+
def __repr__(self) -> str:
96+
return f"OBB(pose={self.pose}, half_extents={self.half_extents})"

genesis/ext/pyrender/interaction/mouse_spring.py

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,66 +15,84 @@ def _ensure_torch_imported() -> None:
1515
import torch
1616

1717
class MouseSpring:
18-
def __init__(self):
19-
self.held_geom: RigidGeom | None = None
18+
def __init__(self) -> None:
19+
self.held_link: RigidLink | None = None
2020
self.held_point_in_local: Vec3 | None = None
2121
self.prev_control_point: Vec3 | None = None
2222

23-
def attach(self, picked_entity: RigidEntity, control_point: Vec3):
23+
def attach(self, picked_link: RigidLink, control_point: Vec3) -> None:
2424
# for now, we just pick the first geometry
25-
self.held_geom = picked_entity.geoms[0]
26-
pose: Pose = Pose.from_geom(self.held_geom)
25+
self.held_link = picked_link
26+
pose: Pose = Pose.from_link(self.held_link)
2727
self.held_point_in_local = pose.inverse_transform_point(control_point)
2828
self.prev_control_point = control_point
2929

30-
def detach(self):
31-
self.held_geom = None
30+
def detach(self) -> None:
31+
self.held_link = None
3232

33-
def apply_force(self, control_point: Vec3, delta_time: float):
33+
def apply_force(self, control_point: Vec3, delta_time: float) -> None:
3434
_ensure_torch_imported()
35+
36+
# note when threaded: apply_force is called before attach!
37+
# note2: that was before we added a lock to ViewerInteraction; this migth be fixed now
38+
if not self.held_link:
39+
return
3540

36-
# works ok:
37-
# delta: Vec3 = control_point - self.prev_control_point
38-
# pos = Vec3.from_tensor(self.held_geom.entity.get_pos())
39-
# pos = pos + delta
40-
# self.held_geom.entity.set_pos(pos.as_tensor())
4141
self.prev_control_point = control_point
4242

4343
# do simple force on COM only:
44-
link: RigidLink = self.held_geom.link
45-
link_pos: Vec3 = Vec3.from_tensor(link.get_pos())
44+
link: RigidLink = self.held_link
4645
lin_vel: Vec3 = Vec3.from_tensor(link.get_vel())
4746
ang_vel: Vec3 = Vec3.from_tensor(link.get_ang())
47+
link_pose: Pose = Pose.from_link(link)
48+
held_point_in_world: Vec3 = link_pose.transform_point(self.held_point_in_local)
49+
50+
# note: we should assert earlier that link inertial_pos/quat are not None
51+
# todo: verify inertial_pos/quat are stored in local frame
52+
link_T_principal: Pose = Pose(Vec3.from_arraylike(link.inertial_pos), Quat.from_arraylike(link.inertial_quat))
53+
world_T_principal: Pose = link_pose * link_T_principal
54+
55+
arm_in_principal: Vec3 = link_T_principal.inverse_transform_point(self.held_point_in_local) # for non-spherical inertia
56+
arm_in_world: Vec3 = world_T_principal.rot * arm_in_principal # for spherical inertia
4857

49-
pos_err_v: Vec3 = control_point - link_pos
50-
vel_err_v: Vec3 = Vec3.zero() - lin_vel
58+
pos_err_v: Vec3 = control_point - held_point_in_world
5159
inv_mass: float = float(1.0 / link.get_mass() if link.get_mass() > 0.0 else 0.0)
60+
inv_spherical_inertia: float = float(1.0 / link.inertial_i[0, 0] if link.inertial_i[0, 0] > 0.0 else 0.0)
5261

5362
inv_dt: float = 1.0 / delta_time
54-
# these are temporary values, till we fix an issue with apply_links_external_force.
55-
# after fixing it, use tau = damp = 1.0:
5663
tau: float = MOUSE_SPRING_POSITION_CORRECTION_FACTOR
5764
damp: float = MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR
5865

5966
total_impulse: Vec3 = Vec3.zero()
67+
total_torque_impulse: Vec3 = Vec3.zero()
68+
69+
for i in range(3*4):
70+
body_point_vel: Vec3 = lin_vel + ang_vel.cross(arm_in_world)
71+
vel_err_v: Vec3 = Vec3.zero() - body_point_vel
6072

61-
for i in range(3):
6273
dir: Vec3 = Vec3.zero()
63-
dir.v[i] = 1.0
74+
dir.v[i % 3] = 1.0
6475
pos_err: float = dir.dot(pos_err_v)
6576
vel_err: float = dir.dot(vel_err_v)
6677
error: float = tau * pos_err * inv_dt + damp * vel_err
67-
virtual_mass: float = 1.0 / (inv_mass + 1e-24)
78+
79+
arm_x_dir: Vec3 = arm_in_world.cross(dir)
80+
virtual_mass: float = 1.0 / (inv_mass + arm_x_dir.sqr_magnitude() * inv_spherical_inertia + 1e-24)
6881
impulse: float = error * virtual_mass
6982

70-
lin_vel += impulse * dir * inv_mass
71-
total_impulse.v[i] = impulse
83+
lin_vel += impulse * inv_mass * dir
84+
ang_vel += impulse * inv_spherical_inertia * arm_x_dir
85+
total_impulse.v[i % 3] += impulse
86+
total_torque_impulse += impulse * arm_x_dir
7287

7388
# Apply the new force
7489
total_force = total_impulse * inv_dt
90+
total_torque = total_torque_impulse * inv_dt
7591
force_tensor: torch.Tensor = total_force.as_tensor().unsqueeze(0)
92+
torque_tensor: torch.Tensor = total_torque.as_tensor().unsqueeze(0)
7693
link.solver.apply_links_external_force(force_tensor, (link.idx,), ref='link_com', local=False)
94+
link.solver.apply_links_external_torque(torque_tensor, (link.idx,), ref='link_com', local=False)
7795

7896
@property
7997
def is_attached(self) -> bool:
80-
return self.held_geom is not None
98+
return self.held_link is not None

genesis/ext/pyrender/interaction/ray.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from dataclasses import dataclass
2+
import sys
3+
4+
from genesis.engine.entities.rigid_entity.rigid_geom import RigidGeom
25

36
from .vec3 import Vec3
47

58

69
EPSILON = 1e-6
710
EPSILON2 = EPSILON * EPSILON
811

12+
_MAX_RAY_DISTANCE = sys.float_info.max
13+
914

1015
class Ray:
1116
origin: Vec3
@@ -24,14 +29,16 @@ class RayHit:
2429
distance: float
2530
position: Vec3
2631
normal: Vec3
32+
geom: RigidGeom | None = None
2733

2834
@property
2935
def is_hit(self) -> bool:
30-
return 0 <= self.distance
36+
assert 0.0 <= self.distance
37+
return self.distance < _MAX_RAY_DISTANCE
3138

3239
@classmethod
3340
def no_hit(cls) -> 'RayHit':
34-
return RayHit(-1.0, Vec3.zero(), Vec3.zero())
41+
return RayHit(_MAX_RAY_DISTANCE, Vec3.zero(), Vec3.zero(), None)
3542

3643

3744
class Plane:
@@ -50,4 +57,4 @@ def raycast(self, ray: Ray) -> RayHit:
5057
return RayHit.no_hit()
5158
else:
5259
dist_along_ray = dist / -dot
53-
return RayHit(dist_along_ray, ray.origin + ray.direction * dist_along_ray, self.normal)
60+
return RayHit(dist_along_ray, ray.origin + ray.direction * dist_along_ray, self.normal, None)

genesis/ext/pyrender/interaction/vec3.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from typing import TYPE_CHECKING, Union
33

44
import numpy as np
5-
from numpy.typing import NDArray
5+
from numpy.typing import NDArray, ArrayLike
66

77
from genesis.utils.misc import tensor_to_array
88

99
if TYPE_CHECKING:
1010
from genesis.engine.entities.rigid_entity.rigid_geom import RigidGeom
11+
from genesis.engine.entities.rigid_entity.rigid_link import RigidLink
1112

1213
# If not needing runtime checks, we can just use annotated types:
1314
# Vec3 = Annotated[npt.NDArray[np.float32], (3,)]
@@ -62,6 +63,9 @@ def normalized(self) -> 'Vec3':
6263
def magnitude(self) -> float:
6364
return np.linalg.norm(self.v)
6465

66+
def sqr_magnitude(self) -> float:
67+
return np.dot(self.v, self.v)
68+
6569
def copy(self) -> 'Vec3':
6670
return Vec3(self.v.copy())
6771

@@ -101,15 +105,30 @@ def from_tensor(cls, v: 'torch.Tensor') -> 'Vec3':
101105
array: np.ndarray = tensor_to_array(v)
102106
return cls.from_array(array)
103107

108+
@classmethod
109+
def from_arraylike(cls, v: ArrayLike) -> 'Vec3':
110+
if isinstance(v, np.ndarray):
111+
return cls.from_array(v)
112+
elif isinstance(v, torch.Tensor):
113+
return cls.from_tensor(v)
114+
elif isinstance(v, ArrayLike):
115+
assert len(v) == 3, f"Vec3 must be initialized with a 3-element ArrayLike, got {len(v)}"
116+
return cls.from_xyz(*v)
117+
assert False
118+
104119

105120
@classmethod
106-
def zero(cls):
121+
def zero(cls) -> 'Vec3':
107122
return cls(np.array([0, 0, 0], dtype=np.float32))
108123

109124
@classmethod
110-
def one(cls):
125+
def one(cls) -> 'Vec3':
111126
return cls(np.array([1, 1, 1], dtype=np.float32))
112127

128+
@classmethod
129+
def full(cls, fill_value: float) -> 'Vec3':
130+
return cls(np.full((3,), fill_value, dtype=np.float32))
131+
113132

114133
class Quat:
115134
v: NDArray[np.float32]
@@ -183,6 +202,17 @@ def from_tensor(cls, v: 'torch.Tensor') -> 'Quat':
183202
array: np.ndarray = tensor_to_array(v)
184203
return cls.from_array(array)
185204

205+
@classmethod
206+
def from_arraylike(cls, v: ArrayLike) -> 'Quat':
207+
if isinstance(v, np.ndarray):
208+
return cls.from_array(v)
209+
elif isinstance(v, torch.Tensor):
210+
return cls.from_tensor(v)
211+
elif isinstance(v, ArrayLike):
212+
assert len(v) == 4, f"Quat must be initialized with a 4-element ArrayLike, got {len(v)}"
213+
return cls.from_wxyz(*v)
214+
assert False
215+
186216

187217
@dataclass
188218
class Pose:
@@ -212,6 +242,17 @@ def get_inverse(self) -> 'Pose':
212242
inv_pos = Vec3(-inv_pos.v[1:])
213243
return Pose(inv_pos, inv_rot)
214244

245+
def __mul__(self, other: Union['Pose', Vec3]) -> Union['Pose', Vec3]:
246+
if isinstance(other, Pose):
247+
return Pose(self.pos + self.rot * other.pos, self.rot * other.rot)
248+
elif isinstance(other, Vec3):
249+
return self.pos + self.rot * other
250+
else:
251+
return NotImplemented
252+
253+
def __repr__(self) -> str:
254+
return f"Pose(pos={self.pos}, rot={self.rot})"
255+
215256
@classmethod
216257
def from_geom(cls, geom: 'RigidGeom') -> 'Pose':
217258
assert geom._solver.n_envs == 0, "ViewerInteraction only supports single-env for now"
@@ -220,6 +261,14 @@ def from_geom(cls, geom: 'RigidGeom') -> 'Pose':
220261
quat = Quat.from_tensor(geom.get_quat())
221262
return Pose(pos, quat)
222263

264+
@classmethod
265+
def from_link(cls, link: 'RigidLink') -> 'Pose':
266+
assert link._solver.n_envs == 0, "ViewerInteraction only supports single-env for now"
267+
# geom.get_pos() and .get_quat() are squeezed if n_envs == 0
268+
pos = Vec3.from_tensor(link.get_pos())
269+
quat = Quat.from_tensor(link.get_quat())
270+
return Pose(pos, quat)
271+
223272

224273
@dataclass
225274
class Color:

0 commit comments

Comments
 (0)