77from genesis .engine .entities .rigid_entity .rigid_entity import RigidEntity
88
99from .aabb import AABB
10+ from .mouse_spring import MouseSpring
1011from .ray import Plane , Ray , RayHit
1112from .vec3 import Pose , Quat , Vec3 , Color
1213from .viewer_interaction_base import ViewerInteractionBase , EVENT_HANDLE_STATE , EVENT_HANDLED
@@ -38,13 +39,15 @@ def __init__(self,
3839 self .camera_yfov : float = camera_yfov
3940
4041 self .tan_half_fov : float = np .tan (0.5 * self .camera_yfov )
41- self .prev_mouse_pos : tuple [int , int ] = (0.5 * viewport_size [0 ], 0.5 * viewport_size [1 ])
42+ self .prev_mouse_pos : tuple [int , int ] = (viewport_size [0 ] // 2 , viewport_size [1 ] // 2 )
4243
4344 self .picked_entity : RigidEntity | None = None
4445 self .picked_point_in_local : Vec3 | None = None
4546 self .mouse_drag_plane : Plane | None = None
4647 self .prev_mouse_3d_pos : Vec3 | None = None
4748
49+ self .mouse_spring : MouseSpring = MouseSpring ()
50+
4851 @override
4952 def on_mouse_motion (self , x : int , y : int , dx : int , dy : int ) -> EVENT_HANDLE_STATE :
5053 super ().on_mouse_motion (x , y , dx , dy )
@@ -55,23 +58,7 @@ def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifier
5558 super ().on_mouse_drag (x , y , dx , dy , buttons , modifiers )
5659 self .prev_mouse_pos = (x , y )
5760 if self .picked_entity :
58- mouse_ray : Ray = self .screen_position_to_ray (x , y )
59- ray_hit : RayHit = self .mouse_drag_plane .raycast (mouse_ray )
60- assert ray_hit .is_hit
61- if ray_hit .is_hit :
62- new_mouse_3d_pos : Vec3 = ray_hit .position
63- delta_3d_pos : Vec3 = new_mouse_3d_pos - self .prev_mouse_3d_pos
64- self .prev_mouse_3d_pos = new_mouse_3d_pos
65-
66- use_force : bool = False
67- if use_force :
68- # apply force
69- pass
70- else :
71- #apply displacement
72- pos = Vec3 .from_tensor (self .picked_entity .get_pos ())
73- pos = pos + delta_3d_pos
74- self .picked_entity .set_pos (pos .as_tensor ())
61+ # actual processing done in update_on_sim_step()
7562
7663 return EVENT_HANDLED
7764
@@ -90,13 +77,44 @@ def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_H
9077 pose : Pose = self .get_pose_of_first_geom (self .picked_entity )
9178 self .picked_point_in_local = pose .inverse_transform_point (ray_hit .position )
9279
80+ self .mouse_spring .attach (self .picked_entity , ray_hit .position )
81+
9382 @override
9483 def on_mouse_release (self , x : int , y : int , button : int , modifiers : int ) -> EVENT_HANDLE_STATE :
9584 super ().on_mouse_release (x , y , button , modifiers )
9685 if button == 1 : # left mouse button
9786 self .picked_entity = None
9887 self .picked_point_in_local = None
9988
89+ self .mouse_spring .detach ()
90+
91+ @override
92+ def on_resize (self , width : int , height : int ) -> EVENT_HANDLE_STATE :
93+ super ().on_resize (width , height )
94+ self .viewport_size = (width , height )
95+ self .tan_half_fov = np .tan (0.5 * self .camera_yfov )
96+
97+ @override
98+ def update_on_sim_step (self ) -> None :
99+ if self .picked_entity :
100+ mouse_ray : Ray = self .screen_position_to_ray (* self .prev_mouse_pos )
101+ ray_hit : RayHit = self .mouse_drag_plane .raycast (mouse_ray )
102+ assert ray_hit .is_hit
103+ if ray_hit .is_hit :
104+ new_mouse_3d_pos : Vec3 = ray_hit .position
105+ delta_3d_pos : Vec3 = new_mouse_3d_pos - self .prev_mouse_3d_pos
106+ self .prev_mouse_3d_pos = new_mouse_3d_pos
107+
108+ use_force : bool = True
109+ if use_force :
110+ # apply force
111+ self .mouse_spring .apply_force (new_mouse_3d_pos , self .scene .sim .dt )
112+ else :
113+ #apply displacement
114+ pos = Vec3 .from_tensor (self .picked_entity .get_pos ())
115+ pos = pos + delta_3d_pos
116+ self .picked_entity .set_pos (pos .as_tensor ())
117+
100118 @override
101119 def on_draw (self ) -> None :
102120 super ().on_draw ()
@@ -202,11 +220,7 @@ def get_entities(self) -> list[RigidEntity]:
202220 return self .scene .sim .rigid_solver .entities
203221
204222 def get_pose_of_first_geom (self , entity : RigidEntity ) -> 'Pose' :
205- geom : RigidGeom = entity .geoms [0 ]
206- assert geom ._solver .n_envs == 0 , "ViewerInteraction only supports single-env for now"
207- gpos = geom .get_pos () # squeezed if n_envs == 0
208- gquat = geom .get_quat () # squeezed if n_envs == 0
209- return Pose (Vec3 (gpos .cpu ().numpy ()), Quat (gquat .cpu ().numpy ()))
223+ return Pose .from_geom (entity .geoms [0 ])
210224
211225 def _draw_arrow (
212226 self , pos : Vec3 , dir : Vec3 , color : tuple [float , float , float , float ] = (1.0 , 1.0 , 1.0 , 1.0 ),
@@ -215,6 +229,9 @@ def _draw_arrow(
215229 self .scene .draw_debug_line (pos .v , pos .v + dir .v , color = color )
216230
217231 def _draw_entity_unrotated_obb (self , entity : RigidEntity ) -> None :
232+ if self .picked_entity :
233+ return
234+
218235 if isinstance (entity .morph , gs .morphs .Plane ):
219236 plane : gs .morphs .Plane = entity .morph
220237 pass
0 commit comments