diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index dcf441f94f..68165d8fa2 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -54,7 +54,7 @@ jobs: run: | python -m pip install --upgrade pip setuptools wheel pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install -e '.[dev,usd]' pynput + pip install -e '.[dev,usd]' - name: Get gstaichi version id: gstaichi_version diff --git a/examples/IPC_Solver/ipc_arm_cloth.py b/examples/IPC_Solver/ipc_arm_cloth.py index 20df431765..5165421988 100644 --- a/examples/IPC_Solver/ipc_arm_cloth.py +++ b/examples/IPC_Solver/ipc_arm_cloth.py @@ -14,45 +14,19 @@ ; - Roll Right (Rotate around X) u - Reset Scene space - Press to close gripper, release to open gripper -esc - Quit """ -import threading import argparse -import numpy as np import csv import os from datetime import datetime -from pynput import keyboard -from scipy.spatial.transform import Rotation as R + +import numpy as np from huggingface_hub import snapshot_download import genesis as gs - - -class KeyboardDevice: - def __init__(self): - self.pressed_keys = set() - self.lock = threading.Lock() - self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) - - def start(self): - self.listener.start() - - def stop(self): - self.listener.stop() - self.listener.join() - - def on_press(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.add(key) - - def on_release(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.discard(key) - - def get_cmd(self): - return self.pressed_keys +import genesis.utils.geom as gu +from genesis.vis.keybindings import Key, KeyAction, Keybind def build_scene(use_ipc=False, show_viewer=False, enable_ipc_gui=False): @@ -109,18 +83,12 @@ def build_scene(use_ipc=False, show_viewer=False, enable_ipc_gui=False): euler=(0, 0, 0), ), ) - scene.sim.coupler.set_ipc_link_filter( - entity=entities["robot"], - link_names=["left_finger", "right_finger"], - ) - - material = ( - gs.materials.FEM.Elastic(E=1.0e4, nu=0.45, rho=1000.0, model="stable_neohookean") - if use_ipc - else gs.materials.Rigid() - ) if use_ipc: + scene.sim.coupler.set_ipc_link_filter( + entity=entities["robot"], + link_names=["left_finger", "right_finger"], + ) cloth = scene.add_entity( morph=gs.morphs.Mesh( file="meshes/grid20x20.obj", @@ -171,14 +139,15 @@ def build_scene(use_ipc=False, show_viewer=False, enable_ipc_gui=False): return scene, entities -def run_sim(scene, entities, clients, mode="interactive", trajectory_file=None): +def run_sim(scene, entities, add_keybinds, mode="interactive", trajectory_file=None): robot = entities["robot"] target_entity = entities["target"] + is_running = True robot_init_pos = np.array([0.5, 0, 0.55]) - robot_init_R = R.from_euler("y", np.pi) + robot_init_quat = gu.xyz_to_quat(np.array([0, np.pi, 0])) # Rotation around Y axis target_pos = robot_init_pos.copy() - target_R = robot_init_R + target_quat = robot_init_quat.copy() n_dofs = robot.n_dofs motors_dof = np.arange(n_dofs - 2) @@ -189,11 +158,16 @@ def run_sim(scene, entities, clients, mode="interactive", trajectory_file=None): trajectory = [] recording = mode == "record" + # Gripper state (use list for mutability in closures) + gripper_closed = [False] + + # Control parameters + dpos = 0.002 + drot = 0.01 + def reset_scene(): - nonlocal target_pos, target_R - target_pos = robot_init_pos.copy() - target_R = robot_init_R - target_quat = target_R.as_quat(scalar_first=True) + target_pos[:] = robot_init_pos + target_quat[:] = robot_init_quat target_entity.set_qpos(np.concatenate([target_pos, target_quat])) q = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat) robot.set_qpos(q[:-2], motors_dof) @@ -201,6 +175,46 @@ def reset_scene(): # entities["cube"].set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) # entities["cube"].set_quat(R.from_euler("z", random.uniform(0, np.pi * 2)).as_quat(scalar_first=True)) + # Register keybindings + if add_keybinds: + + def move(dpos_delta: tuple[float, float, float]): + target_pos[:] += np.array(dpos_delta, dtype=gs.np_float) + + def rotate(axis: str, angle: float): + # Create rotation quaternion for the specified axis + euler = np.zeros(3) + axis_map = {"x": 0, "y": 1, "z": 2} + euler[axis_map[axis]] = angle + drot_quat = gu.xyz_to_quat(euler) + target_quat[:] = gu.transform_quat_by_quat(target_quat, drot_quat) + + def toggle_gripper(close: bool = True): + gripper_closed[0] = close + + def stop(): + nonlocal is_running + is_running = False + + scene.viewer.register_keybinds( + Keybind("move_forward", Key.UP, KeyAction.HOLD, callback=move, args=((-dpos, 0, 0),)), + Keybind("move_backward", Key.DOWN, KeyAction.HOLD, callback=move, args=((dpos, 0, 0),)), + Keybind("move_left", Key.LEFT, KeyAction.HOLD, callback=move, args=((0, -dpos, 0),)), + Keybind("move_right", Key.RIGHT, KeyAction.HOLD, callback=move, args=((0, dpos, 0),)), + Keybind("move_up", Key.N, KeyAction.HOLD, callback=move, args=((0, 0, dpos),)), + Keybind("move_down", Key.M, KeyAction.HOLD, callback=move, args=((0, 0, -dpos),)), + Keybind("yaw_left", Key.J, KeyAction.HOLD, callback=rotate, args=("z", drot)), + Keybind("yaw_right", Key.K, KeyAction.HOLD, callback=rotate, args=("z", -drot)), + Keybind("pitch_up", Key.I, KeyAction.HOLD, callback=rotate, args=("y", drot)), + Keybind("pitch_down", Key.O, KeyAction.HOLD, callback=rotate, args=("y", -drot)), + Keybind("roll_left", Key.L, KeyAction.HOLD, callback=rotate, args=("x", drot)), + Keybind("roll_right", Key.SEMICOLON, KeyAction.HOLD, callback=rotate, args=("x", -drot)), + Keybind("reset_scene", Key.U, KeyAction.HOLD, callback=reset_scene), + Keybind("close_gripper", Key.SPACE, KeyAction.PRESS, callback=toggle_gripper, args=(True,)), + Keybind("open_gripper", Key.SPACE, KeyAction.RELEASE, callback=toggle_gripper, args=(False,)), + Keybind("quit", Key.ESCAPE, KeyAction.PRESS, callback=stop), + ) + # Load trajectory if in playback mode if mode == "playback": if not trajectory_file or not os.path.exists(trajectory_file): @@ -232,7 +246,7 @@ def reset_scene(): print(f"\nMode: {mode.upper()}") if mode == "record": - print("Recording trajectory... Press ESC to stop and save.") + print("Recording trajectory...") elif mode == "playback": print("Playing back trajectory...") @@ -248,99 +262,59 @@ def reset_scene(): print("l/;\t- Roll Left/Right (Rotate around X axis)") print("u\t- Reset Scene") print("space\t- Press to close gripper, release to open gripper") - print("esc\t- Quit") + if mode in ["interactive", "record"]: + print("\nPlus all default viewer controls (press 'i' to see them)") # reset scene before starting teleoperation reset_scene() # start teleoperation or playback - stop = False step_count = 0 - while not stop: - if mode == "playback": - # Playback mode: replay recorded trajectory - if step_count < len(trajectory): - step_data = trajectory[step_count] - target_pos = step_data["target_pos"] - target_R = R.from_quat(step_data["target_quat"]) - is_close_gripper = step_data["gripper_closed"] - step_count += 1 - print(f"\rPlayback step: {step_count}/{len(trajectory)}", end="") - # Check if user wants to stop playback - pressed_keys = clients["keyboard"].pressed_keys.copy() - stop = keyboard.Key.esc in pressed_keys + try: + while is_running: + if mode == "playback": + # Playback mode: replay recorded trajectory + if step_count < len(trajectory): + step_data = trajectory[step_count] + target_pos[:] = step_data["target_pos"] + target_quat[:] = step_data["target_quat"] + gripper_closed[0] = step_data["gripper_closed"] + step_count += 1 + print(f"\rPlayback step: {step_count}/{len(trajectory)}", end="") + else: + print("\nPlayback finished!") + break else: - print("\nPlayback finished!") - break - else: - # Interactive or recording mode - pressed_keys = clients["keyboard"].pressed_keys.copy() - - # reset scene: - reset_flag = False - reset_flag |= keyboard.KeyCode.from_char("u") in pressed_keys - if reset_flag: - reset_scene() - - # stop teleoperation - stop = keyboard.Key.esc in pressed_keys - - # get ee target pose - is_close_gripper = False - dpos = 0.002 - drot = 0.01 - for key in pressed_keys: - if key == keyboard.Key.up: - target_pos[0] -= dpos - elif key == keyboard.Key.down: - target_pos[0] += dpos - elif key == keyboard.Key.right: - target_pos[1] += dpos - elif key == keyboard.Key.left: - target_pos[1] -= dpos - elif key == keyboard.KeyCode.from_char("n"): - target_pos[2] += dpos - elif key == keyboard.KeyCode.from_char("m"): - target_pos[2] -= dpos - elif key == keyboard.KeyCode.from_char("j"): - target_R = R.from_euler("z", drot) * target_R - elif key == keyboard.KeyCode.from_char("k"): - target_R = R.from_euler("z", -drot) * target_R - elif key == keyboard.KeyCode.from_char("i"): - target_R = R.from_euler("y", drot) * target_R - elif key == keyboard.KeyCode.from_char("o"): - target_R = R.from_euler("y", -drot) * target_R - elif key == keyboard.KeyCode.from_char("l"): - target_R = R.from_euler("x", drot) * target_R - elif key == keyboard.KeyCode.from_char(";"): - target_R = R.from_euler("x", -drot) * target_R - elif key == keyboard.Key.space: - is_close_gripper = True - - # Record current state if recording - if recording: - step_data = { - "target_pos": target_pos.copy(), - "target_quat": target_R.as_quat(), # x,y,z,w format - "gripper_closed": is_close_gripper, - "step": step_count, - } - trajectory.append(step_data) + # Interactive or recording mode + # Movement is handled by keybinding callbacks + # Record current state if recording + if recording: + step_data = { + "target_pos": target_pos.copy(), + "target_quat": target_quat.copy(), + "gripper_closed": gripper_closed[0], + "step": step_count, + } + trajectory.append(step_data) + + # control arm + target_entity.set_qpos(np.concatenate([target_pos, target_quat])) + q, err = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat, return_error=True) + robot.control_dofs_position(q[:-2], motors_dof) + # control gripper + if gripper_closed[0]: + robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) + else: + robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) - # control arm - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) - q, err = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat, return_error=True) - robot.control_dofs_position(q[:-2], motors_dof) - # control gripper - if is_close_gripper: - robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) - else: - robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) + scene.step() + step_count += 1 - scene.step() - step_count += 1 + if "PYTEST_VERSION" in os.environ: + break + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") # Save trajectory if recording if recording and len(trajectory) > 0: @@ -436,12 +410,14 @@ def main(): elif not os.path.isabs(trajectory_file): trajectory_file = os.path.join(traj_dir, os.path.basename(trajectory_file)) - clients = dict() - clients["keyboard"] = KeyboardDevice() - clients["keyboard"].start() - scene, entities = build_scene(use_ipc=args.ipc, show_viewer=args.vis, enable_ipc_gui=False) - run_sim(scene, entities, clients, mode=args.mode, trajectory_file=trajectory_file) + run_sim( + scene, + entities, + add_keybinds=args.vis or args.mode in ["interactive", "record"], + mode=args.mode, + trajectory_file=trajectory_file, + ) if __name__ == "__main__": diff --git a/examples/drone/interactive_drone.py b/examples/drone/interactive_drone.py index cdc6713d6e..56c750ae1f 100644 --- a/examples/drone/interactive_drone.py +++ b/examples/drone/interactive_drone.py @@ -1,112 +1,35 @@ import os -import time -import threading -from pynput import keyboard import numpy as np import genesis as gs +from genesis.vis.keybindings import Key, KeyAction, Keybind class DroneController: def __init__(self): - self.thrust = 14475.8 # Base hover RPM - constant hover - self.rotation_delta = 200.0 # Differential RPM for rotation + self.thrust = 14475.8 # Base RPM for constant hover + self.rotation_delta = 100.0 # Differential RPM for rotation self.thrust_delta = 10.0 # Amount to change thrust by when accelerating/decelerating - self.running = True - self.rpms = [self.thrust] * 4 - self.pressed_keys = set() - - def on_press(self, key): - try: - if key == keyboard.Key.esc: - self.running = False - return False - self.pressed_keys.add(key) - print(f"Key pressed: {key}") - except AttributeError: - pass - - def on_release(self, key): - try: - self.pressed_keys.discard(key) - except KeyError: - pass - - def update_thrust(self): - # Store previous RPMs for debugging - prev_rpms = self.rpms.copy() - - # Reset RPMs to hover thrust - self.rpms = [self.thrust] * 4 - - # Acceleration (Spacebar) - All rotors spin faster - if keyboard.Key.space in self.pressed_keys: - self.thrust += self.thrust_delta - self.rpms = [self.thrust] * 4 - print("Accelerating") - - # Deceleration (Left Shift) - All rotors spin slower - if keyboard.Key.shift in self.pressed_keys: - self.thrust -= self.thrust_delta - self.rpms = [self.thrust] * 4 - print("Decelerating") - - # Forward (North) - Front rotors spin faster - if keyboard.Key.up in self.pressed_keys: - self.rpms[0] += self.rotation_delta # Front left - self.rpms[1] += self.rotation_delta # Front right - self.rpms[2] -= self.rotation_delta # Back left - self.rpms[3] -= self.rotation_delta # Back right - print("Moving Forward") - - # Backward (South) - Back rotors spin faster - if keyboard.Key.down in self.pressed_keys: - self.rpms[0] -= self.rotation_delta # Front left - self.rpms[1] -= self.rotation_delta # Front right - self.rpms[2] += self.rotation_delta # Back left - self.rpms[3] += self.rotation_delta # Back right - print("Moving Backward") - - # Left (West) - Left rotors spin faster - if keyboard.Key.left in self.pressed_keys: - self.rpms[0] -= self.rotation_delta # Front left - self.rpms[2] -= self.rotation_delta # Back left - self.rpms[1] += self.rotation_delta # Front right - self.rpms[3] += self.rotation_delta # Back right - print("Moving Left") - - # Right (East) - Right rotors spin faster - if keyboard.Key.right in self.pressed_keys: - self.rpms[0] += self.rotation_delta # Front left - self.rpms[2] += self.rotation_delta # Back left - self.rpms[1] -= self.rotation_delta # Front right - self.rpms[3] -= self.rotation_delta # Back right - print("Moving Right") - - self.rpms = np.clip(self.rpms, 0, 25000) - - # Debug print if any RPMs changed - if not np.array_equal(prev_rpms, self.rpms): - print(f"RPMs changed from {prev_rpms} to {self.rpms}") - - return self.rpms - - -def run_sim(scene, drone, controller): - while controller.running: - # Update drone with current RPMs - rpms = controller.update_thrust() - drone.set_propellels_rpm(rpms) - - # Update physics - scene.step(refresh_visualizer=False) - - # Limit simulation rate - time.sleep(1.0 / scene.viewer.max_FPS) - - if "PYTEST_VERSION" in os.environ: - break + self.cur_dir = np.array([0.0, 0.0, 0.0, 0.0]) # rotor directions + + def update_rpms(self): + """Compute RPMs based on current direction and thrust""" + clipped_dir = np.clip(self.cur_dir, -1.0, 1.0) + rpms = self.thrust + clipped_dir * self.rotation_delta + return np.clip(rpms, 0, 25000) + + def add_direction(self, direction: np.ndarray): + """Add direction vector (on key press)""" + self.cur_dir += direction + + def accelerate(self): + """Increase base thrust""" + self.thrust = min(self.thrust + self.thrust_delta, 25000) + + def decelerate(self): + """Decrease base thrust""" + self.thrust = max(self.thrust - self.thrust_delta, 0) def main(): @@ -129,6 +52,7 @@ def main(): show_world_frame=False, ), show_viewer=True, + show_FPS=False, ) # Add entities @@ -145,14 +69,46 @@ def main(): # Initialize controller controller = DroneController() - # Start keyboard listener. - # Note that instantiating the listener after building the scene causes segfault on MacOS. - listener = keyboard.Listener(on_press=controller.on_press, on_release=controller.on_release) - listener.start() - # Build scene scene.build() + # Register keybindings + def direction_keybinds(name: str, key: Key, direction: tuple[float, float, float, float]): + """Helper to create press/release keybinds for a direction""" + dir_arr = np.array(direction) + return [ + Keybind( + name=f"{name}_press", + key=key, + key_action=KeyAction.PRESS, + callback=controller.add_direction, + args=(dir_arr,), + ), + Keybind( + name=f"{name}_release", + key=key, + key_action=KeyAction.RELEASE, + callback=controller.add_direction, + args=(-dir_arr,), + ), + ] + + is_running = True + + def stop(): + nonlocal is_running + is_running = False + + scene.viewer.register_keybinds( + *direction_keybinds("move_forward", Key.UP, (1.0, 1.0, -1.0, -1.0)), + *direction_keybinds("move_backward", Key.DOWN, (-1.0, -1.0, 1.0, 1.0)), + *direction_keybinds("move_left", Key.LEFT, (-1.0, 1.0, -1.0, 1.0)), + *direction_keybinds("move_right", Key.RIGHT, (1.0, -1.0, 1.0, -1.0)), + Keybind("accelerate", Key.SPACE, KeyAction.HOLD, callback=controller.accelerate), + Keybind("decelerate", Key.LSHIFT, KeyAction.HOLD, callback=controller.decelerate), + Keybind("quit", Key.ESCAPE, KeyAction.PRESS, callback=stop), + ) + # Print control instructions print("\nDrone Controls:") print("↑ - Move Forward (North)") @@ -161,19 +117,23 @@ def main(): print("→ - Move Right (East)") print("space - Increase RPM") print("shift - Decrease RPM") - print("ESC - Quit\n") - print("Initial hover RPM:", controller.thrust) - - # Run simulation in another thread - threading.Thread(target=run_sim, args=(scene, drone, controller)).start() - if "PYTEST_VERSION" not in os.environ: - scene.viewer.run() + # Run simulation try: - listener.stop() - except NotImplementedError: - # Dummy backend does not implement stop - pass + while is_running: + # Update and apply RPMs based on current direction + rpms = controller.update_rpms() + drone.set_propellels_rpm(rpms) + + # Step simulation + scene.step() + + if "PYTEST_VERSION" in os.environ: + break + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") + finally: + gs.logger.info("Simulation finished.") if __name__ == "__main__": diff --git a/examples/keyboard_teleop.py b/examples/keyboard_teleop.py index d741b1f0b0..aef19367dc 100644 --- a/examples/keyboard_teleop.py +++ b/examples/keyboard_teleop.py @@ -11,48 +11,20 @@ u - Reset Scene space - Press to close gripper, release to open gripper esc - Quit + +Plus all default viewer controls (press 'i' to see them) """ import os import random -import threading -import genesis as gs import numpy as np -from pynput import keyboard -from scipy.spatial.transform import Rotation as R - - -class KeyboardDevice: - def __init__(self): - self.pressed_keys = set() - self.lock = threading.Lock() - self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) - - def start(self): - self.listener.start() - - def stop(self): - try: - self.listener.stop() - except NotImplementedError: - # Dummy backend does not implement stop - pass - self.listener.join() - - def on_press(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.add(key) - - def on_release(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.discard(key) - - def get_cmd(self): - return self.pressed_keys +import genesis as gs +import genesis.utils.geom as gu +from genesis.vis.keybindings import Key, KeyAction, Keybind -def build_scene(): +if __name__ == "__main__": ########################## init ########################## gs.init(precision="32", logging_level="info", backend=gs.cpu) np.set_printoptions(precision=7, suppress=True) @@ -80,19 +52,19 @@ def build_scene(): ) ########################## entities ########################## - entities = dict() - entities["plane"] = scene.add_entity( + plane = scene.add_entity( gs.morphs.Plane(), ) - entities["robot"] = scene.add_entity( + robot = scene.add_entity( material=gs.materials.Rigid(gravity_compensation=1), morph=gs.morphs.MJCF( file="xml/franka_emika_panda/panda.xml", euler=(0, 0, 0), ), ) - entities["cube"] = scene.add_entity( + + cube = scene.add_entity( material=gs.materials.Rigid(rho=300), morph=gs.morphs.Box( pos=(0.5, 0.0, 0.07), @@ -101,7 +73,7 @@ def build_scene(): surface=gs.surfaces.Default(color=(0.5, 1, 0.5)), ) - entities["target"] = scene.add_entity( + target = scene.add_entity( gs.morphs.Mesh( file="meshes/axis.obj", scale=0.15, @@ -113,115 +85,90 @@ def build_scene(): ########################## build ########################## scene.build() - return scene, entities - - -def run_sim(scene, entities, clients): - robot = entities["robot"] - target_entity = entities["target"] - + # Initialize robot control state robot_init_pos = np.array([0.5, 0, 0.55]) - robot_init_R = R.from_euler("y", np.pi) - target_pos = robot_init_pos.copy() - target_R = robot_init_R + robot_init_quat = gu.xyz_to_quat(np.array([0, np.pi, 0])) # Rotation around Y axis + # Get DOF indices n_dofs = robot.n_dofs motors_dof = np.arange(n_dofs - 2) fingers_dof = np.arange(n_dofs - 2, n_dofs) ee_link = robot.get_link("hand") - def reset_scene(): - nonlocal target_pos, target_R - target_pos = robot_init_pos.copy() - target_R = robot_init_R - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) - q = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat) + # Initialize target pose + target_pos = robot_init_pos.copy() + target_quat = [robot_init_quat.copy()] # Use list to make it mutable in closures + + # Control parameters + dpos = 0.002 + drot = 0.01 + + # Helper function to reset robot + def reset_robot(): + """Reset robot and cube to initial positions.""" + target_pos[:] = robot_init_pos.copy() + target_quat[0] = robot_init_quat.copy() + target.set_qpos(np.concatenate([target_pos, target_quat[0]])) + q = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat[0]) robot.set_qpos(q[:-2], motors_dof) - entities["cube"].set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) - entities["cube"].set_quat(R.from_euler("z", random.uniform(0, np.pi * 2)).as_quat(scalar_first=True)) - - print("\nKeyboard Controls:") - print("↑\t- Move Forward (North)") - print("↓\t- Move Backward (South)") - print("←\t- Move Left (West)") - print("→\t- Move Right (East)") - print("n\t- Move Up") - print("m\t- Move Down") - print("j\t- Rotate Counterclockwise") - print("k\t- Rotate Clockwise") - print("u\t- Reset Scene") - print("space\t- Press to close gripper, release to open gripper") - print("esc\t- Quit") - - # reset scen before starting teleoperation - reset_scene() - - # start teleoperation - stop = False - while not stop: - pressed_keys = clients["keyboard"].pressed_keys.copy() - - # reset scene: - reset_flag = False - reset_flag |= keyboard.KeyCode.from_char("u") in pressed_keys - if reset_flag: - reset_scene() - - # stop teleoperation - stop = keyboard.Key.esc in pressed_keys - - # get ee target pose - is_close_gripper = False - dpos = 0.002 - drot = 0.01 - for key in pressed_keys: - if key == keyboard.Key.up: - target_pos[0] -= dpos - elif key == keyboard.Key.down: - target_pos[0] += dpos - elif key == keyboard.Key.right: - target_pos[1] += dpos - elif key == keyboard.Key.left: - target_pos[1] -= dpos - elif key == keyboard.KeyCode.from_char("n"): - target_pos[2] += dpos - elif key == keyboard.KeyCode.from_char("m"): - target_pos[2] -= dpos - elif key == keyboard.KeyCode.from_char("j"): - target_R = R.from_euler("z", drot) * target_R - elif key == keyboard.KeyCode.from_char("k"): - target_R = R.from_euler("z", -drot) * target_R - elif key == keyboard.Key.space: - is_close_gripper = True - - # control arm - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) - q, _err = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat, return_error=True) - robot.control_dofs_position(q[:-2], motors_dof) - - # control gripper - if is_close_gripper: - robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) - else: - robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) - - scene.step() - - if "PYTEST_VERSION" in os.environ: - break - - -def main(): - clients = dict() - clients["keyboard"] = KeyboardDevice() - clients["keyboard"].start() - - scene, entities = build_scene() - run_sim(scene, entities, clients) + # Randomize cube position + cube.set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) + random_angle = random.uniform(0, np.pi * 2) + cube.set_quat(gu.xyz_to_quat(np.array([0, 0, random_angle]))) + + # Initialize robot pose + reset_robot() + + # Robot teleoperation callback functions + def move(dpos: tuple[float, float, float]): + target_pos[:] += np.array(dpos, dtype=gs.np_float) + + def rotate(drot: float): + drot_quat = gu.xyz_to_quat(np.array([0, 0, drot])) + target_quat[0] = gu.transform_quat_by_quat(target_quat[0], drot_quat) + + def toggle_gripper(close: bool = True): + pos = -1.0 if close else 1.0 + robot.control_dofs_force(np.array([pos, pos]), fingers_dof) + + is_running = True + + def stop(): + global is_running + is_running = False + + # Register robot teleoperation keybindings + scene.viewer.register_keybinds( + Keybind("move_forward", Key.UP, KeyAction.HOLD, callback=move, args=((-dpos, 0, 0),)), + Keybind("move_back", Key.DOWN, KeyAction.HOLD, callback=move, args=((dpos, 0, 0),)), + Keybind("move_left", Key.LEFT, KeyAction.HOLD, callback=move, args=((0, -dpos, 0),)), + Keybind("move_right", Key.RIGHT, KeyAction.HOLD, callback=move, args=((0, dpos, 0),)), + Keybind("move_up", Key.N, KeyAction.HOLD, callback=move, args=((0, 0, dpos),)), + Keybind("move_down", Key.M, KeyAction.HOLD, callback=move, args=((0, 0, -dpos),)), + Keybind("rotate_ccw", Key.J, KeyAction.HOLD, callback=rotate, args=(drot,)), + Keybind("rotate_cw", Key.K, KeyAction.HOLD, callback=rotate, args=(-drot,)), + Keybind("reset_scene", Key.U, KeyAction.HOLD, callback=reset_robot), + Keybind("close_gripper", Key.SPACE, KeyAction.PRESS, callback=toggle_gripper, args=(True,)), + Keybind("open_gripper", Key.SPACE, KeyAction.RELEASE, callback=toggle_gripper, args=(False,)), + Keybind("quit", Key.ESCAPE, KeyAction.PRESS, callback=stop), + ) + ########################## run simulation ########################## + try: + while is_running: + # Update target entity visualization + target.set_qpos(np.concatenate([target_pos, target_quat[0]])) -if __name__ == "__main__": - main() + # Control arm with inverse kinematics + q, err = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat[0], return_error=True) + robot.control_dofs_position(q[:-2], motors_dof) + + scene.step() + + if "PYTEST_VERSION" in os.environ: + break + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") + finally: + gs.logger.info("Simulation finished.") diff --git a/examples/sensors/lidar_teleop.py b/examples/sensors/lidar_teleop.py index 5336114018..d49a075a5b 100644 --- a/examples/sensors/lidar_teleop.py +++ b/examples/sensors/lidar_teleop.py @@ -1,28 +1,16 @@ import argparse import os -import threading import numpy as np import genesis as gs from genesis.utils.geom import euler_to_quat - -IS_PYNPUT_AVAILABLE = False -try: - from pynput import keyboard - - IS_PYNPUT_AVAILABLE = True -except ImportError: - pass +from genesis.vis.keybindings import Key, KeyAction, Keybind # Position and angle increments for keyboard teleop control KEY_DPOS = 0.1 KEY_DANGLE = 0.1 -# Movement when no keyboard control is available -MOVE_RADIUS = 1.0 -MOVE_RATE = 1.0 / 100.0 - # Number of obstacles to create in a ring around the robot NUM_CYLINDERS = 8 NUM_BOXES = 6 @@ -30,35 +18,6 @@ BOX_RING_RADIUS = 5.0 -class KeyboardDevice: - def __init__(self): - self.pressed_keys = set() - self.lock = threading.Lock() - self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) - - def start(self): - self.listener.start() - - def stop(self): - try: - self.listener.stop() - except NotImplementedError: - # Dummy backend does not implement stop - pass - self.listener.join() - - def on_press(self, key: "keyboard.Key"): - with self.lock: - self.pressed_keys.add(key) - - def on_release(self, key: "keyboard.Key"): - with self.lock: - self.pressed_keys.discard(key) - - def get_cmd(self): - return self.pressed_keys - - def main(): parser = argparse.ArgumentParser(description="Genesis LiDAR/Depth Camera Visualization with Keyboard Teleop") parser.add_argument("-B", "--n_envs", type=int, default=0, help="Number of environments to replicate") @@ -69,12 +28,6 @@ def main(): ) args = parser.parse_args() - if IS_PYNPUT_AVAILABLE: - kb = KeyboardDevice() - kb.start() - else: - print("Keyboard teleop is disabled since pynput is not installed. To install, run `pip install pynput`.") - gs.init(backend=gs.cpu if args.cpu else gs.gpu, precision="32", logging_level="info") scene = gs.Scene( @@ -169,17 +122,7 @@ def main(): scene.build(n_envs=args.n_envs) - if IS_PYNPUT_AVAILABLE: - # Avoid using same keys as interactive viewer keyboard controls - print("Keyboard Controls:") - print("[↑/↓/←/→]: Move XY") - print("[j/k]: Down/Up") - print("[n/m]: Roll CCW/CW") - print("[,/.]: Pitch Up/Down") - print("[o/p]: Yaw CCW/CW") - print("[\\]: Reset") - print("[esc]: Quit") - + # Initialize pose state init_pos = np.array([0.0, 0.0, 0.35], dtype=np.float32) init_euler = np.array([0.0, 0.0, 0.0], dtype=np.float32) @@ -193,48 +136,47 @@ def apply_pose_to_all_envs(pos_np: np.ndarray, quat_np: np.ndarray): robot.set_pos(pos_np) robot.set_quat(quat_np) + # Define control callbacks + def reset_pose(): + target_pos[:] = init_pos + target_euler[:] = init_euler + + def translate(index: int, is_negative: bool): + target_pos[index] += (-1 if is_negative else 1) * KEY_DPOS + + def rotate(index: int, is_negative: bool): + target_euler[index] += (-1 if is_negative else 1) * KEY_DANGLE + + # Register keybindings + scene.viewer.register_keybinds( + Keybind("move_forward", Key.UP, KeyAction.HOLD, callback=translate, args=(0, False)), + Keybind("move_backward", Key.DOWN, KeyAction.HOLD, callback=translate, args=(0, True)), + Keybind("move_right", Key.RIGHT, KeyAction.HOLD, callback=translate, args=(1, True)), + Keybind("move_left", Key.LEFT, KeyAction.HOLD, callback=translate, args=(1, False)), + Keybind("move_down", Key.J, KeyAction.HOLD, callback=translate, args=(2, True)), + Keybind("move_up", Key.K, KeyAction.HOLD, callback=translate, args=(2, False)), + Keybind("roll_ccw", Key.N, KeyAction.HOLD, callback=rotate, args=(0, False)), + Keybind("roll_cw", Key.M, KeyAction.HOLD, callback=rotate, args=(0, True)), + Keybind("pitch_up", Key.COMMA, KeyAction.HOLD, callback=rotate, args=(1, False)), + Keybind("pitch_down", Key.PERIOD, KeyAction.HOLD, callback=rotate, args=(1, True)), + Keybind("yaw_ccw", Key.O, KeyAction.HOLD, callback=rotate, args=(2, False)), + Keybind("yaw_cw", Key.P, KeyAction.HOLD, callback=rotate, args=(2, True)), + Keybind("reset", Key.BACKSLASH, KeyAction.HOLD, callback=reset_pose), + ) + + # Print controls + print("Keyboard Controls:") + print("[↑/↓/←/→]: Move XY") + print("[j/k]: Down/Up") + print("[n/m]: Roll CCW/CW") + print("[,/.]: Pitch Up/Down") + print("[o/p]: Yaw CCW/CW") + print("[\\]: Reset") + apply_pose_to_all_envs(target_pos, euler_to_quat(target_euler)) try: while True: - if IS_PYNPUT_AVAILABLE: - pressed = kb.pressed_keys.copy() - if keyboard.Key.esc in pressed: - break - if keyboard.KeyCode.from_char("\\") in pressed: - target_pos[:] = init_pos - target_euler[:] = init_euler - - if keyboard.Key.up in pressed: - target_pos[0] += KEY_DPOS - if keyboard.Key.down in pressed: - target_pos[0] -= KEY_DPOS - if keyboard.Key.right in pressed: - target_pos[1] -= KEY_DPOS - if keyboard.Key.left in pressed: - target_pos[1] += KEY_DPOS - if keyboard.KeyCode.from_char("j") in pressed: - target_pos[2] -= KEY_DPOS - if keyboard.KeyCode.from_char("k") in pressed: - target_pos[2] += KEY_DPOS - - if keyboard.KeyCode.from_char("n") in pressed: - target_euler[0] += KEY_DANGLE # roll CCW around +X - if keyboard.KeyCode.from_char("m") in pressed: - target_euler[0] -= KEY_DANGLE # roll CW around +X - if keyboard.KeyCode.from_char(",") in pressed: - target_euler[1] += KEY_DANGLE # pitch up around +Y - if keyboard.KeyCode.from_char(".") in pressed: - target_euler[1] -= KEY_DANGLE # pitch down around +Y - if keyboard.KeyCode.from_char("o") in pressed: - target_euler[2] += KEY_DANGLE # yaw CCW around +Z - if keyboard.KeyCode.from_char("p") in pressed: - target_euler[2] -= KEY_DANGLE # yaw CW around +Z - else: - # move in a circle if no keyboard control - target_pos[0] = MOVE_RADIUS * np.cos(scene.t * MOVE_RATE) - target_pos[1] = MOVE_RADIUS * np.sin(scene.t * MOVE_RATE) - apply_pose_to_all_envs(target_pos, euler_to_quat(target_euler)) scene.step() diff --git a/genesis/engine/entities/rigid_entity/rigid_link.py b/genesis/engine/entities/rigid_entity/rigid_link.py index 186f19fd93..f5dcbb2f5a 100644 --- a/genesis/engine/entities/rigid_entity/rigid_link.py +++ b/genesis/engine/entities/rigid_entity/rigid_link.py @@ -17,7 +17,6 @@ from .rigid_entity import RigidEntity from .rigid_joint import RigidJoint from genesis.engine.solvers.rigid.rigid_solver import RigidSolver - from genesis.ext.pyrender.interaction.vec3 import Pose # If mass is too small, we do not care much about spatial inertia discrepancy @@ -784,11 +783,6 @@ def is_built(self) -> bool: def is_free(self): raise DeprecationError("This property has been removed.") - @property - def pose(self) -> "Pose": - """Return the current pose of the link (note, this is not necessarily the same as the principal axes frame).""" - return Pose.from_link(self) - # ------------------------------------------------------------------------------------ # -------------------------------------- repr ---------------------------------------- # ------------------------------------------------------------------------------------ diff --git a/genesis/ext/pyrender/constants.py b/genesis/ext/pyrender/constants.py index 26a42e042c..5e90ff45fd 100644 --- a/genesis/ext/pyrender/constants.py +++ b/genesis/ext/pyrender/constants.py @@ -10,6 +10,7 @@ UINT_SZ = 4 # Byte size of GL uint32 SHADOW_TEX_SZ = 8192 # Width and Height of Shadow Textures TEXT_PADDING = 20 # Width of padding for rendering text (px) +FONT_SIZE = 26 # Default font size for rendering text (px) # Flags for render type diff --git a/genesis/ext/pyrender/interaction/aabb.py b/genesis/ext/pyrender/interaction/aabb.py deleted file mode 100644 index 8f22548f0d..0000000000 --- a/genesis/ext/pyrender/interaction/aabb.py +++ /dev/null @@ -1,96 +0,0 @@ -from dataclasses import dataclass - -import numpy as np - -from .ray import Ray, RayHit, EPSILON -from .vec3 import Pose, Vec3 - - -class AABB: - v: "np.typing.NDArray[np.float32]" - - def __init__(self, v: "np.typing.NDArray[np.float32]"): - assert v.shape == (2, 3), f"Aabb must be initialized with a (2,3)-element array, got {v.shape}" - assert v.dtype == np.float32, f"Aabb must be initialized with a float32 array, got {v.dtype}" - self.v = v - - @property - def min(self) -> Vec3: - return Vec3(self.v[0]) - - @property - def max(self) -> Vec3: - return Vec3(self.v[1]) - - @property - def extents(self) -> Vec3: - return self.max - self.min - - def expand(self, padding: float) -> None: - self.v[0] -= padding - self.v[1] += padding - - def raycast(self, ray: Ray) -> RayHit: - """ - Standard AABB slab implementation. Early-exits and returns no-hit for rays withing the XY, XZ, or YZ planes. - Ignores hits for rays originating inside the AABB. - """ - if (np.abs(ray.direction.v) < EPSILON).any(): - # unhandled ray case: early-exit - return RayHit.no_hit() - - tmin = (self.v[0] - ray.origin.v) / ray.direction.v - tmax = (self.v[1] - ray.origin.v) / ray.direction.v - mmin = np.minimum(tmin, tmax) - mmax = np.maximum(tmin, tmax) - min_idx = np.argmax(mmin) - max_idx = np.argmin(mmax) - tnear = mmin[min_idx] - tfar = mmax[max_idx] - - # Drop hits coming from inside - if tfar < tnear or tnear < 0: # tfar < 0 - return RayHit.no_hit() - - # Calculate enter point and normal - enter = tnear # if 0 <= tnear else tfar - normal = Vec3.zero() - normal.v[min_idx] = -np.sign(ray.direction.v[min_idx]) - - hit_pos = ray.origin + ray.direction * enter - return RayHit(enter, hit_pos, normal) - - def __repr__(self) -> str: - return f"AABB: Min({self.min.x}, {self.min.y}, {self.min.z}) Max({self.max.x}, {self.max.y}, {self.max.z})" - - @classmethod - def from_min_max(cls, min: Vec3, max: Vec3) -> "AABB": - bounds = np.stack((min.v, max.v), axis=0) - return cls(bounds) - - @classmethod - def from_center_and_half_extents(cls, center: Vec3, half_extents: Vec3) -> "AABB": - min = center - half_extents - max = center + half_extents - bounds = np.stack((min.v, max.v), axis=0) - return cls(bounds) - - -@dataclass -class OBB: - pose: Pose - half_extents: Vec3 - - def raycast(self, ray: Ray) -> RayHit: - origin2 = self.pose.inverse_transform_point(ray.origin) - direction2 = self.pose.inverse_transform_direction(ray.direction) - ray2 = Ray(origin2, direction2) - aabb = AABB.from_center_and_half_extents(Vec3.zero(), self.half_extents) - ray_hit = aabb.raycast(ray2) - if ray_hit.is_hit: - ray_hit.position = self.pose.transform_point(ray_hit.position) - ray_hit.normal = self.pose.transform_direction(ray_hit.normal) - return ray_hit - - def __repr__(self) -> str: - return f"OBB(pose={self.pose}, half_extents={self.half_extents})" diff --git a/genesis/ext/pyrender/interaction/mouse_spring.py b/genesis/ext/pyrender/interaction/mouse_spring.py deleted file mode 100644 index f7b53a2d33..0000000000 --- a/genesis/ext/pyrender/interaction/mouse_spring.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -from .vec3 import Pose, Quat, Vec3 - -if TYPE_CHECKING: - from genesis.engine.entities.rigid_entity.rigid_link import RigidLink - - -MOUSE_SPRING_POSITION_CORRECTION_FACTOR = 1.0 -MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR = 1.0 - - -class MouseSpring: - def __init__(self) -> None: - self.held_link: "RigidLink | None" = None - self.held_point_in_local: Vec3 | None = None - self.prev_control_point: Vec3 | None = None - - def attach(self, picked_link: "RigidLink", control_point: Vec3) -> None: - # for now, we just pick the first geometry - self.held_link = picked_link - pose: Pose = Pose.from_link(self.held_link) - self.held_point_in_local = pose.inverse_transform_point(control_point) - self.prev_control_point = control_point - - def detach(self) -> None: - self.held_link = None - - def apply_force(self, control_point: Vec3, delta_time: float) -> None: - # note when threaded: apply_force is called before attach! - # note2: that was before we added a lock to ViewerInteraction; this migth be fixed now - if not self.held_link: - return - - self.prev_control_point = control_point - - # do simple force on COM only: - link: "RigidLink" = self.held_link - lin_vel: Vec3 = Vec3.from_tensor(link.get_vel()) - ang_vel: Vec3 = Vec3.from_tensor(link.get_ang()) - link_pose: Pose = Pose.from_link(link) - held_point_in_world: Vec3 = link_pose.transform_point(self.held_point_in_local) - - # note: we should assert earlier that link inertial_pos/quat are not None - # todo: verify inertial_pos/quat are stored in local frame - link_T_principal: Pose = Pose(Vec3.from_arraylike(link.inertial_pos), Quat.from_arraylike(link.inertial_quat)) - world_T_principal: Pose = link_pose * link_T_principal - - # for non-spherical inertia - arm_in_principal: Vec3 = link_T_principal.inverse_transform_point(self.held_point_in_local) - arm_in_world: Vec3 = world_T_principal.rot * arm_in_principal # for spherical inertia - - pos_err_v: Vec3 = control_point - held_point_in_world - inv_mass: float = float(1.0 / link.get_mass() if link.get_mass() > 0.0 else 0.0) - inv_spherical_inertia: float = float(1.0 / link.inertial_i[0, 0] if link.inertial_i[0, 0] > 0.0 else 0.0) - - inv_dt: float = 1.0 / delta_time - tau: float = MOUSE_SPRING_POSITION_CORRECTION_FACTOR - damp: float = MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR - - total_impulse: Vec3 = Vec3.zero() - total_torque_impulse: Vec3 = Vec3.zero() - - for i in range(3 * 4): - body_point_vel: Vec3 = lin_vel + ang_vel.cross(arm_in_world) - vel_err_v: Vec3 = Vec3.zero() - body_point_vel - - dir: Vec3 = Vec3.zero() - dir.v[i % 3] = 1.0 - pos_err: float = dir.dot(pos_err_v) - vel_err: float = dir.dot(vel_err_v) - error: float = tau * pos_err * inv_dt + damp * vel_err - - arm_x_dir: Vec3 = arm_in_world.cross(dir) - virtual_mass: float = 1.0 / (inv_mass + arm_x_dir.sqr_magnitude() * inv_spherical_inertia + 1e-24) - impulse: float = error * virtual_mass - - lin_vel += impulse * inv_mass * dir - ang_vel += impulse * inv_spherical_inertia * arm_x_dir - total_impulse.v[i % 3] += impulse - total_torque_impulse += impulse * arm_x_dir - - # Apply the new force - total_force = total_impulse * inv_dt - total_torque = total_torque_impulse * inv_dt - force_tensor: torch.Tensor = total_force.as_tensor()[None] - torque_tensor: torch.Tensor = total_torque.as_tensor()[None] - link.solver.apply_links_external_force(force_tensor, (link.idx,), ref="link_com", local=False) - link.solver.apply_links_external_torque(torque_tensor, (link.idx,), ref="link_com", local=False) - - @property - def is_attached(self) -> bool: - return self.held_link is not None diff --git a/genesis/ext/pyrender/interaction/ray.py b/genesis/ext/pyrender/interaction/ray.py deleted file mode 100644 index b571731ac0..0000000000 --- a/genesis/ext/pyrender/interaction/ray.py +++ /dev/null @@ -1,62 +0,0 @@ -import sys -from dataclasses import dataclass -from typing import TYPE_CHECKING - -from .vec3 import Vec3 - -if TYPE_CHECKING: - from genesis.engine.entities.rigid_entity.rigid_geom import RigidGeom - - -EPSILON = 1e-6 -EPSILON2 = EPSILON * EPSILON - -_MAX_RAY_DISTANCE = sys.float_info.max - - -class Ray: - origin: Vec3 - direction: Vec3 - - def __init__(self, origin: Vec3, direction: Vec3): - self.origin = origin - self.direction = direction.normalized() - - def __repr__(self) -> str: - return f"Ray(origin={self.origin}, direction={self.direction})" - - -@dataclass -class RayHit: - distance: float - position: Vec3 - normal: Vec3 - geom: "RigidGeom | None" = None - - @property - def is_hit(self) -> bool: - assert 0.0 <= self.distance - return self.distance < _MAX_RAY_DISTANCE - - @classmethod - def no_hit(cls) -> "RayHit": - return RayHit(_MAX_RAY_DISTANCE, Vec3.zero(), Vec3.zero(), None) - - -class Plane: - normal: Vec3 - distance: float # distance from plane to origin along normal - - def __init__(self, normal: Vec3, point: Vec3): - self.normal = normal - self.distance = -normal.dot(point) - - def raycast(self, ray: Ray) -> RayHit: - dot = ray.direction.dot(self.normal) - dist = ray.origin.dot(self.normal) + self.distance - - if -EPSILON < dot or dist < EPSILON: - return RayHit.no_hit() - else: - dist_along_ray = dist / -dot - return RayHit(dist_along_ray, ray.origin + ray.direction * dist_along_ray, self.normal, None) diff --git a/genesis/ext/pyrender/interaction/vec3.py b/genesis/ext/pyrender/interaction/vec3.py deleted file mode 100644 index 7fe61a490e..0000000000 --- a/genesis/ext/pyrender/interaction/vec3.py +++ /dev/null @@ -1,286 +0,0 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Union - -import numpy as np -import torch - -from genesis.utils.misc import tensor_to_array - -if TYPE_CHECKING: - from genesis.engine.entities.rigid_entity import RigidGeom, RigidLink - - -class Vec3: - """ - Use this wrapper around np.array if you want to ensure adherence to float32 arithmethic - with runtime checks, and avoid hidden and costly conversions between float32 and float64. - - This also makes vector dimensionality explicit for linting and static analysis. - """ - - v: "np.typing.NDArray[np.float32]" - - def __init__(self, v: "np.typing.NDArray[np.float32]"): - assert v.shape == (3,), f"Vec3 must be initialized with a 3-element array, got {v.shape}" - assert v.dtype == np.float32, f"Vec3 must be initialized with a float32 array, got {v.dtype}" - self.v = v - - def __add__(self, other: "Vec3") -> "Vec3": - return Vec3(self.v + other.v) - - def __sub__(self, other: "Vec3") -> "Vec3": - return Vec3(self.v - other.v) - - def __mul__(self, other: float) -> "Vec3": - return Vec3(self.v * np.float32(other)) - - def __rmul__(self, other: float) -> "Vec3": - return Vec3(self.v * np.float32(other)) - - def __neg__(self) -> "Vec3": - return Vec3(-self.v) - - def dot(self, other: "Vec3") -> float: - return np.dot(self.v, other.v).item() - - def cross(self, other: "Vec3") -> "Vec3": - return Vec3(np.cross(self.v, other.v)) - - def normalized(self) -> "Vec3": - return Vec3(self.v / (np.linalg.norm(self.v) + 1e-24)) - - def magnitude(self) -> float: - return np.linalg.norm(self.v) - - def sqr_magnitude(self) -> float: - return np.dot(self.v, self.v) - - def copy(self) -> "Vec3": - return Vec3(self.v.copy()) - - def __repr__(self) -> str: - return f"Vec3({self.v[0]}, {self.v[1]}, {self.v[2]})" - - def as_tensor(self) -> torch.Tensor: - return torch.tensor(self.v) - - @property - def x(self) -> float: - return self.v[0] - - @property - def y(self) -> float: - return self.v[1] - - @property - def z(self) -> float: - return self.v[2] - - @classmethod - def from_xyz(cls, x: float, y: float, z: float) -> "Vec3": - return cls(np.array([x, y, z], dtype=np.float32)) - - @classmethod - def from_array(cls, v: np.ndarray) -> "Vec3": - assert v.shape == (3,), f"Vec3 must be initialized with a 3-element array, got {v.shape}" - assert v.dtype == np.int32 or v.dtype == np.int64 or v.dtype == np.float32 or v.dtype == np.float64, ( - f"from_array must be initialized with a array of ints/floats 32/64-bit, got {v.dtype}" - ) - return cls.from_xyz(*v) - - @classmethod - def from_tensor(cls, v: torch.Tensor) -> "Vec3": - array: np.ndarray = tensor_to_array(v) - return cls.from_array(array) - - @classmethod - def from_arraylike(cls, v: "np.typing.ArrayLike") -> "Vec3": - if isinstance(v, np.ndarray): - return cls.from_array(v) - elif isinstance(v, torch.Tensor): - return cls.from_tensor(v) - elif isinstance(v, np.typing.ArrayLike): - assert len(v) == 3, f"Vec3 must be initialized with a 3-element ArrayLike, got {len(v)}" - return cls.from_xyz(*v) - assert False - - @classmethod - def zero(cls) -> "Vec3": - return cls(np.array([0, 0, 0], dtype=np.float32)) - - @classmethod - def one(cls) -> "Vec3": - return cls(np.array([1, 1, 1], dtype=np.float32)) - - @classmethod - def full(cls, fill_value: float) -> "Vec3": - return cls(np.full((3,), fill_value, dtype=np.float32)) - - -class Quat: - v: "np.typing.NDArray[np.float32]" - - def __init__(self, v: "np.typing.NDArray[np.float32]"): - assert v.shape == (4,), f"Quat must be initialized with a 4-element array, got {v.shape}" - assert v.dtype == np.float32, f"Quat must be initialized with a float32 array, got {v.dtype}" - self.v = v - - def get_inverse(self) -> "Quat": - quat_inv = self.v.copy() - quat_inv[1:] *= -1 - return Quat(quat_inv) - - def __mul__(self, other: Union["Quat", Vec3]) -> Union["Quat", Vec3]: - if isinstance(other, Quat): - # Quaternion * Quaternion - w1, x1, y1, z1 = self.w, self.x, self.y, self.z - w2, x2, y2, z2 = other.w, other.x, other.y, other.z - return Quat.from_wxyz( - w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, - w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, - w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, - w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, - ) - elif isinstance(other, Vec3): # (other, np.ndarray) and other.shape == (3,): - # Quaternion * Vector3 -> rotate vector - v_quat = Quat.from_wxyz(0, *other.v) - result = self * v_quat * self.get_inverse() - return Vec3(result.v[1:]) - else: - return NotImplemented - - def copy(self) -> "Quat": - return Quat(self.v.copy()) - - def __repr__(self) -> str: - return f"Quat({self.v[0]}, {self.v[1]}, {self.v[2]}, {self.v[3]})" - - def as_tensor(self) -> torch.Tensor: - return torch.tensor(self.v) - - @property - def w(self) -> float: - return self.v[0] - - @property - def x(self) -> float: - return self.v[1] - - @property - def y(self) -> float: - return self.v[2] - - @property - def z(self) -> float: - return self.v[3] - - @classmethod - def from_wxyz(cls, w: float, x: float, y: float, z: float) -> "Quat": - return cls(np.array([w, x, y, z], dtype=np.float32)) - - @classmethod - def from_array(cls, v: np.ndarray) -> "Quat": - assert v.shape == (4,), f"Quat must be initialized with a 4-element array, got {v.shape}" - return cls.from_wxyz(*v) - - @classmethod - def from_tensor(cls, v: torch.Tensor) -> "Quat": - array: np.ndarray = tensor_to_array(v) - return cls.from_array(array) - - @classmethod - def from_arraylike(cls, v: "np.typing.ArrayLike") -> "Quat": - if isinstance(v, np.ndarray): - return cls.from_array(v) - elif isinstance(v, torch.Tensor): - return cls.from_tensor(v) - elif isinstance(v, np.typing.ArrayLike): - assert len(v) == 4, f"Quat must be initialized with a 4-element ArrayLike, got {len(v)}" - return cls.from_wxyz(*v) - assert False - - -@dataclass -class Pose: - pos: Vec3 - rot: Quat - - # todo: consider using a single np.array with views - - def transform_point(self, point: Vec3) -> Vec3: - return self.pos + self.rot * point - - def inverse_transform_point(self, point: Vec3) -> Vec3: - return self.rot.get_inverse() * (point - self.pos) - - def transform_direction(self, direction: Vec3) -> Vec3: - return self.rot * direction - - def inverse_transform_direction(self, direction: Vec3) -> Vec3: - return self.rot.get_inverse() * direction - - def get_inverse(self) -> "Pose": - inv_rot = self.rot.get_inverse() - # inv_pos = -1.0 * (inv_rot * self.pos) - # faster -- avoid repeated quat inversion: - pos_quat = Quat.from_wxyz(0, *self.pos.v) - inv_pos = inv_rot * pos_quat * self.rot - inv_pos = Vec3(-inv_pos.v[1:]) - return Pose(inv_pos, inv_rot) - - def __mul__(self, other: Union["Pose", Vec3]) -> Union["Pose", Vec3]: - if isinstance(other, Pose): - return Pose(self.pos + self.rot * other.pos, self.rot * other.rot) - elif isinstance(other, Vec3): - return self.pos + self.rot * other - else: - return NotImplemented - - def __repr__(self) -> str: - return f"Pose(pos={self.pos}, rot={self.rot})" - - @classmethod - def from_geom(cls, geom: "RigidGeom") -> "Pose": - assert geom._solver.n_envs == 0, "ViewerInteraction only supports single-env for now" - # geom.get_pos() and .get_quat() are squeezed if n_envs == 0 - pos = Vec3.from_tensor(geom.get_pos()) - quat = Quat.from_tensor(geom.get_quat()) - return Pose(pos, quat) - - @classmethod - def from_link(cls, link: "RigidLink") -> "Pose": - assert link._solver.n_envs == 0, "ViewerInteraction only supports single-env for now" - # geom.get_pos() and .get_quat() are squeezed if n_envs == 0 - pos = Vec3.from_tensor(link.get_pos()) - quat = Quat.from_tensor(link.get_quat()) - return Pose(pos, quat) - - -@dataclass -class Color: - r: float - g: float - b: float - a: float - - def tuple(self) -> tuple[float, float, float, float]: - return (self.r, self.g, self.b, self.a) - - def with_alpha(self, alpha: float) -> "Color": - return Color(self.r, self.g, self.b, alpha) - - @classmethod - def red(cls) -> "Color": - return cls(1.0, 0.0, 0.0, 1.0) - - @classmethod - def green(cls) -> "Color": - return cls(0.0, 1.0, 0.0, 1.0) - - @classmethod - def blue(cls) -> "Color": - return cls(0.0, 0.0, 1.0, 1.0) - - @classmethod - def yellow(cls) -> "Color": - return cls(1.0, 1.0, 0.0, 1.0) diff --git a/genesis/ext/pyrender/interaction/viewer_interaction.py b/genesis/ext/pyrender/interaction/viewer_interaction.py deleted file mode 100644 index c801baea08..0000000000 --- a/genesis/ext/pyrender/interaction/viewer_interaction.py +++ /dev/null @@ -1,260 +0,0 @@ -from typing import TYPE_CHECKING, cast -from typing_extensions import override # Made it into standard lib from Python 3.12 -from threading import Lock as threading_Lock - -import numpy as np - -import genesis as gs - -from .aabb import AABB, OBB -from .mouse_spring import MouseSpring -from .ray import Plane, Ray, RayHit -from .vec3 import Pose, Vec3, Color -from .viewer_interaction_base import ViewerInteractionBase, EVENT_HANDLE_STATE, EVENT_HANDLED - -if TYPE_CHECKING: - from genesis.engine.entities.rigid_entity import RigidGeom, RigidLink, RigidEntity - from genesis.engine.scene import Scene - from genesis.ext.pyrender.node import Node - - -class ViewerInteraction(ViewerInteractionBase): - """Functionalities to be implemented: - - mouse picking - - mouse dragging - """ - - def __init__( - self, - camera: "Node", - scene: "Scene", - viewport_size: tuple[int, int], - camera_yfov: float, - log_events: bool = False, - camera_fov: float = 60.0, - ) -> None: - super().__init__(log_events) - self.camera: "Node" = camera - self.scene: "Scene" = scene - self.viewport_size: tuple[int, int] = viewport_size - self.camera_yfov: float = camera_yfov - - self.tan_half_fov: float = np.tan(0.5 * self.camera_yfov) - self.prev_mouse_pos: tuple[int, int] = (viewport_size[0] // 2, viewport_size[1] // 2) - - self.picked_link: RigidLink | None = None - self.picked_point_in_local: Vec3 | None = None - self.mouse_drag_plane: Plane | None = None - self.prev_mouse_3d_pos: Vec3 | None = None - - self.mouse_spring: MouseSpring = MouseSpring() - self.lock = threading_Lock() - - @override - def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: - super().on_mouse_motion(x, y, dx, dy) - self.prev_mouse_pos = (x, y) - - @override - def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: - super().on_mouse_drag(x, y, dx, dy, buttons, modifiers) - self.prev_mouse_pos = (x, y) - if self.picked_link: - # actual processing done in update_on_sim_step() - - return EVENT_HANDLED - - @override - def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: - super().on_mouse_press(x, y, button, modifiers) - if button == 1: # left mouse button - ray_hit = self.raycast_against_entities(self.screen_position_to_ray(x, y)) - with self.lock: - if ray_hit.geom: - self.picked_link = ray_hit.geom.link - assert self.picked_link is not None - - temp_fwd = self.get_camera_forward() - temp_back = -temp_fwd - - self.mouse_drag_plane = Plane(temp_back, ray_hit.position) - self.prev_mouse_3d_pos = ray_hit.position - - pose: Pose = Pose.from_link(self.picked_link) - self.picked_point_in_local = pose.inverse_transform_point(ray_hit.position) - - self.mouse_spring.attach(self.picked_link, ray_hit.position) - - @override - def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: - super().on_mouse_release(x, y, button, modifiers) - if button == 1: # left mouse button - with self.lock: - self.picked_link = None - self.picked_point_in_local = None - self.mouse_drag_plane = None - self.prev_mouse_3d_pos = None - - self.mouse_spring.detach() - - @override - def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: - super().on_resize(width, height) - self.viewport_size = (width, height) - self.tan_half_fov = np.tan(0.5 * self.camera_yfov) - - @override - def update_on_sim_step(self) -> None: - with self.lock: - if self.picked_link: - mouse_ray: Ray = self.screen_position_to_ray(*self.prev_mouse_pos) - ray_hit: RayHit = self.mouse_drag_plane.raycast(mouse_ray) - assert ray_hit.is_hit - if ray_hit.is_hit: - new_mouse_3d_pos: Vec3 = ray_hit.position - delta_3d_pos: Vec3 = new_mouse_3d_pos - self.prev_mouse_3d_pos - self.prev_mouse_3d_pos = new_mouse_3d_pos - - use_force: bool = True - if use_force: - # apply force - self.mouse_spring.apply_force(new_mouse_3d_pos, self.scene.sim.dt) - else: - # apply displacement - pos = Vec3.from_tensor(self.picked_link.entity.get_pos()) - pos += delta_3d_pos - self.picked_link.entity.set_pos(pos.as_tensor()) - - @override - def on_draw(self) -> None: - super().on_draw() - if self.scene._visualizer is not None and self.scene._visualizer.is_built: - self.scene.clear_debug_objects() - mouse_ray: Ray = self.screen_position_to_ray(*self.prev_mouse_pos) - - closest_hit = self.raycast_against_entities(mouse_ray) - if not closest_hit.is_hit: - closest_hit = self._raycast_against_ground_plane(mouse_ray) - - with self.lock: - if self.picked_link: - assert self.mouse_drag_plane is not None - assert self.picked_point_in_local is not None - - # draw held point - pose: Pose = Pose.from_link(self.picked_link) - held_point: Vec3 = pose.transform_point(self.picked_point_in_local) - self.scene.draw_debug_sphere(held_point.v, 0.02, Color.red().tuple()) - - plane_hit: RayHit = self.mouse_drag_plane.raycast(mouse_ray) - if plane_hit.is_hit: - self.scene.draw_debug_sphere(plane_hit.position.v, 0.02, Color.red().tuple()) - self.scene.draw_debug_line(held_point.v, plane_hit.position.v, color=Color.red().tuple()) - else: - if closest_hit.is_hit: - self.scene.draw_debug_sphere(closest_hit.position.v, 0.01, (0, 1, 0, 1)) - self._draw_arrow(closest_hit.position, 0.25 * closest_hit.normal, (0, 1, 0, 1)) - if closest_hit.geom: - self._draw_entity_unrotated_obb(closest_hit.geom) - - def screen_position_to_ray(self, x: float, y: float) -> Ray: - # convert screen position to ray - if True: - x = x - 0.5 * self.viewport_size[0] - y = y - 0.5 * self.viewport_size[1] - x = 2.0 * x / self.viewport_size[1] * self.tan_half_fov - y = 2.0 * y / self.viewport_size[1] * self.tan_half_fov - else: - # alternative way - projection_matrix = self.camera.camera.get_projection_matrix(*self.viewport_size) - x = x - 0.5 * self.viewport_size[0] - y = y - 0.5 * self.viewport_size[1] - x = 2.0 * x / self.viewport_size[0] / projection_matrix[0, 0] - y = 2.0 * y / self.viewport_size[1] / projection_matrix[1, 1] - - # Note: ignoring pixel aspect ratio - - mtx = self.camera.matrix - position = Vec3.from_array(mtx[:3, 3]) - forward = Vec3.from_array(-mtx[:3, 2]) - right = Vec3.from_array(mtx[:3, 0]) - up = Vec3.from_array(mtx[:3, 1]) - - direction = forward + right * x + up * y - return Ray(position, direction) - - def get_camera_forward(self) -> Vec3: - mtx = self.camera.matrix - return Vec3.from_array(-mtx[:3, 2]) - - def get_camera_ray(self) -> Ray: - mtx = self.camera.matrix - position = Vec3.from_array(mtx[:3, 3]) - forward = Vec3.from_array(-mtx[:3, 2]) - return Ray(position, forward) - - def _raycast_against_ground_plane(self, ray: Ray) -> RayHit: - ground_plane = Plane(Vec3.from_xyz(0, 0, 1), Vec3.zero()) - return ground_plane.raycast(ray) - - def raycast_against_entity_obb(self, entity: "RigidEntity", ray: Ray) -> RayHit: - if isinstance(entity.morph, gs.morphs.Box): - obb: OBB = self._get_box_obb(entity) - ray_hit = obb.raycast(ray) - if ray_hit.is_hit: - ray_hit.geom = entity.geoms[0] - return ray_hit - elif isinstance(entity.morph, gs.morphs.Plane): - # ignore plane - return RayHit.no_hit() - else: - closest_hit = RayHit.no_hit() - for link in entity.links: - if not link.is_fixed: - for geom in link.geoms: - obb: OBB = self._get_geom_placeholder_obb(geom) - ray_hit = obb.raycast(ray) - if ray_hit.distance < closest_hit.distance: - ray_hit.geom = geom - closest_hit = ray_hit - return closest_hit - - def raycast_against_entities(self, ray: Ray) -> RayHit: - closest_hit = RayHit.no_hit() - for entity in self.scene.sim.rigid_solver.entities: - rigid_entity: "RigidEntity" = cast("RigidEntity", entity) - ray_hit = self.raycast_against_entity_obb(rigid_entity, ray) - if ray_hit.distance < closest_hit.distance: - closest_hit = ray_hit - return closest_hit - - def _get_box_obb(self, box_entity: "RigidEntity") -> OBB: - box: gs.morphs.Box = box_entity.morph - pose = Pose.from_link(box_entity.links[0]) - half_extents = 0.5 * Vec3.from_xyz(*box.size) - return OBB(pose, half_extents) - - def _get_geom_placeholder_obb(self, geom: "RigidGeom") -> OBB: - pose = Pose.from_geom(geom) - half_extents = Vec3.full(0.5 * 0.125) - return OBB(pose, half_extents) - - def _draw_arrow( - self, - pos: Vec3, - dir: Vec3, - color: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0), - ) -> None: - self.scene.draw_debug_arrow(pos.v, dir.v, color=color) - - def _draw_entity_unrotated_obb(self, geom: "RigidGeom") -> None: - obb: OBB | None = None - if isinstance(geom.entity.morph, gs.morphs.Box): - obb = self._get_box_obb(geom.entity) - else: - obb = self._get_geom_placeholder_obb(geom) - - if obb: - aabb: AABB = AABB.from_center_and_half_extents(obb.pose.pos, obb.half_extents) - aabb.expand(padding=0.01) - self.scene.draw_debug_box(aabb.v, color=Color.red().with_alpha(0.5).tuple(), wireframe=False) diff --git a/genesis/ext/pyrender/interaction/viewer_interaction_base.py b/genesis/ext/pyrender/interaction/viewer_interaction_base.py deleted file mode 100644 index 9be9f545f2..0000000000 --- a/genesis/ext/pyrender/interaction/viewer_interaction_base.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Union, Literal - -import genesis as gs - - -EVENT_HANDLE_STATE = Union[Literal[True], None] -EVENT_HANDLED: Literal[True] = True - -# Note: Viewer window is based on pyglet.window.Window, mouse events are defined in pyglet.window.BaseWindow - - -class ViewerInteractionBase: - """Base class for handling pyglet.window.Window events.""" - - log_events: bool - - def __init__(self, log_events: bool = False): - self.log_events = log_events - - def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse moved to {x}, {y}") - - def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse dragged to {x}, {y}") - - def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse buttons {button} pressed at {x}, {y}") - - def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse buttons {button} released at {x}, {y}") - - def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Key pressed: {chr(symbol)}") - - def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Key released: {chr(symbol)}") - - def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Window resized to {width}x{height}") - - def update_on_sim_step(self) -> None: - pass - - def on_draw(self) -> None: - pass diff --git a/genesis/ext/pyrender/viewer.py b/genesis/ext/pyrender/viewer.py index bba4bf014a..9d3ae7ec8b 100644 --- a/genesis/ext/pyrender/viewer.py +++ b/genesis/ext/pyrender/viewer.py @@ -1,20 +1,21 @@ """A pyglet-based interactive 3D scene viewer.""" import copy -from contextlib import nullcontext import os import shutil import sys -import time import threading +import time +from contextlib import nullcontext from threading import Event, RLock, Semaphore, Thread -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import numpy as np import OpenGL from OpenGL.GL import * import genesis as gs +from genesis.vis.keybindings import Key, KeyAction, Keybind, Keybindings, KeyMod # Importing tkinter and creating a first context before importing pyglet is necessary to avoid later segfault on MacOS. # Note that destroying the window will cause segfault at exit. @@ -31,11 +32,14 @@ import pyglet +from genesis.vis.viewer_plugins import EVENT_HANDLE_STATE, EVENT_HANDLED, ViewerPlugin + from .camera import IntrinsicsCamera, OrthographicCamera, PerspectiveCamera from .constants import ( DEFAULT_SCENE_SCALE, DEFAULT_Z_FAR, DEFAULT_Z_NEAR, + FONT_SIZE, MIN_OPEN_GL_MAJOR, MIN_OPEN_GL_MINOR, TARGET_OPEN_GL_MAJOR, @@ -44,8 +48,6 @@ RenderFlags, TextAlign, ) -from .interaction.viewer_interaction import ViewerInteraction -from .interaction.viewer_interaction_base import ViewerInteractionBase, EVENT_HANDLE_STATE, EVENT_HANDLED from .light import DirectionalLight from .node import Node from .renderer import Renderer @@ -61,6 +63,9 @@ MODULE_DIR = os.path.dirname(__file__) +HELP_TEXT_KEY = Key.I +HELP_TEXT_KEYBIND_NAME = "toggle_instructions" + class Viewer(pyglet.window.Window): """An interactive viewer for 3D scenes. @@ -80,17 +85,7 @@ class Viewer(pyglet.window.Window): viewer_flags : dict A set of flags for controlling the viewer's behavior. Described in the note below. - registered_keys : dict - A map from ASCII key characters to tuples containing: - - - A function to be called whenever the key is pressed, - whose first argument will be the viewer itself. - - (Optionally) A list of additional positional arguments - to be passed to the function. - - (Optionally) A dict of keyword arguments to be passed - to the function. - - kwargs : dict + **kwargs : dict Any keyword arguments left over will be interpreted as belonging to either the :attr:`.Viewer.render_flags` or :attr:`.Viewer.viewer_flags` dictionaries. Those flag sets will be updated appropriately. @@ -199,14 +194,13 @@ def __init__( viewport_size=None, render_flags=None, viewer_flags=None, - registered_keys=None, run_in_thread=False, auto_start=True, shadow=False, plane_reflection=False, env_separate_rigid=False, - enable_interaction=False, - disable_keyboard_shortcuts=False, + plugins=None, + disable_help_text=False, **kwargs, ): ####################################################################### @@ -231,7 +225,6 @@ def __init__( self._offscreen_semaphore = Semaphore(0) self._offscreen_result = None - self._video_saver = None self._video_recorder = None self._default_render_flags = { @@ -282,42 +275,13 @@ def __init__( elif key in self.viewer_flags: self._viewer_flags[key] = kwargs[key] - self._registered_keys = {} - if registered_keys is not None: - self._registered_keys = {ord(k.lower()): registered_keys[k] for k in registered_keys} - - self._disable_keyboard_shortcuts = disable_keyboard_shortcuts + self._keybindings: Keybindings = Keybindings() + self._held_keys: dict[tuple[int, int], bool] = {} ####################################################################### # Save internal settings ####################################################################### - # Set up caption stuff - self._message_text = None - self._ticks_till_fade = 2.0 / 3.0 * self.viewer_flags["refresh_rate"] - self._message_opac = 1.0 + self._ticks_till_fade - - self._display_instr = False - - self._instr_texts = [ - ["> [i]: show keyboard instructions"], - [ - "< [i]: hide keyboard instructions", - " [r]: record video", - " [s]: save image", - " [z]: reset camera", - " [a]: camera rotation", - " [h]: shadow", - " [f]: face normal", - " [v]: vertex normal", - " [w]: world frame", - " [l]: link frame", - " [d]: wireframe", - " [c]: camera & frustrum", - " [F11]: full-screen mode", - ], - ] - # Set up raymond lights and direct lights self._raymond_lights = self._create_raymond_lights() self._direct_light = self._create_direct_light() @@ -379,15 +343,25 @@ def __init__( self.scene.main_camera_node = self._camera_node self._reset_view() - # Setup mouse interaction + # Setup help text functionality + self._disable_help_text = disable_help_text + if not self._disable_help_text: + self._collapse_instructions = True + instr_key_str = str(Key(HELP_TEXT_KEY)) + self._instr_texts: tuple[list[str], list[str]] = ( + [f"> [{instr_key_str}]: show keyboard instructions"], + [f"< [{instr_key_str}]: hide keyboard instructions"], + ) + self._key_instr_texts: list[str] = [] + self._message_text = None + self._ticks_till_fade = 2.0 / 3.0 * self.viewer_flags["refresh_rate"] + self._message_opac = 1.0 + self._ticks_till_fade + self.register_keybinds(Keybind(HELP_TEXT_KEYBIND_NAME, HELP_TEXT_KEY, callback=self._toggle_instructions)) - # Note: context.scene is genesis.engine.scene.Scene - # Note: context._scene is genesis.ext.pyrender.scene.Scene - self.viewer_interaction = ( - ViewerInteraction(self._camera_node, context.scene, viewport_size, camera.yfov) - if enable_interaction - else ViewerInteractionBase() - ) + # Setup viewer plugins + self.plugins: list[ViewerPlugin] = [] + for plugin in plugins: + self.register_plugin(plugin) ####################################################################### # Initialize OpenGL context and renderer @@ -521,25 +495,73 @@ def viewer_flags(self): def viewer_flags(self, value): self._viewer_flags = value - @property - def registered_keys(self): - """dict : Map from ASCII key character to a handler function. + def register_plugin(self, plugin: ViewerPlugin) -> None: + """ + Register a viewer plugin. - This is a map from ASCII key characters to tuples containing: + Parameters + ---------- + plugin : :class:`.ViewerPlugin` + The viewer plugin to add. + """ + self.plugins.append(plugin) + plugin.build(self, self._camera_node, self.gs_context.scene) + # Register pyglet.window event handlers from the plugin + self.push_handlers(plugin) - - A function to be called whenever the key is pressed, - whose first argument will be the viewer itself. - - (Optionally) A list of additional positional arguments - to be passed to the function. - - (Optionally) A dict of keyword arguments to be passed - to the function. + def register_keybinds(self, *keybinds: Keybind) -> None: + """ + Add a key handler to call a function when the given key is pressed. + Parameters + ---------- + keybinds : Keybind + One or more Keybind objects to register. """ - return self._registered_keys + for keybind in keybinds: + self._keybindings.register(keybind) + self._update_instr_texts() - @registered_keys.setter - def registered_keys(self, value): - self._registered_keys = value + def remap_keybind( + self, + keybind_name: str, + new_key_code: Key, + new_key_mods: tuple[KeyMod] | None, + new_key_action: KeyAction = KeyAction.PRESS, + ) -> None: + """ + Remap an existing keybind to a new key combination. + + Parameters + ---------- + keybind_name : str + The name of the keybind to remap. + new_key_code : int + The new key code from pyglet. + new_key_mods : tuple[KeyMod] | None + The new modifier keys pressed. + new_key_action : KeyAction, optional + The new type of key action. If not provided, the key action of the old keybind is used. + """ + self._keybindings.rebind( + keybind_name, + new_key_code, + new_key_mods, + new_key_action, + ) + self._update_instr_texts() + + def remove_keybind(self, keybind_name: str) -> None: + """ + Remove an existing keybind. + + Parameters + ---------- + keybind_name : str + The name of the keybind to remove. + """ + self._keybindings.remove(keybind_name) + self._update_instr_texts() def close(self): """Close the viewer. @@ -590,6 +612,9 @@ def on_close(self): # Do not consider the viewer as active anymore self._is_active = False + for plugin in self.plugins: + plugin.on_close() + # Remove our camera and restore the prior one try: if self._camera_node is not None: @@ -733,49 +758,25 @@ def on_draw(self): self.clear() self._render() - self.viewer_interaction.on_draw() - - if not self._disable_keyboard_shortcuts: - if self._display_instr: - self._renderer.render_texts( - self._instr_texts[1], - TEXT_PADDING, - self.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - else: - self._renderer.render_texts( - self._instr_texts[0], - TEXT_PADDING, - self.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - - if self._message_text is not None: - self._renderer.render_text( - self._message_text, - self.viewport_size[0] - TEXT_PADDING, - TEXT_PADDING, - font_pt=20, - color=np.array([0.1, 0.7, 0.2, np.clip(self._message_opac, 0.0, 1.0)]), - align=TextAlign.BOTTOM_RIGHT, - ) - - if self.viewer_flags["caption"] is not None: - for caption in self.viewer_flags["caption"]: - xpos, ypos = self._location_to_x_y(caption["location"]) - self._renderer.render_text( - caption["text"], - xpos, - ypos, - font_name=caption["font_name"], - font_pt=caption["font_pt"], - color=caption["color"], - scale=caption["scale"], - align=caption["location"], - ) + if self.viewer_flags["caption"] is not None: + for caption in self.viewer_flags["caption"]: + xpos, ypos = self._location_to_x_y(caption["location"]) + self._renderer.render_text( + caption["text"], + xpos, + ypos, + font_name=caption["font_name"], + font_pt=caption["font_pt"], + color=caption["color"], + scale=caption["scale"], + align=caption["location"], + ) + + # Render help text + self._render_help_text() + + for plugin in self.plugins: + plugin.on_draw() def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: """Resize the camera and trackball when the window is resized.""" @@ -789,15 +790,17 @@ def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: self._trackball.resize(self._viewport_size) self._renderer.viewport_width = self._viewport_size[0] self._renderer.viewport_height = self._viewport_size[1] - self.viewer_interaction.on_resize(width, height) self.on_draw() def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: """The mouse was moved with no buttons held down.""" - return self.viewer_interaction.on_mouse_motion(x, y, dx, dy) + pass def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record an initial mouse press.""" + # Stop animating while using the mouse + self.viewer_flags["mouse_pressed"] = True + self._trackball.set_state(Trackball.STATE_ROTATE) if button == pyglet.window.mouse.LEFT: ctrl = modifiers & pyglet.window.key.MOD_CTRL @@ -814,23 +817,19 @@ def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_H self._trackball.down(np.array([x, y])) - # Stop animating while using the mouse - self.viewer_flags["mouse_pressed"] = True - return self.viewer_interaction.on_mouse_press(x, y, button, modifiers) + return EVENT_HANDLED def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: """The mouse was moved with one or more buttons held down.""" - result = self.viewer_interaction.on_mouse_drag(x, y, dx, dy, buttons, modifiers) - if result is not EVENT_HANDLED: - result = self._trackball.drag(np.array([x, y])) + result = self._trackball.drag(np.array([x, y])) return result def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a mouse release.""" self.viewer_flags["mouse_pressed"] = False - return self.viewer_interaction.on_mouse_release(x, y, button, modifiers) + return EVENT_HANDLED - def on_mouse_scroll(self, x, y, dx, dy): + def on_mouse_scroll(self, x, y, dx, dy) -> EVENT_HANDLE_STATE: """Record a mouse scroll.""" if self.viewer_flags["use_perspective_cam"]: self._trackball.scroll(dy) @@ -849,176 +848,29 @@ def on_mouse_scroll(self, x, y, dx, dy): c.xmag = xmag c.ymag = ymag - def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - """Record a key press.""" - # First, check for registered key callbacks - if symbol in self.registered_keys: - tup = self.registered_keys[symbol] - callback = None - args = [] - kwargs = {} - if not isinstance(tup, (list, tuple, np.ndarray)): - callback = tup - else: - callback = tup[0] - if len(tup) == 2: - args = tup[1] - if len(tup) == 3: - kwargs = tup[2] - callback(self, *args, **kwargs) - return self.viewer_interaction.on_key_press(symbol, modifiers) - - # If keyboard shortcuts are disabled, skip default key functions - if self._disable_keyboard_shortcuts: - return self.viewer_interaction.on_key_press(symbol, modifiers) - - # Otherwise, use default key functions - - # A causes the frame to rotate - self._message_text = None - if symbol == pyglet.window.key.A: - self.viewer_flags["rotate"] = not self.viewer_flags["rotate"] - if self.viewer_flags["rotate"]: - self._message_text = "Rotation On" - else: - self._message_text = "Rotation Off" - - # F11 toggles face normals - elif symbol == pyglet.window.key.F11: - self.viewer_flags["fullscreen"] = not self.viewer_flags["fullscreen"] - self.set_fullscreen(self.viewer_flags["fullscreen"]) - self.activate() - if self.viewer_flags["fullscreen"]: - self._message_text = "Fullscreen On" - else: - self._message_text = "Fullscreen Off" + return EVENT_HANDLED - # H toggles shadows - elif symbol == pyglet.window.key.H: - self.render_flags["shadows"] = not self.render_flags["shadows"] - if self.render_flags["shadows"]: - self._message_text = "Shadows On" - else: - self._message_text = "Shadows Off" + def _call_keybind_callback(self, symbol: int, modifiers: int, action: KeyAction) -> None: + """Call registered keybind callbacks for the given key event.""" + keybind: Keybind = self._keybindings.get(symbol, modifiers, action) + if keybind is not None and keybind.callback is not None: + keybind.callback(*keybind.args, **keybind.kwargs) - # W toggles world frame - elif symbol == pyglet.window.key.W: - if not self.gs_context.world_frame_shown: - self.gs_context.on_world_frame() - self._message_text = "World Frame On" - else: - self.gs_context.off_world_frame() - self._message_text = "World Frame Off" - - # L toggles link frame - elif symbol == pyglet.window.key.L: - if not self.gs_context.link_frame_shown: - self.gs_context.on_link_frame() - self._message_text = "Link Frame On" - else: - self.gs_context.off_link_frame() - self._message_text = "Link Frame Off" - - # C toggles camera frustum - elif symbol == pyglet.window.key.C: - if not self.gs_context.camera_frustum_shown: - self.gs_context.on_camera_frustum() - self._message_text = "Camera Frustrum On" - else: - self.gs_context.off_camera_frustum() - self._message_text = "Camera Frustrum Off" - - # F toggles face normals - elif symbol == pyglet.window.key.F: - self.render_flags["face_normals"] = not self.render_flags["face_normals"] - if self.render_flags["face_normals"]: - self._message_text = "Face Normals On" - else: - self._message_text = "Face Normals Off" - - # V toggles vertex normals - elif symbol == pyglet.window.key.V: - self.render_flags["vertex_normals"] = not self.render_flags["vertex_normals"] - if self.render_flags["vertex_normals"]: - self._message_text = "Vert Normals On" - else: - self._message_text = "Vert Normals Off" - - # R starts recording frames - elif symbol == pyglet.window.key.R: - if self.viewer_flags["record"]: - self.save_video() - self.set_caption(self.viewer_flags["window_title"]) - else: - # Importing moviepy is very slow and not used very often. Let's delay import. - from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter - - self._video_recorder = FFMPEG_VideoWriter( - filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), - fps=self.viewer_flags["refresh_rate"], - size=self.viewport_size, - ) - self.set_caption("{} (RECORDING)".format(self.viewer_flags["window_title"])) - self.viewer_flags["record"] = not self.viewer_flags["record"] - - # S saves the current frame as an image - elif symbol == pyglet.window.key.S: - self._save_image() - - # T toggles through geom types - # elif symbol == pyglet.window.key.T: - # if self.gs_context.rigid_shown == 'visual': - # self.gs_context.on_rigid('collision') - # self._message_text = "Geom Type: 'collision'" - # elif self.gs_context.rigid_shown == 'collision': - # self.gs_context.on_rigid('sdf') - # self._message_text = "Geom Type: 'sdf'" - # else: - # self.gs_context.on_rigid('visual') - # self._message_text = "Geom Type: 'visual'" - - # D toggles through wireframe modes - elif symbol == pyglet.window.key.D: - if self.render_flags["flip_wireframe"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = True - self.render_flags["all_solid"] = False - self._message_text = "All Wireframe" - elif self.render_flags["all_wireframe"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = False - self.render_flags["all_solid"] = True - self._message_text = "All Solid" - elif self.render_flags["all_solid"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = False - self.render_flags["all_solid"] = False - self._message_text = "Default Wireframe" - else: - self.render_flags["flip_wireframe"] = True - self.render_flags["all_wireframe"] = False - self.render_flags["all_solid"] = False - self._message_text = "Flip Wireframe" - - # Z resets the camera viewpoint - elif symbol == pyglet.window.key.Z: - self._reset_view() - - # i toggles instruction display - elif symbol == pyglet.window.key.I: - self._display_instr = not self._display_instr - - elif symbol == pyglet.window.key.P: - self._renderer.reload_program() - - if self._message_text is not None: - self._message_opac = 1.0 + self._ticks_till_fade + def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + """Record a key press.""" + self._held_keys[(symbol, modifiers)] = True - return self.viewer_interaction.on_key_press(symbol, modifiers) + self._call_keybind_callback(symbol, modifiers, KeyAction.PRESS) def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a key release.""" - return self.viewer_interaction.on_key_release(symbol, modifiers) + self._held_keys.pop((symbol, modifiers), None) + + self._call_keybind_callback(symbol, modifiers, KeyAction.RELEASE) + + def on_deactivate(self) -> EVENT_HANDLE_STATE: + """Clear held keys when window loses focus.""" + self._held_keys.clear() @staticmethod def _time_event(dt, self): @@ -1032,26 +884,6 @@ def _time_event(dt, self): if self.viewer_flags["rotate"] and not self.viewer_flags["mouse_pressed"]: self._rotate() - # Manage message opacity - if self._message_text is not None: - if self._message_opac > 1.0: - self._message_opac -= 1.0 - else: - self._message_opac *= 0.90 - if self._message_opac < 0.05: - self._message_opac = 1.0 + self._ticks_till_fade - self._message_text = None - - # video saving warning - if self._video_saver is not None: - if self._video_saver.is_alive(): - self._message_text = "Saving video... Please don't exit." - self._message_opac = 1.0 - else: - self._message_text = f"Video saved to {self._video_file_name}" - self._message_opac = self.viewer_flags["refresh_rate"] * 2 - self._video_saver = None - self.on_draw() def _reset_view(self): @@ -1089,8 +921,7 @@ def _get_save_filename(self, file_exts): try: # Importing tkinter is very slow and not used very often. Let's delay import. - from tkinter import Tk - from tkinter import filedialog + from tkinter import Tk, filedialog if root is None: root = Tk() @@ -1234,8 +1065,8 @@ def start(self, auto_refresh=True): import pyglet # For some reason, this is necessary if 'pyglet.window.xlib' fails to import... try: - import pyglet.window.xlib import pyglet.display.xlib + import pyglet.window.xlib xlib_exceptions = (pyglet.window.xlib.XlibException, pyglet.display.xlib.NoSuchDisplayException) except ImportError: @@ -1406,7 +1237,11 @@ def refresh(self): self.flip() def update_on_sim_step(self): - self.viewer_interaction.update_on_sim_step() + # Call HOLD callbacks for all currently held keys + for symbol, modifiers in list(self._held_keys.keys()): + self._call_keybind_callback(symbol, modifiers, KeyAction.HOLD) + for plugin in self.plugins: + plugin.update_on_sim_step() def _compute_initial_camera_pose(self): centroid = self.scene.centroid @@ -1475,5 +1310,71 @@ def _location_to_x_y(self, location): elif location == TextAlign.TOP_CENTER: return (self.viewport_size[0] / 2.0, self.viewport_size[1] - TEXT_PADDING) + def _update_instr_texts(self): + """Update the instruction text based on current keybindings.""" + if self._disable_help_text: + return + + self._key_instr_texts = self._instr_texts[0] + [ + # f"{'[' + get_keycode_string(kb.key_code):>{7}}]: " + kb.name.replace("_", " ") + f"{'[' + str(kb.key):>{7}}]: " + kb.name.replace("_", " ") + for kb in self._keybindings.keybinds + if kb.name != HELP_TEXT_KEYBIND_NAME and kb.key_action != KeyAction.RELEASE + ] + + def _toggle_instructions(self): + """Toggle the display of keyboard instructions.""" + if self._disable_help_text: + raise RuntimeError("Instructions display is disabled.") + self._collapse_instructions = not self._collapse_instructions + + def set_message_text(self, text: str): + """Set a temporary message to display on the viewer.""" + self._message_text = text + self._message_opac = 1.0 + self._ticks_till_fade + + def _render_help_text(self): + """Render help text and messages on the viewer.""" + if self._disable_help_text: + return + + # Render temporary message + if self._message_text is not None: + self._renderer.render_text( + self._message_text, + self._viewport_size[0] - TEXT_PADDING, + TEXT_PADDING, + font_pt=FONT_SIZE, + color=np.array([0.1, 0.7, 0.2, np.clip(self._message_opac, 0.0, 1.0)]), + align=TextAlign.BOTTOM_RIGHT, + ) + + if self._message_opac > 1.0: + self._message_opac -= 1.0 + else: + self._message_opac *= 0.90 + + if self._message_opac < 0.05: + self._message_opac = 1.0 + self._ticks_till_fade + self._message_text = None + + # Render keyboard instructions + if self._collapse_instructions: + self._renderer.render_texts( + self._instr_texts[0], + TEXT_PADDING, + self._viewport_size[1] - TEXT_PADDING, + font_pt=FONT_SIZE, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) + else: + self._renderer.render_texts( + self._key_instr_texts, + TEXT_PADDING, + self._viewport_size[1] - TEXT_PADDING, + font_pt=FONT_SIZE, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) + __all__ = ["Viewer"] diff --git a/genesis/options/vis.py b/genesis/options/vis.py index 939df400c0..4f126cb44d 100644 --- a/genesis/options/vis.py +++ b/genesis/options/vis.py @@ -33,20 +33,22 @@ class ViewerOptions(Options): The up vector of the camera's extrinsic pose. camera_fov : float The field of view (in degrees) of the camera. - disable_keyboard_shortcuts : bool - Whether to disable all keyboard shortcuts in the viewer. Defaults to False. + disable_help_text : bool + Whether to disable the rendering of instructions text in the viewer. + disable_default_keybinds : bool + Whether to disable the default keyboard controls in the viewer. """ - res: Optional[tuple] = None - run_in_thread: Optional[bool] = None + res: tuple | None = None + run_in_thread: bool | None = None refresh_rate: int = 60 - max_FPS: Optional[int] = 60 + max_FPS: int | None = 60 camera_pos: tuple = (3.5, 0.5, 2.5) camera_lookat: tuple = (0.0, 0.0, 0.5) camera_up: tuple = (0.0, 0.0, 1.0) camera_fov: float = 40 - enable_interaction: bool = False - disable_keyboard_shortcuts: bool = False + disable_help_text: bool = False + disable_default_keybinds: bool = False class VisOptions(Options): diff --git a/genesis/vis/keybindings.py b/genesis/vis/keybindings.py new file mode 100644 index 0000000000..84697f0493 --- /dev/null +++ b/genesis/vis/keybindings.py @@ -0,0 +1,390 @@ +from dataclasses import dataclass, field +from enum import IntEnum +from typing import Callable + + +class LabeledIntEnum(IntEnum): + def __new__(cls, value, label): + obj = int.__new__(cls, value) + obj._value_ = value + obj._label = label + return obj + + def __str__(self) -> str: + return self._label + + +class Key(LabeledIntEnum): + """ + Key codes for keyboard keys. + + These are compatible with the pyglet key codes. + https://github.com/pyglet/pyglet/blob/master/pyglet/window/key.py + """ + + # fmt: off + + # ASCII commands + BACKSPACE = 0xff08, "backspace" + TAB = 0xff09, "tab" + LINEFEED = 0xff0a, "linefeed" + CLEAR = 0xff0b, "clear" + RETURN = 0xff0d, "return" + ENTER = 0xff0d, "enter" # synonym + PAUSE = 0xff13, "pause" + SCROLLLOCK = 0xff14, "scrolllock" + SYSREQ = 0xff15, "sysreq" + ESCAPE = 0xff1b, "escape" + + # Cursor control and motion + HOME = 0xff50, "home" + LEFT = 0xff51, "left" + UP = 0xff52, "up" + RIGHT = 0xff53, "right" + DOWN = 0xff54, "down" + PAGEUP = 0xff55, "pageup" + PAGEDOWN = 0xff56, "pagedown" + END = 0xff57, "end" + BEGIN = 0xff58, "begin" + + # Misc functions + DELETE = 0xffff, "delete" + SELECT = 0xff60, "select" + PRINT = 0xff61, "print" + EXECUTE = 0xff62, "execute" + INSERT = 0xff63, "insert" + UNDO = 0xff65, "undo" + REDO = 0xff66, "redo" + MENU = 0xff67, "menu" + FIND = 0xff68, "find" + CANCEL = 0xff69, "cancel" + HELP = 0xff6a, "help" + BREAK = 0xff6b, "break" + MODESWITCH = 0xff7e, "modeswitch" + SCRIPTSWITCH = 0xff7e, "scriptswitch" + FUNCTION = 0xffd2, "function" + + # Number pad + NUMLOCK = 0xff7f, "numlock" + NUM_SPACE = 0xff80, "num_space" + NUM_TAB = 0xff89, "num_tab" + NUM_ENTER = 0xff8d, "num_enter" + NUM_F1 = 0xff91, "num_f1" + NUM_F2 = 0xff92, "num_f2" + NUM_F3 = 0xff93, "num_f3" + NUM_F4 = 0xff94, "num_f4" + NUM_HOME = 0xff95, "num_home" + NUM_LEFT = 0xff96, "num_left" + NUM_UP = 0xff97, "num_up" + NUM_RIGHT = 0xff98, "num_right" + NUM_DOWN = 0xff99, "num_down" + NUM_PRIOR = 0xff9a, "num_prior" + NUM_PAGE_UP = 0xff9a, "num_page_up" + NUM_NEXT = 0xff9b, "num_next" + NUM_PAGE_DOWN = 0xff9b, "num_page_down" + NUM_END = 0xff9c, "num_end" + NUM_BEGIN = 0xff9d, "num_begin" + NUM_INSERT = 0xff9e, "num_insert" + NUM_DELETE = 0xff9f, "num_delete" + NUM_EQUAL = 0xffbd, "num_equal" + NUM_MULTIPLY = 0xffaa, "num_multiply" + NUM_ADD = 0xffab, "num_add" + NUM_SEPARATOR = 0xffac, "num_separator" + NUM_SUBTRACT = 0xffad, "num_subtract" + NUM_DECIMAL = 0xffae, "num_decimal" + NUM_DIVIDE = 0xffaf, "num_divide" + + NUM_0 = 0xffb0, "num_0" + NUM_1 = 0xffb1, "num_1" + NUM_2 = 0xffb2, "num_2" + NUM_3 = 0xffb3, "num_3" + NUM_4 = 0xffb4, "num_4" + NUM_5 = 0xffb5, "num_5" + NUM_6 = 0xffb6, "num_6" + NUM_7 = 0xffb7, "num_7" + NUM_8 = 0xffb8, "num_8" + NUM_9 = 0xffb9, "num_9" + + # Function keys + F1 = 0xffbe, "f1" + F2 = 0xffbf, "f2" + F3 = 0xffc0, "f3" + F4 = 0xffc1, "f4" + F5 = 0xffc2, "f5" + F6 = 0xffc3, "f6" + F7 = 0xffc4, "f7" + F8 = 0xffc5, "f8" + F9 = 0xffc6, "f9" + F10 = 0xffc7, "f10" + F11 = 0xffc8, "f11" + F12 = 0xffc9, "f12" + F13 = 0xffca, "f13" + F14 = 0xffcb, "f14" + F15 = 0xffcc, "f15" + F16 = 0xffcd, "f16" + F17 = 0xffce, "f17" + F18 = 0xffcf, "f18" + F19 = 0xffd0, "f19" + F20 = 0xffd1, "f20" + F21 = 0xffd2, "f21" + F22 = 0xffd3, "f22" + F23 = 0xffd4, "f23" + F24 = 0xffd5, "f24" + + # Modifiers + LSHIFT = 0xffe1, "left_shift" + RSHIFT = 0xffe2, "right_shift" + LCTRL = 0xffe3, "left_ctrl" + RCTRL = 0xffe4, "right_ctrl" + CAPSLOCK = 0xffe5, "capslock" + LMETA = 0xffe7, "left_meta" + RMETA = 0xffe8, "right_meta" + LALT = 0xffe9, "left_alt" + RALT = 0xffea, "right_alt" + LWINDOWS = 0xffeb, "left_windows" + RWINDOWS = 0xffec, "right_windows" + LCOMMAND = 0xffed, "left_command" + RCOMMAND = 0xffee, "right_command" + LOPTION = 0xffef, "left_option" + ROPTION = 0xfff0, "right_option" + + # Latin-1 + SPACE = 0x020, "space" + EXCLAMATION = 0x021, "!" + DOUBLEQUOTE = 0x022, "\"" + HASH = 0x023, "#" + POUND = 0x023, "#" # synonym + DOLLAR = 0x024, "$" + PERCENT = 0x025, "%" + AMPERSAND = 0x026, "&" + APOSTROPHE = 0x027, "'" + PARENLEFT = 0x028, "(" + PARENRIGHT = 0x029, ")" + ASTERISK = 0x02a, "*" + PLUS = 0x02b, "+" + COMMA = 0x02c, "," + MINUS = 0x02d, "-" + PERIOD = 0x02e, "." + SLASH = 0x02f, "/" + _0 = 0x030, "0" + _1 = 0x031, "1" + _2 = 0x032, "2" + _3 = 0x033, "3" + _4 = 0x034, "4" + _5 = 0x035, "5" + _6 = 0x036, "6" + _7 = 0x037, "7" + _8 = 0x038, "8" + _9 = 0x039, "9" + COLON = 0x03a, ":" + SEMICOLON = 0x03b, ";" + LESS = 0x03c, "<" + EQUAL = 0x03d, "=" + GREATER = 0x03e, ">" + QUESTION = 0x03f, "?" + AT = 0x040, "@" + BRACKETLEFT = 0x05b, "[" + BACKSLASH = 0x05c, "\\" + BRACKETRIGHT = 0x05d, "]" + ASCIICIRCUM = 0x05e, "^" + UNDERSCORE = 0x05f, "_" + GRAVE = 0x060, "`" + QUOTELEFT = 0x060, "`" + A = 0x061, "a" + B = 0x062, "b" + C = 0x063, "c" + D = 0x064, "d" + E = 0x065, "e" + F = 0x066, "f" + G = 0x067, "g" + H = 0x068, "h" + I = 0x069, "i" + J = 0x06a, "j" + K = 0x06b, "k" + L = 0x06c, "l" + M = 0x06d, "m" + N = 0x06e, "n" + O = 0x06f, "o" + P = 0x070, "p" + Q = 0x071, "q" + R = 0x072, "r" + S = 0x073, "s" + T = 0x074, "t" + U = 0x075, "u" + V = 0x076, "v" + W = 0x077, "w" + X = 0x078, "x" + Y = 0x079, "y" + Z = 0x07a, "z" + BRACELEFT = 0x07b, "{" + BAR = 0x07c, "|" + BRACERIGHT = 0x07d, "}" + ASCIITILDE = 0x07e, "~" + # fmt: on + + +class KeyMod(LabeledIntEnum): + # fmt: off + SHIFT = 1 << 0, "shift" + CTRL = 1 << 1, "ctrl" + ALT = 1 << 2, "alt" + CAPSLOCK = 1 << 3, "capslock" + NUMLOCK = 1 << 4, "numlock" + WINDOWS = 1 << 5, "windows" + COMMAND = 1 << 6, "command" + OPTION = 1 << 7, "option" + SCROLLLOCK = 1 << 8, "scrolllock" + FUNCTION = 1 << 9, "function" + # fmt: on + + +class KeyAction(LabeledIntEnum): + PRESS = 0, "press" + HOLD = 1, "hold" + RELEASE = 2, "release" + + +def get_key_hash(key_code: int, modifiers: int | None, action: KeyAction) -> int: + """Generate a unique hash for a key combination. + + Parameters + ---------- + key_code: int + The key code as an int. + modifiers : int | None + The modifier keys pressed, as an int with bit flags, or None to ignore modifiers. + action : KeyAction + The type of key action (press, hold, release). + + Returns + ------- + int + A unique hash for this key combination. + """ + return hash((key_code, modifiers, action)) + + +@dataclass +class Keybind: + """ + A keybinding with an associated callback. + + Parameters + ---------- + name : str + The name of the keybind. + key : Key + The key code for the keybind. + key_action : KeyAction + The type of key action (press, hold, release). + key_mods : tuple[KeyMod] | None + The modifier keys required for the keybind. If None, modifiers are ignored. + callback : Callable[[], None] | None + The function to call when the keybind is activated. + args : tuple + Positional arguments to pass to the callback. + kwargs : dict + Keyword arguments to pass to the callback. + """ + + name: str + key: Key + key_action: KeyAction = KeyAction.PRESS + key_mods: tuple[KeyMod] | None = None + callback: Callable[[], None] | None = None + args: tuple = () + kwargs: dict = field(default_factory=dict) + + _modifiers: int | None = field(default=None, init=False, repr=False) + + def __post_init__(self): + if self.key_mods is not None: + self._modifiers = 0 + for mod in self.key_mods: + self._modifiers |= mod + if self.kwargs is None: + self.kwargs = {} + + def key_hash(self) -> int: + """Generate a unique hash for the keybind based on key code and modifiers.""" + return get_key_hash(self.key, self._modifiers, self.key_action) + + +class Keybindings: + def __init__(self, keybinds: tuple[Keybind] = ()): + self._keybinds_map: dict[int, Keybind] = {} + self._name_to_hash: dict[str, int] = {} + for kb in keybinds: + key_hash = kb.key_hash() + self._keybinds_map[key_hash] = kb + self._name_to_hash[kb.name] = key_hash + + def register(self, keybind: Keybind) -> None: + key_hash = keybind.key_hash() + if key_hash in self._keybinds_map: + existing_kb = self._keybinds_map[key_hash] + raise ValueError(f"Key [{keybind.key}] is already assigned to '{existing_kb.name}'.") + if keybind.name and keybind.name in self._name_to_hash: + raise ValueError(f"Name '{keybind.name}' is already assigned to another keybind.") + + self._keybinds_map[key_hash] = keybind + self._name_to_hash[keybind.name] = key_hash + + def remove(self, name: str) -> None: + if name not in self._name_to_hash: + raise ValueError(f"No keybind found with name '{name}'.") + key_hash = self._name_to_hash[name] + del self._keybinds_map[key_hash] + del self._name_to_hash[name] + + def rebind( + self, + name: str, + new_key: Key | None, + new_key_mods: tuple[KeyMod] | None, + new_key_action: KeyAction | None = None, + ) -> None: + if name not in self._name_to_hash: + raise ValueError(f"No keybind found with name '{name}'.") + old_hash = self._name_to_hash[name] + kb = self._keybinds_map[old_hash] + new_kb = Keybind( + name=kb.name, + key=new_key or kb.key, + key_action=new_key_action or kb.key_action, + key_mods=new_key_mods, + callback=kb.callback, + args=kb.args, + kwargs=kb.kwargs, + ) + del self._keybinds_map[old_hash] + new_hash = new_kb.key_hash() + print("new_kb", new_kb) + self._keybinds_map[new_hash] = new_kb + self._name_to_hash[name] = new_hash + + def get(self, key: int, modifiers: int, key_action: KeyAction) -> Keybind | None: + key_hash = get_key_hash(key, modifiers, key_action) + if key_hash in self._keybinds_map: + return self._keybinds_map[key_hash] + + # Try ignoring modifiers (for keybinds where modifiers=None) + key_hash_no_mods = get_key_hash(key, None, key_action) + if key_hash_no_mods in self._keybinds_map: + return self._keybinds_map[key_hash_no_mods] + + return None + + def get_by_name(self, name: str) -> Keybind | None: + if name in self._name_to_hash: + key_hash = self._name_to_hash[name] + return self._keybinds_map[key_hash] + return None + + def __len__(self) -> int: + return len(self._keybinds_map) + + @property + def keybinds(self) -> tuple[Keybind]: + """Return a tuple of all registered Keybinds.""" + return tuple(self._keybinds_map.values()) diff --git a/genesis/vis/viewer.py b/genesis/vis/viewer.py index 3a0c24d4c8..1b7381d8b1 100644 --- a/genesis/vis/viewer.py +++ b/genesis/vis/viewer.py @@ -1,6 +1,6 @@ +import importlib import os import threading -import importlib from typing import TYPE_CHECKING import numpy as np @@ -9,14 +9,16 @@ import genesis as gs import genesis.utils.geom as gu - from genesis.ext import pyrender from genesis.repr_base import RBC -from genesis.utils.tools import Rate from genesis.utils.misc import redirect_libc_stderr, tensor_to_array +from genesis.utils.tools import Rate +from genesis.vis.keybindings import Key, KeyAction, Keybind, KeyMod +from genesis.vis.viewer_plugins import DefaultControlsPlugin if TYPE_CHECKING: from genesis.options.vis import ViewerOptions + from genesis.vis.viewer_plugins import ViewerPlugin class ViewerLock: @@ -32,6 +34,7 @@ def __exit__(self, exc_type, exc_value, traceback): class Viewer(RBC): def __init__(self, options: "ViewerOptions", context): + self._is_built = False self._res = options.res self._run_in_thread = options.run_in_thread self._refresh_rate = options.refresh_rate @@ -40,16 +43,16 @@ def __init__(self, options: "ViewerOptions", context): self._camera_init_lookat = np.asarray(options.camera_lookat, dtype=gs.np_float) self._camera_up = np.asarray(options.camera_up, dtype=gs.np_float) self._camera_fov = options.camera_fov - self._enable_interaction = options.enable_interaction - self._disable_keyboard_shortcuts = options.disable_keyboard_shortcuts + + self._disable_help_text = options.disable_help_text + self._viewer_plugins: list["ViewerPlugin"] = [] + if not options.disable_default_keybinds: + self._viewer_plugins.append(DefaultControlsPlugin()) # Validate viewer options if any(e.shape != (3,) for e in (self._camera_init_pos, self._camera_init_lookat, self._camera_up)): gs.raise_exception("ViewerOptions.camera_(pos|lookat|up) must be sequences of length 3.") - if options.enable_interaction and gs.backend != gs.cpu: - gs.logger.warning("Interaction code is slow on GPU. Switch to CPU backend or disable interaction.") - self._pyrender_viewer = None self.context = context @@ -100,8 +103,8 @@ def build(self, scene): shadow=self.context.shadow, plane_reflection=self.context.plane_reflection, env_separate_rigid=self.context.env_separate_rigid, - enable_interaction=self._enable_interaction, - disable_keyboard_shortcuts=self._disable_keyboard_shortcuts, + disable_help_text=self._disable_help_text, + plugins=self._viewer_plugins, viewer_flags={ "window_title": f"Genesis {gs.__version__}", "refresh_rate": self._refresh_rate, @@ -135,6 +138,8 @@ def build(self, scene): renderer = glinfo.get_renderer() gs.logger.debug(f"Using interactive viewer OpenGL device: {renderer}") + self._is_built = True + def run(self): if self._pyrender_viewer is None: gs.raise_exception("Viewer must be built successfully before calling this method.") @@ -265,10 +270,80 @@ def update_following(self): else: self.set_camera_pose(pos=camera_pos, lookat=self._follow_lookat) + @gs.assert_built + def register_keybinds(self, *keybinds: Keybind) -> None: + """ + Register a callback function to be called when a key is pressed. + + Parameters + ---------- + keybinds : Keybind + One or more Keybind objects to register. See Keybind documentation for usage. + """ + self._pyrender_viewer.register_keybinds(*keybinds) + + @gs.assert_built + def remap_keybind( + self, + keybind_name: str, + new_key: Key, + new_key_mods: tuple[KeyMod] | None, + new_key_action: KeyAction = KeyAction.PRESS, + ) -> None: + """ + Remap an existing keybind by name to a new key combination. + + Parameters + ---------- + keybind_name : str + The name of the keybind to remap. + new_key : int + The new key code from pyglet. + new_key_mods : tuple[KeyMod] | None + The new modifier keys pressed. + new_key_action : KeyAction, optional + The new type of key action. If not provided, the key action of the old keybind is used. + """ + self._pyrender_viewer.remap_keybind( + keybind_name, + new_key, + new_key_mods, + new_key_action, + ) + + @gs.assert_built + def remove_keybind(self, keybind_name: str) -> None: + """ + Remove an existing keybind by name. + + Parameters + ---------- + keybind_name : str + The name of the keybind to remove. + """ + self._pyrender_viewer.remove_keybind(keybind_name) + + def add_plugin(self, plugin: "ViewerPlugin") -> None: + """ + Add a viewer plugin to the viewer. + + Parameters + ---------- + plugin : ViewerPlugin + The viewer plugin to add. + """ + self._viewer_plugins.append(plugin) + if self.is_built: + self._viewer.register_plugin(plugin) + # ------------------------------------------------------------------------------------ # ----------------------------------- properties ------------------------------------- # ------------------------------------------------------------------------------------ + @property + def is_built(self): + return self._is_built + @property def res(self): return self._res diff --git a/genesis/vis/viewer_plugins/__init__.py b/genesis/vis/viewer_plugins/__init__.py new file mode 100644 index 0000000000..ca16b2038c --- /dev/null +++ b/genesis/vis/viewer_plugins/__init__.py @@ -0,0 +1,6 @@ +from .plugins import * +from .viewer_plugin import ( + EVENT_HANDLE_STATE, + EVENT_HANDLED, + ViewerPlugin, +) diff --git a/genesis/vis/viewer_plugins/plugins/__init__.py b/genesis/vis/viewer_plugins/plugins/__init__.py new file mode 100644 index 0000000000..83634f65bc --- /dev/null +++ b/genesis/vis/viewer_plugins/plugins/__init__.py @@ -0,0 +1,5 @@ +from .default_controls import DefaultControlsPlugin + +__all__ = [ + "DefaultControlsPlugin", +] diff --git a/genesis/vis/viewer_plugins/plugins/default_controls.py b/genesis/vis/viewer_plugins/plugins/default_controls.py new file mode 100644 index 0000000000..16c03d34e8 --- /dev/null +++ b/genesis/vis/viewer_plugins/plugins/default_controls.py @@ -0,0 +1,149 @@ +import os +from typing import TYPE_CHECKING + +import genesis as gs +from genesis.vis.keybindings import Key, Keybind + +from ..viewer_plugin import ViewerPlugin + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + + +class DefaultControlsPlugin(ViewerPlugin): + """ + Default keyboard controls for the Genesis viewer. + + This plugin handles the standard viewer keyboard shortcuts for recording, changing render modes, etc. + """ + + def __init__(self): + super().__init__() + + def build(self, viewer, camera: "Node", scene: "Scene"): + super().build(viewer, camera, scene) + + self.viewer.register_keybinds( + Keybind("record_video", Key.R, callback=self._toggle_record_video), + Keybind("save_image", Key.S, callback=self._save_image), + Keybind("reset_camera", Key.Z, callback=self._reset_camera), + Keybind("camera_rotation", Key.A, callback=self._toggle_cam_rotation), + Keybind("shadow", Key.H, callback=self._toggle_shadow), + Keybind("face_normals", Key.F, callback=self._toggle_face_normals), + Keybind("vertex_normals", Key.V, callback=self._toggle_vertex_normals), + Keybind("world_frame", Key.W, callback=self._toggle_world_frame), + Keybind("link_frame", Key.L, callback=self._toggle_link_frame), + Keybind("wireframe", Key.D, callback=self._toggle_wireframe), + Keybind("camera_frustum", Key.C, callback=self._toggle_camera_frustum), + Keybind("reload_shader", Key.P, callback=self._reload_shader), + Keybind("fullscreen_mode", Key.F11, callback=self._toggle_fullscreen), + ) + + def _toggle_cam_rotation(self): + self.viewer.viewer_flags["rotate"] = not self.viewer.viewer_flags["rotate"] + if self.viewer.viewer_flags["rotate"]: + self.viewer.set_message_text("Rotation On") + else: + self.viewer.set_message_text("Rotation Off") + + def _toggle_fullscreen(self): + self.viewer.viewer_flags["fullscreen"] = not self.viewer.viewer_flags["fullscreen"] + self.viewer.set_fullscreen(self.viewer.viewer_flags["fullscreen"]) + self.viewer.activate() + if self.viewer.viewer_flags["fullscreen"]: + self.viewer.set_message_text("Fullscreen On") + else: + self.viewer.set_message_text("Fullscreen Off") + + def _toggle_shadow(self): + self.viewer.render_flags["shadows"] = not self.viewer.render_flags["shadows"] + if self.viewer.render_flags["shadows"]: + self.viewer.set_message_text("Shadows On") + else: + self.viewer.set_message_text("Shadows Off") + + def _toggle_world_frame(self): + if not self.viewer.gs_context.world_frame_shown: + self.viewer.gs_context.on_world_frame() + self.viewer.set_message_text("World Frame On") + else: + self.viewer.gs_context.off_world_frame() + self.viewer.set_message_text("World Frame Off") + + def _toggle_link_frame(self): + if not self.viewer.gs_context.link_frame_shown: + self.viewer.gs_context.on_link_frame() + self.viewer.set_message_text("Link Frame On") + else: + self.viewer.gs_context.off_link_frame() + self.viewer.set_message_text("Link Frame Off") + + def _toggle_camera_frustum(self): + if not self.viewer.gs_context.camera_frustum_shown: + self.viewer.gs_context.on_camera_frustum() + self.viewer.set_message_text("Camera Frustum On") + else: + self.viewer.gs_context.off_camera_frustum() + self.viewer.set_message_text("Camera Frustum Off") + + def _toggle_face_normals(self): + self.viewer.render_flags["face_normals"] = not self.viewer.render_flags["face_normals"] + if self.viewer.render_flags["face_normals"]: + self.viewer.set_message_text("Face Normals On") + else: + self.viewer.set_message_text("Face Normals Off") + + def _toggle_vertex_normals(self): + self.viewer.render_flags["vertex_normals"] = not self.viewer.render_flags["vertex_normals"] + if self.viewer.render_flags["vertex_normals"]: + self.viewer.set_message_text("Vert Normals On") + else: + self.viewer.set_message_text("Vert Normals Off") + + def _toggle_record_video(self): + if self.viewer.viewer_flags["record"]: + self.viewer.save_video() + self.viewer.set_caption(self.viewer.viewer_flags["window_title"]) + else: + # Importing moviepy is very slow and not used very often. Let's delay import. + from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter + + self.viewer._video_recorder = FFMPEG_VideoWriter( + filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), + fps=self.viewer.viewer_flags["refresh_rate"], + size=self.viewer.viewport_size, + ) + self.viewer.set_caption("{} (RECORDING)".format(self.viewer.viewer_flags["window_title"])) + self.viewer.viewer_flags["record"] = not self.viewer.viewer_flags["record"] + + def _save_image(self): + self.viewer._save_image() + + def _toggle_wireframe(self): + if self.viewer.render_flags["flip_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = True + self.viewer.render_flags["all_solid"] = False + self.viewer.set_message_text("All Wireframe") + elif self.viewer.render_flags["all_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = True + self.viewer.set_message_text("All Solid") + elif self.viewer.render_flags["all_solid"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.viewer.set_message_text("Default Wireframe") + else: + self.viewer.render_flags["flip_wireframe"] = True + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.viewer.set_message_text("Flip Wireframe") + + def _reset_camera(self): + self.viewer._reset_view() + + def _reload_shader(self): + self.viewer._renderer.reload_program() diff --git a/genesis/vis/viewer_plugins/viewer_plugin.py b/genesis/vis/viewer_plugins/viewer_plugin.py new file mode 100644 index 0000000000..cb336f0190 --- /dev/null +++ b/genesis/vis/viewer_plugins/viewer_plugin.py @@ -0,0 +1,67 @@ +from typing import TYPE_CHECKING, Literal + +import numpy as np + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + from genesis.ext.pyrender.viewer import Viewer + + +EVENT_HANDLE_STATE = Literal[True] | None +EVENT_HANDLED: Literal[True] = True + + +class ViewerPlugin: + """ + Base class for handling pyglet.window.Window events. + """ + + def __init__(self): + self.viewer = None + self.camera: "Node | None" = None + self.scene: "Scene | None" = None + self._camera_yfov: float = 0.0 + self._tan_half_fov: float = 0.0 + + def build(self, viewer: "Viewer", camera: "Node", scene: "Scene"): + """Build and initialize the plugin with pyrender viewer context.""" + + self.viewer = viewer + self.camera = camera + self.scene = scene + self._camera_yfov: float = camera.camera.yfov + self._tan_half_fov: float = np.tan(0.5 * self._camera_yfov) + + def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_scroll(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: + pass + + def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: + pass + + def update_on_sim_step(self) -> None: + pass + + def on_draw(self) -> None: + pass + + def on_close(self) -> None: + pass diff --git a/tests/test_render.py b/tests/test_render.py index cd6bf084ca..26593ce5ea 100644 --- a/tests/test_render.py +++ b/tests/test_render.py @@ -6,10 +6,10 @@ import time import numpy as np +import OpenGL.error import pyglet import pytest import torch -import OpenGL.error import genesis as gs import genesis.utils.geom as gu @@ -1232,28 +1232,6 @@ def on_key_press(self, symbol: int, modifiers: int): assert f.read() == png_snapshot -@pytest.mark.required -@pytest.mark.parametrize("renderer_type", [RENDERER_TYPE.RASTERIZER]) -@pytest.mark.skipif(not IS_INTERACTIVE_VIEWER_AVAILABLE, reason="Interactive viewer not supported on this platform.") -@pytest.mark.xfail(sys.platform == "win32", raises=OpenGL.error.Error, reason="Invalid OpenGL context.") -def test_interactive_viewer_disable_keyboard_shortcuts(): - """Test that keyboard shortcuts can be disabled in the interactive viewer.""" - - # Test with keyboard shortcuts DISABLED - scene = gs.Scene( - viewer_options=gs.options.ViewerOptions( - disable_keyboard_shortcuts=True, - ), - show_viewer=True, - ) - scene.build() - pyrender_viewer = scene.visualizer.viewer._pyrender_viewer - assert pyrender_viewer.is_active - - # Verify the flag is set correctly - assert pyrender_viewer._disable_keyboard_shortcuts is True - - @pytest.mark.required @pytest.mark.parametrize("renderer_type", [RENDERER_TYPE.RASTERIZER]) def test_camera_gimbal_lock_singularity(renderer, show_viewer): diff --git a/tests/test_utils.py b/tests/test_utils.py index ba474574cc..4b24ec55c7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -331,118 +331,6 @@ def test_geom_tensor_identity(batch_shape): np.testing.assert_allclose(tensor_to_array(tc_args[0]), tensor_to_array(tc_args[-1]), atol=1e2 * gs.EPS) -@pytest.mark.required -def test_pyrender_vec3(): - from genesis.ext.pyrender.interaction.vec3 import Vec3, Quat - - tol = 1e-6 - # construction helpers enforce shape and dtype - v = Vec3.from_xyz(1.0, 2.0, 3.0) - assert v.v.shape == (3,) - assert_allclose(v.v, np.array([1.0, 2.0, 3.0]), tol=gs.EPS) - assert_allclose((v.x, v.y, v.z), (1.0, 2.0, 3.0), tol=gs.EPS) - - # from_array converts various dtypes to float32 - v_i64 = Vec3.from_array(np.array([1, 2, 3], dtype=np.int64)) - assert_allclose(v_i64.v, np.array([1, 2, 3]), tol=gs.EPS) - - v_f64 = Vec3.from_array(np.array([0.5, -1.5, 2.0], dtype=np.float64)) - assert_allclose(v_f64.v, np.array([0.5, -1.5, 2.0]), tol=gs.EPS) - - # from_tensor - v_t = Vec3.from_tensor(torch.tensor([4.0, 5.0, 6.0], dtype=torch.float32)) - assert_allclose(v_t.v, np.array([4.0, 5.0, 6.0]), tol=gs.EPS) - - # constants - assert_allclose(Vec3.zero().v, 0.0, tol=gs.EPS) - assert_allclose(Vec3.one().v, 1.0, tol=gs.EPS) - assert_allclose(Vec3.full(5.5).v, 5.5, tol=gs.EPS) - - # arithmetic ops and dtype preservation - a = Vec3.from_xyz(1, 2, 3) - b = Vec3.from_xyz(4, 5, 6) - c = a + b - d = b - a - assert_allclose(c.v, np.array([5, 7, 9]), tol=gs.EPS) - assert_allclose(d.v, np.array([3, 3, 3]), tol=gs.EPS) - - m1 = a * 2.0 - m2 = 2.0 * a - assert_allclose(m1.v, np.array([2, 4, 6]), tol=gs.EPS) - assert_allclose(m2.v, np.array([2, 4, 6]), tol=gs.EPS) - - # dot and cross - dot_ab = a.dot(b) - assert_allclose(dot_ab, 1 * 4 + 2 * 5 + 3 * 6, tol=gs.EPS) - - cross_ab = a.cross(b) - assert_allclose(cross_ab.v, np.array([-3.0, 6.0, -3.0]), tol=gs.EPS) - - # norms - assert_allclose(a.sqr_magnitude(), 1.0 + 4.0 + 9.0, tol=gs.EPS) - assert_allclose(a.magnitude(), np.sqrt(a.sqr_magnitude()), tol=gs.EPS) - na = a.normalized() - assert_allclose(na.magnitude(), 1.0, tol=tol) - assert_allclose(Vec3.zero().normalized().v, 0.0, tol=gs.EPS) - - # copy is deep for underlying array - cp = a.copy() - assert cp is not a - cp.v[...] = 0.0 - assert_allclose(a.v, np.array([1.0, 2.0, 3.0]), tol=gs.EPS) - assert_allclose(cp.v, 0.0, tol=gs.EPS) - - # repr and tensor conversion - t = a.as_tensor() - assert isinstance(t, torch.Tensor) - assert_allclose(t, a.v, tol=gs.EPS) - - # --- Quat tests --- - q = Quat.from_wxyz(1.0, 0.0, 0.0, 0.0) # identity - assert q.v.shape == (4,) - assert_allclose(np.array([q.w, q.x, q.y, q.z]), np.array([1.0, 0.0, 0.0, 0.0]), tol=gs.EPS) - - # from_array converts dtype and enforces shape - q_arr = Quat.from_array(np.array([0.5, 0.5, -0.5, 0.5], dtype=np.float64)) - assert_allclose(q_arr.v, np.array([0.5, 0.5, -0.5, 0.5]), tol=gs.EPS) - - # from_tensor - q_t = Quat.from_tensor(torch.tensor([0.0, 1.0, 0.0, 0.0], dtype=torch.float32)) - assert_allclose(q_t.v, np.array([0.0, 1.0, 0.0, 0.0]), tol=gs.EPS) - - # inverse - q_inv = q_arr.get_inverse() - assert_allclose(q_inv.w, q_arr.w, tol=gs.EPS) - assert_allclose(q_inv.v[1:], -q_arr.v[1:], tol=gs.EPS) - - # quat * quat (identity) - qq = q * q_arr - assert_allclose(qq.v, q_arr.v, tol=gs.EPS) - - # rotation of a vector by 90deg about z: (1,0,0) -> (0,1,0) - theta = np.pi / 2.0 - qz = Quat.from_wxyz(np.cos(theta / 2.0), 0.0, 0.0, np.sin(theta / 2.0)) - v_x = Vec3.from_xyz(1.0, 0.0, 0.0) - v_rot = qz * v_x - assert_allclose(v_rot.v, np.array([0.0, 1.0, 0.0]), tol=tol) - - # quat * quat inverse -> identity - q_unit = qz * qz.get_inverse() - assert_allclose(q_unit.v, Quat.from_wxyz(1.0, 0.0, 0.0, 0.0).v, tol=tol) - - # copy independence - q_cp = qz.copy() - assert q_cp is not qz - q_cp.v[...] = 0.0 - assert_allclose(qz.v, np.array([np.cos(theta / 2.0), 0.0, 0.0, np.sin(theta / 2.0)]), tol=tol) - assert_allclose(q_cp.v, np.array([0.0, 0.0, 0.0, 0.0]), tol=gs.EPS) - - # tensor conversion - tq = qz.as_tensor() - assert isinstance(tq, torch.Tensor) - assert_allclose(tq, qz.v, tol=gs.EPS) - - def test_fps_tracker(): n_envs = 23 tracker = FPSTracker(alpha=0.0, minimum_interval_seconds=0.1, n_envs=n_envs) diff --git a/tests/test_viewer.py b/tests/test_viewer.py new file mode 100644 index 0000000000..eccc22d63a --- /dev/null +++ b/tests/test_viewer.py @@ -0,0 +1,152 @@ +import sys +import time + +import OpenGL.error +import pyglet +import pytest + +import genesis as gs +from genesis.vis.keybindings import Key, KeyAction, Keybind, KeyMod + +from .conftest import IS_INTERACTIVE_VIEWER_AVAILABLE +from .utils import rgb_array_to_png_bytes + +CAM_RES = (480, 320) + + +def wait_for_viewer_events(viewer, condition_fn, timeout=2.0, sleep_interval=0.1): + """Utility function to wait for viewer events to be processed in a threaded viewer.""" + if not viewer.run_in_thread: + viewer.dispatch_pending_events() + viewer.dispatch_events() + + for _ in range(int(timeout / sleep_interval)): + if condition_fn(): + return + time.sleep(sleep_interval) + else: + raise AssertionError("Keyboard event not processed before timeout") + + +@pytest.mark.required +@pytest.mark.skipif(not IS_INTERACTIVE_VIEWER_AVAILABLE, reason="Interactive viewer not supported on this platform.") +@pytest.mark.xfail(sys.platform == "win32", raises=OpenGL.error.Error, reason="Invalid OpenGL context.") +def test_interactive_viewer_disable_viewer_defaults(): + """Test that keyboard shortcuts can be disabled in the interactive viewer.""" + + # Test with keyboard shortcuts DISABLED + scene = gs.Scene( + viewer_options=gs.options.ViewerOptions( + disable_help_text=True, + disable_default_keybinds=True, + ), + profiling_options=gs.options.ProfilingOptions( + show_FPS=False, + ), + show_viewer=True, + ) + scene.build() + pyrender_viewer = scene.visualizer.viewer._pyrender_viewer + assert pyrender_viewer.is_active + + # Verify the flag is set correctly + assert pyrender_viewer._disable_help_text is True + # Verify that no keybindings are registered + assert len(pyrender_viewer._keybindings) == 0 + + +@pytest.mark.required +@pytest.mark.skipif(not IS_INTERACTIVE_VIEWER_AVAILABLE, reason="Interactive viewer not supported on this platform.") +def test_default_viewer_plugin(): + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=1e-2, + substeps=1, + ), + viewer_options=gs.options.ViewerOptions( + camera_pos=(2.0, 0.0, 1.0), + camera_lookat=(0.0, 0.0, 0.0), + camera_fov=30, + res=CAM_RES, + run_in_thread=(sys.platform == "linux"), + disable_help_text=False, + disable_default_keybinds=False, + ), + profiling_options=gs.options.ProfilingOptions( + show_FPS=False, + ), + show_viewer=True, + ) + + scene.add_entity(morph=gs.morphs.Plane()) + scene.add_entity( + morph=gs.morphs.Box( + pos=(0.0, 0.0, 0.2), + size=(0.2, 0.2, 0.2), + euler=(30, 40, 0), + ) + ) + scene.build() + + pyrender_viewer = scene.visualizer.viewer._pyrender_viewer + assert pyrender_viewer.is_active + + assert len(pyrender_viewer._keybindings) > 0, "Expected default keybindings to be registered." + + # Add a custom keybind + flags = [False, False, False] + + def toggle_flag(idx): + flags[idx] = not flags[idx] + + scene.viewer.register_keybinds( + Keybind( + name="toggle_flag_0", + key=Key._0, + key_action=KeyAction.PRESS, + callback=lambda: toggle_flag(0), + ), + Keybind( + name="toggle_flag_1", + key=Key._1, + key_action=KeyAction.PRESS, + key_mods=(KeyMod.SHIFT, KeyMod.CTRL), + callback=toggle_flag, + args=(1,), + ), + ) + + # Press key to toggle flag on + pyrender_viewer.dispatch_event("on_key_press", Key._0, 0) + # Press key with modifiers to toggle flag off + pyrender_viewer.dispatch_event("on_key_press", Key._1, KeyMod.SHIFT | KeyMod.CTRL) + # Press key toggle world frame + pyrender_viewer.dispatch_event("on_key_press", Key.W, 0) + + wait_for_viewer_events(pyrender_viewer, lambda: flags[0] and flags[1]) + + assert flags[0], "Expected custom keybind callback to toggle flag on." + assert flags[1], "Expected custom keybind with key modifiers to toggle flag on." + assert pyrender_viewer.gs_context.world_frame_shown, "Expected world frame to be shown after pressing 'W' key." + + # Remove the keybind and press key to verify it no longer works + scene.viewer.remove_keybind("toggle_flag_0") + pyrender_viewer.dispatch_event("on_key_press", Key._0, 0) + # Remap the keybind and check it works + scene.viewer.remap_keybind("toggle_flag_1", new_key=Key._2, new_key_mods=None) + pyrender_viewer.dispatch_event("on_key_press", Key._2, 0) + + wait_for_viewer_events(pyrender_viewer, lambda: not flags[1]) + + assert flags[0], "Keybind was not removed properly." + assert not flags[1], "Expected rebinded keybind to toggle flag off." + + # Error when remapping non-existent keybind + with pytest.raises(ValueError): + scene.viewer.remap_keybind("non_existent_keybind", new_key=Key._3, new_key_mods=None) + + # Error when adding a keybind with same key + with pytest.raises(ValueError): + scene.viewer.register_keybinds( + Keybind(name="conflicting_keybind", key=Key._2, key_action=KeyAction.PRESS, callback=lambda: None), + ) diff --git a/tests/utils.py b/tests/utils.py index 72a61055ea..e1b96b72cd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,22 +17,21 @@ from typing import Literal, Sequence import cpuinfo -import numpy as np import mujoco +import numpy as np import torch -from httpx import HTTPError as HTTPXError from httpcore import TimeoutException as HTTPTimeoutException +from httpx import HTTPError as HTTPXError from huggingface_hub import snapshot_download from PIL import Image, UnidentifiedImageError from requests.exceptions import HTTPError import genesis as gs import genesis.utils.geom as gu +from genesis.options.morphs import GLTF_FORMATS, MESH_FORMATS, MJCF_FORMAT, URDF_FORMAT, USD_FORMATS from genesis.utils import mjcf as mju from genesis.utils.mesh import get_assets_dir from genesis.utils.misc import tensor_to_array -from genesis.options.morphs import URDF_FORMAT, MJCF_FORMAT, MESH_FORMATS, GLTF_FORMATS, USD_FORMATS - REPOSITY_URL = "Genesis-Embodied-AI/Genesis" DEFAULT_BRANCH_NAME = "main"