diff --git a/examples/drone/interactive_drone.py b/examples/drone/interactive_drone.py index 6a0e078841..05bfba97f7 100644 --- a/examples/drone/interactive_drone.py +++ b/examples/drone/interactive_drone.py @@ -91,23 +91,6 @@ def update_thrust(self): return self.rpms -def update_camera(scene, drone): - """Updates the camera position to follow the drone""" - if not scene.viewer: - return - - drone_pos = drone.get_pos() - - # Camera position relative to drone - offset_x = 0.0 # centered horizontally - offset_y = -4.0 # 4 units behind (in Y axis) - offset_z = 2.0 # 2 units above - - camera_pos = (float(drone_pos[0] + offset_x), float(drone_pos[1] + offset_y), float(drone_pos[2] + offset_z)) - - # Update camera position and look target - scene.viewer.set_camera_pose(pos=camera_pos, lookat=tuple(float(x) for x in drone_pos)) - def run_sim(scene, drone, controller): while controller.running: @@ -119,9 +102,6 @@ def run_sim(scene, drone, controller): # Update physics scene.step() - # Update camera position to follow drone - update_camera(scene, drone) - time.sleep(1 / 60) # Limit simulation rate except Exception as e: print(f"Error in simulation loop: {e}") @@ -165,6 +145,8 @@ def main(): ), ) + scene.viewer.follow_entity(drone) + # Build scene scene.build() diff --git a/examples/locomotion/go2_env.py b/examples/locomotion/go2_env.py index 0ff5825c17..a99ad64e4f 100644 --- a/examples/locomotion/go2_env.py +++ b/examples/locomotion/go2_env.py @@ -64,6 +64,16 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie ), ) + # add follower camera + if show_viewer: + self.follower_camera = self.scene.add_camera(res=(640,480), + pos=(0.0, 2.0, 0.5), + lookat=(0.0, 0.0, 0.5), + fov=40, + GUI=True) + # follow the robot at a fixed height and orientation + self.follower_camera.follow_entity(self.robot, fixed_axis=(None, None, 0.5), smoothing=0.5, fix_orientation=True) + # build self.scene.build(n_envs=num_envs) @@ -124,6 +134,9 @@ def step(self, actions): self.robot.control_dofs_position(target_dof_pos, self.motor_dofs) self.scene.step() + if hasattr(self, "follower_camera"): + self.follower_camera.render() + # update buffers self.episode_length_buf += 1 self.base_pos[:] = self.robot.get_pos() diff --git a/examples/locomotion/go2_train.py b/examples/locomotion/go2_train.py index 7c77990cee..fe68f8ebd9 100644 --- a/examples/locomotion/go2_train.py +++ b/examples/locomotion/go2_train.py @@ -137,6 +137,7 @@ def get_cfgs(): def main(): parser = argparse.ArgumentParser() + parser.add_argument("-v", "--vis", action="store_true", default=False, help="Enable visualization (default: False)") parser.add_argument("-e", "--exp_name", type=str, default="go2-walking") parser.add_argument("-B", "--num_envs", type=int, default=4096) parser.add_argument("--max_iterations", type=int, default=100) @@ -153,7 +154,7 @@ def main(): os.makedirs(log_dir, exist_ok=True) env = Go2Env( - num_envs=args.num_envs, env_cfg=env_cfg, obs_cfg=obs_cfg, reward_cfg=reward_cfg, command_cfg=command_cfg + num_envs=args.num_envs, env_cfg=env_cfg, obs_cfg=obs_cfg, reward_cfg=reward_cfg, command_cfg=command_cfg, show_viewer=args.vis ) runner = OnPolicyRunner(env, train_cfg, log_dir, device="cuda:0") diff --git a/genesis/vis/camera.py b/genesis/vis/camera.py index 687ab052ab..2ad6d865ec 100644 --- a/genesis/vis/camera.py +++ b/genesis/vis/camera.py @@ -93,6 +93,13 @@ def __init__( self._in_recording = False self._recorded_imgs = [] + self._init_pos = np.array(pos) + + self._followed_entity = None + self._follow_fixed_axis = None + self._follow_smoothing = None + self._follow_fix_orientation = None + if self._model not in ["pinhole", "thinlens"]: gs.raise_exception(f"Invalid camera model: {self._model}") @@ -146,6 +153,9 @@ def render(self, rgb=True, depth=False, segmentation=False, colorize_seg=False, rgb_arr, depth_arr, seg_idxc_arr, seg_arr, normal_arr = None, None, None, None, None + if self._followed_entity is not None: + self.update_following() + if self._raytracer is not None: if rgb: self._raytracer.update_scene() @@ -245,6 +255,60 @@ def set_pose(self, transform=None, pos=None, lookat=None, up=None): if self._raytracer is not None: self._raytracer.update_camera(self) + def follow_entity(self, entity, fixed_axis=(None, None, None), smoothing=None, fix_orientation=False): + """ + Set the camera to follow a specified entity. + + Parameters + ---------- + entity : genesis.Entity + The entity to follow. + fixed_axis : (float, float, float), optional + The fixed axis for the camera's movement. For each axis, if None, the camera will move freely. If a float, the viewer will be fixed on at that value. + For example, [None, None, None] will allow the camera to move freely while following, [None, None, 0.5] will fix the viewer's z-axis at 0.5. + smoothing : float, optional + The smoothing factor for the camera's movement. If None, no smoothing will be applied. + fix_orientation : bool, optional + If True, the camera will maintain its orientation relative to the world. If False, the camera will look at the base link of the entity. + """ + self._followed_entity = entity + self._follow_fixed_axis = fixed_axis + self._follow_smoothing = smoothing + self._follow_fix_orientation = fix_orientation + + @gs.assert_built + def update_following(self): + """ + Update the camera position to follow the specified entity. + """ + + entity_pos = self._followed_entity.get_pos()[0].cpu().numpy() + if entity_pos.ndim > 1: #check for multiple envs + entity_pos = entity_pos[0] + camera_pos = np.array(self._pos) + camera_pose = np.array(self._transform) + lookat_pos = np.array(self._lookat) + + if self._follow_smoothing is not None: + # Smooth camera movement with a low-pass filter + camera_pos = self._follow_smoothing * camera_pos + (1 - self._follow_smoothing) * (entity_pos + self._init_pos) + lookat_pos = self._follow_smoothing * lookat_pos + (1 - self._follow_smoothing) * entity_pos + else: + camera_pos = entity_pos + self._init_pos + lookat_pos = entity_pos + + for i, fixed_axis in enumerate(self._follow_fixed_axis): + # Fix the camera's position along the specified axis + if fixed_axis is not None: + camera_pos[i] = fixed_axis + + if self._follow_fix_orientation: + # Keep the camera orientation fixed by overriding the lookat point + camera_pose[:3, 3] = camera_pos + self.set_pose(transform=camera_pose) + else: + self.set_pose(pos=camera_pos, lookat=lookat_pos) + @gs.assert_built def set_params(self, fov=None, aperture=None, focus_dist=None): """ diff --git a/genesis/vis/viewer.py b/genesis/vis/viewer.py index e7ae2395f3..c58d445fbf 100644 --- a/genesis/vis/viewer.py +++ b/genesis/vis/viewer.py @@ -42,6 +42,12 @@ def __init__(self, options, context): self.context = context + self._followed_entity = None + self._follow_fixed_axis = None + self._follow_smoothing = None + self._follow_fix_orientation = None + self._follow_lookat = None + if self._max_FPS is not None: self.rate = Rate(self._max_FPS) @@ -109,6 +115,9 @@ def setup_camera(self): self._camera_node = self.context.add_node(pyrender.PerspectiveCamera(yfov=yfov), pose=pose) def update(self): + if self._followed_entity is not None: + self.update_following() + with self.lock: buffer_updates = self.context.update() for buffer_id, buffer_data in buffer_updates.items(): @@ -153,6 +162,57 @@ def set_camera_pose(self, pose=None, pos=None, lookat=None): self._pyrender_viewer._trackball.set_camera_pose(pose) + def follow_entity(self, entity, fixed_axis=(None, None, None), smoothing=None, fix_orientation=False): + """ + Set the viewer to follow a specified entity. + Parameters + ---------- + entity : genesis.Entity + The entity to follow. + fixed_axis : (float, float, float), optional + The fixed axis for the viewer's movement. For each axis, if None, the viewer will move freely. If a float, the viewer will be fixed on at that value. + For example, [None, None, None] will allow the viewer to move freely while following, [None, None, 0.5] will fix the viewer's z-axis at 0.5. + smoothing : float, optional + The smoothing factor in ]0,1[ for the viewer's movement. If None, no smoothing will be applied. + fix_orientation : bool, optional + If True, the viewer will maintain its orientation relative to the world. If False, the viewer will look at the base link of the entity. + """ + self._followed_entity = entity + self._follow_fixed_axis = fixed_axis + self._follow_smoothing = smoothing + self._follow_fix_orientation = fix_orientation + self._follow_lookat = self._camera_init_lookat + + def update_following(self): + """ + Update the viewer position to follow the specified entity. + """ + entity_pos = self._followed_entity.get_pos().cpu().numpy() + if entity_pos.ndim > 1: #check for multiple envs + entity_pos = entity_pos[0] + camera_pose = np.array(self._pyrender_viewer._trackball.pose) + camera_pos = np.array(self._pyrender_viewer._trackball.pose[:3, 3]) + + if self._follow_smoothing is not None: + # Smooth viewer movement with a low-pass filter + camera_pos = self._follow_smoothing * camera_pos + (1 - self._follow_smoothing) * (entity_pos + np.array(self._camera_init_pos)) + self._follow_lookat = self._follow_smoothing * self._follow_lookat + (1 - self._follow_smoothing) * entity_pos + else: + camera_pos = entity_pos + np.array(self._camera_init_pos) + self._follow_lookat = entity_pos + + for i, fixed_axis in enumerate(self._follow_fixed_axis): + # Fix the camera's position along the specified axis + if fixed_axis is not None: + camera_pos[i] = fixed_axis + + if self._follow_fix_orientation: + # Keep the camera orientation fixed by overriding the lookat point + camera_pose[:3, 3] = camera_pos + self.set_camera_pose(pose=camera_pose) + else: + self.set_camera_pose(pos=camera_pos, lookat=self._follow_lookat) + # ------------------------------------------------------------------------------------ # ----------------------------------- properties ------------------------------------- # ------------------------------------------------------------------------------------