Skip to content

Commit b3917fe

Browse files
committed
Stub code for physical object dragging
1 parent e5d7d92 commit b3917fe

File tree

10 files changed

+159
-34
lines changed

10 files changed

+159
-34
lines changed

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from genesis.utils.misc import tensor_to_array, ti_field_to_torch, ALLOCATE_TENSOR_WARNING
2222

2323
from ..base_entity import Entity
24+
from .rigid_geom import RigidGeom
2425
from .rigid_joint import RigidJoint
2526
from .rigid_link import RigidLink
2627
from .rigid_equality import RigidEquality
@@ -3146,7 +3147,7 @@ def q_end(self):
31463147
return self._q_start + self.n_qs
31473148

31483149
@property
3149-
def geoms(self):
3150+
def geoms(self) -> list[RigidGeom]:
31503151
"""The list of collision geoms (`RigidGeom`) in the entity."""
31513152
if self.is_built:
31523153
return self._geoms
@@ -3168,7 +3169,7 @@ def vgeoms(self):
31683169
return vgeoms
31693170

31703171
@property
3171-
def links(self):
3172+
def links(self) -> list[RigidLink]:
31723173
"""The list of links (`RigidLink`) in the entity."""
31733174
return self._links
31743175

genesis/engine/entities/rigid_entity/rigid_geom.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class RigidGeom(RBC):
3030

