22from typing_extensions import override
33
44import numpy as np
5- from numpy .typing import NDArray
65
76import genesis as gs
87from genesis .engine .entities .rigid_entity .rigid_entity import RigidEntity
98
109from .aabb import AABB
1110from .ray import Plane , Ray , RayHit
1211from .vec3 import Pose , Quat , Vec3 , Color
13- from .viewer_interaction_base import ViewerInteractionBase , EVENT_HANDLE_STATE
12+ from .viewer_interaction_base import ViewerInteractionBase , EVENT_HANDLE_STATE , EVENT_HANDLED
1413
1514if TYPE_CHECKING :
1615 from genesis .engine .entities .rigid_entity .rigid_geom import RigidGeom
@@ -24,14 +23,6 @@ class ViewerInteraction(ViewerInteractionBase):
2423 - mouse dragging
2524 """
2625
27- camera : 'Node'
28- scene : 'Scene'
29- viewport_size : tuple [int , int ]
30- camera_yfov : float
31-
32- tan_half_fov : float
33- prev_mouse_pos : tuple [int , int ]
34-
3526 def __init__ (self ,
3627 camera : 'Node' ,
3728 scene : 'Scene' ,
@@ -41,13 +32,18 @@ def __init__(self,
4132 camera_fov : float = 60.0 ,
4233 ):
4334 super ().__init__ (log_events )
44- self .camera = camera
45- self .scene = scene
46- self .viewport_size = viewport_size
47- self .camera_yfov = camera_yfov
35+ self .camera : 'Node' = camera
36+ self .scene : 'Scene' = scene
37+ self .viewport_size : tuple [ int , int ] = viewport_size
38+ self .camera_yfov : float = camera_yfov
4839
49- self .tan_half_fov = np .tan (0.5 * self .camera_yfov )
50- self .prev_mouse_pos = tuple (np .array (viewport_size ) / 2 )
40+ self .tan_half_fov : float = np .tan (0.5 * self .camera_yfov )
41+ self .prev_mouse_pos : tuple [int , int ] = (0.5 * viewport_size [0 ], 0.5 * viewport_size [1 ])
42+
43+ self .picked_entity : RigidEntity | None = None
44+ self .picked_point_in_local : Vec3 | None = None
45+ self .mouse_drag_plane : Plane | None = None
46+ self .prev_mouse_3d_pos : Vec3 | None = None
5147
5248 @override
5349 def on_mouse_motion (self , x : int , y : int , dx : int , dy : int ) -> EVENT_HANDLE_STATE :
@@ -58,17 +54,55 @@ def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STAT
5854 def on_mouse_drag (self , x : int , y : int , dx : int , dy : int , buttons : int , modifiers : int ) -> EVENT_HANDLE_STATE :
5955 super ().on_mouse_drag (x , y , dx , dy , buttons , modifiers )
6056 self .prev_mouse_pos = (x , y )
57+ if self .picked_entity :
58+ mouse_ray : Ray = self .screen_position_to_ray (x , y )
59+ ray_hit : RayHit = self .mouse_drag_plane .raycast (mouse_ray )
60+ assert ray_hit .is_hit
61+ if ray_hit .is_hit :
62+ new_mouse_3d_pos : Vec3 = ray_hit .position
63+ delta_3d_pos : Vec3 = new_mouse_3d_pos - self .prev_mouse_3d_pos
64+ self .prev_mouse_3d_pos = new_mouse_3d_pos
65+
66+ use_force : bool = False
67+ if use_force :
68+ # apply force
69+ pass
70+ else :
71+ #apply displacement
72+ pos = Vec3 .from_tensor (self .picked_entity .get_pos ())
73+ pos = pos + delta_3d_pos
74+ self .picked_entity .set_pos (pos .as_tensor ())
75+
76+ return EVENT_HANDLED
6177
6278 @override
6379 def on_mouse_press (self , x : int , y : int , button : int , modifiers : int ) -> EVENT_HANDLE_STATE :
6480 super ().on_mouse_press (x , y , button , modifiers )
81+ if button == 1 : # left mouse button
82+ (ray_hit , self .picked_entity ) = self .raycast_against_entities (self .screen_position_to_ray (x , y ))
83+ if self .picked_entity and ray_hit :
84+ temp_fwd = self .get_camera_forward ()
85+ temp_back = - temp_fwd
86+
87+ self .mouse_drag_plane = Plane (temp_back , ray_hit .position )
88+ self .prev_mouse_3d_pos = ray_hit .position
89+
90+ pose : Pose = self .get_pose_of_first_geom (self .picked_entity )
91+ self .picked_point_in_local = pose .inverse_transform_point (ray_hit .position )
92+
93+ @override
94+ def on_mouse_release (self , x : int , y : int , button : int , modifiers : int ) -> EVENT_HANDLE_STATE :
95+ super ().on_mouse_release (x , y , button , modifiers )
96+ if button == 1 : # left mouse button
97+ self .picked_entity = None
98+ self .picked_point_in_local = None
6599
66100 @override
67101 def on_draw (self ) -> None :
68102 super ().on_draw ()
69103 if self .scene ._visualizer is not None and self .scene ._visualizer .viewer_lock is not None :
70104 self .scene .clear_debug_objects ()
71- mouse_ray = self .screen_position_to_ray (* self .prev_mouse_pos )
105+ mouse_ray : Ray = self .screen_position_to_ray (* self .prev_mouse_pos )
72106 closest_hit = None
73107 hit_entity : RigidEntity | None = None
74108
@@ -77,7 +111,7 @@ def on_draw(self) -> None:
77111 closest_hit = ray_hit
78112
79113 for entity in self .get_entities ():
80- ray_hit = self .raycast_against_entity_oobb (entity , mouse_ray )
114+ ray_hit = self .raycast_against_entity_obb (entity , mouse_ray )
81115 if ray_hit .is_hit :
82116 if closest_hit is None or ray_hit .distance < closest_hit .distance :
83117 closest_hit = ray_hit
@@ -87,7 +121,21 @@ def on_draw(self) -> None:
87121 self .scene .draw_debug_sphere (closest_hit .position .v , 0.01 , (0 , 1 , 0 , 1 ))
88122 self ._draw_arrow (closest_hit .position , 0.25 * closest_hit .normal , (0 , 1 , 0 , 1 ))
89123 if hit_entity :
90- self ._draw_entity_unrotated_oobb (hit_entity )
124+ self ._draw_entity_unrotated_obb (hit_entity )
125+
126+ if self .picked_entity :
127+ assert self .mouse_drag_plane is not None
128+ assert self .picked_point_in_local is not None
129+
130+ # draw held point
131+ pose : Pose = self .get_pose_of_first_geom (self .picked_entity )
132+ held_point : Vec3 = pose .transform_point (self .picked_point_in_local )
133+ self .scene .draw_debug_sphere (held_point .v , 0.02 , Color .red ().tuple ())
134+
135+ plane_hit : RayHit = self .mouse_drag_plane .raycast (mouse_ray )
136+ if plane_hit .is_hit :
137+ self .scene .draw_debug_sphere (plane_hit .position .v , 0.02 , Color .red ().tuple ())
138+ self .scene .draw_debug_line (held_point .v , plane_hit .position .v , color = Color .red ().tuple ())
91139
92140 def screen_position_to_ray (self , x : float , y : float ) -> Ray :
93141 # convert screen position to ray
@@ -107,35 +155,49 @@ def screen_position_to_ray(self, x: float, y: float) -> Ray:
107155 # Note: ignoring pixel aspect ratio
108156
109157 mtx = self .camera .matrix
110- position = Vec3 .from_float64 (mtx [:3 , 3 ])
111- forward = Vec3 .from_float64 (- mtx [:3 , 2 ])
112- right = Vec3 .from_float64 (mtx [:3 , 0 ])
113- up = Vec3 .from_float64 (mtx [:3 , 1 ])
158+ position = Vec3 .from_array (mtx [:3 , 3 ])
159+ forward = Vec3 .from_array (- mtx [:3 , 2 ])
160+ right = Vec3 .from_array (mtx [:3 , 0 ])
161+ up = Vec3 .from_array (mtx [:3 , 1 ])
114162
115163 direction = forward + right * x + up * y
116164 return Ray (position , direction )
117165
166+ def get_camera_forward (self ) -> Vec3 :
167+ mtx = self .camera .matrix
168+ return Vec3 .from_array (- mtx [:3 , 2 ])
169+
118170 def get_camera_ray (self ) -> Ray :
119171 mtx = self .camera .matrix
120- position = Vec3 .from_float64 (mtx [:3 , 3 ])
121- forward = Vec3 .from_float64 (- mtx [:3 , 2 ])
172+ position = Vec3 .from_array (mtx [:3 , 3 ])
173+ forward = Vec3 .from_array (- mtx [:3 , 2 ])
122174 return Ray (position , forward )
123175
124176 def _raycast_against_ground_plane (self , ray : Ray ) -> RayHit :
125177 ground_plane = Plane (Vec3 .from_xyz (0 , 0 , 1 ), Vec3 .zero ())
126178 return ground_plane .raycast (ray )
127179
128- def raycast_against_entity_oobb (self , entity : RigidEntity , ray : Ray ) -> RayHit :
180+ def raycast_against_entity_obb (self , entity : RigidEntity , ray : Ray ) -> RayHit :
129181 if isinstance (entity .morph , gs .morphs .Box ):
130182 box : gs .morphs .Box = entity .morph
131183 size = Vec3 .from_xyz (* box .size )
132184 pose = self .get_pose_of_first_geom (entity )
133185 aabb = AABB .from_center_and_size (Vec3 .zero (), size )
134- ray_hit = aabb .raycast_oobb (pose , ray )
186+ ray_hit = aabb .raycast_obb (pose , ray )
135187 return ray_hit
136188 else :
137189 return RayHit .no_hit ()
138190
191+ def raycast_against_entities (self , ray : Ray ) -> tuple [RayHit | None , RigidEntity | None ]:
192+ closest_hit = None
193+ hit_entity : RigidEntity | None = None
194+ for entity in self .get_entities ():
195+ ray_hit = self .raycast_against_entity_obb (entity , ray )
196+ if ray_hit .is_hit and (closest_hit is None or ray_hit .distance < closest_hit .distance ):
197+ closest_hit = ray_hit
198+ hit_entity = entity
199+ return (closest_hit , hit_entity )
200+
139201 def get_entities (self ) -> list [RigidEntity ]:
140202 return self .scene .sim .rigid_solver .entities
141203
@@ -152,7 +214,7 @@ def _draw_arrow(
152214 self .scene .draw_debug_arrow (pos .v , dir .v , color = color ) # Only draws arrowhead -- bug?
153215 self .scene .draw_debug_line (pos .v , pos .v + dir .v , color = color )
154216
155- def _draw_entity_unrotated_oobb (self , entity : RigidEntity ) -> None :
217+ def _draw_entity_unrotated_obb (self , entity : RigidEntity ) -> None :
156218 if isinstance (entity .morph , gs .morphs .Plane ):
157219 plane : gs .morphs .Plane = entity .morph
158220 pass
@@ -161,10 +223,9 @@ def _draw_entity_unrotated_oobb(self, entity: RigidEntity) -> None:
161223 size = Vec3 .from_xyz (* box .size )
162224 geom : RigidGeom = entity .geoms [0 ]
163225 assert geom ._solver .n_envs == 0 , "ViewerInteraction only supports single-env for now"
164- gpos = geom .get_pos () # squeezed if n_envs == 0
165- gquat = geom .get_quat () # squeezed if n_envs == 0
166- pos = Vec3 .from_any_array (gpos .cpu ().numpy ())
167- quat = Quat .from_any_array (gquat .cpu ().numpy ())
226+ # geom.get_pos() and .get_quat() are squeezed if n_envs == 0
227+ pos = Vec3 .from_tensor (geom .get_pos ())
228+ quat = Quat .from_tensor (geom .get_quat ())
168229 aabb = AABB .from_center_and_size (pos , size )
169230 aabb .expand (0.01 )
170231 self .scene .draw_debug_box (aabb .v , color = Color .red ().with_alpha (0.5 ).tuple (), wireframe = False )
0 commit comments