3131
def __init__(
3232
self,
33-
link,
33+
link: "RigidLink",
3434
idx,
3535
cell_start,
3636
vert_start,
@@ -494,14 +494,14 @@ def metadata(self):
494494
return self._metadata
495495

496496
@property
497-
def link(self):
497+
def link(self) -> "RigidLink":
498498
"""
499499
Get the link that the geom belongs to.
500500
"""
501501
return self._link
502502

503503
@property
504-
def entity(self):
504+
def entity(self) -> "RigidEntity":
505505
"""
506506
Get the entity that the geom belongs to.
507507
"""

genesis/engine/entities/rigid_entity/rigid_link.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
if TYPE_CHECKING:
1616
from .rigid_entity import RigidEntity
17+
from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver
1718

1819

1920
@ti.data_oriented
@@ -51,7 +52,7 @@ def __init__(
5152
):
5253
self._name: str = name
5354
self._entity: "RigidEntity" = entity
54-
self._solver = entity.solver
55+
self._solver: "RigidSolver" = entity.solver
5556
self._entity_idx_in_solver = entity.idx
5657

5758
self._uid = gs.UID()
@@ -265,7 +266,7 @@ def get_quat(self, envs_idx=None):
265266
return self._solver.get_links_quat([self._idx], envs_idx).squeeze(-2)
266267

267268
@gs.assert_built
268-
def get_vel(self, envs_idx=None):
269+
def get_vel(self, envs_idx=None) -> torch.Tensor:
269270
"""
270271
Get the linear velocity of the link in the world frame.
271272
@@ -277,7 +278,7 @@ def get_vel(self, envs_idx=None):
277278
return self._solver.get_links_vel([self._idx], envs_idx).squeeze(-2)
278279

279280
@gs.assert_built
280-
def get_ang(self, envs_idx=None):
281+
def get_ang(self, envs_idx=None) -> torch.Tensor:
281282
"""
282283
Get the angular velocity of the link in the world frame.
283284

genesis/engine/scene.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1201,7 +1201,7 @@ def active_solvers(self):
12011201
return self._sim.active_solvers
12021202

12031203
@property
1204-
def entities(self):
1204+
def entities(self) -> list[Entity]:
12051205
"""All the entities in the scene."""
12061206
return self._sim.entities
12071207

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
2+
from genesis.engine.entities.rigid_entity.rigid_entity import RigidEntity
3+
from genesis.engine.entities.rigid_entity.rigid_geom import RigidGeom
4+
5+
from .ray import Plane, Ray, RayHit
6+
from .vec3 import Pose, Quat, Vec3, Color
7+
8+
from genesis.engine.entities.rigid_entity.rigid_link import RigidLink
9+
10+
def _ensure_torch_imported() -> None:
11+
global torch
12+
import torch
13+
14+
class MouseSpring:
15+
def __init__(self):
16+
self.held_geom: RigidGeom | None = None
17+
self.held_point_in_local: Vec3 | None = None
18+
self.prev_control_point: Vec3 | None = None
19+
20+
def attach(self, picked_entity: RigidEntity, control_point: Vec3):
21+
# for now, we just pick the first geometry
22+
self.held_geom = picked_entity.geoms[0]
23+
pose: Pose = Pose.from_geom(self.held_geom)
24+
self.held_point_in_local = pose.inverse_transform_point(control_point)
25+
self.prev_control_point = control_point
26+
27+
def detach(self):
28+
self.held_geom = None
29+
30+
def apply_force(self, control_point: Vec3, delta_time: float):
31+
_ensure_torch_imported()
32+
33+
# works ok:
34+
# delta: Vec3 = control_point - self.prev_control_point
35+
# pos = Vec3.from_tensor(self.held_geom.entity.get_pos())
36+
# pos = pos + delta
37+
# self.held_geom.entity.set_pos(pos.as_tensor())
38+
self.prev_control_point = control_point
39+
40+
# do simple force on COM only:
41+
link: RigidLink = self.held_geom.link
42+
link_pos: Vec3 = Vec3.from_tensor(link.get_pos())
43+
lin_vel: Vec3 = Vec3.from_tensor(link.get_vel())
44+
ang_vel: Vec3 = Vec3.from_tensor(link.get_ang())
45+
46+
pos_err_v: Vec3 = control_point - link_pos
47+
vel_err_v: Vec3 = Vec3.zero() - lin_vel
48+
inv_mass: float = float(1.0 / link.get_mass() if 0.0 < link.get_mass() else 0.0)
49+
50+
inv_dt: float = 1.0 / delta_time
51+
tau: float = 1.0 / 2
52+
damp: float = 1.0 * 2
53+
54+
total_impulse: Vec3 = Vec3.zero()
55+
56+
for i in range(3):
57+
dir: Vec3 = Vec3.zero()
58+
dir.v[i] = 1.0
59+
pos_err: float = dir.dot(pos_err_v)
60+
vel_err: float = dir.dot(vel_err_v)
61+
err: float = tau * pos_err * inv_dt + damp * vel_err
62+
vm: float = inv_mass
63+
imp: float = err * vm
64+
65+
lin_vel += imp * dir * inv_mass
66+
total_impulse.v[i] = imp
67+
68+
total_force = total_impulse * inv_dt
69+
force_tensor: torch.Tensor = total_force.as_tensor().unsqueeze(0)
70+
link.solver.apply_links_external_force(force_tensor, (link.idx,), ref='link_com', local=False)
71+
72+
# print("vel", lin_vel, "total_impulse", total_impulse)
73+
# print("exp", lin_vel + total_force * inv_mass * link.solver.dt, "at dt", link.solver.dt, "inv dt", inv_dt)
74+
75+
76+
@property
77+
def is_attached(self) -> bool:
78+
return self.held_geom is not None
79+
80+
81+
def _solve_mouse_spring():
82+
pass

genesis/ext/pyrender/interaction/vec3.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from dataclasses import dataclass
2-
from typing import Union
2+
from typing import TYPE_CHECKING, Union
33

44
import numpy as np
55
from numpy.typing import NDArray
66

77
from genesis.utils.misc import tensor_to_array
88

9+
if TYPE_CHECKING:
10+
from genesis.engine.entities.rigid_entity.rigid_geom import RigidGeom
11+
912
# If not needing runtime checks, we can just use annotated types:
1013
# Vec3 = Annotated[npt.NDArray[np.float32], (3,)]
1114
# Aabb = Annotated[npt.NDArray[np.float32], (2, 3)]
@@ -209,6 +212,14 @@ def get_inverse(self) -> 'Pose':
209212
inv_pos = Vec3(-inv_pos.v[1:])
210213
return Pose(inv_pos, inv_rot)
211214

215+
@classmethod
216+
def from_geom(cls, geom: 'RigidGeom') -> 'Pose':
217+
assert geom._solver.n_envs == 0, "ViewerInteraction only supports single-env for now"
218+
# geom.get_pos() and .get_quat() are squeezed if n_envs == 0
219+
pos = Vec3.from_tensor(geom.get_pos())
220+
quat = Quat.from_tensor(geom.get_quat())
221+
return Pose(pos, quat)
222+
212223

213224
@dataclass
214225
class Color:

genesis/ext/pyrender/interaction/viewer_interaction.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from genesis.engine.entities.rigid_entity.rigid_entity import RigidEntity
88

99
from .aabb import AABB
10+
from .mouse_spring import MouseSpring
1011
from .ray import Plane, Ray, RayHit
1112
from .vec3 import Pose, Quat, Vec3, Color
1213
from .viewer_interaction_base import ViewerInteractionBase, EVENT_HANDLE_STATE, EVENT_HANDLED
@@ -38,13 +39,15 @@ def __init__(self,
3839
self.camera_yfov: float = camera_yfov
3940

4041
self.tan_half_fov: float = np.tan(0.5 * self.camera_yfov)
41-
self.prev_mouse_pos: tuple[int, int] = (0.5 * viewport_size[0], 0.5 * viewport_size[1])
42+
self.prev_mouse_pos: tuple[int, int] = (viewport_size[0] // 2, viewport_size[1] // 2)
4243

4344
self.picked_entity: RigidEntity | None = None
4445
self.picked_point_in_local: Vec3 | None = None
4546
self.mouse_drag_plane: Plane | None = None
4647
self.prev_mouse_3d_pos: Vec3 | None = None
4748

49+
self.mouse_spring: MouseSpring = MouseSpring()
50+
4851
@override
4952
def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE:
5053
super().on_mouse_motion(x, y, dx, dy)
@@ -55,23 +58,7 @@ def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifier
5558
super().on_mouse_drag(x, y, dx, dy, buttons, modifiers)
5659
self.prev_mouse_pos = (x, y)
5760
if self.picked_entity:
58-
mouse_ray: Ray = self.screen_position_to_ray(x, y)
59-
ray_hit: RayHit = self.mouse_drag_plane.raycast(mouse_ray)
60-
assert ray_hit.is_hit
61-
if ray_hit.is_hit:
62-
new_mouse_3d_pos: Vec3 = ray_hit.position
63-
delta_3d_pos: Vec3 = new_mouse_3d_pos - self.prev_mouse_3d_pos
64-
self.prev_mouse_3d_pos = new_mouse_3d_pos
65-
66-
use_force: bool = False
67-
if use_force:
68-
# apply force
69-
pass
70-
else:
71-
#apply displacement
72-
pos = Vec3.from_tensor(self.picked_entity.get_pos())
73-
pos = pos + delta_3d_pos
74-
self.picked_entity.set_pos(pos.as_tensor())
61+
# actual processing done in update_on_sim_step()
7562

7663
return EVENT_HANDLED
7764

@@ -90,13 +77,44 @@ def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_H
9077
pose: Pose = self.get_pose_of_first_geom(self.picked_entity)
9178
self.picked_point_in_local = pose.inverse_transform_point(ray_hit.position)
9279

80+
self.mouse_spring.attach(self.picked_entity, ray_hit.position)
81+
9382
@override
9483
def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE:
9584
super().on_mouse_release(x, y, button, modifiers)
9685
if button == 1: # left mouse button
9786
self.picked_entity = None
9887
self.picked_point_in_local = None
9988

89+
self.mouse_spring.detach()
90+
91+
@override
92+
def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE:
93+
super().on_resize(width, height)
94+
self.viewport_size = (width, height)
95+
self.tan_half_fov = np.tan(0.5 * self.camera_yfov)
96+
97+
@override
98+
def update_on_sim_step(self) -> None:
99+
if self.picked_entity:
100+
mouse_ray: Ray = self.screen_position_to_ray(*self.prev_mouse_pos)
101+
ray_hit: RayHit = self.mouse_drag_plane.raycast(mouse_ray)
102+
assert ray_hit.is_hit
103+
if ray_hit.is_hit:
104+
new_mouse_3d_pos: Vec3 = ray_hit.position
105+
delta_3d_pos: Vec3 = new_mouse_3d_pos - self.prev_mouse_3d_pos
106+
self.prev_mouse_3d_pos = new_mouse_3d_pos
107+
108+
use_force: bool = True
109+
if use_force:
110+
# apply force
111+
self.mouse_spring.apply_force(new_mouse_3d_pos, self.scene.sim.dt)
112+
else:
113+
#apply displacement
114+
pos = Vec3.from_tensor(self.picked_entity.get_pos())
115+
pos = pos + delta_3d_pos
116+
self.picked_entity.set_pos(pos.as_tensor())
117+
100118
@override
101119
def on_draw(self) -> None:
102120
super().on_draw()
@@ -202,11 +220,7 @@ def get_entities(self) -> list[RigidEntity]:
202220
return self.scene.sim.rigid_solver.entities
203221

204222
def get_pose_of_first_geom(self, entity: RigidEntity) -> 'Pose':
205-
geom: RigidGeom = entity.geoms[0]
206-
assert geom._solver.n_envs == 0, "ViewerInteraction only supports single-env for now"
207-
gpos = geom.get_pos() # squeezed if n_envs == 0
208-
gquat = geom.get_quat() # squeezed if n_envs == 0
209-
return Pose(Vec3(gpos.cpu().numpy()), Quat(gquat.cpu().numpy()))
223+
return Pose.from_geom(entity.geoms[0])
210224

211225
def _draw_arrow(
212226
self, pos: Vec3, dir: Vec3, color: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0),
@@ -215,6 +229,9 @@ def _draw_arrow(
215229
self.scene.draw_debug_line(pos.v, pos.v + dir.v, color=color)
216230

217231
def _draw_entity_unrotated_obb(self, entity: RigidEntity) -> None:
232+
if self.picked_entity:
233+
return
234+
218235
if isinstance(entity.morph, gs.morphs.Plane):
219236
plane: gs.morphs.Plane = entity.morph
220237
pass

genesis/ext/pyrender/interaction/viewer_interaction_base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,12 @@ def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE:
4141
if self.log_events:
4242
gs.logger.info(f"Key released: {chr(symbol)}")
4343

44+
def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE:
45+
if self.log_events:
46+
gs.logger.info(f"Window resized to {width}x{height}")
47+
48+
def update_on_sim_step(self) -> None:
49+
pass
50+
4451
def on_draw(self) -> None:
4552
pass

genesis/ext/pyrender/viewer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ def on_draw(self):
751751
if self._run_in_thread or not self.auto_start:
752752
self.render_lock.release()
753753

754-
def on_resize(self, width, height):
754+
def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE:
755755
"""Resize the camera and trackball when the window is resized."""
756756
if self._renderer is None:
757757
return
@@ -763,6 +763,7 @@ def on_resize(self, width, height):
763763
self._trackball.resize(self._viewport_size)
764764
self._renderer.viewport_width = self._viewport_size[0]
765765
self._renderer.viewport_height = self._viewport_size[1]
766+
self.viewer_interaction.on_resize(width, height)
766767
self.on_draw()
767768

768769
def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE:
@@ -1288,6 +1289,9 @@ def refresh(self):
12881289
if self._is_active:
12891290
self.flip()
12901291

1292+
def update_on_sim_step(self):
1293+
self.viewer_interaction.update_on_sim_step()
1294+
12911295
def _compute_initial_camera_pose(self):
12921296
centroid = self.scene.centroid
12931297
if self.viewer_flags["view_center"] is not None:

genesis/vis/viewer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ def update(self, auto_refresh=None):
139139
if self._followed_entity is not None:
140140
self.update_following()
141141

142+
self._pyrender_viewer.update_on_sim_step()
143+
142144
with self.lock:
143145
self._pyrender_viewer.pending_buffer_updates |= self.context.update()
144146

0 commit comments

Comments
 (0)