diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 1055975d7bc..a1766185c06 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -59,6 +59,8 @@ title: Use Async Inference - local: rtc title: Real-Time Chunking (RTC) + - local: drtc + title: Distributed Real-Time Chunking (DRTC) title: "Inference" - sections: - local: envhub diff --git a/docs/source/drtc.mdx b/docs/source/drtc.mdx new file mode 100644 index 00000000000..bf10eab91ea --- /dev/null +++ b/docs/source/drtc.mdx @@ -0,0 +1,198 @@ +# Distributed Real-Time Chunking (DRTC) + +[Distributed Real-Time Chunking](https://jackvial.com/posts/distributed-real-time-chunking.html) (DRTC) extends [RTC](./rtc.mdx) to a distributed client-server setup. You can think of it as combining [RTC](./rtc.mdx)'s in-painting with [async inference](./async.mdx)'s networked client-server pattern. + +## Supported models + +[SmolVLA](./smolvla.mdx), [pi0](./pi0.mdx) and any other flow matching models should work. + +## Quick Start + +DRTC assumes you already have a working LeRobot environment. If not, follow the main [Installation Guide](./installation), then install the extras used by the default DRTC scripts: + +```bash +uv pip install -e ".[smolvla,async,feetech,scipy-dep]" +``` + +The examples below are currently set up around the default SO101 hardware profile in this repo. Configure your policy type and pretrained weight path in `examples/experiments/configs/baseline.yaml` before running. + +### Client + Server on Same Node + +If your policy server and robot client are running on the same machine, the simplest entrypoint is: + +```bash +./scripts/run_drtc_experiment.sh \ + --config examples/experiments/configs/baseline.yaml +``` + +This starts the DRTC policy server locally and then runs the experiment client against it. Add `--viz` if you also want the trajectory visualization server on `http://localhost:8088`. 
+ +### Remote Server With Local Client + +Use this setup when you want the policy server on a remote GPU machine while keeping the robot client on your local robot computer. + +> Note: The workflow below uses Prime Intellect for cloud GPUs and Tailscale for secure networking because that is one setup currently documented in this repo. They are not required. Any comparable cloud GPU provider and VPN or private network setup should work as well. + +#### Prerequisites + +- A Prime Intellect account: https://www.primeintellect.ai/ +- A local `~/.prime/config.json` containing your API key and SSH key path for `./scripts/provision_prime_lerobot.sh` +- A Tailscale network shared by the local client and remote server: https://tailscale.com/ + +#### 1. Provision a remote policy server + +Run from the repository root: + +```bash +./scripts/provision_prime_lerobot.sh +``` + +This script searches for available GPUs with the required CUDA image, lets you choose one, provisions the instance, clones the repo, installs dependencies, sets up Tailscale, and prints: + +- SSH connection details (`user@host` and port) +- The Tailscale domain for the remote machine + +To resume setup on an existing pod (for example, after a network interruption): + +```bash +./scripts/provision_prime_lerobot.sh --pod-id <pod-id> +``` + +#### 2. Start the policy server on the remote machine + +SSH to the provisioned machine using the connection details printed at the end of provisioning, then start the policy server: + +```bash +ssh -i <path-to-ssh-key> -p <ssh-port> <user>@<host> +cd /workspace/drtc +./scripts/start_drtc_server.sh +``` + +Leave this process running while the local client connects. + +#### 3. 
Start the local client against the remote server + +From your local client or robot machine, make sure you are connected to the same Tailscale network as the remote server, then run: + +```bash +./scripts/run_drtc_experiment_with_remote_server.sh \ + --remote-server-host <remote-tailscale-host> \ + --config examples/experiments/configs/baseline.yaml +``` + +## API Usage + +After getting familiar with the quick start you will likely want to interact with the DRTC client and server APIs directly. Reference implementations can be found in: + +- Client: `examples/tutorial/async-inf/robot_client_drtc.py` +- Server: `src/lerobot/async_inference/policy_server_drtc.py` + +### Minimal Client API Usage Example + +```python +import threading + +from lerobot.async_inference.robot_client_drtc import RobotClientDrtc +from lerobot.async_inference.configs_drtc import RobotClientDrtcConfig +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.robots.so101_follower import SO101FollowerConfig + +camera_cfg = { + "camera1": OpenCVCameraConfig( + index_or_path="/dev/video0", + width=800, + height=600, + fps=30, + fourcc="MJPG", + use_threaded_async_read=True, + allow_stale_frames=True, + ), + "camera2": OpenCVCameraConfig( + index_or_path="/dev/video2", + width=800, + height=600, + fps=30, + fourcc="MJPG", + use_threaded_async_read=True, + allow_stale_frames=True, + ), +} + +robot_cfg = SO101FollowerConfig( + port="/dev/ttyACM0", + id="so101_follower", + cameras=camera_cfg, +) + +client_cfg = RobotClientDrtcConfig( + robot=robot_cfg, + server_address="127.0.0.1:8080", + policy_device="cuda", + policy_type="smolvla", + pretrained_name_or_path="jackvial/so101_smolvla_pickplaceorangecube_e100", + actions_per_chunk=50, + fps=60, + s_min=15, + epsilon=2, + rtc_sigma_d=0.2, + rtc_full_trajectory_alignment=False, + num_flow_matching_steps=None, + action_filter_mode="butterworth", + action_filter_past_buffer_size=10, + action_filter_butterworth_cutoff=3.0, + action_filter_butterworth_order=2, + 
action_filter_gain=1.4, + metrics_diagnostic_enabled=True, + control_use_deadline_clock=True, + obs_fallback_on_failure=True, + obs_fallback_max_age_s=2.0, + trajectory_viz_enabled=True, + trajectory_viz_ws_url="ws://localhost:8089", +) + +client = RobotClientDrtc(client_cfg) + +if client.start(): + observation_thread = threading.Thread( + target=client.observation_sender, + name="observation_sender", + daemon=True, + ) + action_thread = threading.Thread( + target=client.action_receiver, + name="action_receiver", + daemon=True, + ) + + observation_thread.start() + action_thread.start() + + try: + client.control_loop( + "Pick up the orange cube and place it on the black X marker with the white background" + ) + finally: + client.stop() + observation_thread.join(timeout=2.0) + action_thread.join(timeout=2.0) +``` + +### Minimal Server API Usage Example + +```python +from lerobot.async_inference.configs_drtc import PolicyServerDrtcConfig +from lerobot.async_inference.policy_server_drtc import serve_drtc + +serve_drtc( + PolicyServerDrtcConfig( + host="0.0.0.0", + port=8080, + fps=30, + warmup_passes=2, + trajectory_viz_enabled=True, + trajectory_viz_http_port=8088, + trajectory_viz_ws_port=8089, + ) +) +``` + diff --git a/examples/experiments/__init__.py b/examples/experiments/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/experiments/configs/baseline.yaml b/examples/experiments/configs/baseline.yaml new file mode 100644 index 00000000000..aed046c7bbc --- /dev/null +++ b/examples/experiments/configs/baseline.yaml @@ -0,0 +1,40 @@ +# Baseline configuration -- no fault injection +name: baseline_cloud_raspberry + +# Hardware +robot_type: so101 +gpu: RTX 4090 +client_host: raspberrypiv5 +server_host: cloudserver4090 + +# Policy +policy_type: smolvla +pretrained_name_or_path: jackvial/so101_smolvla_pickplaceorangecube_e100 + +# DRTC parameters +estimator: jk +cooldown: true +latency_k: 1.5 +epsilon: 1 
+actions_per_chunk: 50 +s_min: 20 +latency_alpha: 0.125 +latency_beta: 0.25 + +# Timing +duration_s: 25.0 +fps: 60 + +# Flow matching / RTC +num_flow_matching_steps: 8 +rtc_enabled: true +rtc_sigma_d: 0.2 +rtc_prefix_attention_schedule: linear + +# Butterworth filter +action_filter_mode: butterworth +action_filter_butterworth_cutoff: 3.0 +action_filter_gain: 1.4 + +# Diagnostics +full_diagnostics: true diff --git a/examples/experiments/configs/lww_mixture_of_faults.yaml b/examples/experiments/configs/lww_mixture_of_faults.yaml new file mode 100644 index 00000000000..ad081d9fde0 --- /dev/null +++ b/examples/experiments/configs/lww_mixture_of_faults.yaml @@ -0,0 +1,56 @@ +# Mixture of faults -- drops, duplicates, reordering, and disconnect combined +name: mixture_of_faults_cloud_raspberry + +# Hardware +robot_type: so101 +gpu: RTX 4090 +client_host: raspberrypiv5 +server_host: cloudserver4090 + +# Policy +policy_type: smolvla +pretrained_name_or_path: jackvial/so101_smolvla_pickplaceorangecube_e100 + +# DRTC parameters +estimator: jk +cooldown: true +latency_k: 1.5 +epsilon: 1 +s_min: 20 +latency_alpha: 0.125 +latency_beta: 0.25 + +# Timing +duration_s: 25.0 +fps: 60 +actions_per_chunk: 50 + +# Butterworth filter +action_filter_mode: butterworth +action_filter_butterworth_cutoff: 3.0 +action_filter_gain: 1.4 + +# Flow matching / RTC +num_flow_matching_steps: 8 +rtc_enabled: true +rtc_sigma_d: 0.2 +rtc_prefix_attention_schedule: linear + +# Diagnostics +full_diagnostics: false + +# Fault injection +drop_obs: + - {start_s: 4.0, duration_s: 2.0} +drop_action: + - {start_s: 7.0, duration_s: 2.0} +dup_obs: + - {start_s: 10.0, duration_s: 2.0} +dup_action: + - {start_s: 13.0, duration_s: 2.0} +reorder_obs: + - {start_s: 16.0, duration_s: 3.0} +reorder_action: + - {start_s: 20.0, duration_s: 3.0} +# disconnect: +# - {start_s: 22.0, duration_s: 2.0} diff --git a/examples/experiments/configs/spikes_jk.yaml b/examples/experiments/configs/spikes_jk.yaml new file mode 100644 
index 00000000000..48812553dfe --- /dev/null +++ b/examples/experiments/configs/spikes_jk.yaml @@ -0,0 +1,45 @@ +# Compare estimators under spike conditions +defaults: + # Hardware + robot_type: so101 + gpu: RTX 4090 + client_host: raspberrypiv5 + server_host: cloudserver4090 + + # Policy + policy_type: smolvla + pretrained_name_or_path: jackvial/so101_smolvla_pickplaceorangecube_e100 + + # DRTC parameters + cooldown: true + latency_k: 1.5 + epsilon: 1 + s_min: 20 + latency_alpha: 0.125 + latency_beta: 0.25 + + # Flow matching / RTC + num_flow_matching_steps: 8 + rtc_enabled: true + rtc_sigma_d: 0.2 + rtc_prefix_attention_schedule: linear + + # Timing + duration_s: 25.0 + fps: 60 + actions_per_chunk: 50 + + # Butterworth filter + action_filter_mode: butterworth + action_filter_butterworth_cutoff: 3.0 + action_filter_gain: 1.4 + + # Diagnostics + full_diagnostics: false + +experiments: + - name: spikes_jk_cloud_raspberry + estimator: jk + spikes: + - {start_s: 5.0, delay_ms: 500} + - {start_s: 15.0, delay_ms: 1000} diff --git a/examples/experiments/configs/spikes_max_last_10.yaml b/examples/experiments/configs/spikes_max_last_10.yaml new file mode 100644 index 00000000000..890f8ff344f --- /dev/null +++ b/examples/experiments/configs/spikes_max_last_10.yaml @@ -0,0 +1,45 @@ +# Compare estimators under spike conditions +defaults: + # Hardware + robot_type: so101 + gpu: RTX 4090 + client_host: raspberrypiv5 + server_host: cloudserver4090 + + # Policy + policy_type: smolvla + pretrained_name_or_path: jackvial/so101_smolvla_pickplaceorangecube_e100 + + # DRTC parameters + cooldown: true + latency_k: 2.0 + epsilon: 2 + s_min: 20 + latency_alpha: 0.125 + latency_beta: 0.25 + + # Flow matching / RTC + num_flow_matching_steps: 8 + rtc_enabled: true + rtc_sigma_d: 0.2 + rtc_prefix_attention_schedule: linear + + # Timing + duration_s: 25.0 + fps: 60 + actions_per_chunk: 50 + + # Butterworth filter + action_filter_mode: butterworth + action_filter_butterworth_cutoff: 3.0 + 
action_filter_gain: 1.4 + + # Diagnostics + full_diagnostics: false + +experiments: + - name: spikes_max_last_10_cloud_raspberry + estimator: max_last_10 + spikes: + - {start_s: 5.0, delay_ms: 500} + - {start_s: 15.0, delay_ms: 1000} diff --git a/examples/experiments/plot_results.py b/examples/experiments/plot_results.py new file mode 100644 index 00000000000..77dfce08f33 --- /dev/null +++ b/examples/experiments/plot_results.py @@ -0,0 +1,1420 @@ +#!/usr/bin/env python3 +""" +Plot results from DRTC experiments. + +Usage: + # Plot an experiment directory (finds CSV + trajectory JSON automatically) + uv run python examples/experiments/plot_results.py \ + --input results/experiments/drop_obs_00 + + # Plot a single CSV file + uv run python examples/experiments/plot_results.py \ + --input results/experiments/drop_obs_00/drop_obs_00.csv + + # Compare estimators (filter by pattern) + uv run python examples/experiments/plot_results.py \ + --input results/ \ + --filter "estimator_" \ + --mode estimator_comparison \ + --output results/estimator_comparison.png + + # Show latency spikes with measured RTT + uv run python examples/experiments/plot_results.py \ + --input results/spike_experiment.csv \ + --mode detailed \ + --output results/spike_detailed.png +""" + +import argparse +import json +import shutil +import subprocess +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.patches import Patch + + +def setup_paper_style(): + """Configure matplotlib rcParams for clean, academic paper-ready plots.""" + plt.rcParams.update( + { + "font.family": "sans-serif", + "font.size": 11, + "axes.titlesize": 13, + "axes.labelsize": 11, + "xtick.labelsize": 10, + "ytick.labelsize": 10, + "legend.fontsize": 9, + "axes.spines.top": False, + "axes.spines.right": False, + "axes.grid": True, + "grid.alpha": 0.3, + "grid.linewidth": 0.5, + "lines.linewidth": 1.5, + "lines.markersize": 5, + "figure.facecolor": "white", + "axes.facecolor": "white", + 
"savefig.dpi": 300, + "savefig.bbox": "tight", + "figure.dpi": 100, + } + ) + + +def _latex_escape(text: str) -> str: + """Escape special LaTeX characters in *text*.""" + # Order matters: ampersand first so we don't double-escape later subs. + for char, replacement in [ + ("&", r"\&"), + ("%", r"\%"), + ("$", r"\$"), + ("#", r"\#"), + ("_", r"\_"), + ("{", r"\{"), + ("}", r"\}"), + ("~", r"\textasciitilde{}"), + ("^", r"\textasciicircum{}"), + ]: + text = text.replace(char, replacement) + return text + + +# Human-readable labels for each config key, in display order. +_CONFIG_DISPLAY = [ + # Robot / hardware + ("robot_type", "Robot type"), + ("gpu", "GPU"), + ("client_host", "Client host"), + ("server_host", "Server host"), + ("num_cameras", "Number of cameras"), + ("cameras", "Cameras"), + # Policy + ("policy_type", "Policy type"), + ("pretrained_name_or_path", "Model path"), + ("chunk_size", "Chunk size"), + ("fps", "FPS"), + ("s_min", r"$s_{\min}$"), + ("epsilon", r"$\epsilon$"), + ("latency_estimator_type", "Latency estimator"), + ("latency_alpha", r"$\alpha$"), + ("latency_beta", r"$\beta$"), + ("latency_k", r"$K$"), + # Flow matching / RTC + ("num_flow_matching_steps", "Flow matching steps"), + ("rtc_enabled", "RTC enabled"), + ("rtc_max_guidance_weight", r"RTC max guidance weight ($\beta$)"), + ("rtc_prefix_attention_schedule", "RTC attention schedule"), + ("rtc_sigma_d", r"RTC $\sigma_d$"), + ("rtc_full_trajectory_alignment", "RTC full trajectory alignment"), + # Action filter + ("filter_type", "Filter type"), + ("filter_cutoff", "Filter cutoff (Hz)"), + ("gain", "Gain"), +] + +# Human-readable names for latency estimator types. +_ESTIMATOR_DISPLAY_NAMES = { + "jk": "Jacobson--Karels", + "max_last_10": "Max of last 10", + "fixed": "Fixed", +} + +# Human-readable labels for simulation config fault-injection keys. 
+_SIM_CONFIG_DISPLAY = [ + ("drop_obs", "Drop obs"), + ("drop_action", "Drop action"), + ("dup_obs", "Duplicate obs"), + ("dup_action", "Duplicate action"), + ("reorder_obs", "Reorder obs"), + ("reorder_action", "Reorder action"), + ("disconnect", "Disconnect"), + ("spikes", "Spikes"), +] + + +def _format_fault_windows(windows: list[dict]) -> str: + """Format a list of fault-injection window dicts into a compact string. + + Each window is rendered as ``start_s`` -- ``+duration_s`` (or for spikes, + ``start_s`` @ ``delay_ms`` ms). Multiple windows are comma-separated. + Returns ``---`` when the list is empty. + """ + if not windows: + return "---" + parts = [] + for w in windows: + if "delay_ms" in w: + # Spike event + parts.append(f"{w.get('start_s', 0):.1f}s @ {w['delay_ms']}ms") + else: + start = w.get("start_s", 0) + dur = w.get("duration_s", 0) + parts.append(f"{start:.1f}s +{dur:.1f}s") + return ", ".join(parts) + + +def generate_config_table( + experiment_config: dict, + simulation_config: dict | None = None, +) -> str: + """Return a LaTeX ``table`` environment summarising the experiment config. + + The table has two columns (Parameter / Value) and uses the ``booktabs`` + package for clean horizontal rules. Long model paths are wrapped in + ``\\texttt``. + + When *simulation_config* is provided, a second section is appended + showing fault-injection windows (drops, duplicates, reorders, + disconnects, spikes). + + Args: + experiment_config: Dict of config key/value pairs (as written by + ``ExperimentMetricsWriter``). + simulation_config: Optional dict of fault-injection window lists + (as written to the trajectory JSON under ``simulation_config``). + + Returns: + A string of LaTeX source for the config table (no surrounding + ``\\begin{document}``). 
+ """ + rows: list[str] = [] + for key, label in _CONFIG_DISPLAY: + value = experiment_config.get(key) + if value is None: + value_str = "N/A" + elif key == "pretrained_name_or_path": + # Render model path in monospace; allow line-break at slashes. + escaped = _latex_escape(str(value)) + value_str = r"\texttt{" + escaped + "}" + elif key == "latency_estimator_type": + display = _ESTIMATOR_DISPLAY_NAMES.get(str(value), str(value)) + value_str = _latex_escape(display) + elif isinstance(value, float): + # Use a sensible number of decimals. + value_str = f"{value:g}" + else: + value_str = _latex_escape(str(value)) + rows.append(f" {label} & {value_str} \\\\") + + # Fault-injection section (only if any faults are configured) + if simulation_config: + has_any_faults = any(simulation_config.get(key) for key, _ in _SIM_CONFIG_DISPLAY) + if has_any_faults: + rows.append(" \\midrule") + rows.append(" \\multicolumn{2}{l}{\\textit{Fault Injection}} \\\\") + for key, label in _SIM_CONFIG_DISPLAY: + windows = simulation_config.get(key, []) + if windows: + value_str = _latex_escape(_format_fault_windows(windows)) + rows.append(f" {_latex_escape(label)} & {value_str} \\\\") + + rows_str = "\n".join(rows) + return ( + "\\begin{table}[htbp]\n" + " \\centering\n" + " \\caption{Experiment Configuration}\n" + " \\begin{tabular}{l p{0.65\\textwidth}}\n" + " \\toprule\n" + " Parameter & Value \\\\\n" + " \\midrule\n" + f"{rows_str}\n" + " \\bottomrule\n" + " \\end{tabular}\n" + "\\end{table}\n" + ) + + +# Kandinsky-inspired color palette (from trajectory_viz.html) +CHUNK_COLORS = [ + "#c1272d", # vermilion + "#1a3a6e", # ultramarine + "#f4c430", # cadmium yellow + "#e85d04", # orange + "#5c3d6e", # purple + "#2d6a4f", # deep green + "#1a1a1a", # black + "#8b7355", # ochre + "#0077b6", # cerulean + "#9d4edd", # violet +] + + +def load_trajectory_data(csv_path: Path) -> dict | None: + """Load trajectory JSON data corresponding to a CSV file. 
+ + The trajectory file is expected to be at the same path as the CSV + but with a '.trajectory.json' suffix. + """ + trajectory_path = csv_path.with_suffix(".trajectory.json") + if not trajectory_path.exists(): + return None + + with open(trajectory_path) as f: + return json.load(f) + + +# Default joint names for the SO101 follower robot +_SO101_JOINT_NAMES = [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", +] + + +def plot_trajectory_on_axis( + ax, + trajectory_data: dict, + joint_names: list[str] | None = None, +): + """Plot all joint trajectories on a given axis. + + Plots each joint dimension as a separate coloured line. Gaps in the + lines correspond to stalls (timesteps where no action was executed). + + Args: + ax: Matplotlib axis to plot on. + trajectory_data: Dict with ``executed`` list of action records. + joint_names: Human-readable names for each joint dimension. + Defaults to the SO101 follower joint names. + """ + if trajectory_data is None: + ax.text( + 0.5, + 0.5, + "No trajectory data", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=10, + color="gray", + ) + return + + executed = trajectory_data.get("executed", []) + + if not executed: + ax.text( + 0.5, + 0.5, + "No executed actions", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=10, + color="gray", + ) + return + + if joint_names is None: + joint_names = _SO101_JOINT_NAMES + + t0 = min(e["t"] for e in executed) + n_joints = len(executed[0]["action"]) + + for j in range(n_joints): + times = [(e["t"] - t0) for e in executed if j < len(e["action"])] + values = [e["action"][j] for e in executed if j < len(e["action"])] + label = joint_names[j] if j < len(joint_names) else f"joint {j}" + color = CHUNK_COLORS[j % len(CHUNK_COLORS)] + ax.scatter(times, values, s=3, color=color, alpha=0.8, label=label, linewidths=0) + + ax.set_ylabel("Position") + + +# Colors for event types (sim events + obs/action from CSV) 
+_SIM_EVENT_COLORS = { + "obs_triggered": "#3498db", # blue + "action_received": "#e67e22", # orange + "obs_dropped": "#c1272d", # vermilion + "obs_reorder_held": "#f4c430", # cadmium yellow + "obs_reorder_swapped": "#e85d04", # orange + "obs_duplicated": "#5c3d6e", # purple + "action_dropped": "#1a3a6e", # ultramarine + "action_reorder_held": "#0077b6", # cerulean + "action_reorder_swapped": "#2d6a4f", # deep green + "action_duplicated": "#9d4edd", # violet + "disconnect": "#333333", # dark gray + "spike": "#e74c3c", # red +} + +# Y-position for each event type so they don't overlap (contiguous) +_SIM_EVENT_YPOS = { + "obs_triggered": 11, + "action_received": 10, + "disconnect": 9, + "spike": 8, + "obs_dropped": 7, + "obs_reorder_held": 6, + "obs_reorder_swapped": 5, + "obs_duplicated": 4, + "action_dropped": 3, + "action_reorder_held": 2, + "action_reorder_swapped": 1, + "action_duplicated": 0, +} + +# Shading colors for configured simulation windows +_SIM_CONFIG_COLORS = { + "drop_obs": ("#c1272d", 0.12), + "drop_action": ("#1a3a6e", 0.12), + "dup_obs": ("#5c3d6e", 0.10), + "dup_action": ("#9d4edd", 0.10), + "reorder_obs": ("#f4c430", 0.10), + "reorder_action": ("#2d6a4f", 0.10), + "disconnect": ("#333333", 0.18), + "spikes": ("#e74c3c", 0.12), +} + + +def plot_gantt_on_axis( + ax, + trajectory_data: dict, + sim_config_offset: float = 0.0, +) -> None: + """Plot a Gantt chart of configured fault-injection windows. + + Each fault type gets its own horizontal lane with coloured bars showing + when the fault is active. Spikes are rendered as bars whose width + corresponds to the injected delay duration. + + Args: + ax: Matplotlib axis. + trajectory_data: Trajectory JSON dict (must contain ``simulation_config``). + sim_config_offset: Seconds between experiment start (CSV t0) and the + first executed action (trajectory t0). Subtracted from ``start_s`` + values so the bars align with other subplots. 
+ """ + sim_config = trajectory_data.get("simulation_config", {}) + if not sim_config: + ax.text( + 0.5, + 0.5, + "No simulation config", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=10, + color="gray", + ) + return + + # Collect lanes: only include fault types that have at least one window. + # Order follows _SIM_CONFIG_COLORS so related faults are grouped. + lane_keys: list[str] = [] + for key in _SIM_CONFIG_COLORS: + windows = sim_config.get(key, []) + if windows: + lane_keys.append(key) + + if not lane_keys: + ax.text( + 0.5, + 0.5, + "No fault windows configured", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=10, + color="gray", + ) + return + + n_lanes = len(lane_keys) + bar_height = 0.6 + + for i, lane in enumerate(lane_keys): + y = i # y-position for this lane + windows = sim_config.get(lane, []) + color, _alpha = _SIM_CONFIG_COLORS.get(lane, ("#888888", 0.5)) + + is_spike = lane == "spikes" + + # Build (start, width) tuples and labels for each window. + # Spikes use delay_ms converted to seconds; other faults use duration_s. 
+ bars: list[tuple[float, float]] = [] + labels: list[str] = [] + for w in windows: + start = w.get("start_s", 0) - sim_config_offset + if is_spike: + dur = w.get("delay_ms", 0) / 1000.0 + labels.append(f"{w.get('delay_ms', 0):.0f}ms") + else: + dur = w.get("duration_s", 0) + labels.append(f"{dur:.1f}s") + bars.append((start, dur)) + + ax.broken_barh( + bars, + (y - bar_height / 2, bar_height), + facecolors=color, + alpha=0.7, + edgecolors="white", + linewidth=0.5, + ) + for (start, dur), label in zip(bars, labels, strict=False): + if is_spike: + ax.text( + start + dur / 2, + y, + label, + ha="center", + va="center", + fontsize=8, + color="white", + fontweight="bold", + rotation=-90, + rotation_mode="anchor", + ) + else: + ax.text( + start + dur / 2, + y, + label, + ha="center", + va="center", + fontsize=8, + color="white", + fontweight="bold", + ) + + # Configure axis + ax.set_yticks(range(n_lanes)) + ax.set_yticklabels( + [lane.replace("_", " ") for lane in lane_keys], + fontsize=9, + ) + # ax.set_ylim(-0.5, n_lanes - 0.5) + # ax.set_ylabel("Fault Schedule") + + +# Lane configuration for the latency breakdown Gantt chart. +_LATENCY_GANTT_LANES = [ + ("total", "Total", "#e74c3c"), # red + ("client_to_server", "Client \u2192 Server", "#0077b6"), # cerulean + ("model_inference", "Model Inference", "#2d6a4f"), # deep green + ("server_to_client", "Server \u2192 Client", "#e85d04"), # orange +] + + +def plot_latency_gantt_on_axis( + ax, + df: pd.DataFrame, + time_offset: float = 0.0, +) -> None: + """Plot a two-lane horizontal stacked bar chart of inference latency breakdown. + + Each inference round-trip (rows where ``action_received == 1`` and all + timestamp columns are present) is rendered across two lanes: a Total + bar (top) and a Breakdown bar (bottom) with three end-to-end colored + segments (Client -> Server, Model Inference, Server -> Client). + + Supports both new timestamp columns (``obs_sent_ts``, etc.) 
and legacy + duration columns (``client_to_server_ms``, etc.) for backward + compatibility with older CSVs. + + Args: + ax: Matplotlib axis. + df: Experiment DataFrame. + time_offset: Seconds to subtract from ``t_relative`` so bar + positions align with other subplots. + """ + # Detect whether we have new timestamp columns or legacy duration columns + ts_cols = ["obs_sent_ts", "server_obs_received_ts", "server_action_sent_ts", "action_received_ts"] + legacy_cols = ["measured_latency_ms", "client_to_server_ms", "model_inference_ms", "server_to_client_ms"] + use_timestamps = all(c in df.columns for c in ts_cols) + use_legacy = all(c in df.columns for c in legacy_cols) + + if not use_timestamps and not use_legacy: + ax.text( + 0.5, + 0.5, + "No latency breakdown data", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=10, + color="gray", + ) + return + + # Filter to rows where action was received and all required cols exist + check_cols = ts_cols if use_timestamps else legacy_cols + mask = df["action_received"] == 1 + for col in check_cols: + mask = mask & df[col].notna() + rtt_rows = df[mask] + + if len(rtt_rows) == 0: + ax.text( + 0.5, + 0.5, + "No latency breakdown data", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=10, + color="gray", + ) + return + + bar_height = 0.6 + + # Compute a minimum visual bar width so that very short phases + # (e.g. 1ms client->server on a 25s axis) are still visible. 
+ t_all = df["t_relative"] - time_offset + t_span = t_all.max() - t_all.min() if len(t_all) > 1 else 1.0 + min_bar_width = t_span * 0.004 # ~0.4% of axis width + + # Reference t0 for converting absolute timestamps to relative x-axis values + csv_t0 = df["t"].iloc[0] + + # Colors for the two-lane layout + total_color = "#e74c3c" # red + c2s_color = "#0077b6" # cerulean + model_color = "#2d6a4f" # deep green + s2c_color = "#e85d04" # orange + + for _, row in rtt_rows.iterrows(): + if use_timestamps: + # Derive bar positions directly from wall-clock timestamps + t_send = row["obs_sent_ts"] - csv_t0 - time_offset + t_server_recv = row["server_obs_received_ts"] - csv_t0 - time_offset + t_server_send = row["server_action_sent_ts"] - csv_t0 - time_offset + t_recv = row["action_received_ts"] - csv_t0 - time_offset + + rtt_s = t_recv - t_send + c2s_s = t_server_recv - t_send + model_s = t_server_send - t_server_recv + s2c_s = t_recv - t_server_send + else: + # Legacy: derive from duration columns + t_recv = row["t_relative"] - time_offset + rtt_s = row["measured_latency_ms"] / 1000.0 + c2s_s = row["client_to_server_ms"] / 1000.0 + model_s = row["model_inference_ms"] / 1000.0 + s2c_s = row["server_to_client_ms"] / 1000.0 + t_send = t_recv - rtt_s + t_server_recv = t_send + c2s_s + t_server_send = t_server_recv + model_s + + # Lane 1 (top, y=1): Total round-trip bar + ax.broken_barh( + [(t_send, max(rtt_s, min_bar_width))], + (1 - bar_height / 2, bar_height), + facecolors=total_color, + alpha=0.7, + edgecolors="white", + linewidth=0.3, + ) + + # Lane 0 (bottom, y=0): Breakdown segments tiled end-to-end. + # Compute visual widths first, then place each segment right + # after the previous one so min_bar_width clamping never causes + # overlap. 
+ vis_c2s = max(c2s_s, min_bar_width) + vis_model = max(model_s, min_bar_width) + vis_s2c = max(s2c_s, min_bar_width) + + seg_start = t_send + for dur, color in [ + (vis_c2s, c2s_color), + (vis_model, model_color), + (vis_s2c, s2c_color), + ]: + ax.broken_barh( + [(seg_start, dur)], + (0 - bar_height / 2, bar_height), + facecolors=color, + alpha=0.7, + edgecolors="white", + linewidth=0.3, + ) + seg_start += dur + + # Vertical guide lines at phase boundaries (span both lanes) + for bx in [t_send, t_send + vis_c2s, t_send + vis_c2s + vis_model, seg_start]: + ax.axvline(bx, color="#888888", linewidth=0.4, alpha=0.5) + + # Configure axis -- 2 lanes + ax.set_yticks([0, 1]) + ax.set_yticklabels(["Breakdown", "Total"], fontsize=9) + ax.set_ylim(-0.5, 1.5) + ax.set_title("Latency Breakdown") + + # Legend with proxy artists for all 4 colors + legend_handles = [ + Patch(facecolor=total_color, alpha=0.7, label="Total"), + Patch(facecolor=c2s_color, alpha=0.7, label="Client \u2192 Server"), + Patch(facecolor=model_color, alpha=0.7, label="Model Inference"), + Patch(facecolor=s2c_color, alpha=0.7, label="Server \u2192 Client"), + ] + ax.legend(handles=legend_handles, loc="upper right", fontsize=8) + + +def plot_sim_events_on_axis( + ax, + trajectory_data: dict, + sim_config_offset: float = 0.0, + df: pd.DataFrame | None = None, +) -> None: + """Plot events timeline on a given axis. + + Shows actual recorded sim events as scatter markers. Spike events + from the simulation config are shown as scatter markers in a + dedicated lane. When *df* is provided, ``obs_triggered`` and + ``action_received`` events from the CSV are also plotted as + additional lanes. + + Args: + ax: Matplotlib axis. + trajectory_data: Trajectory JSON dict. + sim_config_offset: Seconds between experiment start (CSV t0) and + the first executed action (trajectory t0). 
The ``start_s`` + values in the simulation config are relative to experiment + start, so we subtract this offset to align them with the + trajectory t0 used by all other subplots. + df: Optional experiment DataFrame (for obs_triggered / action_received). + """ + sim_events = trajectory_data.get("sim_events", []) + sim_config = trajectory_data.get("simulation_config", {}) + executed = trajectory_data.get("executed", []) + + if not sim_events and not sim_config and df is None: + ax.text( + 0.5, + 0.5, + "No events", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=10, + color="gray", + ) + return + + # Derive t0 from executed actions (same baseline as trajectory plot) + t0 = min(e["t"] for e in executed) if executed else 0.0 + + # --- Collect all events (sim events + obs/action from CSV) --- + events_by_type: dict[str, list[float]] = {} + + # Sim events from trajectory JSON + for ev in sim_events: + etype = ev["event_type"] + t_rel = ev["t"] - t0 + events_by_type.setdefault(etype, []).append(t_rel) + + # Spike config entries (point events from simulation_config) + spikes = sim_config.get("spikes", []) + for spike in spikes: + t_spike = spike.get("start_s", 0) - sim_config_offset + events_by_type.setdefault("spike", []).append(t_spike) + + # Obs sent / action received from CSV DataFrame + if df is not None: + # CSV times are relative to CSV t0; shift by sim_config_offset + # so they align with the trajectory t0 baseline. + t_csv = df["t_relative"] - sim_config_offset + obs_mask = df["obs_triggered"] == 1 + if obs_mask.any(): + events_by_type["obs_triggered"] = t_csv[obs_mask].tolist() + act_mask = df["action_received"] == 1 + if act_mask.any(): + events_by_type["action_received"] = t_csv[act_mask].tolist() + + # Determine active lanes (preserve order from _SIM_EVENT_YPOS) + # Assign contiguous y-positions so there are no gaps between active lanes. 
def load_experiment_data(csv_path: Path) -> pd.DataFrame:
    """Load and preprocess experiment CSV data.

    Handles corrupted CSVs that contain concatenated runs (caused by a
    race between ``signal_stop`` and ``stop`` both calling ``flush()``).
    Detection: when ``step`` resets back to -1 after having been positive,
    the earlier rows are discarded and only the last complete run is kept.
    Rows with clearly anomalous timestamps (more than 10 IQRs from the
    median) are also dropped.

    Args:
        csv_path: Path to the experiment metrics CSV.

    Returns:
        DataFrame with added ``t_relative`` (seconds since the first kept
        row) and ``stall_rolling`` (30-sample rolling stall rate) columns,
        with latency / timestamp / L2 columns coerced to numeric dtypes,
        and with the legacy ``obs_sent`` column renamed to ``obs_triggered``.
    """
    df = pd.read_csv(csv_path)

    # --- Detect concatenated runs ---
    # When the CSV contains data from two flushes of the same experiment,
    # `step` jumps from a positive value back to -1. Keep only the last
    # contiguous run (the complete one).
    if "step" in df.columns:
        step = df["step"].values
        # Indices where step resets from >= 0 back to -1
        reset_indices = [i for i in range(1, len(step)) if step[i] == -1 and step[i - 1] >= 0]
        if reset_indices:
            last_reset = reset_indices[-1]
            n_dropped = last_reset
            df = df.iloc[last_reset:].reset_index(drop=True)
            print(
                f" WARNING: CSV contains concatenated runs; "
                f"dropped first {n_dropped} rows, keeping last run ({len(df)} rows)"
            )

    # --- Drop rows with anomalous timestamps ---
    # A flush race can also truncate a timestamp (e.g. '773263.12' instead
    # of '1770773263.12'). Detect these as outliers more than 10 IQRs from
    # the median. Skipped when the IQR is zero (constant timestamps).
    t = df["t"]
    t_median = t.median()
    t_iqr = t.quantile(0.75) - t.quantile(0.25)
    if t_iqr > 0:
        lower = t_median - 10 * t_iqr
        upper = t_median + 10 * t_iqr
        bad_mask = (t < lower) | (t > upper)
        n_bad = bad_mask.sum()
        if n_bad > 0:
            df = df[~bad_mask].reset_index(drop=True)
            print(f" WARNING: Dropped {n_bad} rows with anomalous timestamps")

    # Normalize timestamps to start at 0
    df["t_relative"] = df["t"] - df["t"].iloc[0]

    # Calculate rolling stall rate (30-sample window ~= 1 second at 30fps)
    df["stall_rolling"] = df["stall"].rolling(window=30, min_periods=1).mean()

    # Coerce measurement columns to numeric (they may contain empty strings).
    # Includes legacy duration columns from older CSVs and the L2 metrics.
    numeric_cols = [
        "measured_latency_ms",
        "latency_estimate_ms",
        "obs_sent_ts",
        "server_obs_received_ts",
        "server_action_sent_ts",
        "action_received_ts",
        # Backward compat: duration columns from older CSVs
        "client_to_server_ms",
        "model_inference_ms",
        "server_to_client_ms",
        # L2 discrepancy metrics
        "chunk_mean_l2",
        "chunk_max_l2",
    ]
    for col in numeric_cols:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")

    # Backward compat: older CSVs used "obs_sent" instead of "obs_triggered"
    if "obs_triggered" not in df.columns and "obs_sent" in df.columns:
        df.rename(columns={"obs_sent": "obs_triggered"}, inplace=True)

    return df
+ + Args: + ax_schedule: Optional axis for the action schedule size plot. + time_offset: Seconds to subtract from ``t_relative`` so that the + cooldown / latency x-values align with the trajectory t0. + """ + t = df["t_relative"] - time_offset + + # Calculate summary stats + total_ticks = len(df) + stall_count = df["stall"].sum() + stall_fraction = stall_count / total_ticks if total_ticks > 0 else 0 + obs_triggered_count = df["obs_triggered"].sum() + action_received_count = df["action_received"].sum() + + # 1. Cooldown counter + quantized latency estimate (both in steps) + ax_cooldown.plot(t, df["cooldown"], linewidth=1.5, alpha=0.7, color="#9b59b6", label="Cooldown") + if "latency_estimate_steps" in df.columns: + ax_cooldown.plot( + t, + df["latency_estimate_steps"], + drawstyle="steps-post", + linewidth=1.2, + color="#2ecc71", + alpha=0.7, + label="Latency estimate (steps)", + ) + ax_cooldown.set_title("Cooldown & Latency Estimate (steps)") + ax_cooldown.set_ylabel("Steps") + ax_cooldown.legend(loc="upper right", fontsize=8) + + # 2. 
Latency estimate (smooth) + measured RTT overlay (all in ms) + if "latency_estimate_ms" in df.columns: + estimate_ms = pd.to_numeric(df["latency_estimate_ms"], errors="coerce") + else: + # Fallback for older CSVs: convert quantized steps to ms + t_span = df["t_relative"].iloc[-1] + fps = (len(df) - 1) / t_span if t_span > 0 else 60.0 + estimate_ms = df["latency_estimate_steps"] / fps * 1000.0 + ax_latency.plot(t, estimate_ms, linewidth=1.5, color="#3498db", label="Estimate") + # Overlay measured RTT in ms (red scatter) + if "measured_latency_ms" in df.columns: + measured = df[df["measured_latency_ms"].notna()] + if len(measured) > 0: + ax_latency.scatter( + measured["t_relative"] - time_offset, + measured["measured_latency_ms"], + s=25, + alpha=0.8, + color="#e74c3c", + label="Measured RTT", + zorder=5, + ) + ax_latency.legend(loc="upper right", fontsize=8) + ax_latency.set_title("Inference Latency") + ax_latency.set_ylabel("ms") + + # 3. Action schedule size + if ax_schedule is not None and "schedule_size" in df.columns: + ax_schedule.plot(t, df["schedule_size"], linewidth=1, color="#3498db") + ax_schedule.axhline(y=0, color="#e74c3c", linestyle="--", alpha=0.5, linewidth=0.5) + ax_schedule.set_title("Action Schedule Size") + ax_schedule.set_ylabel("Actions") + + # 4. 
def plot_estimator_comparison(
    dfs: dict[str, pd.DataFrame],
    output_path: Path,
    csv_paths: dict[str, Path] | None = None,
):
    """Plot latency estimator comparison with detailed metrics.

    Expects dfs to be a dict mapping experiment name to DataFrame,
    where names contain 'jk' or 'max_last_10' to identify the estimator.

    Args:
        dfs: Dict mapping experiment name to DataFrame
        output_path: Path to save the plot image
        csv_paths: Optional dict mapping experiment name to CSV path (for loading trajectory data)
    """
    setup_paper_style()

    # Check if we have trajectory data
    trajectories: dict[str, dict | None] = {}
    if csv_paths:
        for name, csv_path in csv_paths.items():
            trajectories[name] = load_trajectory_data(csv_path)

    has_trajectory_data = any(t is not None for t in trajectories.values())

    # Create figure with extra rows for trajectory if available
    if has_trajectory_data:
        # 2 main plots + 2 trajectory plots (one for JK, one for Max10)
        fig, axes = plt.subplots(
            4, 1, figsize=(14, 12), sharex=True, gridspec_kw={"height_ratios": [2, 2, 2, 2]}
        )
        ax_measured, ax_latency, ax_traj_jk, ax_traj_max = axes
    else:
        fig, axes = plt.subplots(2, 1, figsize=(14, 8), sharex=True)
        ax_measured, ax_latency = axes
        ax_traj_jk = None
        ax_traj_max = None

    colors = {"jk": "#2ecc71", "max_last_10": "#e74c3c"}
    linestyles = {"jk": "-", "max_last_10": "--"}

    for name, df in dfs.items():
        t = df["t_relative"]

        # Determine estimator type from name.
        # "jk" is checked first, so a name containing both substrings is
        # classified as JK.
        if "jk" in name.lower():
            estimator = "jk"
            label = f"JK: {name}"
        elif "max" in name.lower():
            estimator = "max_last_10"
            label = f"Max10: {name}"
        else:
            # NOTE(review): names matching neither pattern fall back to
            # JK styling — confirm this default is intended.
            estimator = "jk"
            label = name

        color = colors.get(estimator, "#3498db")
        linestyle = linestyles.get(estimator, "-")

        # Measured RTT in milliseconds (line plot with distinct linestyle)
        if "measured_latency_ms" in df.columns:
            measured = df[df["measured_latency_ms"].notna()]
            if len(measured) > 0:
                ax_measured.plot(
                    measured["t_relative"],
                    measured["measured_latency_ms"],
                    linewidth=1.5,
                    linestyle=linestyle,
                    color=color,
                    label=f"RTT ({estimator})",
                )

        # Latency estimate (line plot with distinct linestyle)
        ax_latency.plot(
            t, df["latency_estimate_steps"], linewidth=1.5, linestyle=linestyle, color=color, label=label
        )

        # Plot trajectory data if available.
        # Each estimator type gets its own trajectory axis; if multiple
        # experiments map to the same estimator, later ones draw on top.
        if has_trajectory_data and name in trajectories and trajectories[name] is not None:
            traj_ax = ax_traj_jk if estimator == "jk" else ax_traj_max
            if traj_ax is not None:
                plot_trajectory_on_axis(
                    traj_ax,
                    trajectories[name],
                )
                traj_ax.set_title(f"Trajectory: {estimator.upper()}")
                traj_ax.legend(loc="upper right", ncol=3)

    ax_measured.set_ylabel("Measured RTT (ms)")
    ax_measured.legend(loc="upper right")
    ax_measured.grid(True, alpha=0.3)
    ax_measured.set_title("Latency Estimation: JK vs Max-of-Last-10")

    ax_latency.set_ylabel("Latency Estimate (steps)")
    ax_latency.legend(loc="upper right")
    ax_latency.grid(True, alpha=0.3)

    # Set x-axis label on the bottom-most plot
    if has_trajectory_data and ax_traj_max is not None:
        ax_traj_max.set_xlabel("Time (seconds)")
    else:
        ax_latency.set_xlabel("Time (seconds)")

    plt.tight_layout()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, bbox_inches="tight")
    print(f"Estimator comparison plot saved to: {output_path}")
Latency estimate with measured RTT overlay + ax_latency.plot(t, df["latency_estimate_steps"], linewidth=1.5, color="#2ecc71", label="Estimate") + if "measured_latency_ms" in df.columns: + measured = df[df["measured_latency_ms"].notna()] + if len(measured) > 0: + measured_steps = measured["measured_latency_ms"] / 33.3 + ax_latency.scatter( + measured["t_relative"], + measured_steps, + s=15, + alpha=0.7, + color="#e74c3c", + label="Measured RTT", + zorder=5, + ) + ax_latency.set_ylabel("Latency (steps)") + ax_latency.legend(loc="upper right", fontsize=8) + ax_latency.grid(True, alpha=0.3) + + # 3. Latency breakdown Gantt chart + plot_latency_gantt_on_axis(ax_latency_gantt, df) + + # 4. Cooldown counter + ax_cooldown.plot(t, df["cooldown"], linewidth=0.5, color="#9b59b6", alpha=0.8) + ax_cooldown.set_ylabel("Cooldown") + ax_cooldown.grid(True, alpha=0.3) + + # 5. Stall indicator + ax_stall.fill_between(t, 0, df["stall"], alpha=0.5, color="#e74c3c", step="mid") + ax_stall.set_ylabel("Stall") + ax_stall.set_ylim(-0.1, 1.1) + ax_stall.grid(True, alpha=0.3) + + # 6. L2 discrepancy (if available) + if "chunk_mean_l2" in df.columns: + l2_data = df[df["chunk_mean_l2"].notna()] + if len(l2_data) > 0: + ax_l2.scatter(l2_data["t_relative"], l2_data["chunk_mean_l2"], s=10, alpha=0.7, color="#f39c12") + ax_l2.set_ylabel("Chunk L2") + ax_l2.grid(True, alpha=0.3) + + ax_l2.set_xlabel("Time (seconds)") + + # Summary stats + total_ticks = len(df) + stall_count = df["stall"].sum() + stall_frac = stall_count / total_ticks if total_ticks > 0 else 0 + fig.suptitle(f"{title} | Stalls: {stall_count} ({stall_frac:.1%})", fontsize=11) + + plt.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(output_path, bbox_inches="tight") + print(f"Detailed plot saved to: {output_path}") + + +def plot_results(input_path: Path, output_path: Path, mode: str = "basic", filter_pattern: str | None = None): + """Load CSV(s) and generate plots. 
+ + When ``input_path`` is a directory the function looks for CSV files inside + it. If exactly one CSV is found it also loads the matching + ``.trajectory.json`` (if present) and includes a trajectory subplot. + + Args: + input_path: Path to CSV file or directory + output_path: Path to save the plot image + mode: Plot mode - 'basic', 'detailed', or 'estimator_comparison' + filter_pattern: Optional pattern to filter CSV files by name + """ + setup_paper_style() + + # Collect CSV files + if input_path.is_file(): + csv_files = [input_path] + elif input_path.is_dir(): + csv_files = sorted(input_path.glob("*.csv")) + if filter_pattern: + csv_files = [f for f in csv_files if filter_pattern in f.name] + if not csv_files: + print( + f"No CSV files found in {input_path}" + + (f" matching '{filter_pattern}'" if filter_pattern else "") + ) + return + else: + print(f"Input path does not exist: {input_path}") + return + + print(f"Found {len(csv_files)} CSV file(s)") + + # Load all data + dfs = {} + csv_paths = {} + for csv_file in csv_files: + print(f" Loading: {csv_file.name}") + dfs[csv_file.stem] = load_experiment_data(csv_file) + csv_paths[csv_file.stem] = csv_file + + # Route to appropriate plotting function based on mode + if mode == "estimator_comparison": + plot_estimator_comparison(dfs, output_path, csv_paths=csv_paths) + return + elif mode == "detailed" and len(csv_files) == 1: + plot_detailed(list(dfs.values())[0], list(dfs.keys())[0], output_path) + return + + # For a single experiment, try to load trajectory data for extra subplots + trajectory_data: dict | None = None + if len(csv_files) == 1: + trajectory_data = load_trajectory_data(csv_files[0]) + if trajectory_data is not None: + traj_path = csv_files[0].with_suffix(".trajectory.json") + print(f" Loading: {traj_path.name}") + + # Decide subplot layout: + # Without trajectory: 4 base rows (latency, cooldown, schedule, events) + # With trajectory: trajectory + gantt + events + latency + cooldown + schedule = 6 + 
# (obs/action events are merged into the sim events plot) + has_trajectory = trajectory_data is not None + if has_trajectory: + n_rows = 6 + # trajectory gets 2, the rest 1 + height_ratios = [2, 1, 1, 1, 1, 1] + fig, axes = plt.subplots( + n_rows, + 1, + figsize=(14, 16), + sharex=True, + gridspec_kw={"height_ratios": height_ratios}, + ) + (ax_traj, ax_gantt, ax_sim_events, ax_latency, ax_cooldown, ax_schedule) = axes + ax_events = None # obs/action events merged into sim_events + else: + n_rows = 4 + fig, axes = plt.subplots(n_rows, 1, figsize=(12, 10), sharex=True) + ax_latency, ax_cooldown, ax_schedule, ax_events = axes + ax_traj = None + ax_gantt = None + ax_sim_events = None + + all_stats = [] + + # Compute the time offset between CSV t0 and trajectory t0 so that + # cooldown / latency plots align with trajectory-derived subplots. + time_offset = 0.0 + if has_trajectory: + df0 = list(dfs.values())[0] + csv_t0 = df0["t"].iloc[0] + traj_executed = trajectory_data.get("executed", []) + traj_t0 = min(e["t"] for e in traj_executed) if traj_executed else csv_t0 + time_offset = traj_t0 - csv_t0 + + for name, df in dfs.items(): + stats = plot_single_experiment( + df, + title=name, + ax_cooldown=ax_cooldown, + ax_latency=ax_latency, + ax_schedule=ax_schedule, + ax_events=ax_events, + time_offset=time_offset, + ) + stats["file"] = name + all_stats.append(stats) + + # Plot trajectory-derived subplots if available + if trajectory_data is not None: + sim_config_offset = time_offset # already computed above + df0 = list(dfs.values())[0] + + # 4. Trajectory (all joints) + if ax_traj is not None: + plot_trajectory_on_axis(ax_traj, trajectory_data) + ax_traj.set_title("Trajectory") + ax_traj.legend(loc="upper right", ncol=3) + + # 5. Gantt chart of fault injection windows + if ax_gantt is not None: + plot_gantt_on_axis( + ax_gantt, + trajectory_data, + sim_config_offset=sim_config_offset, + ) + ax_gantt.set_title("Fault Injection Schedule") + + # 6. 
Events timeline (sim events + obs/action from CSV) + if ax_sim_events is not None: + plot_sim_events_on_axis( + ax_sim_events, + trajectory_data, + sim_config_offset=sim_config_offset, + df=df0, + ) + ax_sim_events.set_title("Events") + + # Label only the bottom subplot with "Time (seconds)" + axes[-1].set_xlabel("Time (seconds)") + + # Add legends (only for the separate events axis in non-trajectory mode) + if ax_events is not None: + ax_events.legend(loc="upper right") + + plt.tight_layout() + + # Save figure as both PNG and PDF + output_path.parent.mkdir(parents=True, exist_ok=True) + stem = output_path.with_suffix("").as_posix().rstrip(".") + png_path = Path(f"{stem}.png") + pdf_path = Path(f"{stem}.pdf") + plt.savefig(png_path, bbox_inches="tight") + plt.savefig(pdf_path, bbox_inches="tight") + print(f"\nPlot saved to: {png_path}") + print(f"Plot saved to: {pdf_path}") + + # Generate a LaTeX document that imports the PDF figure + tex_path = Path(f"{stem}.tex") + pdf_filename = pdf_path.name + + # Build an optional config table from trajectory data + config_table_tex = "" + if trajectory_data is not None: + exp_config = trajectory_data.get("experiment_config") + sim_config = trajectory_data.get("simulation_config") + if exp_config: + config_table_tex = ( + "\n" + + generate_config_table( + exp_config, + simulation_config=sim_config, + ) + + "\n" + ) + + tex_content = rf"""\documentclass[11pt]{{article}} +\usepackage[margin=1in]{{geometry}} +\usepackage{{graphicx}} +\usepackage{{caption}} +\usepackage{{booktabs}} +\begin{{document}} +{config_table_tex} +\begin{{figure}}[htbp] + \centering + \includegraphics[width=\textwidth]{{{pdf_filename}}} + \caption{{Experiment results.}} + \label{{fig:{pdf_path.stem}}} +\end{{figure}} + +\end{{document}} +""" + tex_path.write_text(tex_content) + print(f"LaTeX saved to: {tex_path}") + + # Compile LaTeX to PDF if pdflatex is available + pdflatex_path = shutil.which("pdflatex") + if pdflatex_path: + # Rename the plot PDF 
def main():
    """Command-line entry point: parse arguments and dispatch to plot_results."""
    arg_parser = argparse.ArgumentParser(description="Plot DRTC experiment results")
    arg_parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Path to CSV file or directory containing CSV files",
    )
    arg_parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Output path for the plot image (default: saved beside the input CSV)",
    )
    arg_parser.add_argument(
        "--mode",
        type=str,
        choices=["basic", "detailed", "estimator_comparison"],
        default="basic",
        help="Plot mode: basic (default), detailed (single file), or estimator_comparison",
    )
    arg_parser.add_argument(
        "--filter",
        type=str,
        default=None,
        help="Filter CSV files by pattern (e.g., 'estimator_' to only plot estimator experiments)",
    )
    cli = arg_parser.parse_args()

    # Default output path (stem only – both .png and .pdf are generated)
    if cli.output is not None:
        destination = cli.output
    elif cli.input.is_file():
        destination = cli.input.parent / cli.input.stem
    else:
        destination = cli.input / cli.input.name

    plot_results(cli.input, destination, mode=cli.mode, filter_pattern=cli.filter)
Usage:
    python examples/experiments/run_drtc_experiment.py --config mixture_of_faults
    python examples/experiments/run_drtc_experiment.py --config spike --output_dir results/experiments
    python examples/experiments/run_drtc_experiment.py --config path/to/custom.yaml
camera2_path: str = DEFAULT_CAMERA2_PATH + camera_width: int = DEFAULT_CAMERA_WIDTH + camera_height: int = DEFAULT_CAMERA_HEIGHT + camera_fps: int = DEFAULT_CAMERA_FPS + camera_fourcc: str | None = DEFAULT_CAMERA_FOURCC + # Policy + policy_type: str = "smolvla" + pretrained_name_or_path: str = DEFAULT_MODEL_PATH + # DRTC parameters + latency_k: float = 2.0 + epsilon: int = 2 + s_min: int = 15 + latency_alpha: float = 0.125 + latency_beta: float = 0.25 + # Timing + duration_s: float = 60.0 + fps: int = 60 + actions_per_chunk: int = 50 + # Flow matching / RTC + num_flow_matching_steps: int | None = 8 + rtc_enabled: bool = True + rtc_max_guidance_weight: float | None = None + rtc_prefix_attention_schedule: str = "linear" + rtc_sigma_d: float = 0.2 + rtc_full_trajectory_alignment: bool = False + # Butterworth filter + action_filter_mode: str = "butterworth" + action_filter_butterworth_cutoff: float = 3.0 + action_filter_butterworth_order: int = 2 + action_filter_gain: float = 1.4 + action_filter_past_buffer_size: int = 10 + # Drop/spike/duplicate/reorder/disconnect injection + drop_obs_config: DropConfig | None = None + drop_action_config: DropConfig | None = None + dup_obs_config: DuplicateConfig | None = None + dup_action_config: DuplicateConfig | None = None + reorder_obs_config: ReorderConfig | None = None + reorder_action_config: ReorderConfig | None = None + disconnect_config: DisconnectConfig | None = None + spikes: list[dict] = field(default_factory=list) + # Diagnostics + full_diagnostics: bool = False + trajectory_viz_enabled: bool = False + + +# ---- YAML config loading ---- + +# Scalar fields that map 1:1 from YAML keys to ExperimentConfig constructor args. 
+_SCALAR_FIELDS = frozenset( + { + "name", + "estimator", + "cooldown", + "robot_type", + "gpu", + "client_host", + "server_host", + "robot_port", + "robot_id", + "camera1_path", + "camera2_path", + "camera_width", + "camera_height", + "camera_fps", + "camera_fourcc", + "policy_type", + "pretrained_name_or_path", + "latency_k", + "epsilon", + "s_min", + "latency_alpha", + "latency_beta", + "duration_s", + "fps", + "actions_per_chunk", + "num_flow_matching_steps", + "rtc_enabled", + "rtc_max_guidance_weight", + "rtc_prefix_attention_schedule", + "rtc_sigma_d", + "rtc_full_trajectory_alignment", + "action_filter_mode", + "action_filter_butterworth_cutoff", + "action_filter_butterworth_order", + "action_filter_gain", + "action_filter_past_buffer_size", + "full_diagnostics", + "trajectory_viz_enabled", + } +) + + +def _parse_experiment_dict(d: dict) -> ExperimentConfig: + """Convert a raw YAML dict into an ExperimentConfig.""" + kwargs: dict = {k: v for k, v in d.items() if k in _SCALAR_FIELDS} + + # Fault-injection lists -> typed config objects + if "drop_obs" in d: + kwargs["drop_obs_config"] = DropConfig(drops=[DropEvent(**e) for e in d["drop_obs"]]) + if "drop_action" in d: + kwargs["drop_action_config"] = DropConfig(drops=[DropEvent(**e) for e in d["drop_action"]]) + if "dup_obs" in d: + kwargs["dup_obs_config"] = DuplicateConfig(duplicates=[DuplicateEvent(**e) for e in d["dup_obs"]]) + if "dup_action" in d: + kwargs["dup_action_config"] = DuplicateConfig( + duplicates=[DuplicateEvent(**e) for e in d["dup_action"]] + ) + if "reorder_obs" in d: + kwargs["reorder_obs_config"] = ReorderConfig(reorders=[ReorderEvent(**e) for e in d["reorder_obs"]]) + if "reorder_action" in d: + kwargs["reorder_action_config"] = ReorderConfig( + reorders=[ReorderEvent(**e) for e in d["reorder_action"]] + ) + if "disconnect" in d: + kwargs["disconnect_config"] = DisconnectConfig( + disconnects=[DisconnectEvent(**e) for e in d["disconnect"]] + ) + if "spikes" in d: + kwargs["spikes"] = 
def resolve_config_path(config_arg: str) -> Path:
    """Resolve a ``--config`` argument to a YAML file path.

    Accepts:
    - A relative or absolute path to a ``.yaml`` file.
    - A bare name (e.g. ``spike``), which resolves to
      ``examples/experiments/configs/<name>.yaml``.

    Raises:
        FileNotFoundError: if none of the candidate locations exist.
    """

    def _candidates():
        # Yield candidate locations lazily, in priority order: the argument
        # as given, the argument with ".yaml" appended, then the same two
        # forms inside the bundled configs directory.
        yield Path(config_arg)
        if not config_arg.endswith(".yaml"):
            yield Path(config_arg + ".yaml")
        yield CONFIGS_DIR / config_arg
        if not config_arg.endswith(".yaml"):
            yield CONFIGS_DIR / (config_arg + ".yaml")

    for candidate in _candidates():
        if candidate.exists():
            return candidate

    raise FileNotFoundError(f"Config not found: {config_arg} (also tried {CONFIGS_DIR / config_arg})")
+ ) + + +def create_client_config( + config: ExperimentConfig, + metrics_path: Path, + server_address: str = DEFAULT_SERVER_ADDRESS, + trajectory_viz_ws_url: str | None = None, +) -> RobotClientDrtcConfig: + """Create a client config for a single experiment.""" + robot_cfg = create_robot_config(config) + client_kwargs = { + "robot": robot_cfg, + "server_address": server_address, + "robot_type": config.robot_type, + "gpu": config.gpu, + "client_host": config.client_host, + "server_host": config.server_host, + "policy_device": "cuda", + "policy_type": config.policy_type, + "pretrained_name_or_path": config.pretrained_name_or_path, + "actions_per_chunk": config.actions_per_chunk, + "fps": config.fps, + "s_min": config.s_min, + "latency_estimator_type": config.estimator, + "cooldown_enabled": config.cooldown, + "latency_k": config.latency_k, + "epsilon": config.epsilon, + "latency_alpha": config.latency_alpha, + "latency_beta": config.latency_beta, + # Flow matching / RTC + "num_flow_matching_steps": config.num_flow_matching_steps, + "rtc_enabled": config.rtc_enabled, + "rtc_max_guidance_weight": config.rtc_max_guidance_weight, + "rtc_prefix_attention_schedule": config.rtc_prefix_attention_schedule, + "rtc_sigma_d": config.rtc_sigma_d, + "rtc_full_trajectory_alignment": config.rtc_full_trajectory_alignment, + # Butterworth filter + "action_filter_mode": config.action_filter_mode, + "action_filter_butterworth_cutoff": config.action_filter_butterworth_cutoff, + "action_filter_butterworth_order": config.action_filter_butterworth_order, + "action_filter_gain": config.action_filter_gain, + "action_filter_past_buffer_size": config.action_filter_past_buffer_size, + # Diagnostics and robustness + "metrics_diagnostic_enabled": True, + "metrics_diagnostic_interval_s": 2.0, + "metrics_diagnostic_window_s": 10.0, + "metrics_diagnostic_verbose": config.full_diagnostics, + "control_use_deadline_clock": True, + "obs_fallback_on_failure": True, + "obs_fallback_max_age_s": 2.0, + 
"trajectory_viz_enabled": config.trajectory_viz_enabled, + # Drop/spike/duplicate/reorder/disconnect injection + "drop_obs_config": config.drop_obs_config, + "drop_action_config": config.drop_action_config, + "dup_obs_config": config.dup_obs_config, + "dup_action_config": config.dup_action_config, + "reorder_obs_config": config.reorder_obs_config, + "reorder_action_config": config.reorder_action_config, + "disconnect_config": config.disconnect_config, + "spikes": config.spikes, + "metrics_path": str(metrics_path), + } + if trajectory_viz_ws_url: + client_kwargs["trajectory_viz_ws_url"] = trajectory_viz_ws_url + return RobotClientDrtcConfig(**client_kwargs) + + +def run_experiment( + config: ExperimentConfig, + output_dir: Path, + server_address: str = DEFAULT_SERVER_ADDRESS, + trajectory_viz_ws_url: str | None = None, + task: str = DEFAULT_TASK, + experiment_name: str | None = None, +) -> dict: + """Run a single standalone experiment (creates and tears down client).""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + if experiment_name: + # Use the provided name verbatim; append timestamp if folder already exists. 
+ exp_dir = output_dir / experiment_name + if exp_dir.exists(): + exp_dir = output_dir / f"{experiment_name}_{timestamp}" + exp_name = exp_dir.name + else: + exp_name = f"{config.name}_{timestamp}" + exp_dir = output_dir / exp_name + exp_dir.mkdir(parents=True, exist_ok=True) + metrics_path = exp_dir / f"{exp_name}.csv" + + logger.info(f"Running experiment: {config.name}") + logger.info( + f" Estimator: {config.estimator}, Cooldown: {config.cooldown}, Full diagnostics: {config.full_diagnostics}" + ) + if config.drop_obs_config: + logger.info(f" Drop obs: {config.drop_obs_config}") + if config.drop_action_config: + logger.info(f" Drop action: {config.drop_action_config}") + if config.dup_obs_config: + logger.info(f" Dup obs: {config.dup_obs_config}") + if config.dup_action_config: + logger.info(f" Dup action: {config.dup_action_config}") + if config.reorder_obs_config: + logger.info(f" Reorder obs: {config.reorder_obs_config}") + if config.reorder_action_config: + logger.info(f" Reorder action: {config.reorder_action_config}") + if config.disconnect_config: + logger.info(f" Disconnect: {config.disconnect_config}") + if config.spikes: + logger.info(f" Spikes: {config.spikes}") + + client_cfg = create_client_config( + config, + metrics_path, + server_address=server_address, + trajectory_viz_ws_url=trajectory_viz_ws_url, + ) + client = RobotClientDrtc(client_cfg) + + def stop_after_duration(): + time.sleep(config.duration_s) + client.signal_stop() + + def signal_handler(sig, frame): + client.signal_stop() + + timer_thread = threading.Thread(target=stop_after_duration, daemon=True) + original_handler = signal.signal(signal.SIGINT, signal_handler) + + try: + logger.info("Starting client...") + if client.start(): + logger.info("Client started successfully") + obs_thread = threading.Thread(target=client.observation_sender, daemon=True) + action_thread = threading.Thread(target=client.action_receiver, daemon=True) + obs_thread.start() + action_thread.start() + 
timer_thread.start() + logger.info(f"Running for {config.duration_s}s...") + try: + client.control_loop(task=task) + except Exception as e: + logger.exception(f"Control loop error: {e}") + # Wait for the timer thread to finish (it calls signal_stop which flushes) + timer_thread.join(timeout=5.0) + # Ensure metrics are flushed from the main thread in case signal_stop + # hasn't finished or was never called (e.g. control loop exited early). + if client._metrics.experiment is not None and client.config.metrics_path: + with suppress(Exception): + client._metrics.experiment.flush(client.config.metrics_path) + success = metrics_path.exists() + logger.info(f"Experiment finished. Metrics saved: {success}") + if success: + exp_dir = metrics_path.parent + logger.info(f"Metrics file: {metrics_path}") + logger.info("To plot:") + logger.info(f" uv run python examples/experiments/plot_results.py --input {exp_dir}") + return {"success": success, "metrics_path": str(metrics_path)} + else: + logger.error("Client failed to start!") + return {"success": False, "error": "Client failed to start"} + except Exception as e: + logger.exception(f"Exception during experiment: {e}") + return {"success": False, "error": str(e)} + finally: + signal.signal(signal.SIGINT, original_handler) + # Only disconnect at the very end for standalone experiments + with suppress(Exception): + client.stop() + + +def main(): + parser = argparse.ArgumentParser( + description="DRTC Experiment Runner", + epilog=( + "Config files live in examples/experiments/configs/. " + "Pass a bare name (e.g. spike) or a path to a .yaml file." + ), + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to a YAML config file, or a bare config name from examples/experiments/configs/", + ) + parser.add_argument( + "--experiment_name", + type=str, + default="", + help=( + "Optional custom run name (single-experiment configs only). " + "Overrides the name from the YAML file." 
+ ), + ) + parser.add_argument("--output_dir", type=str, default="results/experiments") + parser.add_argument("--server_address", type=str, default=DEFAULT_SERVER_ADDRESS) + parser.add_argument( + "--trajectory_viz_ws_url", + type=str, + default=None, + help=( + "Optional WebSocket URL for trajectory visualization. " + "Used when trajectory visualization is enabled in the experiment config." + ), + ) + parser.add_argument("--pause_between_s", type=float, default=10.0) + + args = parser.parse_args() + + config_path = resolve_config_path(args.config) + configs = load_experiments_from_yaml(config_path) + logger.info(f"Loaded {len(configs)} experiment(s) from {config_path}") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + results = [] + for i, config in enumerate(configs): + if len(configs) > 1: + logger.info(f"{'=' * 50}") + logger.info(f"[{i + 1}/{len(configs)}] {config.name}") + logger.info(f"{'=' * 50}") + + experiment_name = (args.experiment_name or "").strip() or None + result = run_experiment( + config, + output_dir, + server_address=args.server_address, + trajectory_viz_ws_url=args.trajectory_viz_ws_url, + task=DEFAULT_TASK, + experiment_name=experiment_name if len(configs) == 1 else None, + ) + results.append(result) + + if i < len(configs) - 1: + logger.info(f"Pausing {args.pause_between_s}s before next experiment...") + time.sleep(args.pause_between_s) + + if len(configs) > 1: + success_count = sum(1 for r in results if r.get("success")) + logger.info(f"All experiments complete: {success_count}/{len(results)} succeeded") + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/async-inf/policy_server.py b/examples/tutorial/async-inf/policy_server.py index 244205bcfc6..87ebec11683 100644 --- a/examples/tutorial/async-inf/policy_server.py +++ b/examples/tutorial/async-inf/policy_server.py @@ -1,15 +1,27 @@ +""" +Run a Policy Server. 
+ +To expose it to your local network, bind to 0.0.0.0 and connect from other machines +using this machine's LAN IP (e.g. 192.168.x.y), not 0.0.0.0. +""" + +import argparse + from lerobot.async_inference.configs import PolicyServerConfig from lerobot.async_inference.policy_server import serve -def main(): - host = ... # something like "127.0.0.1" if you're exposing to localhost - port = ... # something like 8080 - - config = PolicyServerConfig( - host=host, - port=port, +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--host", + default="0.0.0.0", + help='Host/interface to bind to. Use "127.0.0.1" for local-only.', ) + parser.add_argument("--port", type=int, default=8080, help="Port to bind to.") + args = parser.parse_args() + + config = PolicyServerConfig(host=args.host, port=args.port) serve(config) diff --git a/examples/tutorial/async-inf/policy_server_drtc.py b/examples/tutorial/async-inf/policy_server_drtc.py new file mode 100644 index 00000000000..204070801e9 --- /dev/null +++ b/examples/tutorial/async-inf/policy_server_drtc.py @@ -0,0 +1,92 @@ +""" +DRTC Policy Server Example + +This example demonstrates how to run the DRTC policy server with: +- 2-thread architecture (observation receiver + main inference loop) +- SPSC one-slot mailbox for observation queue + +Usage: + python examples/tutorial/async-inf/policy_server_drtc.py + python examples/tutorial/async-inf/policy_server_drtc.py --host 0.0.0.0 --port 8080 + +To expose it to your local network, bind to 0.0.0.0 and connect from other machines +using this machine's LAN IP (e.g. 192.168.x.y), not 0.0.0.0. 
+""" + +import argparse +import faulthandler +import logging +import os +import sys + +from lerobot.async_inference.policy_server_drtc import ( + PolicyServerDrtcConfig, + serve_drtc, +) + +faulthandler.enable(file=sys.stderr, all_threads=True) + + +def main() -> None: + debug = os.environ.get("LEROBOT_DEBUG", "0") == "1" + logging.basicConfig( + level=logging.DEBUG if debug else logging.INFO, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + ) + + parser = argparse.ArgumentParser(description="Run the DRTC Policy Server") + parser.add_argument( + "--host", + default="0.0.0.0", + help='Host/interface to bind to. Use "127.0.0.1" for local-only.', + ) + parser.add_argument( + "--port", + type=int, + default=8080, + help="Port to bind to.", + ) + parser.add_argument( + "--fps", + type=int, + default=30, + help="Control frequency in Hz (should match client).", + ) + parser.add_argument( + "--obs-queue-timeout", + type=float, + default=2.0, + help="Timeout for observation queue in seconds.", + ) + parser.add_argument( + "--verbose-diagnostics", + action="store_true", + default=False, + help="Enable verbose diagnostic metrics (all timings/counters instead of compact summary).", + ) + parser.add_argument( + "--viz", + action="store_true", + default=False, + help="Enable trajectory visualization server (HTTP on :8088, WebSocket on :8089).", + ) + args = parser.parse_args() + + config = PolicyServerDrtcConfig( + host=args.host, + port=args.port, + fps=args.fps, + obs_queue_timeout=args.obs_queue_timeout, + metrics_diagnostic_verbose=args.verbose_diagnostics, + trajectory_viz_enabled=args.viz, + ) + + try: + serve_drtc(config) + except Exception: + logging.exception("Policy server example runner crashed") + raise + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/async-inf/robot_client_drtc.py b/examples/tutorial/async-inf/robot_client_drtc.py new file mode 100644 index 00000000000..9b75c6fd953 --- /dev/null +++ 
b/examples/tutorial/async-inf/robot_client_drtc.py @@ -0,0 +1,242 @@ +import logging +import os +import threading + +from lerobot.async_inference.helpers import visualize_action_queue_size +from lerobot.async_inference.robot_client_drtc import ( + RobotClientDrtc, + RobotClientDrtcConfig, +) +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.robots.so100_follower import SO100FollowerConfig +from lerobot.robots.so101_follower import SO101FollowerConfig + +DEFAULT_ROBOT_TYPE = "so101" +DEFAULT_POLICY_TYPE = "smolvla" +DEFAULT_MODEL_PATH = "jackvial/so101_smolvla_pickplaceorangecube_e100" +DEFAULT_FOLLOWER_PORT = "/dev/ttyACM0" +DEFAULT_FOLLOWER_ID = "so101_follower_2026_01_03" +DEFAULT_CAMERA1_PATH = "/dev/v4l/by-path/platform-xhci-hcd.1-usb-0:2:1.0-video-index0" +DEFAULT_CAMERA2_PATH = "/dev/v4l/by-path/platform-xhci-hcd.0-usb-0:2:1.0-video-index0" +DEFAULT_CAMERA_WIDTH = 800 +DEFAULT_CAMERA_HEIGHT = 600 +DEFAULT_CAMERA_FPS = 30 +DEFAULT_CAMERA_FOURCC = "MJPG" +DEFAULT_CAMERA_THREADED_ASYNC_READ = True +DEFAULT_CAMERA_ALLOW_STALE_FRAMES = True + + +def _env_bool(name: str, default: bool) -> bool: + raw = os.getenv(name) + if raw is None: + return default + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +def _env_int(name: str, default: int) -> int: + raw = os.getenv(name) + if raw is None: + return default + return int(raw) + + +def _enable_debug_logging_if_requested() -> None: + """Enable DEBUG logs in the console. + + Note: async-inference uses `init_logging()` internally at import time, which sets the console + handler to INFO by default. Setting the root logger level is not enough; we must also bump the + handler level. + """ + if os.getenv("LEROBOT_DEBUG", "0") != "1": + return + + root_logger = logging.getLogger() + for handler in root_logger.handlers: + handler.setLevel(logging.DEBUG) + + # Make sure the module logger itself does not filter DEBUG. 
+ logging.getLogger("robot_client_drtc").setLevel(logging.DEBUG) + + +def main() -> None: + _enable_debug_logging_if_requested() + + # Optional user overrides (defaults preserve current behavior): + # LEROBOT_ROBOT_TYPE (so101, so101_follower, so100, so100_follower) + # LEROBOT_POLICY_TYPE + # LEROBOT_PRETRAINED_NAME_OR_PATH + # LEROBOT_FOLLOWER_PORT / LEROBOT_FOLLOWER_ID + # LEROBOT_CAMERA1_PATH / LEROBOT_CAMERA2_PATH + # LEROBOT_CAMERA_WIDTH / LEROBOT_CAMERA_HEIGHT / LEROBOT_CAMERA_FPS + # LEROBOT_CAMERA_FOURCC + # LEROBOT_CAMERA_THREADED_ASYNC_READ / LEROBOT_CAMERA_ALLOW_STALE_FRAMES + robot_type = os.getenv("LEROBOT_ROBOT_TYPE", DEFAULT_ROBOT_TYPE).strip().lower() + policy_type = os.getenv("LEROBOT_POLICY_TYPE", DEFAULT_POLICY_TYPE) + pretrained_name_or_path = os.getenv("LEROBOT_PRETRAINED_NAME_OR_PATH", DEFAULT_MODEL_PATH) + follower_port = os.getenv("LEROBOT_FOLLOWER_PORT", DEFAULT_FOLLOWER_PORT) + follower_id = os.getenv("LEROBOT_FOLLOWER_ID", DEFAULT_FOLLOWER_ID) + + camera1_path = os.getenv("LEROBOT_CAMERA1_PATH", DEFAULT_CAMERA1_PATH) + camera2_path = os.getenv("LEROBOT_CAMERA2_PATH", DEFAULT_CAMERA2_PATH) + camera_width = _env_int("LEROBOT_CAMERA_WIDTH", DEFAULT_CAMERA_WIDTH) + camera_height = _env_int("LEROBOT_CAMERA_HEIGHT", DEFAULT_CAMERA_HEIGHT) + camera_fps = _env_int("LEROBOT_CAMERA_FPS", DEFAULT_CAMERA_FPS) + camera_fourcc_raw = os.getenv("LEROBOT_CAMERA_FOURCC", DEFAULT_CAMERA_FOURCC).strip() + camera_fourcc = camera_fourcc_raw or None + camera_threaded_async_read = _env_bool( + "LEROBOT_CAMERA_THREADED_ASYNC_READ", + DEFAULT_CAMERA_THREADED_ASYNC_READ, + ) + camera_allow_stale_frames = _env_bool( + "LEROBOT_CAMERA_ALLOW_STALE_FRAMES", + DEFAULT_CAMERA_ALLOW_STALE_FRAMES, + ) + + # These cameras must match the ones expected by the policy. + # Find your cameras with: lerobot-find-cameras + # Check the config.json on the Hub for the policy you are using. 
+ camera_cfg = { + "camera2": OpenCVCameraConfig( + index_or_path=camera2_path, + width=camera_width, + height=camera_height, + fps=camera_fps, + fourcc=camera_fourcc, + use_threaded_async_read=camera_threaded_async_read, + allow_stale_frames=camera_allow_stale_frames, + ), + "camera1": OpenCVCameraConfig( + index_or_path=camera1_path, + width=camera_width, + height=camera_height, + fps=camera_fps, + fourcc=camera_fourcc, + use_threaded_async_read=camera_threaded_async_read, + allow_stale_frames=camera_allow_stale_frames, + ), + } + + if robot_type in {"so101", "so101_follower"}: + robot_cfg = SO101FollowerConfig( + port=follower_port, + id=follower_id, + cameras=camera_cfg, + ) + elif robot_type in {"so100", "so100_follower"}: + robot_cfg = SO100FollowerConfig( + port=follower_port, + id=follower_id, + cameras=camera_cfg, + ) + else: + raise ValueError( + f"Unsupported LEROBOT_ROBOT_TYPE '{robot_type}'. " + "Supported values: so101, so101_follower, so100, so100_follower." + ) + + # Server address (use LAN IP if connecting over network) + # Examples: + # - Local: 127.0.0.1:8080 + # - LAN: 192.168.4.37:8080 + # - Tunnel (see scripts/start_client.sh): 127.0.0.1:18080 + server_address = os.getenv("LEROBOT_SERVER_ADDRESS", "127.0.0.1:8080") + + client_cfg = RobotClientDrtcConfig( + robot=robot_cfg, + server_address=server_address, + policy_device="cuda", + # Policy selection: + # - `policy_type` must be one of the async-inference supported policies (includes "smolvla"). + # - `pretrained_name_or_path` is passed to `.from_pretrained(...)` on the server. 
+        policy_type=policy_type,
+        pretrained_name_or_path=pretrained_name_or_path,
+        actions_per_chunk=50,
+        # Control frequency
+        fps=60,
+        # RTC s_min (aka minimum execution horizon)
+        s_min=15,
+        # DRTC cooldown margin
+        epsilon=2,
+        # DRTC Jacobson-Karels parameters (default values work well in most cases)
+        latency_alpha=0.125,  # Smoothing factor for RTT mean
+        latency_beta=0.25,  # Smoothing factor for RTT deviation
+        latency_k=2.0,  # Scaling factor for deviation (lower values, e.g. K=1, recover faster)
+        # DRTC trajectory smoothing filter
+        action_filter_mode="butterworth",
+        action_filter_past_buffer_size=10,
+        action_filter_butterworth_cutoff=3.0,  # Hz - passes motion, attenuates jitter
+        action_filter_butterworth_order=2,  # Good balance of sharpness vs phase lag
+        action_filter_gain=1.4,  # Slight boost to compensate attenuation
+        # Debug: visualize action queue size after stopping
+        debug_visualize_queue_size=False,
+        # Diagnostics (helpful to distinguish model stutter vs timing/latency jitter)
+        metrics_diagnostic_enabled=True,
+        metrics_diagnostic_interval_s=2.0,
+        metrics_diagnostic_window_s=10.0,
+        # Optional: use a deadline-based control clock for steadier action timing
+        control_use_deadline_clock=True,
+        # Robustness: if the robot state read occasionally fails, reuse the last good observation
+        # to avoid stalling action production (reduces visible hitches).
+ obs_fallback_on_failure=True, + obs_fallback_max_age_s=2.0, + # Trajectory visualization (sends data to policy server for real-time visualization) + # Local: + # - Open http://localhost:8088 in your browser to view trajectories + # Tunnel (see scripts/start_client.sh): + # - Open http://localhost:18088 in your browser to view trajectories + trajectory_viz_enabled=True, + trajectory_viz_ws_url=os.getenv("LEROBOT_TRAJECTORY_VIZ_WS_URL", "ws://localhost:8089"), + # RTC parameters + rtc_sigma_d=0.2, + rtc_full_trajectory_alignment=False, + num_flow_matching_steps=None, # Use policy default + rtc_max_guidance_weight=None, # Auto (Beta = n) + # Experiment metrics + metrics_path="results/jitter_analysis.csv", + ) + + # ------------------------------------------------------------------------- + # 4. Create and start client + # ------------------------------------------------------------------------- + client = RobotClientDrtc(client_cfg) + + # Task description for VLA policies + task = "Pick up the orange cube and place it on the black X marker with the white background" + + if client.start(): + # Start observation sender thread + obs_sender_thread = threading.Thread( + target=client.observation_sender, + name="observation_sender", + daemon=True, + ) + + # Start action receiver thread + action_receiver_thread = threading.Thread( + target=client.action_receiver, + name="action_receiver", + daemon=True, + ) + + obs_sender_thread.start() + action_receiver_thread.start() + + try: + # Main thread runs the control loop + client.control_loop(task) + + except KeyboardInterrupt: + print("\nStopping client...") + + finally: + client.stop() + obs_sender_thread.join(timeout=2.0) + action_receiver_thread.join(timeout=2.0) + + # Visualize action queue size if enabled + if client_cfg.debug_visualize_queue_size and client.action_queue_sizes: + visualize_action_queue_size(client.action_queue_sizes) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml 
index e85d695df47..8c4bde83876 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,6 +91,7 @@ dependencies = [
     "deepdiff>=7.0.1,<9.0.0",
     "imageio[ffmpeg]>=2.34.0,<3.0.0",
     "termcolor>=2.4.0,<4.0.0",
+    "sortedcontainers>=2.4.0,<3.0.0",
 ]
 
 # Optional dependencies
diff --git a/scripts/provision_prime_lerobot.sh b/scripts/provision_prime_lerobot.sh
new file mode 100755
index 00000000000..bebb0ad8b4f
--- /dev/null
+++ b/scripts/provision_prime_lerobot.sh
@@ -0,0 +1,453 @@
+#!/bin/bash
+# =============================================================================
+# Prime Intellect GPU Provisioning for drtc
+# =============================================================================
+#
+# Provisions a GPU instance on Prime Intellect via REST API, SSHs in,
+# clones drtc, sets up a Python venv with deps, and installs/configures
+# Tailscale.
+#
+# Prerequisites (local):
+#   curl, jq, ssh
+#   ~/.prime/config.json with at least api_key and ssh_key_path
+#
+# Usage:
+#   ./scripts/provision_prime_lerobot.sh                     # create new pod + setup
+#   ./scripts/provision_prime_lerobot.sh --pod-id <pod_id>   # resume setup on existing pod
+#
+# =============================================================================
+
+set -euo pipefail
+
+# ---------------------------------------------------------------------------
+# Colors / helpers
+# ---------------------------------------------------------------------------
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+CYAN='\033[0;36m'
+NC='\033[0m'
+
+info() { echo -e "${CYAN}[INFO]${NC} $*"; }
+ok()   { echo -e "${GREEN}[OK]${NC} $*"; }
+warn() { echo -e "${YELLOW}[WARN]${NC} $*"; }
+die()  { echo -e "${RED}[ERROR]${NC} $*" >&2; exit 1; }
+
+usage() {
+    cat <<EOF
+Usage: $0 [--pod-id <pod_id>]
+
+Optional:
+  --pod-id <pod_id>    Resume setup on an existing pod
+  -h, --help           Show this help message
+EOF
+}
+
+# ---------------------------------------------------------------------------
+# Dependency check
+# ---------------------------------------------------------------------------
+for cmd in curl jq ssh; do
+    command -v "$cmd" >/dev/null 2>&1 || die "Required command '$cmd' not found. Please install it."
+done
+
+# ---------------------------------------------------------------------------
+# Parse arguments
+# ---------------------------------------------------------------------------
+EXISTING_POD_ID=""
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+        --pod-id)
+            [ $# -ge 2 ] || die "Missing value for --pod-id"
+            EXISTING_POD_ID="$2"
+            shift 2
+            ;;
+        --pod-id=*)
+            EXISTING_POD_ID="${1#*=}"
+            shift
+            ;;
+        -h|--help)
+            usage
+            exit 0
+            ;;
+        *)
+            die "Unknown argument: $1. Usage: $0 [--pod-id <pod_id>]"
+            ;;
+    esac
+done
+
+# ---------------------------------------------------------------------------
+# Read ~/.prime/config.json
+# ---------------------------------------------------------------------------
+PRIME_CONFIG="$HOME/.prime/config.json"
+[ -f "$PRIME_CONFIG" ] || die "Config file not found: $PRIME_CONFIG"
+
+API_KEY=$(jq -r '.api_key // empty' "$PRIME_CONFIG")
+BASE_URL=$(jq -r '.base_url // "https://api.primeintellect.ai"' "$PRIME_CONFIG")
+SSH_KEY_PATH=$(jq -r '.ssh_key_path // empty' "$PRIME_CONFIG")
+DEFAULT_GPU=$(jq -r '.default_gpu // "H100_80GB"' "$PRIME_CONFIG")
+DEFAULT_IMAGE=$(jq -r '.default_image // "cuda_12_4_pytorch_2_4"' "$PRIME_CONFIG")
+DEFAULT_DISK_SIZE=$(jq -r '.default_disk_size // 120' "$PRIME_CONFIG")
+PROVIDER_TYPE=$(jq -r '.provider_type // "runpod"' "$PRIME_CONFIG")
+TEAM_ID=$(jq -r '.team_id // empty' "$PRIME_CONFIG")
+
+[ -n "$API_KEY" ] || die "api_key missing from $PRIME_CONFIG"
+[ -n "$SSH_KEY_PATH" ] || die "ssh_key_path missing from $PRIME_CONFIG"
+
+# Expand ~ in SSH_KEY_PATH
+SSH_KEY_PATH="${SSH_KEY_PATH/#\~/$HOME}"
+[ -f "$SSH_KEY_PATH" ] || die "SSH key not found at $SSH_KEY_PATH"
+
+info "Config loaded from $PRIME_CONFIG"
+info "  API base URL: $BASE_URL"
+info "  GPU: $DEFAULT_GPU"
+info "  Image: $DEFAULT_IMAGE"
+info "  Disk: ${DEFAULT_DISK_SIZE}GB"
+info "  Provider: $PROVIDER_TYPE"
+info "  SSH key: $SSH_KEY_PATH"
+ +if [ -n "$EXISTING_POD_ID" ]; then + # ----------------------------------------------------------------------- + # Resume mode: skip creation, use existing pod + # ----------------------------------------------------------------------- + POD_ID="$EXISTING_POD_ID" + POD_NAME="(existing)" + info "Resuming setup for existing pod: $POD_ID" +else + # ----------------------------------------------------------------------- + # Query availability API across ALL gpu types, filter by required image + # ----------------------------------------------------------------------- + GPU_TYPES=("H100_80GB" "A100_80GB" "RTX4090_24GB" "RTX6000Ada_48GB" "A6000_48GB" "L40S_48GB" "H200_141GB") + + info "Searching for GPUs with image ${DEFAULT_IMAGE} ..." + ALL_MATCHES="[]" + + for gpu_type in "${GPU_TYPES[@]}"; do + RESP=$(curl -sS -X GET \ + "${BASE_URL}/api/v1/availability/gpus?gpu_type=${gpu_type}&gpu_count=1&page_size=100" \ + -H "Authorization: Bearer ${API_KEY}" 2>/dev/null || echo '{"items":[]}') + + MATCHES=$(echo "$RESP" | jq -c \ + --arg img "$DEFAULT_IMAGE" \ + '[.items[] | select(.stockStatus != "Unavailable") | select(.images | index($img))]' 2>/dev/null || echo '[]') + + ALL_MATCHES=$(echo "$ALL_MATCHES" "$MATCHES" | jq -s '.[0] + .[1]') + done + + NUM_MATCHES=$(echo "$ALL_MATCHES" | jq 'length') + [ "$NUM_MATCHES" -gt 0 ] || die "No GPUs found with image ${DEFAULT_IMAGE} across any GPU type." 
+ + # Deduplicate by cloudId+provider+dataCenter, sort by price + ALL_MATCHES=$(echo "$ALL_MATCHES" | jq '[group_by(.cloudId + .provider + .dataCenter) | .[] | .[0]] | sort_by(.prices.onDemand // 9999)') + + NUM_MATCHES=$(echo "$ALL_MATCHES" | jq 'length') + + echo "" + echo "==============================================" + echo " Available GPUs with image: ${DEFAULT_IMAGE}" + echo "==============================================" + echo "" + printf " %-4s %-20s %-15s %-12s %-10s %-10s\n" "#" "GPU Type" "Provider" "Datacenter" "Stock" "Price/hr" + printf " %-4s %-20s %-15s %-12s %-10s %-10s\n" "----" "--------------------" "---------------" "------------" "----------" "----------" + + for i in $(seq 0 $((NUM_MATCHES - 1))); do + ROW=$(echo "$ALL_MATCHES" | jq -r --argjson i "$i" '.[$i]') + R_GPU=$(echo "$ROW" | jq -r '.gpuType') + R_PROV=$(echo "$ROW" | jq -r '.provider') + R_DC=$(echo "$ROW" | jq -r '.dataCenter // "?"') + R_STOCK=$(echo "$ROW" | jq -r '.stockStatus') + R_PRICE=$(echo "$ROW" | jq -r '.prices.onDemand // .prices.communityPrice // "?"') + printf " %-4s %-20s %-15s %-12s %-10s \$%-9s\n" "$((i+1))" "$R_GPU" "$R_PROV" "$R_DC" "$R_STOCK" "$R_PRICE" + done + + echo "" + read -rp "Select a GPU [1-${NUM_MATCHES}]: " GPU_CHOICE + + if ! 
[[ "$GPU_CHOICE" =~ ^[0-9]+$ ]] || [ "$GPU_CHOICE" -lt 1 ] || [ "$GPU_CHOICE" -gt "$NUM_MATCHES" ]; then + die "Invalid selection: $GPU_CHOICE" + fi + + AVAIL_FILTER=$(echo "$ALL_MATCHES" | jq -r --argjson i "$((GPU_CHOICE - 1))" '.[$i]') + + GPU_CLOUD_ID=$(echo "$AVAIL_FILTER" | jq -r '.cloudId') + GPU_SOCKET=$(echo "$AVAIL_FILTER" | jq -r '.socket') + RESOLVED_PROVIDER=$(echo "$AVAIL_FILTER" | jq -r '.provider') + STOCK_STATUS=$(echo "$AVAIL_FILTER" | jq -r '.stockStatus') + PRICE=$(echo "$AVAIL_FILTER" | jq -r '.prices.onDemand // .prices.communityPrice // "unknown"') + DATACENTER=$(echo "$AVAIL_FILTER" | jq -r '.dataCenter // "unknown"') + COUNTRY=$(echo "$AVAIL_FILTER" | jq -r '.country // empty') + SECURITY=$(echo "$AVAIL_FILTER" | jq -r '.security // "unknown"') + DEFAULT_GPU=$(echo "$AVAIL_FILTER" | jq -r '.gpuType') + + ok "Selected GPU" + info " GPU Type: $DEFAULT_GPU" + info " Cloud ID: $GPU_CLOUD_ID" + info " Socket: $GPU_SOCKET" + info " Provider: $RESOLVED_PROVIDER" + info " Datacenter: $DATACENTER" + info " Security: $SECURITY" + info " Stock: $STOCK_STATUS" + info " Price: \$${PRICE}/hr" + + # ------------------------------------------------------------------- + # Register SSH public key via API (or find existing) + # ------------------------------------------------------------------- + SSH_PUB_KEY_PATH="${SSH_KEY_PATH}.pub" + [ -f "$SSH_PUB_KEY_PATH" ] || die "SSH public key not found at $SSH_PUB_KEY_PATH" + SSH_PUBLIC_KEY=$(cat "$SSH_PUB_KEY_PATH") + + info "Checking for existing SSH key on Prime Intellect ..." 
+ EXISTING_KEYS=$(curl -sS -X GET \ + "${BASE_URL}/api/v1/ssh_keys/" \ + -H "Authorization: Bearer ${API_KEY}") + + SSH_KEY_ID=$(echo "$EXISTING_KEYS" | jq -r --arg pk "$SSH_PUBLIC_KEY" ' + if type == "array" then + [.[] | select(.publicKey == $pk)][0].id // empty + elif .items then + [.items[] | select(.publicKey == $pk)][0].id // empty + else empty end') + + if [ -n "$SSH_KEY_ID" ]; then + ok "Found existing SSH key: $SSH_KEY_ID" + else + info "Uploading SSH key ..." + SSH_KEY_NAME="lerobot-provision-$(date +%Y%m%d)" + UPLOAD_RESPONSE=$(curl -sS -X POST \ + "${BASE_URL}/api/v1/ssh_keys/" \ + -H "Authorization: Bearer ${API_KEY}" \ + -H "Content-Type: application/json" \ + -d "$(jq -n --arg name "$SSH_KEY_NAME" --arg pk "$SSH_PUBLIC_KEY" \ + '{name: $name, publicKey: $pk}')") + SSH_KEY_ID=$(echo "$UPLOAD_RESPONSE" | jq -r '.id // empty') + [ -n "$SSH_KEY_ID" ] || die "Failed to upload SSH key. Response:\n$UPLOAD_RESPONSE" + ok "SSH key uploaded: $SSH_KEY_ID" + fi + + # ------------------------------------------------------------------- + # Create pod + # ------------------------------------------------------------------- + POD_NAME="lerobot-$(date +%Y%m%d-%H%M%S)" + info "Creating pod '$POD_NAME' ..." 
+ + CREATE_PAYLOAD=$(jq -n \ + --arg name "$POD_NAME" \ + --arg cloud_id "$GPU_CLOUD_ID" \ + --arg gpu_type "$DEFAULT_GPU" \ + --arg socket "$GPU_SOCKET" \ + --arg image "$DEFAULT_IMAGE" \ + --argjson disk "$DEFAULT_DISK_SIZE" \ + --arg ssh_key_id "$SSH_KEY_ID" \ + --arg datacenter_id "$DATACENTER" \ + --arg country "$COUNTRY" \ + --arg security "$SECURITY" \ + --arg provider "$RESOLVED_PROVIDER" \ + '{ + pod: { + name: $name, + cloudId: $cloud_id, + gpuType: $gpu_type, + socket: $socket, + gpuCount: 1, + diskSize: $disk, + image: $image, + sshKeyId: $ssh_key_id, + dataCenterId: $datacenter_id, + country: $country, + security: $security + }, + provider: { + type: $provider + } + }') + + if [ -n "$TEAM_ID" ]; then + CREATE_PAYLOAD=$(echo "$CREATE_PAYLOAD" | jq --arg tid "$TEAM_ID" '. + {team: {teamId: $tid}}') + fi + + info "Payload: $(echo "$CREATE_PAYLOAD" | jq -c .)" + + CREATE_RESPONSE=$(curl -sS -X POST \ + "${BASE_URL}/api/v1/pods/" \ + -H "Authorization: Bearer ${API_KEY}" \ + -H "Content-Type: application/json" \ + -d "$CREATE_PAYLOAD") + + POD_ID=$(echo "$CREATE_RESPONSE" | jq -r '.id // .pod_id // empty') + [ -n "$POD_ID" ] || die "Failed to create pod. Response:\n$CREATE_RESPONSE" + ok "Pod created: $POD_ID" + + # ------------------------------------------------------------------- + # Poll until ACTIVE / RUNNING (timeout 10 min) + # ------------------------------------------------------------------- + info "Waiting for pod to become ACTIVE (timeout: 10 min) ..." 
+ POLL_INTERVAL=10 + MAX_WAIT=600 + ELAPSED=0 + + while true; do + STATUS_RESPONSE=$(curl -sS -X GET \ + "${BASE_URL}/api/v1/pods/${POD_ID}" \ + -H "Authorization: Bearer ${API_KEY}") + + POD_STATUS=$(echo "$STATUS_RESPONSE" | jq -r '.status // empty') + + if [[ "$POD_STATUS" == "ACTIVE" || "$POD_STATUS" == "RUNNING" ]]; then + ok "Pod is $POD_STATUS" + break + fi + + if [ "$ELAPSED" -ge "$MAX_WAIT" ]; then + die "Timed out after ${MAX_WAIT}s waiting for pod to become ACTIVE (current: $POD_STATUS)" + fi + + info " Status: ${POD_STATUS:-unknown} (${ELAPSED}s elapsed) ..." + sleep "$POLL_INTERVAL" + ELAPSED=$((ELAPSED + POLL_INTERVAL)) + done +fi + +# --------------------------------------------------------------------------- +# Get SSH connection info +# --------------------------------------------------------------------------- +info "Fetching SSH connection info ..." +SSH_STATUS_RESPONSE=$(curl -sS -X GET \ + "${BASE_URL}/api/v1/pods/status?pod_ids=${POD_ID}" \ + -H "Authorization: Bearer ${API_KEY}") + +SSH_CONNECTION=$(echo "$SSH_STATUS_RESPONSE" | jq -r ' + if .data then .data[0].sshConnection // .data[0].ssh_connection // empty + elif type == "array" then .[0].sshConnection // .[0].ssh_connection // empty + else .sshConnection // .ssh_connection // empty + end') + +[ -n "$SSH_CONNECTION" ] || die "Could not parse SSH connection from response:\n$SSH_STATUS_RESPONSE" +ok "SSH connection: $SSH_CONNECTION" + +# Parse "root@1.2.3.4 -p 22" -> user, host, port +SSH_USER=$(echo "$SSH_CONNECTION" | grep -oP '^[^@]+') +SSH_HOST=$(echo "$SSH_CONNECTION" | grep -oP '(?<=@)[^\s]+') +SSH_PORT=$(echo "$SSH_CONNECTION" | grep -oP '(?<=-p\s)\d+' || echo "22") + +info " User: $SSH_USER Host: $SSH_HOST Port: $SSH_PORT" + +# --------------------------------------------------------------------------- +# remote_exec helper +# --------------------------------------------------------------------------- +remote_exec() { + ssh -i "$SSH_KEY_PATH" \ + -o StrictHostKeyChecking=no \ + -o 
UserKnownHostsFile=/dev/null \ + -p "$SSH_PORT" \ + "${SSH_USER}@${SSH_HOST}" \ + "$@" +} + +# Wait a moment for SSH to be ready +info "Waiting for SSH to be ready ..." +SSH_READY=0 +for i in $(seq 1 30); do + if remote_exec "echo ok" >/dev/null 2>&1; then + SSH_READY=1 + break + fi + sleep 5 +done +[ "$SSH_READY" -eq 1 ] || die "SSH not reachable after 150s" +ok "SSH is ready" + +# =========================================================================== +# Phase 2: Remote setup +# =========================================================================== +echo "" +echo "==============================================" +echo " Phase 2: Remote Setup" +echo "==============================================" +echo "" + +# --------------------------------------------------------------------------- +# 1. Clone repo +# --------------------------------------------------------------------------- +info "Cloning drtc ..." +remote_exec 'git clone https://github.com/jackvial/drtc.git /workspace/drtc' +ok "Repo cloned to /workspace/drtc" + +# --------------------------------------------------------------------------- +# 2. Install UV +# --------------------------------------------------------------------------- +info "Installing uv ..." +remote_exec 'curl -LsSf https://astral.sh/uv/install.sh | sh' +ok "uv installed" + +# --------------------------------------------------------------------------- +# 3. Create venv +# --------------------------------------------------------------------------- +info "Creating Python 3.12 venv ..." +remote_exec 'export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH" && cd /workspace/drtc && uv venv --python 3.12' +ok "Venv created" + +# --------------------------------------------------------------------------- +# 4. Install deps +# --------------------------------------------------------------------------- +info "Installing Python dependencies (this may take a few minutes) ..." 
+remote_exec 'export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH" && cd /workspace/drtc && source .venv/bin/activate && uv pip install -e ".[async,smolvla]"' +ok "Dependencies installed" + +# --------------------------------------------------------------------------- +# 5. Install Tailscale +# --------------------------------------------------------------------------- +info "Installing Tailscale ..." +remote_exec 'curl -fsSL https://tailscale.com/install.sh | sh' +ok "Tailscale installed" + +# --------------------------------------------------------------------------- +# 6. Tailscale auth +# --------------------------------------------------------------------------- +echo "" +echo "==============================================" +echo -e "${YELLOW}ACTION REQUIRED: Tailscale auth key${NC}" +echo "==============================================" +echo "" +echo " Generate a key at: https://login.tailscale.com/admin/settings/keys" +echo "" +read -rsp "Paste your Tailscale auth key: " TAILSCALE_AUTH_KEY +echo "" +echo "" + +[ -n "$TAILSCALE_AUTH_KEY" ] || die "Tailscale auth key cannot be empty" + +info "Starting tailscaled ..." +remote_exec 'tailscaled --tun=userspace-networking --state=/var/lib/tailscale/tailscaled.state > /var/log/tailscaled.log 2>&1 &' +sleep 3 + +info "Authenticating with Tailscale ..." +remote_exec "tailscale up --auth-key=${TAILSCALE_AUTH_KEY}" +ok "Tailscale authenticated" + +# --------------------------------------------------------------------------- +# 7. Print Tailscale hostname +# --------------------------------------------------------------------------- +info "Fetching Tailscale status ..." 
+echo "" +TS_STATUS=$(remote_exec 'tailscale status' 2>/dev/null || true) +TS_HOSTNAME=$(remote_exec 'tailscale status --self --json' 2>/dev/null | jq -r '.Self.DNSName // empty' || true) + +echo "==============================================" +echo -e "${GREEN} Provisioning complete!${NC}" +echo "==============================================" +echo "" +echo " Pod ID: $POD_ID" +echo " Pod Name: $POD_NAME" +echo " SSH: ssh -i $SSH_KEY_PATH -p $SSH_PORT ${SSH_USER}@${SSH_HOST}" +echo "" +if [ -n "$TS_HOSTNAME" ]; then + echo -e " ${CYAN}Tailscale domain: ${TS_HOSTNAME}${NC}" +else + echo " Tailscale status:" + echo " $TS_STATUS" +fi +echo "" +echo " Next steps:" +echo " 1. Run run_drtc_experiment_with_remote_server.sh with --remote-server-host set to the Tailscale domain above" +echo " 2. Run the experiment from the robot client" +echo "" diff --git a/scripts/run_drtc_experiment.sh b/scripts/run_drtc_experiment.sh new file mode 100755 index 00000000000..a4a75899cf9 --- /dev/null +++ b/scripts/run_drtc_experiment.sh @@ -0,0 +1,140 @@ +#!/bin/bash +# ============================================================================= +# DRTC Experiment Runner +# ============================================================================= +# +# Starts the policy server (if not already running), then runs experiments +# defined in a YAML config. All arguments are forwarded to the Python +# experiment runner. 
+# +# Usage: +# ./scripts/run_drtc_experiment.sh --config mixture_of_faults +# ./scripts/run_drtc_experiment.sh --config spike --output_dir results/experiments +# ./scripts/run_drtc_experiment.sh --config examples/experiments/configs/disconnect.yaml +# ./scripts/run_drtc_experiment.sh --viz --config baseline # enable trajectory viz +# +# Flags (consumed by this script, not forwarded to the experiment runner): +# --viz - Start the trajectory visualization server (HTTP :8088) +# +# Environment variables: +# POLICY_SERVER_DELAY_S - Seconds to wait for policy server startup (default: 3) +# POLICY_SERVER_PORT - Port to check / bind (default: 8080) +# +# ============================================================================= + +set -e + +# --- Parse flags consumed by this script ----------------------------------- # +ENABLE_VIZ=false +PASSTHROUGH_ARGS=() +for arg in "$@"; do + case "$arg" in + --viz) ENABLE_VIZ=true ;; + *) PASSTHROUGH_ARGS+=("$arg") ;; + esac +done +set -- "${PASSTHROUGH_ARGS[@]}" + +POLICY_SERVER_DELAY_S="${POLICY_SERVER_DELAY_S:-3}" +POLICY_SERVER_PORT="${POLICY_SERVER_PORT:-8080}" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +LOG_DIR="$PROJECT_ROOT/logs" +LOG_TIMESTAMP="$(date +%Y%m%d_%H%M%S)" +LOG_FILE="$LOG_DIR/policy_server_${LOG_TIMESTAMP}.log" + +# PIDs for cleanup +POLICY_SERVER_PID="" +STARTED_SERVER=false + +cleanup() { + echo "" + echo "Shutting down experiment components..." + + if [ "$STARTED_SERVER" = true ] && [ -n "$POLICY_SERVER_PID" ] && kill -0 "$POLICY_SERVER_PID" 2>/dev/null; then + echo "Stopping policy server (PID: $POLICY_SERVER_PID)..." + kill -TERM "$POLICY_SERVER_PID" 2>/dev/null || true + wait "$POLICY_SERVER_PID" 2>/dev/null || true + fi + + echo "Cleanup complete." 
+ exit 0 +} + +trap cleanup SIGINT SIGTERM EXIT + +cd "$PROJECT_ROOT" + +echo "==============================================" +echo " DRTC Experiment Runner" +echo "==============================================" +echo "Project root: $PROJECT_ROOT" +echo "Arguments: $*" +echo "" + +# ----------------------------------------------------------------------------- +# Step 1: Start Policy Server (kill existing + start fresh) +# ----------------------------------------------------------------------------- +mkdir -p "$LOG_DIR" + +if ss -tlnp 2>/dev/null | grep -q ":${POLICY_SERVER_PORT} " || \ + lsof -iTCP:"${POLICY_SERVER_PORT}" -sTCP:LISTEN >/dev/null 2>&1; then + echo "[1/2] Killing existing policy server on port ${POLICY_SERVER_PORT}..." + # Find and kill the process listening on the port + EXISTING_PID=$(lsof -ti TCP:"${POLICY_SERVER_PORT}" -sTCP:LISTEN 2>/dev/null || true) + if [ -n "$EXISTING_PID" ]; then + kill -TERM $EXISTING_PID 2>/dev/null || true + sleep 1 + # Force-kill if still running + kill -0 $EXISTING_PID 2>/dev/null && kill -9 $EXISTING_PID 2>/dev/null || true + sleep 0.5 + fi + echo " Old server stopped." +fi + +echo "[1/2] Starting policy server..." +echo " Policy server logs: $LOG_FILE" +POLICY_SERVER_CMD=(uv run --no-sync python examples/tutorial/async-inf/policy_server_drtc.py --verbose-diagnostics) +if [ "$ENABLE_VIZ" = true ]; then + POLICY_SERVER_CMD+=(--viz) +fi +"${POLICY_SERVER_CMD[@]}" >"$LOG_FILE" 2>&1 & +POLICY_SERVER_PID=$! +STARTED_SERVER=true +echo " Policy server started (PID: $POLICY_SERVER_PID)" +if [ "$ENABLE_VIZ" = true ]; then + echo " Trajectory visualization: http://localhost:8088" +fi +echo " Waiting ${POLICY_SERVER_DELAY_S}s for server to initialize..." +sleep "$POLICY_SERVER_DELAY_S" + +if ! kill -0 "$POLICY_SERVER_PID" 2>/dev/null; then + echo "ERROR: Policy server failed to start!" 
+ echo "" + echo "---- policy server log (last 200 lines) ----" + tail -n 200 "$LOG_FILE" 2>/dev/null || true + exit 1 +fi +echo " Policy server is running." +echo "" + +# ----------------------------------------------------------------------------- +# Step 2: Run Experiment (foreground) +# ----------------------------------------------------------------------------- +echo "[2/2] Starting experiment..." +echo " Press Ctrl+C to stop." +echo "" +echo "----------------------------------------------" + +uv run --no-sync python examples/experiments/run_drtc_experiment.py "$@" + +# Show server-side diagnostics from the log (if any DIAG_SERVER lines exist) +if [ -f "$LOG_FILE" ] && grep -q "DIAG_SERVER" "$LOG_FILE"; then + echo "" + echo "----------------------------------------------" + echo " Server diagnostics (from $LOG_FILE):" + echo "----------------------------------------------" + grep "DIAG_SERVER" "$LOG_FILE" +fi +echo "" +echo "Server log: $LOG_FILE" diff --git a/scripts/run_drtc_experiment_with_remote_server.sh b/scripts/run_drtc_experiment_with_remote_server.sh new file mode 100755 index 00000000000..6ff71276ad1 --- /dev/null +++ b/scripts/run_drtc_experiment_with_remote_server.sh @@ -0,0 +1,167 @@ +#!/bin/bash +# ============================================================================= +# DRTC Experiment Runner (Remote Policy Server) +# ============================================================================= +# +# Runs DRTC experiments from a robot client machine (e.g., Raspberry Pi) +# while the policy server runs on a remote host reachable directly +# (e.g., via Tailscale). +# +# This script runs the standard experiment runner with --server_address and +# --trajectory_viz_ws_url set to the remote host. 
+#
+# Usage:
+#   ./scripts/run_drtc_experiment_with_remote_server.sh --remote-server-host <host> --config mixture_of_faults
+#   ./scripts/run_drtc_experiment_with_remote_server.sh --remote-server-host <host> --config spike --output_dir results/experiments
+#
+# Notes:
+#   - Do NOT pass --server_address; this script sets it to the remote host.
+#   - You can still pass all normal run_drtc_experiment.py flags.
+#
+# Environment variables:
+#   REMOTE_GRPC_PORT     - Remote gRPC port (default: 8080)
+#   REMOTE_VIZ_HTTP_PORT - Remote viz HTTP port (default: 8088)
+#   REMOTE_VIZ_WS_PORT   - Remote viz WebSocket port (default: 8089)
+#
+# =============================================================================
+
+set -e
+
+# Optional debug tracing for this script
+if [ "${LEROBOT_DEBUG:-0}" = "1" ]; then
+  set -x
+fi
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
+
+usage() {
+  cat <<EOF
+Usage: $0 --remote-server-host <host> [run_drtc_experiment.py args...]
+
+Required:
+  --remote-server-host <host>   Remote policy server host/IP/domain
+
+Optional:
+  -h, --help                    Show this help message
+
+Environment variables:
+  REMOTE_GRPC_PORT       Remote gRPC port (default: 8080)
+  REMOTE_VIZ_HTTP_PORT   Remote viz HTTP port (default: 8088)
+  REMOTE_VIZ_WS_PORT     Remote viz WebSocket port (default: 8089)
+EOF
+}
+
+# -----------------------------------------------------------------------------
+# Remote policy server (direct reachability, e.g. 
Tailscale) +# ----------------------------------------------------------------------------- +REMOTE_SERVER_HOST="" +REMOTE_GRPC_PORT="${REMOTE_GRPC_PORT:-8080}" +REMOTE_VIZ_HTTP_PORT="${REMOTE_VIZ_HTTP_PORT:-8088}" +REMOTE_VIZ_WS_PORT="${REMOTE_VIZ_WS_PORT:-8089}" + +# ----------------------------------------------------------------------------- +# Parse script arguments and preserve passthrough args +# ----------------------------------------------------------------------------- +HAS_SERVER_ADDRESS_ARG=0 +HAS_VIZ_WS_URL_ARG=0 +PASSTHROUGH_ARGS=() +while [[ $# -gt 0 ]]; do + case "$1" in + --remote-server-host) + if [ $# -lt 2 ]; then + echo "ERROR: Missing value for --remote-server-host." + usage + exit 1 + fi + REMOTE_SERVER_HOST="$2" + shift 2 + ;; + --remote-server-host=*) + REMOTE_SERVER_HOST="${1#*=}" + shift + ;; + -h|--help) + usage + exit 0 + ;; + --server_address|--server_address=*) + HAS_SERVER_ADDRESS_ARG=1 + PASSTHROUGH_ARGS+=("$1") + shift + ;; + --trajectory_viz_ws_url|--trajectory_viz_ws_url=*) + HAS_VIZ_WS_URL_ARG=1 + PASSTHROUGH_ARGS+=("$1") + shift + ;; + *) + PASSTHROUGH_ARGS+=("$1") + shift + ;; + esac +done +set -- "${PASSTHROUGH_ARGS[@]}" + +if [ -z "$REMOTE_SERVER_HOST" ]; then + echo "ERROR: Missing required --remote-server-host." + usage + exit 1 +fi + +if [[ "$REMOTE_SERVER_HOST" =~ [[:space:]] ]]; then + echo "ERROR: --remote-server-host cannot contain whitespace: $REMOTE_SERVER_HOST" + exit 1 +fi + +if [ "$HAS_SERVER_ADDRESS_ARG" = "1" ]; then + echo "ERROR: Do not pass --server_address to this script." + echo " This script sets --server_address to ${REMOTE_SERVER_HOST}:${REMOTE_GRPC_PORT}." 
+ exit 1 +fi + +cd "$PROJECT_ROOT" + +echo "==============================================" +echo " DRTC Experiment Runner (Remote Server)" +echo "==============================================" +echo "Project root: $PROJECT_ROOT" +echo "Arguments: $*" +echo "" + +# ----------------------------------------------------------------------------- +# Optional: fail fast if remote policy server is unreachable +# ----------------------------------------------------------------------------- +tcp_probe() { + local host="$1" + local port="$2" + timeout 2 bash -c "cat < /dev/null > /dev/tcp/${host}/${port}" >/dev/null 2>&1 +} + +if ! tcp_probe "$REMOTE_SERVER_HOST" "$REMOTE_GRPC_PORT"; then + echo "ERROR: Policy server is not reachable at ${REMOTE_SERVER_HOST}:${REMOTE_GRPC_PORT}." + echo " Start/verify the server on the remote host (e.g. via scripts/start_drtc_server.sh)." + exit 1 +fi + +echo "Starting experiment..." +echo " Server address: ${REMOTE_SERVER_HOST}:${REMOTE_GRPC_PORT}" +if [ "$HAS_VIZ_WS_URL_ARG" = "0" ]; then + echo " Trajectory viz WS URL: ws://${REMOTE_SERVER_HOST}:${REMOTE_VIZ_WS_PORT}" +fi +echo " Press Ctrl+C to stop." 
+echo ""
+echo "----------------------------------------------"
+
+EXPERIMENT_CMD=(
+  uv run --no-sync python examples/experiments/run_drtc_experiment.py
+  --server_address "${REMOTE_SERVER_HOST}:${REMOTE_GRPC_PORT}"
+)
+
+if [ "$HAS_VIZ_WS_URL_ARG" = "0" ]; then
+  EXPERIMENT_CMD+=(--trajectory_viz_ws_url "ws://${REMOTE_SERVER_HOST}:${REMOTE_VIZ_WS_PORT}")
+fi
+
+EXPERIMENT_CMD+=("$@")
+"${EXPERIMENT_CMD[@]}"
diff --git a/scripts/start_drtc_client.sh b/scripts/start_drtc_client.sh
new file mode 100755
index 00000000000..8940c09341f
--- /dev/null
+++ b/scripts/start_drtc_client.sh
@@ -0,0 +1,263 @@
+#!/bin/bash
+# =============================================================================
+# DRTC Robot Client Startup Script
+# =============================================================================
+#
+# Starts the DRTC client-side components in the correct order:
+#   1. SSH tunnel to the cloud policy server (gRPC + trajectory visualization)
+#   2. Robot client (connects to the policy server through the tunnel)
+#
+# The trajectory visualization runs inside the remote policy server; once the
+# tunnel is up you can view it locally at http://localhost:18088.
+#
+# Usage:
+#   ./scripts/start_drtc_client.sh --tunnel-ssh-user-host <user@host>
+#
+# Environment variables:
+#   TUNNEL_SSH_PORT - SSH port on cloud host (default: 18468)
+#   LEROBOT_DEBUG   - Set to 1 for debug logging
+#
+# =============================================================================
+
+set -e
+
+# Configuration
+POLICY_SERVER_DELAY_S="${POLICY_SERVER_DELAY_S:-3}"
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
+LOG_DIR="$PROJECT_ROOT/logs"
+TUNNEL_LOG_FILE="$LOG_DIR/ssh_tunnel.log"
+
+usage() {
+  cat <<EOF
+Usage: $0 --tunnel-ssh-user-host <user@host>
+
+Required:
+  --tunnel-ssh-user-host <user@host>   SSH target used for tunnel setup
+
+Optional:
+  -h, --help                           Show this help message
+
+Environment variables:
+  TUNNEL_SSH_PORT              SSH port on cloud host (default: 18468)
+  TUNNEL_GRPC_LOCAL_PORT       Local tunnel port for gRPC (default: 18080)
+  TUNNEL_VIZ_HTTP_LOCAL_PORT   Local tunnel port for viz HTTP (default: 18088)
+  TUNNEL_VIZ_WS_LOCAL_PORT     Local tunnel port for viz WS (default: 18089)
+  TUNNEL_GRPC_REMOTE_PORT      Remote gRPC port (default: 8080)
+  TUNNEL_VIZ_HTTP_REMOTE_PORT  Remote viz HTTP port (default: 8088)
+  TUNNEL_VIZ_WS_REMOTE_PORT    Remote viz WS port (default: 8089)
+EOF
+}
+
+# -----------------------------------------------------------------------------
+# Cloud tunnel configuration (LAN box -> cloud policy server)
+# -----------------------------------------------------------------------------
+# These are the *local* ports on this machine. They forward to the cloud machine's
+# standard ports (8080/8088/8089). We use 1808x to avoid conflicts with other
+# port forwarders (e.g. editor/remote tooling).
+TUNNEL_SSH_PORT="${TUNNEL_SSH_PORT:-18468}"
+TUNNEL_SSH_USER_HOST=""
+TUNNEL_GRPC_LOCAL_PORT="${TUNNEL_GRPC_LOCAL_PORT:-18080}"
+TUNNEL_VIZ_HTTP_LOCAL_PORT="${TUNNEL_VIZ_HTTP_LOCAL_PORT:-18088}"
+TUNNEL_VIZ_WS_LOCAL_PORT="${TUNNEL_VIZ_WS_LOCAL_PORT:-18089}"
+
+# Remote ports (on the cloud host)
+TUNNEL_GRPC_REMOTE_PORT="${TUNNEL_GRPC_REMOTE_PORT:-8080}"
+TUNNEL_VIZ_HTTP_REMOTE_PORT="${TUNNEL_VIZ_HTTP_REMOTE_PORT:-8088}"
+TUNNEL_VIZ_WS_REMOTE_PORT="${TUNNEL_VIZ_WS_REMOTE_PORT:-8089}"
+
+# If set to 1, allow reusing an existing listener on the tunnel ports.
+# Default is to fail fast (helps avoid "dangling" / stale tunnels).
+TUNNEL_REUSE_EXISTING="${TUNNEL_REUSE_EXISTING:-0}" + +# Parse script arguments +while [[ $# -gt 0 ]]; do + case "$1" in + --tunnel-ssh-user-host) + if [ $# -lt 2 ]; then + echo "ERROR: Missing value for --tunnel-ssh-user-host." + usage + exit 1 + fi + TUNNEL_SSH_USER_HOST="$2" + shift 2 + ;; + --tunnel-ssh-user-host=*) + TUNNEL_SSH_USER_HOST="${1#*=}" + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "ERROR: Unknown argument: $1" + usage + exit 1 + ;; + esac +done + +if [ -z "$TUNNEL_SSH_USER_HOST" ]; then + echo "ERROR: Missing required --tunnel-ssh-user-host." + usage + exit 1 +fi + +if [[ "$TUNNEL_SSH_USER_HOST" =~ [[:space:]] ]]; then + echo "ERROR: --tunnel-ssh-user-host cannot contain whitespace: $TUNNEL_SSH_USER_HOST" + exit 1 +fi + +if [[ "$TUNNEL_SSH_USER_HOST" != *@* ]]; then + echo "ERROR: --tunnel-ssh-user-host must be in USER@HOST format: $TUNNEL_SSH_USER_HOST" + exit 1 +fi + +# PIDs for cleanup +POLICY_SERVER_PID="" +SSH_TUNNEL_PID="" + +# Cleanup function +cleanup() { + echo "" + echo "Shutting down async inference components..." + + if [ -n "$POLICY_SERVER_PID" ] && kill -0 "$POLICY_SERVER_PID" 2>/dev/null; then + echo "Stopping policy server (PID: $POLICY_SERVER_PID)..." + kill -TERM "$POLICY_SERVER_PID" 2>/dev/null || true + wait "$POLICY_SERVER_PID" 2>/dev/null || true + fi + + if [ -n "$SSH_TUNNEL_PID" ] && kill -0 "$SSH_TUNNEL_PID" 2>/dev/null; then + echo "Stopping SSH tunnel (PID: $SSH_TUNNEL_PID)..." + kill -TERM "$SSH_TUNNEL_PID" 2>/dev/null || true + wait "$SSH_TUNNEL_PID" 2>/dev/null || true + fi + + echo "Cleanup complete." 
+  exit 0
+}
+
+# Register signal handlers
+trap cleanup SIGINT SIGTERM EXIT
+
+# Change to project root
+cd "$PROJECT_ROOT"
+mkdir -p "$LOG_DIR"
+
+# -----------------------------------------------------------------------------
+# Step 1: Start SSH tunnel to cloud policy server
+# -----------------------------------------------------------------------------
+echo "[1/2] Starting SSH tunnel to cloud policy server..."
+echo "  Target: $TUNNEL_SSH_USER_HOST (ssh port: $TUNNEL_SSH_PORT)"
+echo "  Local forwards:"
+echo "    - 127.0.0.1:${TUNNEL_GRPC_LOCAL_PORT} -> localhost:${TUNNEL_GRPC_REMOTE_PORT} (gRPC)"
+echo "    - 127.0.0.1:${TUNNEL_VIZ_HTTP_LOCAL_PORT} -> localhost:${TUNNEL_VIZ_HTTP_REMOTE_PORT} (viz HTTP)"
+echo "    - 127.0.0.1:${TUNNEL_VIZ_WS_LOCAL_PORT} -> localhost:${TUNNEL_VIZ_WS_REMOTE_PORT} (viz WS)"
+
+# Check whether any of the local tunnel ports are already bound.
+existing_listeners=""
+for p in "$TUNNEL_GRPC_LOCAL_PORT" "$TUNNEL_VIZ_HTTP_LOCAL_PORT" "$TUNNEL_VIZ_WS_LOCAL_PORT"; do
+  if ss -lnt | grep -Eq ":${p}\b"; then
+    existing_listeners="1"
+  fi
+done
+
+if [ -n "$existing_listeners" ]; then
+  echo "  Detected existing listener(s) on one or more tunnel ports:"
+  ss -lntp | grep -E ":((${TUNNEL_GRPC_LOCAL_PORT})|(${TUNNEL_VIZ_HTTP_LOCAL_PORT})|(${TUNNEL_VIZ_WS_LOCAL_PORT}))\b" || true
+  if [ "$TUNNEL_REUSE_EXISTING" = "1" ]; then
+    echo "  Reusing existing listener(s) (TUNNEL_REUSE_EXISTING=1)."
+  else
+    echo "ERROR: Tunnel ports are already in use."
+    echo "  Either stop the existing tunnel/process, or re-run with:"
+    echo "    TUNNEL_REUSE_EXISTING=1 ./scripts/start_drtc_client.sh"
+    exit 1
+  fi
+else
+  # ExitOnForwardFailure ensures we fail fast if any -L can't bind.
+  # LogLevel=ERROR + redirect prevents noisy 'channel open failed' spam in your terminal.
+ : >"$TUNNEL_LOG_FILE" + ssh -p "$TUNNEL_SSH_PORT" -N \ + -o ExitOnForwardFailure=yes \ + -o ConnectTimeout=10 \ + -o ServerAliveInterval=30 \ + -o ServerAliveCountMax=3 \ + -o LogLevel=ERROR \ + -L "${TUNNEL_GRPC_LOCAL_PORT}:localhost:${TUNNEL_GRPC_REMOTE_PORT}" \ + -L "${TUNNEL_VIZ_HTTP_LOCAL_PORT}:localhost:${TUNNEL_VIZ_HTTP_REMOTE_PORT}" \ + -L "${TUNNEL_VIZ_WS_LOCAL_PORT}:localhost:${TUNNEL_VIZ_WS_REMOTE_PORT}" \ + "$TUNNEL_SSH_USER_HOST" >"$TUNNEL_LOG_FILE" 2>&1 & + SSH_TUNNEL_PID=$! + echo " SSH tunnel started (PID: $SSH_TUNNEL_PID)" + + # Give ssh a moment to error out if something is wrong, then validate. + sleep 0.2 + if ! kill -0 "$SSH_TUNNEL_PID" 2>/dev/null; then + echo "ERROR: SSH tunnel failed to start (process exited)." + echo "---- ssh tunnel log (last 50 lines) ----" + tail -n 50 "$TUNNEL_LOG_FILE" 2>/dev/null || true + exit 1 + fi +fi + +echo "" + +# ----------------------------------------------------------------------------- +# Tunnel health checks (fail fast on missing remote services) +# ----------------------------------------------------------------------------- +# Notes: +# - Local ports 18080/18088/18089 exist on THIS machine (LAN box), not on the cloud host. +# - Cloud services must listen on 8080/8088/8089 for the forwards to succeed. +echo "Checking tunnel targets..." + +# Minimal TCP connect probe (no protocol validation). +tcp_probe() { + local port="$1" + # `timeout` is used to avoid hanging if something wedges. + timeout 1 bash -c "cat < /dev/null > /dev/tcp/127.0.0.1/${port}" >/dev/null 2>&1 +} + +# gRPC is required. If this fails, the client will hit "connection reset by peer". +if ! tcp_probe "$TUNNEL_GRPC_LOCAL_PORT"; then + echo "ERROR: Tunnel is up but policy server is not reachable on 127.0.0.1:${TUNNEL_GRPC_LOCAL_PORT}." + echo " This usually means the cloud policy server is not listening on localhost:${TUNNEL_GRPC_REMOTE_PORT}." 
+ echo " On the cloud machine, confirm with:" + echo " ss -lntp | egrep ':(8080|8088|8089)\\b' || true" + echo " If needed, start it (cloud):" + echo " uv run --no-sync python examples/tutorial/async-inf/policy_server_drtc.py --host 127.0.0.1 --port 8080" + echo "---- ssh tunnel log (last 50 lines) ----" + tail -n 50 "$TUNNEL_LOG_FILE" 2>/dev/null || true + exit 1 +fi + +# Viz endpoints are optional (warn only). HTTP typically works even if WS doesn't. +if ! tcp_probe "$TUNNEL_VIZ_HTTP_LOCAL_PORT"; then + echo "WARNING: Viz HTTP not reachable on 127.0.0.1:${TUNNEL_VIZ_HTTP_LOCAL_PORT} (tunnels to :${TUNNEL_VIZ_HTTP_REMOTE_PORT})." +fi +if ! tcp_probe "$TUNNEL_VIZ_WS_LOCAL_PORT"; then + echo "WARNING: Viz WebSocket not reachable on 127.0.0.1:${TUNNEL_VIZ_WS_LOCAL_PORT} (tunnels to :${TUNNEL_VIZ_WS_REMOTE_PORT})." + echo " If 8088 is listening but 8089 is not, install websockets on the cloud env and restart the policy server." +fi + +echo "" + +# ----------------------------------------------------------------------------- +# Step 2: Start Robot Client (foreground) +# ----------------------------------------------------------------------------- +echo "[2/2] Starting robot client..." +echo " Press Ctrl+C to stop all components." +echo "" +echo "----------------------------------------------" + +# Configure the example robot client to use the tunnel's local ports. 
+export LEROBOT_SERVER_ADDRESS="127.0.0.1:${TUNNEL_GRPC_LOCAL_PORT}"
+export LEROBOT_TRAJECTORY_VIZ_WS_URL="ws://localhost:${TUNNEL_VIZ_WS_LOCAL_PORT}"
+
+# Run robot client in foreground (this blocks until Ctrl+C)
+# Use --no-sync to skip dependency resolution (avoids grpcio version conflicts)
+uv run --no-sync python examples/tutorial/async-inf/robot_client_drtc.py
+
+# If robot client exits normally, cleanup will be called via trap
diff --git a/scripts/start_drtc_server.sh b/scripts/start_drtc_server.sh
new file mode 100755
index 00000000000..c0ef11eb73b
--- /dev/null
+++ b/scripts/start_drtc_server.sh
@@ -0,0 +1,106 @@
+#!/bin/bash
+# =============================================================================
+# DRTC Policy Server Startup Script
+# =============================================================================
+#
+# Starts the DRTC policy server (includes trajectory visualization via
+# HTTP/WebSocket) and keeps it in the foreground until Ctrl+C.
+#
+# The trajectory visualization runs inside the policy server, so you can
+# view it at http://localhost:8088 once the policy server starts.
+#
+# Usage:
+#   ./scripts/start_drtc_server.sh          # Normal mode
+#
+# Environment variables:
+#   POLICY_SERVER_DELAY_S - Seconds to wait after starting policy server (default: 3)
+#   LEROBOT_DEBUG         - Set to 1 for debug logging
+#
+# =============================================================================
+
+set -e
+
+# Optional debug tracing for this script
+if [ "${LEROBOT_DEBUG:-0}" = "1" ]; then
+  set -x
+fi
+
+# Configuration
+POLICY_SERVER_DELAY_S="${POLICY_SERVER_DELAY_S:-3}"
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
+LOG_DIR="$PROJECT_ROOT/logs"
+LOG_FILE="$LOG_DIR/policy_server.log"
+
+# PIDs for cleanup
+POLICY_SERVER_PID=""
+
+# Cleanup function
+cleanup() {
+  echo ""
+  echo "Shutting down async inference components..."
+ + if [ -n "$POLICY_SERVER_PID" ] && kill -0 "$POLICY_SERVER_PID" 2>/dev/null; then + echo "Stopping policy server (PID: $POLICY_SERVER_PID)..." + kill -TERM "$POLICY_SERVER_PID" 2>/dev/null || true + wait "$POLICY_SERVER_PID" 2>/dev/null || true + fi + + echo "Cleanup complete." + exit 0 +} + +# Register signal handlers +trap cleanup SIGINT SIGTERM + +# Change to project root +cd "$PROJECT_ROOT" + +echo "==============================================" +echo " Async Inference Startup Script" +echo "==============================================" +echo "Project root: $PROJECT_ROOT" +echo "" + +# ----------------------------------------------------------------------------- +# Step 1: Start Policy Server (includes trajectory visualization) +# ----------------------------------------------------------------------------- +echo "[1/2] Starting policy server..." +# Ensure log directory exists and capture server output for debugging. +mkdir -p "$LOG_DIR" +echo " Policy server logs: $LOG_FILE" +# Use --no-sync to skip dependency resolution (avoids grpcio version conflicts) +uv run --no-sync python examples/tutorial/async-inf/policy_server_drtc.py >"$LOG_FILE" 2>&1 & +POLICY_SERVER_PID=$! +echo " Policy server started (PID: $POLICY_SERVER_PID)" +echo " Trajectory visualization: http://localhost:8088" +echo " Waiting ${POLICY_SERVER_DELAY_S}s for server to initialize..." +sleep "$POLICY_SERVER_DELAY_S" + +# Verify policy server is still running +if ! kill -0 "$POLICY_SERVER_PID" 2>/dev/null; then + echo "ERROR: Policy server failed to start!" + echo "" + echo "---- policy server log (last 200 lines) ----" + tail -n 200 "$LOG_FILE" 2>/dev/null || true + exit 1 +fi +echo " Policy server is running." +echo "" + +# Keep the policy server alive until Ctrl+C (otherwise the EXIT trap will stop it). +echo " Press Ctrl+C to stop the policy server." + +# `wait` returns the child exit code. With `set -e`, capture it explicitly so we can print logs. 
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration classes for the DRTC async inference implementation.

These configurations follow the DRTC algorithm
with proper SPSC mailboxes, Jacobson-Karels latency estimation,
cool-down mechanism, and freshest-observation-wins merging.
"""

from dataclasses import dataclass, field

from lerobot.robots.config import RobotConfig

from .constants import DEFAULT_FPS, DEFAULT_OBS_QUEUE_TIMEOUT
from .utils.simulation import DisconnectConfig, DropConfig, DuplicateConfig, ReorderConfig, SpikeDelayConfig

# =============================================================================
# Robot Client Configuration
# =============================================================================


@dataclass
class RobotClientDrtcConfig:
    """Configuration for the DRTC robot client.

    This configuration follows the DRTC algorithm
    with proper SPSC mailboxes, Jacobson-Karels latency estimation,
    cool-down mechanism, and freshest-observation-wins merging.
    """

    # Policy configuration
    policy_type: str = field(metadata={"help": "Type of policy to use (e.g., 'act', 'smolvla')"})
    pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})

    # Robot configuration
    robot: RobotConfig = field(metadata={"help": "Robot configuration"})

    # Actions per chunk (should be <= policy's max action horizon)
    actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk (H in the paper)"})

    # Hardware metadata (for experiment reports)
    robot_type: str = field(default="", metadata={"help": "Robot type identifier (e.g. so101)"})
    gpu: str = field(default="", metadata={"help": "GPU used for inference (e.g. RTX 4070 TI SUPER)"})
    client_host: str = field(
        default="", metadata={"help": "Description of the client host (e.g. local server)"}
    )
    server_host: str = field(
        default="", metadata={"help": "Description of the server host (e.g. local server)"}
    )

    # Task instruction for the robot
    task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})

    # Network configuration
    server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})

    # Device configuration (for policy inference on server)
    policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})

    # Control frequency
    fps: int = field(default=DEFAULT_FPS, metadata={"help": "Control loop frequency in Hz"})

    # DRTC parameters
    s_min: int = field(
        default=14,
        metadata={
            "help": "Minimum execution horizon in action steps (s_min from RTC paper). "
            "Trigger inference when schedule_size <= s_min. "
            "Effective execution horizon is max(s_min, latency_steps)."
        },
    )
    epsilon: int = field(
        default=1,
        metadata={
            "help": "Cooldown buffer in action steps. "
            "After triggering inference, cooldown is set to latency_steps + epsilon. "
            "Small values (1-2) prevent over-triggering without adding significant delay."
        },
    )
    latency_estimator_type: str = field(
        default="jk",
        metadata={"help": "Latency estimator type: 'jk' (Jacobson-Karels), 'max_last_10', or 'fixed'"},
    )
    latency_alpha: float = field(
        default=0.125, metadata={"help": "Jacobson-Karels smoothing factor for RTT mean"}
    )
    latency_beta: float = field(
        default=0.25, metadata={"help": "Jacobson-Karels smoothing factor for RTT deviation"}
    )
    latency_k: float = field(
        default=1.5, metadata={"help": "Jacobson-Karels scaling factor for deviation (K)"}
    )
    # Debug configuration
    debug_visualize_queue_size: bool = field(
        default=False, metadata={"help": "Visualize the action queue size after stopping"}
    )

    # RTC (client-driven, server-side inpainting; flow policies only)
    rtc_enabled: bool = field(
        default=True,
        metadata={"help": "Enable RTC-style inpainting on the policy server (flow policies only)"},
    )
    rtc_max_guidance_weight: float | None = field(
        default=None,
        metadata={
            "help": "RTC max guidance weight (clamp). If None, uses num_flow_matching_steps "
            "(Alex Soare optimization: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html)"
        },
    )
    rtc_prefix_attention_schedule: str = field(
        default="linear",
        metadata={"help": "RTC prefix attention schedule: zeros|ones|linear|exp"},
    )
    rtc_sigma_d: float = field(
        default=0.2,
        metadata={
            "help": "RTC prior variance σ_d. Lower values (e.g., 0.2) give stronger guidance "
            "and smoother transitions. 1.0 = original RTC behavior. "
            "(Alex Soare optimization: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html)"
        },
    )
    rtc_full_trajectory_alignment: bool = field(
        default=False,
        metadata={
            "help": "Skip gradient computation in RTC and use error directly. "
            "Faster and smoother when distance between chunks is small."
        },
    )
    num_flow_matching_steps: int | None = field(
        default=8,
        metadata={
            "help": "Override for number of flow matching denoising steps. "
            "If None, uses the policy's default (e.g., 10 for PI0/SmolVLA). "
            "Higher values = smoother but slower inference. "
            "(Alex Soare optimization: Beta should scale with n)"
        },
    )

    # Diagnostic metrics (console output; avg/max only)
    metrics_diagnostic_enabled: bool = field(
        default=True,
        metadata={"help": "Enable periodic diagnostic metrics printed to console (avg/max timings)"},
    )
    metrics_diagnostic_interval_s: float = field(
        default=2.0, metadata={"help": "How often to print diagnostic metrics (seconds)"}
    )
    metrics_diagnostic_window_s: float = field(
        default=10.0, metadata={"help": "Rolling window for diagnostic metrics (seconds)"}
    )
    metrics_diagnostic_verbose: bool = field(
        default=False,
        metadata={"help": "If True, include full timing/counter details in diagnostic console output"},
    )

    # Trajectory visualization (sends data to policy server via gRPC)
    trajectory_viz_enabled: bool = field(
        default=False,
        metadata={"help": "Enable sending trajectory data to policy server for visualization"},
    )
    trajectory_viz_ws_url: str = field(
        default="ws://localhost:8089",
        metadata={"help": "WebSocket URL for trajectory visualization server (for executed actions)"},
    )

    # Control-loop clocking (optional)
    control_use_deadline_clock: bool = field(
        default=True,
        metadata={"help": "Use a deadline-based control clock (reduces jitter under overruns)"},
    )

    # Observation sender robustness
    obs_fallback_on_failure: bool = field(
        default=True,
        metadata={
            "help": "If robot observation capture fails, reuse the last good observation to avoid stalling"
        },
    )
    obs_fallback_max_age_s: float = field(
        default=2.0,
        metadata={"help": "Max age (seconds) of the last good observation that may be reused on failure"},
    )

    # Simulation mode (for experiments)
    use_mock_robot: bool = field(
        default=False,
        metadata={
            "help": "Use mock robot instead of real hardware (for experiments without a physical robot)"
        },
    )
    cooldown_enabled: bool = field(
        default=True,
        metadata={"help": "Enable cooldown mechanism (set False for classic async baseline)"},
    )
    inference_reset_mode: str = field(
        default="cooldown",
        metadata={
            "help": "Mode for resetting inference readiness: "
            "'cooldown' (default) decrements each tick and allows recovery from drops; "
            "'merge_reset' resets only when actions are merged (RTC-style, stalls on drops)"
        },
    )

    # Drop injection (for experiments)
    drop_obs_config: DropConfig | None = field(
        default=None,
        metadata={
            "help": "Configuration for observation drop injection. "
            "Example: DropConfig(random_drop_p=0.05) or DropConfig(burst_period_s=20, burst_duration_s=1)"
        },
    )
    drop_action_config: DropConfig | None = field(
        default=None,
        metadata={
            "help": "Configuration for action chunk drop injection. "
            "Example: DropConfig(random_drop_p=0.05) or DropConfig(burst_period_s=20, burst_duration_s=1)"
        },
    )

    # Duplicate injection (for experiments)
    dup_obs_config: DuplicateConfig | None = field(
        default=None,
        metadata={
            "help": "Configuration for observation duplicate injection. "
            "Example: DuplicateConfig(duplicates=[DuplicateEvent(start_s=5, duration_s=1)])"
        },
    )
    dup_action_config: DuplicateConfig | None = field(
        default=None,
        metadata={
            "help": "Configuration for action chunk duplicate injection. "
            "Example: DuplicateConfig(duplicates=[DuplicateEvent(start_s=5, duration_s=1)])"
        },
    )

    # Reorder injection (for experiments)
    reorder_obs_config: ReorderConfig | None = field(
        default=None,
        metadata={
            "help": "Configuration for observation reorder injection (pairwise hold-and-swap). "
            "Example: ReorderConfig(reorders=[ReorderEvent(start_s=5, duration_s=2)])"
        },
    )
    reorder_action_config: ReorderConfig | None = field(
        default=None,
        metadata={
            "help": "Configuration for action chunk reorder injection (pairwise hold-and-swap). "
            "Example: ReorderConfig(reorders=[ReorderEvent(start_s=5, duration_s=2)])"
        },
    )

    # Disconnect injection (for experiments)
    disconnect_config: DisconnectConfig | None = field(
        default=None,
        metadata={
            "help": "Configuration for network disconnect injection (blocks obs and action threads). "
            "Example: DisconnectConfig(disconnects=[DisconnectEvent(start_s=5, duration_s=3)])"
        },
    )

    # Spike injection (for experiments, passed to server)
    # List of dicts: [{"start_s": 5.0, "delay_ms": 2000}, ...]
    spikes: list[dict] = field(
        default_factory=list,
        metadata={
            "help": "Explicit spike events as list of dicts. "
            "Example: [{'start_s': 5, 'delay_ms': 2000}, {'start_s': 15, 'delay_ms': 1000}]"
        },
    )

    # Experiment metrics (disk output; CSV export)
    metrics_path: str | None = field(
        default=None,
        metadata={"help": "Path to write experiment metrics CSV (None = disabled)"},
    )
    # Action smoothing to reduce policy jitter / servo hunting
    # Modes: "none", "adaptive_lowpass", "hold_stable", "butterworth"
    action_filter_mode: str = field(
        default="none",
        metadata={
            "help": "Action filtering mode: "
            "'none' = no filtering, "
            "'adaptive_lowpass' = IIR filter with adaptive alpha based on delta magnitude, "
            "'hold_stable' = hold previous action when delta is below threshold (eliminates jitter), "
            "'butterworth' = proper low-pass filter with configurable cutoff frequency"
        },
    )
    action_filter_alpha_min: float = field(
        default=0.1,
        metadata={
            "help": "Low-pass filter alpha for small deltas (heavy smoothing). "
            "Used when action delta is below deadband threshold. Range: (0, 1]. "
            "Lower = more smoothing. 0.1 gives strong attenuation of high-freq jitter."
        },
    )
    action_filter_alpha_max: float = field(
        default=0.5,
        metadata={
            "help": "Low-pass filter alpha for large deltas (faster response). "
            "Used when action delta exceeds deadband threshold. Range: (0, 1]"
        },
    )
    action_filter_deadband: float = field(
        default=0.05,
        metadata={
            "help": "Deadband threshold in action units (radians for joints). "
            "For 'adaptive_lowpass': deltas below this get alpha_min, above get alpha_max. "
            "For 'hold_stable': deltas below this are ignored entirely. "
            "Default 0.05 rad ≈ 3 degrees."
        },
    )
    action_filter_butterworth_cutoff: float = field(
        default=10.0,
        metadata={
            "help": "Butterworth filter cutoff frequency in Hz. "
            "Frequencies above this are attenuated. Should be < fps/2 (Nyquist). "
            "Recommended: 10-12 Hz for 60 Hz control rate."
        },
    )
    action_filter_butterworth_order: int = field(
        default=2,
        metadata={
            "help": "Butterworth filter order (1-4). Higher = sharper frequency rolloff but more phase lag."
        },
    )
    action_filter_gain: float = field(
        default=1.0,
        metadata={
            "help": "Gain multiplier applied after filtering to compensate amplitude attenuation. "
            "Values > 1.0 boost the filtered signal."
        },
    )
    action_filter_past_buffer_size: int = field(
        default=5,
        metadata={
            # FIX: the help previously advertised a 'median' mode, which does not
            # exist (valid modes are validated in __post_init__ below).
            "help": "Number of past executed actions to keep in filter buffer. "
            "Used by the 'butterworth' mode for history."
        },
    )

    @property
    def environment_dt(self) -> float:
        """Environment time step in seconds."""
        return 1.0 / self.fps

    def __post_init__(self):
        """Validate configuration after initialization."""
        if not self.server_address:
            raise ValueError("server_address cannot be empty")
        if not self.policy_type:
            raise ValueError("policy_type cannot be empty")
        if not self.pretrained_name_or_path:
            raise ValueError("pretrained_name_or_path cannot be empty")
        if self.fps <= 0:
            raise ValueError(f"fps must be positive, got {self.fps}")
        if self.actions_per_chunk <= 0:
            raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")
        if self.s_min <= 0:
            raise ValueError(f"s_min must be positive, got {self.s_min}")
        if self.s_min >= self.actions_per_chunk:
            raise ValueError(
                f"s_min must be < actions_per_chunk, got {self.s_min} >= {self.actions_per_chunk}"
            )
        if self.epsilon < 0:
            raise ValueError(f"epsilon must be non-negative, got {self.epsilon}")
        if self.latency_estimator_type not in ("jk", "max_last_10", "fixed"):
            raise ValueError(
                f"latency_estimator_type must be 'jk', 'max_last_10', or 'fixed', got {self.latency_estimator_type}"
            )
        # NEW: validate the Jacobson-Karels gains, which every other numeric
        # field already gets; out-of-range gains silently destabilize the
        # RTT estimate otherwise.
        if not 0 < self.latency_alpha <= 1:
            raise ValueError(f"latency_alpha must be in (0, 1], got {self.latency_alpha}")
        if not 0 < self.latency_beta <= 1:
            raise ValueError(f"latency_beta must be in (0, 1], got {self.latency_beta}")
        if self.latency_k <= 0:
            raise ValueError(f"latency_k must be positive, got {self.latency_k}")
        if self.inference_reset_mode not in ("cooldown", "merge_reset"):
            raise ValueError(
                f"inference_reset_mode must be 'cooldown' or 'merge_reset', got {self.inference_reset_mode}"
            )
        if self.metrics_diagnostic_interval_s <= 0:
            raise ValueError(
                f"metrics_diagnostic_interval_s must be positive, got {self.metrics_diagnostic_interval_s}"
            )
        if self.metrics_diagnostic_window_s <= 0:
            raise ValueError(
                f"metrics_diagnostic_window_s must be positive, got {self.metrics_diagnostic_window_s}"
            )
        if self.obs_fallback_max_age_s <= 0:
            raise ValueError(f"obs_fallback_max_age_s must be positive, got {self.obs_fallback_max_age_s}")
        if self.rtc_max_guidance_weight is not None and self.rtc_max_guidance_weight <= 0:
            raise ValueError(
                f"rtc_max_guidance_weight must be positive or None, got {self.rtc_max_guidance_weight}"
            )
        if self.rtc_sigma_d <= 0:
            raise ValueError(f"rtc_sigma_d must be positive, got {self.rtc_sigma_d}")
        if self.num_flow_matching_steps is not None and self.num_flow_matching_steps <= 0:
            raise ValueError(
                f"num_flow_matching_steps must be positive or None, got {self.num_flow_matching_steps}"
            )
        if self.action_filter_mode not in ("none", "adaptive_lowpass", "hold_stable", "butterworth"):
            raise ValueError(
                f"action_filter_mode must be 'none', 'adaptive_lowpass', 'hold_stable', or 'butterworth', "
                f"got {self.action_filter_mode}"
            )
        if self.action_filter_alpha_min <= 0 or self.action_filter_alpha_min > 1:
            raise ValueError(f"action_filter_alpha_min must be in (0, 1], got {self.action_filter_alpha_min}")
        if self.action_filter_alpha_max <= 0 or self.action_filter_alpha_max > 1:
            raise ValueError(f"action_filter_alpha_max must be in (0, 1], got {self.action_filter_alpha_max}")
        if self.action_filter_deadband < 0:
            raise ValueError(
                f"action_filter_deadband must be non-negative, got {self.action_filter_deadband}"
            )
        if self.action_filter_butterworth_cutoff <= 0:
            raise ValueError(
                f"action_filter_butterworth_cutoff must be positive, got {self.action_filter_butterworth_cutoff}"
            )
        if self.action_filter_butterworth_cutoff >= self.fps / 2:
            raise ValueError(
                f"action_filter_butterworth_cutoff must be < fps/2 (Nyquist), "
                f"got {self.action_filter_butterworth_cutoff} >= {self.fps / 2}"
            )
        if self.action_filter_butterworth_order < 1 or self.action_filter_butterworth_order > 4:
            raise ValueError(
                f"action_filter_butterworth_order must be 1-4, got {self.action_filter_butterworth_order}"
            )
        if self.action_filter_gain <= 0:
            raise ValueError(f"action_filter_gain must be positive, got {self.action_filter_gain}")
        if self.action_filter_past_buffer_size < 1:
            raise ValueError(
                f"action_filter_past_buffer_size must be >= 1, got {self.action_filter_past_buffer_size}"
            )
# =============================================================================
# Policy Server Configuration
# =============================================================================


@dataclass
class PolicyServerDrtcConfig:
    """Configuration for the DRTC PolicyServer.

    This class defines all configurable parameters for the PolicyServer,
    following the 2-thread model from the DRTC paper.
    """

    # Networking configuration
    host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
    port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})

    # Timing configuration
    fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second (control frequency)"})

    # Diagnostic metrics (console output; avg/max only)
    metrics_diagnostic_enabled: bool = field(
        default=True,
        metadata={"help": "Enable periodic diagnostic metrics printed to console (avg/max timings)"},
    )
    metrics_diagnostic_interval_s: float = field(
        default=2.0, metadata={"help": "How often to print diagnostic metrics (seconds)"}
    )
    metrics_diagnostic_window_s: float = field(
        default=10.0, metadata={"help": "Rolling window for diagnostic metrics (seconds)"}
    )
    metrics_diagnostic_verbose: bool = field(
        default=False,
        metadata={"help": "If True, include full timing/counter details in diagnostic console output"},
    )

    # Observation queue timeout
    obs_queue_timeout: float = field(
        default=DEFAULT_OBS_QUEUE_TIMEOUT,
        metadata={"help": "Timeout for observation queue in seconds"},
    )

    # Mock policy configuration (for simulation experiments)
    mock_policy: bool = field(
        default=False,
        metadata={"help": "Use mock policy instead of real model (for experiments)"},
    )
    mock_spike_config: SpikeDelayConfig | None = field(
        default=None,
        metadata={
            "help": "Configuration for mock inference latency spikes. "
            "Example: SpikeDelayConfig.from_dicts([{'start_s': 5, 'delay_ms': 2000}])"
        },
    )
    mock_action_dim: int = field(
        default=6,
        metadata={"help": "Action dimension for mock policy output"},
    )

    # Model warmup (CUDA kernel compilation + memory allocation on first pass)
    warmup_passes: int = field(
        default=2,
        metadata={
            "help": "Number of dummy inference passes to run after loading the model. "
            "This eliminates CUDA cold-start latency from the first real measurement. "
            "Set to 0 to disable warmup."
        },
    )

    # Trajectory visualization (receives data from robot client via gRPC)
    trajectory_viz_enabled: bool = field(
        default=False,
        metadata={"help": "Enable trajectory visualization server (HTTP + WebSocket)"},
    )
    trajectory_viz_http_port: int = field(
        default=8088,
        metadata={"help": "HTTP port for trajectory visualization web page"},
    )
    trajectory_viz_ws_port: int = field(
        default=8089,
        metadata={"help": "WebSocket port for trajectory data streaming"},
    )

    @property
    def environment_dt(self) -> float:
        """Environment time step in seconds."""
        return 1.0 / self.fps

    def __post_init__(self):
        """Validate configuration after initialization."""
        if self.port < 1 or self.port > 65535:
            raise ValueError(f"Port must be between 1 and 65535, got {self.port}")
        if self.fps <= 0:
            raise ValueError(f"fps must be positive, got {self.fps}")
        if self.obs_queue_timeout < 0:
            raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")
        if self.metrics_diagnostic_interval_s <= 0:
            raise ValueError(
                f"metrics_diagnostic_interval_s must be positive, got {self.metrics_diagnostic_interval_s}"
            )
        if self.metrics_diagnostic_window_s <= 0:
            raise ValueError(
                f"metrics_diagnostic_window_s must be positive, got {self.metrics_diagnostic_window_s}"
            )
        # NEW: these fields previously accepted nonsensical values silently,
        # while every sibling field is validated.
        if self.warmup_passes < 0:
            raise ValueError(f"warmup_passes must be non-negative, got {self.warmup_passes}")
        if self.mock_action_dim < 1:
            raise ValueError(f"mock_action_dim must be >= 1, got {self.mock_action_dim}")
        for viz_port in (self.trajectory_viz_http_port, self.trajectory_viz_ws_port):
            if viz_port < 1 or viz_port > 65535:
                raise ValueError(f"Port must be between 1 and 65535, got {viz_port}")
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-only import (PEP 563 lazy annotations): these plain dataclasses do
    # not need `.helpers` at runtime, which keeps importing this module cheap —
    # consistent with the lazy-import strategy used in helpers.py.
    # NOTE(review): typing.get_type_hints() on these classes now requires
    # `.helpers` to be importable — confirm no caller relies on that.
    from .helpers import Action, RawObservation


@dataclass
class DrtcTimedData:
    """Base timing payload for the DRTC control-step clock.

    Attributes:
        timestamp: Wall-clock time (seconds) at which the payload was created.
        control_step: Monotone control-loop tick this payload belongs to.
    """

    timestamp: float
    control_step: int

    def get_timestamp(self) -> float:
        return self.timestamp

    def get_control_step(self) -> int:
        return self.control_step


@dataclass
class DrtcAction(DrtcTimedData):
    """A DRTC action identified by its source control step and execution step."""

    # Execution index of this action (required, no default).
    action_step: int
    # Action payload; None until populated.
    action: Action | None = None

    def get_action_step(self) -> int:
        return self.action_step

    def get_action(self) -> Action | None:
        return self.action


@dataclass
class DrtcObservation(DrtcTimedData):
    """A DRTC observation that carries the target chunk start step."""

    # Raw robot observation; None until populated.
    observation: RawObservation | None = None
    # Control step at which the resulting action chunk should start.
    chunk_start_step: int = 0
    # Receive timestamp in seconds; presumably stamped on the server side,
    # 0.0 until then — TODO confirm against the policy server.
    server_received_ts: float = 0.0

    def get_observation(self) -> RawObservation | None:
        return self.observation
RawObservation = dict[str, Any] # observation as those recorded in LeRobot dataset (keys are different) -LeRobotObservation = dict[str, torch.Tensor] +LeRobotObservation = dict[str, Any] # observation, ready for policy inference (image keys resized) -Observation = dict[str, torch.Tensor] +Observation = dict[str, Any] + + +def _validate_feature_names(features: dict[str, dict]) -> None: + """Validate that feature names do not contain invalid characters. + + We keep this local to avoid importing `lerobot.datasets.utils` (which is heavyweight). + """ + invalid_features = {name: ft for name, ft in features.items() if "/" in name} + if invalid_features: + raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") -def visualize_action_queue_size(action_queue_size: list[int]) -> None: +def hw_to_dataset_features( + hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True +) -> dict[str, dict]: + """Lightweight version of `lerobot.datasets.utils.hw_to_dataset_features`. + + The async inference client only needs a small subset of dataset feature logic, and importing + the full dataset stack (datasets/pandas/pyarrow/torchvision/...) is very expensive on small + devices like a Raspberry Pi. 
+ """ + features: dict[str, dict] = {} + + joint_fts = { + key: ftype + for key, ftype in hw_features.items() + if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) + } + cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} + + if joint_fts and prefix == OBS_STR: + features[f"{prefix}.state"] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + for key, shape in cam_fts.items(): + features[f"{prefix}.images.{key}"] = { + "dtype": "video" if use_video else "image", + "shape": shape, + "names": ["height", "width", "channels"], + } + + _validate_feature_names(features) + return features + + +def build_dataset_frame( + ds_features: dict[str, dict], values: dict[str, Any], prefix: str +) -> dict[str, np.ndarray]: + """Lightweight version of `lerobot.datasets.utils.build_dataset_frame`.""" + frame: dict[str, np.ndarray] = {} + for key, ft in ds_features.items(): + if not key.startswith(prefix): + continue + if ft["dtype"] == "float32" and len(ft["shape"]) == 1: + frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) + elif ft["dtype"] in ["image", "video"]: + frame[key] = values[key.removeprefix(f"{prefix}.images.")] + return frame + + +def visualize_action_queue_size(action_queue_size: Sequence[int]) -> None: import matplotlib.pyplot as plt _, ax = plt.subplots() @@ -71,7 +122,9 @@ def is_image_key(k: str) -> bool: return k.startswith(OBS_IMAGES) -def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int, int, int]) -> torch.tensor: +def resize_robot_observation_image(image: Any, resize_dims: tuple[int, int, int]) -> Any: + import torch + assert image.ndim == 3, f"Image must be (C, H, W)! 
Received {image.shape}" # (H, W, C) -> (C, H, W) for resizing from robot obsevation resolution to policy image resolution image = image.permute(2, 0, 1) @@ -90,6 +143,8 @@ def raw_observation_to_observation( lerobot_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature], ) -> Observation: + import torch + observation = {} observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features) @@ -104,8 +159,10 @@ def raw_observation_to_observation( return observation -def prepare_image(image: torch.Tensor) -> torch.Tensor: +def prepare_image(image: Any) -> Any: """Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor""" + import torch + image = image.type(torch.float32) / 255 image = image.contiguous() @@ -114,8 +171,10 @@ def prepare_image(image: torch.Tensor) -> torch.Tensor: def extract_state_from_raw_observation( lerobot_obs: RawObservation, -) -> torch.Tensor: +) -> Any: """Extract the state from a raw observation.""" + import torch + state = torch.tensor(lerobot_obs[OBS_STATE]) if state.ndim == 1: @@ -127,8 +186,10 @@ def extract_state_from_raw_observation( def extract_images_from_raw_observation( lerobot_obs: RawObservation, camera_key: str, -) -> dict[str, torch.Tensor]: +) -> Any: """Extract the images from a raw observation.""" + import torch + return torch.tensor(lerobot_obs[camera_key]) @@ -147,6 +208,8 @@ def prepare_raw_observation( ) -> Observation: """Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as policy_image_features).""" + import torch + # 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} -> # -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray} lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features) @@ -203,8 +266,7 @@ class TimedData: Args: timestamp: Unix timestamp relative to data's creation. - data: The actual data to wrap a timestamp around. 
- timestep: The timestep of the data. + timestep: Monotone async-inference step associated with this data. """ timestamp: float @@ -219,7 +281,9 @@ def get_timestep(self): @dataclass class TimedAction(TimedData): - action: Action + """A timed action associated with a legacy async-inference timestep.""" + + action: Action = None def get_action(self): return self.action @@ -227,7 +291,9 @@ def get_action(self): @dataclass class TimedObservation(TimedData): - observation: RawObservation + """A timed observation associated with a legacy async-inference timestep.""" + + observation: RawObservation = None must_go: bool = False def get_observation(self): @@ -270,10 +336,37 @@ class RemotePolicyConfig: actions_per_chunk: int device: str = "cpu" rename_map: dict[str, str] = field(default_factory=dict) - - -def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool: + # Client-driven RTC configuration (optional; server may ignore if policy doesn't support RTC) + rtc_enabled: bool = False + rtc_max_guidance_weight: float | None = None # None = use num_flow_matching_steps (Alex Soare opt) + rtc_prefix_attention_schedule: str = "linear" + rtc_sigma_d: float = 1.0 # Prior variance (0.2 = stronger guidance, 1.0 = original RTC) + rtc_full_trajectory_alignment: bool = False # Skip gradient for faster/smoother transitions + # Denoising steps override (Alex Soare: Beta should scale with n) + num_flow_matching_steps: int | None = None # None = use policy default (e.g., 10 for PI0/SmolVLA) + # Spike injection (client-driven, for experiments) + # List of dicts: [{"start_s": 5.0, "delay_ms": 2000}, ...] 
+ spikes: list[dict] = field(default_factory=list) + # Diagnostics: when True, the server also enables verbose diagnostic output + diagnostics_verbose: bool = False + + def __setstate__(self, state: dict[str, Any]) -> None: + """Back-compat for pickles created before RTC/spike fields existed.""" + self.__dict__.update(state) + self.__dict__.setdefault("rtc_enabled", False) + self.__dict__.setdefault("rtc_max_guidance_weight", None) # Default to auto (Alex Soare opt) + self.__dict__.setdefault("rtc_prefix_attention_schedule", "linear") + self.__dict__.setdefault("rtc_sigma_d", 1.0) + self.__dict__.setdefault("rtc_full_trajectory_alignment", False) + self.__dict__.setdefault("num_flow_matching_steps", None) # Default to policy config + # Spike injection defaults (new format) + self.__dict__.setdefault("spikes", []) + + +def _compare_observation_states(obs1_state: Any, obs2_state: Any, atol: float) -> bool: """Check if two observation states are similar, under a tolerance threshold""" + import torch + return bool(torch.linalg.norm(obs1_state - obs2_state) < atol) diff --git a/src/lerobot/async_inference/lww_register.py b/src/lerobot/async_inference/lww_register.py new file mode 100644 index 00000000000..8967194e82f --- /dev/null +++ b/src/lerobot/async_inference/lww_register.py @@ -0,0 +1,112 @@ +"""Thread-safe last-write-wins register keyed by control step. + +The control step t is the monotone logical clock in DRTC: it increments +every tick of the robot control loop. This register implements a +monotone join: + + state := state ⊔ incoming + +where ⊔ keeps the state with the larger control step. + +The system uses two clocks: +- control_step (t): monotone per control-loop tick; used for LWW / watermarks. +- action_step (j): execution index; incremented when an action executes. 
+""" + +from __future__ import annotations + +import threading +from dataclasses import dataclass + + +@dataclass(frozen=True) +class LWWState[T]: + """Last-write-wins state element.""" + + control_step: int + value: T + + def __or__(self, other: LWWState[T]) -> LWWState[T]: + """Join (⊔): keep the state with the larger control_step. + + Tie-breaking is intentionally stable: if control_step is equal, keep `self`. + """ + + if other.control_step > self.control_step: + return other + return self + + +@dataclass(frozen=True) +class LWWCursor: + """Monotone consumer cursor (watermark) for read-once semantics.""" + + watermark: int + + def __or__(self, other: LWWCursor) -> LWWCursor: + return self if self.watermark >= other.watermark else other + + +class LWWReader[T]: + """Per-consumer read-once view of an `LWWRegister`. + + The cursor (watermark) is stored inside the reader, so call sites don't need + to carry `_last_*` or explicit cursor arguments. + """ + + def __init__(self, register: LWWRegister[T], *, initial_watermark: int): + self._register = register + self._cursor = LWWCursor(watermark=initial_watermark) + + @property + def cursor(self) -> LWWCursor: + return self._cursor + + def read_if_newer(self) -> tuple[LWWState[T], LWWCursor, bool]: + state = self._register.read() + is_new = state.control_step > self._cursor.watermark + if is_new: + self._cursor = self._cursor | LWWCursor(watermark=state.control_step) + return state, self._cursor, is_new + + +class LWWRegister[T]: + """A thread-safe LWW register holding a single `LWWState`. + + Notes: + - This register has no "consume" semantics. Consumers must track a watermark + (via LWWReader) to avoid re-processing the same state repeatedly. + - Updates are monotone w.r.t. control_step: stale (or equal) updates cannot overwrite. 
+ """ + + def __init__(self, *, initial_control_step: int, initial_value: T): + self._lock = threading.Lock() + self._state: LWWState[T] = LWWState(control_step=initial_control_step, value=initial_value) + + def reader(self, *, initial_watermark: int = -1) -> LWWReader[T]: + """Create a per-consumer reader with an internal monotone cursor.""" + + return LWWReader(self, initial_watermark=initial_watermark) + + def read(self) -> LWWState[T]: + with self._lock: + return self._state + + def update(self, control_step: int, value: T) -> LWWState[T]: + state, _ = self.update_if_newer(control_step, value) + return state + + def update_if_newer(self, control_step: int, value: T) -> tuple[LWWState[T], bool]: + """Update the register iff the incoming control_step is strictly newer. + + Returns: + (state, did_update) + """ + + incoming = LWWState(control_step=control_step, value=value) + with self._lock: + prev = self._state + new = prev | incoming + did_update = new is not prev + self._state = new + return new, did_update diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py index aedce2a7486..c50f348f88f 100644 --- a/src/lerobot/async_inference/policy_server.py +++ b/src/lerobot/async_inference/policy_server.py @@ -24,6 +24,15 @@ ``` """ +# ruff: noqa: E402, I001 + +import os as _os +import sys as _sys +import time as _time + +_IMPORT_TIMING_ENABLED = _os.getenv("LEROBOT_IMPORT_TIMING", "0") == "1" +_IMPORT_T0 = _time.perf_counter() if _IMPORT_TIMING_ENABLED else 0.0 + import logging import pickle # nosec import threading @@ -34,6 +43,8 @@ from queue import Empty, Queue from typing import Any +import cv2 # type: ignore +import numpy as np import draccus import grpc import torch @@ -62,6 +73,24 @@ raw_observation_to_observation, ) +if _IMPORT_TIMING_ENABLED: + _sys.stderr.write( + f"[import-timing] {__name__} imports: {(_time.perf_counter() - _IMPORT_T0) * 1000.0:.2f}ms\n" + ) + + +def 
_infer_model_action_horizon(policy_config: Any) -> tuple[str, int] | None: + """Infer the maximum action horizon from a loaded policy config.""" + if policy_config is None: + return None + + for field_name in ("chunk_size", "n_action_steps", "horizon"): + value = getattr(policy_config, field_name, None) + if isinstance(value, int) and value > 0: + return field_name, value + + return None + class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): prefix = "policy_server" @@ -90,6 +119,10 @@ def __init__(self, config: PolicyServerConfig): self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None + @staticmethod + def _ms(seconds: float) -> float: + return seconds * 1000.0 + @property def running(self): return not self.shutdown_event.is_set() @@ -118,13 +151,17 @@ def Ready(self, request, context): # noqa: N802 def SendPolicyInstructions(self, request, context): # noqa: N802 """Receive policy instructions from the robot client""" + t_total_start = time.perf_counter() + if not self.running: self.logger.warning("Server is not running. Ignoring policy instructions.") return services_pb2.Empty() client_id = context.peer() + t0 = time.perf_counter() policy_specs = pickle.loads(request.data) # nosec + t_deserialize = time.perf_counter() - t0 if not isinstance(policy_specs, RemotePolicyConfig): raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}") @@ -142,6 +179,7 @@ def SendPolicyInstructions(self, request, context): # noqa: N802 f"Actions per chunk: {policy_specs.actions_per_chunk} | " f"Device: {policy_specs.device}" ) + self.logger.debug("Policy instructions payload deserialized in %.2fms", self._ms(t_deserialize)) self.device = policy_specs.device self.policy_type = policy_specs.policy_type # act, pi0, etc. 
@@ -150,12 +188,29 @@ def SendPolicyInstructions(self, request, context): # noqa: N802 policy_class = get_policy_class(self.policy_type) - start = time.perf_counter() + t_load_start = time.perf_counter() self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path) - self.policy.to(self.device) + t_loaded = time.perf_counter() + + t_to_start = time.perf_counter() + self.policy.to(self.device) # includes parameter/device moves + t_to_done = time.perf_counter() + + inferred_horizon = _infer_model_action_horizon(getattr(self.policy, "config", None)) + if inferred_horizon is not None: + horizon_field, model_horizon = inferred_horizon + if self.actions_per_chunk > model_horizon: + raise ValueError( + "Requested actions_per_chunk " + f"({self.actions_per_chunk}) exceeds model-supported horizon " + f"({model_horizon}, from policy config field '{horizon_field}') " + f"for checkpoint '{policy_specs.pretrained_name_or_path}'. " + f"Set actions_per_chunk <= {model_horizon}." + ) # Load preprocessor and postprocessor, overriding device to match requested device device_override = {"device": self.device} + t_pp_start = time.perf_counter() self.preprocessor, self.postprocessor = make_pre_post_processors( self.policy.config, pretrained_path=policy_specs.pretrained_name_or_path, @@ -165,10 +220,16 @@ def SendPolicyInstructions(self, request, context): # noqa: N802 }, postprocessor_overrides={"device_processor": device_override}, ) + t_pp_done = time.perf_counter() - end = time.perf_counter() - - self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds") + self.logger.info( + "Policy init timing | from_pretrained: %.2fms | to(%s): %.2fms | pre/post: %.2fms | total: %.2fms", + self._ms(t_loaded - t_load_start), + self.device, + self._ms(t_to_done - t_to_start), + self._ms(t_pp_done - t_pp_start), + self._ms(time.perf_counter() - t_total_start), + ) return services_pb2.Empty() @@ -177,13 +238,31 @@ def SendObservations(self, 
request_iterator, context): # noqa: N802 client_id = context.peer() self.logger.debug(f"Receiving observations from {client_id}") + t_total_start = time.perf_counter() receive_time = time.time() # comparing timestamps so need time.time() - start_deserialize = time.perf_counter() + + t_recv_start = time.perf_counter() received_bytes = receive_bytes_in_chunks( request_iterator, None, self.shutdown_event, self.logger ) # blocking call while looping over request_iterator + t_recv_done = time.perf_counter() + + t_deser_start = time.perf_counter() timed_observation = pickle.loads(received_bytes) # nosec - deserialize_time = time.perf_counter() - start_deserialize + t_deser_done = time.perf_counter() + + t_decode_start = time.perf_counter() + decoded_observation, decode_stats = _decode_images_from_transport(timed_observation.get_observation()) + timed_observation.observation = decoded_observation + t_decode_done = time.perf_counter() + if decode_stats["images_decoded"] > 0: + self.logger.debug( + "Decoded %s images from transport in %.2fms | encoded_bytes=%s -> raw_bytes=%s", + decode_stats["images_decoded"], + self._ms(t_decode_done - t_decode_start), + decode_stats["encoded_bytes_total"], + decode_stats["raw_bytes_total"], + ) self.logger.debug(f"Received observation #{timed_observation.get_timestep()}") @@ -203,13 +282,29 @@ def SendObservations(self, request_iterator, context): # noqa: N802 self.logger.debug( f"Server timestamp: {receive_time:.6f} | " f"Client timestamp: {obs_timestamp:.6f} | " - f"Deserialization time: {deserialize_time:.6f}s" + f"Chunk-receive time: {self._ms(t_recv_done - t_recv_start):.2f}ms | " + f"Deserialize time: {self._ms(t_deser_done - t_deser_start):.2f}ms | " + f"Payload bytes: {len(received_bytes)}" ) - if not self._enqueue_observation( - timed_observation # wrapping a RawObservation - ): + t_enqueue_start = time.perf_counter() + enqueued = self._enqueue_observation(timed_observation) # wrapping a RawObservation + t_enqueue_done = 
time.perf_counter() + + if not enqueued: self.logger.debug(f"Observation #{obs_timestep} has been filtered out") + else: + self.logger.debug( + "Observation #%s enqueued | enqueue time: %.2fms | queue size: %s", + obs_timestep, + self._ms(t_enqueue_done - t_enqueue_start), + self.observation_queue.qsize(), + ) + + self.logger.debug( + "SendObservations total time: %.2fms", + self._ms(time.perf_counter() - t_total_start), + ) return services_pb2.Empty() @@ -221,12 +316,22 @@ def GetActions(self, request, context): # noqa: N802 # Generate action based on the most recent observation and its timestep try: - getactions_starts = time.perf_counter() + t_total_start = time.perf_counter() + + t_wait_start = time.perf_counter() obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout) + t_wait_done = time.perf_counter() + self.logger.info( f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})" ) + self.logger.debug( + "GetActions waited %.2fms for observation | queue size after get: %s", + self._ms(t_wait_done - t_wait_start), + self.observation_queue.qsize(), + ) + with self._predicted_timesteps_lock: self._predicted_timesteps.add(obs.get_timestep()) @@ -248,14 +353,31 @@ def GetActions(self, request, context): # noqa: N802 self.logger.debug( f"Action chunk #{obs.get_timestep()} generated | " - f"Inference time: {inference_time:.2f}s |" - f"Serialize time: {serialize_time:.2f}s |" - f"Total time: {inference_time + serialize_time:.2f}s" + f"Inference time: {self._ms(inference_time):.2f}ms | " + f"Serialize time: {self._ms(serialize_time):.2f}ms | " + f"Pickle bytes: {len(actions_bytes)}" ) - time.sleep( - max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts)) - ) # sleep controls inference latency + # sleep controls inference latency (wall-clock budget for the entire GetActions call) + elapsed = time.perf_counter() - t_total_start + target = self.config.inference_latency + sleep_s = max(0.0, target - 
max(0.0, elapsed)) + if sleep_s > 0: + t_sleep_start = time.perf_counter() + time.sleep(sleep_s) + t_sleep_done = time.perf_counter() + self.logger.debug( + "GetActions throttling sleep: %.2fms (target %.2fms, elapsed %.2fms)", + self._ms(t_sleep_done - t_sleep_start), + self._ms(target), + self._ms(elapsed), + ) + else: + self.logger.debug( + "GetActions no sleep (target %.2fms, elapsed %.2fms)", + self._ms(target), + self._ms(elapsed), + ) return actions @@ -317,13 +439,21 @@ def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: t_0 + i*environment_dt for i in range(len(action_chunk)) """ return [ - TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action) + # Convert to numpy so the robot client does not need torch installed just to unpickle actions. + TimedAction( + timestamp=t_0 + i * self.config.environment_dt, + timestep=i_0 + i, + action=action.detach().cpu().numpy(), + ) for i, action in enumerate(action_chunk) ] def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: """Get an action chunk from the policy. 
The chunk contains only""" + t0 = time.perf_counter() chunk = self.policy.predict_action_chunk(observation) + t1 = time.perf_counter() + self.logger.debug("Policy predict_action_chunk time: %.2fms", self._ms(t1 - t0)) if chunk.ndim != 3: chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim) @@ -358,8 +488,12 @@ def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAc start_inference = time.perf_counter() action_tensor = self._get_action_chunk(observation) inference_time = time.perf_counter() - start_inference - self.logger.info( - f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}" + self.logger.debug( + "Model timings | prepare: %.2fms | preprocess: %.2fms | inference: %.2fms | action shape: %s", + self._ms(prepare_time), + self._ms(preprocessing_time), + self._ms(inference_time), + tuple(action_tensor.shape), ) """4. Apply postprocessor""" @@ -374,7 +508,15 @@ def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAc for i in range(chunk_size): # Extract action at timestep i: (B, action_dim) single_action = action_tensor[:, i, :] + t_action_post_start = time.perf_counter() processed_action = self.postprocessor(single_action) + t_action_post_done = time.perf_counter() + self.logger.debug( + "Postprocess action[%s/%s] time: %.2fms", + i + 1, + chunk_size, + self._ms(t_action_post_done - t_action_post_start), + ) processed_actions.append(processed_action) # Stack back to (B, chunk_size, action_dim), then remove batch dim @@ -384,23 +526,21 @@ def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAc action_tensor = action_tensor.detach().cpu() """5. 
Convert to TimedAction list""" + t_time_chunk_start = time.perf_counter() action_chunk = self._time_action_chunk( observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() ) + t_time_chunk_done = time.perf_counter() postprocess_stops = time.perf_counter() postprocessing_time = postprocess_stops - start_postprocess - self.logger.info( - f"Observation {observation_t.get_timestep()} | " - f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms" - ) - self.logger.debug( f"Observation {observation_t.get_timestep()} | " f"Prepare time: {1000 * prepare_time:.2f}ms | " f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | " f"Inference time: {1000 * inference_time:.2f}ms | " f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | " + f"Timing chunk time: {1000 * (t_time_chunk_done - t_time_chunk_start):.2f}ms | " f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms" ) @@ -439,3 +579,35 @@ def serve(cfg: PolicyServerConfig): if __name__ == "__main__": serve() + + +def _decode_images_from_transport(observation: Any) -> tuple[Any, dict[str, int]]: + """Recursively decode JPEG-marked images back into uint8 HWC3 RGB numpy arrays.""" + stats = {"images_decoded": 0, "raw_bytes_total": 0, "encoded_bytes_total": 0} + + def _maybe_decode_payload(x: Any) -> Any: + if isinstance(x, dict) and x.get("__lerobot_image_encoding__") == "jpeg": + data = x.get("data") + if not isinstance(data, (bytes, bytearray)): + raise TypeError("JPEG payload missing bytes 'data'") + + buf = np.frombuffer(data, dtype=np.uint8) + bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) + if bgr is None: + raise RuntimeError("OpenCV failed to decode JPEG payload") + + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + stats["images_decoded"] += 1 + stats["encoded_bytes_total"] += len(data) + stats["raw_bytes_total"] += int(rgb.nbytes) + return rgb + + if isinstance(x, dict): + return {k: _maybe_decode_payload(v) for k, v in x.items()} + if isinstance(x, list): + return 
[_maybe_decode_payload(v) for v in x] + if isinstance(x, tuple): + return tuple(_maybe_decode_payload(v) for v in x) + return x + + return _maybe_decode_payload(observation), stats diff --git a/src/lerobot/async_inference/policy_server_drtc.py b/src/lerobot/async_inference/policy_server_drtc.py new file mode 100644 index 00000000000..f66f5e6e284 --- /dev/null +++ b/src/lerobot/async_inference/policy_server_drtc.py @@ -0,0 +1,899 @@ +""" +DRTC Policy Server + +This implementation follows the DRTC algorithm with: +- 2-thread architecture (observation receiver + main inference loop) +- SPSC last-write-wins registers for observation/actions handoff + +Threading model (2 threads): +- Main thread: inference loop, runs policy, sends actions +- Observation receiver thread: receives observations from clients via gRPC + +Example: +```shell +python -m lerobot.async_inference.policy_server_drtc \ + --host=127.0.0.1 \ + --port=8080 \ + --fps=30 \ + --obs_queue_timeout=2 +``` +""" + +import logging +import pickle # nosec +import threading +import time +from collections import OrderedDict +from concurrent import futures +from contextlib import suppress +from typing import Any + +import draccus +import grpc +import numpy as np +import torch + +from lerobot.policies.factory import get_policy_class, make_pre_post_processors +from lerobot.processor import ( + PolicyAction, + PolicyProcessorPipeline, +) +from lerobot.transport import ( + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore +) +from lerobot.transport.utils import receive_bytes_in_chunks + +from .configs_drtc import PolicyServerDrtcConfig +from .constants import SUPPORTED_POLICIES +from .drtc_timed import DrtcObservation +from .helpers import ( + Observation, + RemotePolicyConfig, + get_logger, + raw_observation_to_observation, +) +from .lww_register import LWWRegister +from .rtc_guidance import AsyncRTCConfig, AsyncRTCProcessor +from .utils.compression import decode_images_from_transport +from 
.utils.metrics import DiagnosticMetrics, EvActionChunk, Metrics +from .utils.simulation import SpikeDelaySimulator +from .utils.trajectory_viz import TrajectoryVizServer +from .utils.viz_utils import compute_prefix_weights_for_viz + +_INITIAL_K = -(2**63) + + +def _infer_model_action_horizon(policy_config: Any) -> tuple[str, int] | None: + """Infer the maximum action horizon from a loaded policy config.""" + if policy_config is None: + return None + + for field_name in ("chunk_size", "n_action_steps", "horizon"): + value = getattr(policy_config, field_name, None) + if isinstance(value, int) and value > 0: + return field_name, value + + return None + + +class ActionChunkCache: + """LRU cache for raw action chunks, keyed by source control step (t). + + Used for RTC inpainting: the server caches raw (pre-postprocess) action chunks + so the client can reference them by source control step + index range instead + of sending post-processed actions (which have different dimensions). + """ + + def __init__(self, max_size: int = 10): + """Initialize the cache. + + Args: + max_size: Maximum number of chunks to cache (oldest evicted first). + """ + self._cache: OrderedDict[int, torch.Tensor] = OrderedDict() + self._max_size = max_size + + def put(self, src_step: int, raw_actions: torch.Tensor) -> None: + """Store a raw action chunk keyed by source step. + + Args: + src_step: The source step (observation timestep) for this chunk. + raw_actions: Raw action tensor of shape (B, T, A) or (T, A). + """ + # If already exists, remove it first so it goes to the end (most recent) + if src_step in self._cache: + del self._cache[src_step] + + # Evict oldest if at capacity + while len(self._cache) >= self._max_size: + self._cache.popitem(last=False) + + # Store a detached clone to avoid holding onto computation graph + self._cache[src_step] = raw_actions.detach().clone() + + def get(self, src_step: int) -> torch.Tensor | None: + """Retrieve a cached chunk by source step. 
+ + Args: + src_step: The source step to look up. + + Returns: + The cached tensor or None if not found. + """ + return self._cache.get(src_step) + + def clear(self) -> None: + """Clear all cached chunks.""" + self._cache.clear() + + +class PolicyServerDrtc(services_pb2_grpc.AsyncInferenceServicer): + """DRTC policy server. + + This implementation follows the 2-thread model from the paper: + - Main thread: runs the inference loop + - Observation receiver thread: receives observations from clients via gRPC + + Thread communication uses SPSC last-write-wins registers (keyed by timesteps). + """ + + prefix = "policy_server_drtc" + logger = get_logger(prefix) + + def __init__(self, config: PolicyServerDrtcConfig): + """Initialize the policy server. + + Args: + config: Server configuration. + """ + self.config = config + self.shutdown_event = threading.Event() + + # Diagnostic metrics (console only; avg/max timings). + diag = DiagnosticMetrics( + fps=config.fps, + window_s=config.metrics_diagnostic_window_s, + interval_s=config.metrics_diagnostic_interval_s, + enabled=config.metrics_diagnostic_enabled, + verbose=config.metrics_diagnostic_verbose, + prefix="DIAG_SERVER", + ) + diag.start() + self._metrics = Metrics(experiment=None, diagnostic=diag) + + # SPSC LWW registers + # - Receiver thread -> inference producer: latest observation (by control_step) + # - Inference producer -> StreamActionsDense: latest dense actions (by control_step) + self._obs_reg: LWWRegister[DrtcObservation | None] = LWWRegister( + initial_control_step=_INITIAL_K, initial_value=None + ) + self._action_reg: LWWRegister[services_pb2.ActionsDense | None] = LWWRegister( + initial_control_step=_INITIAL_K, initial_value=None + ) + + self._policy_ready = threading.Event() + self._producer_thread: threading.Thread | None = None + + # Policy components (set by SendPolicyInstructions) + self.device: str | None = None + self.policy_type: str | None = None + self.lerobot_features: dict[str, Any] | None = 
None + self.actions_per_chunk: int | None = None + self.policy: Any = None + self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None + self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None + + # Client-driven RTC (optional) + self._rtc_cfg: AsyncRTCConfig | None = None + + # Action chunk cache for RTC (stores raw actions before postprocessing). + # Placeholder; resized to match actions_per_chunk in SendPolicyInstructions. + self._action_cache = ActionChunkCache(max_size=10) + + # Spike delay simulator for experiments + self._delay_simulator = SpikeDelaySimulator(config=config.mock_spike_config) + + # Trajectory visualization server (HTTP + WebSocket) + self._trajectory_viz_server: TrajectoryVizServer | None = None + self._trajectory_viz_thread: threading.Thread | None = None + if config.trajectory_viz_enabled: + self._trajectory_viz_server = TrajectoryVizServer( + ws_port=config.trajectory_viz_ws_port, + http_port=config.trajectory_viz_http_port, + ) + self._trajectory_viz_thread = threading.Thread( + target=self._trajectory_viz_server.start, + name="trajectory_viz_server", + daemon=True, + ) + self._trajectory_viz_thread.start() + print( + "Trajectory visualization server started on " + f"http://0.0.0.0:{config.trajectory_viz_http_port} " + f"(WebSocket: ws://0.0.0.0:{config.trajectory_viz_ws_port})" + ) + + @property + def running(self) -> bool: + return not self.shutdown_event.is_set() + + @property + def policy_image_features(self): + return self.policy.config.image_features + + def _reset_server(self) -> None: + """Reset server state when a new client connects. + + Joins the old producer thread before reassigning registers so the + thread doesn't leak (it holds a reader bound to the old register). 
+ """ + self.shutdown_event.set() + self._policy_ready.clear() + + # Wait for the old producer thread to observe shutdown and exit + # before replacing registers, so it doesn't loop forever on the + # old register after shutdown_event is cleared. + if self._producer_thread is not None and self._producer_thread.is_alive(): + self._producer_thread.join(timeout=5.0) + if self._producer_thread.is_alive(): + self.logger.warning( + "Producer thread did not exit within 5s during reset; " + "a new thread will be started anyway." + ) + self._producer_thread = None + + # Reset registers (avoid leaking prior session values) + self._obs_reg = LWWRegister(initial_control_step=_INITIAL_K, initial_value=None) + self._action_reg = LWWRegister(initial_control_step=_INITIAL_K, initial_value=None) + self._action_cache.clear() + + # ------------------------------------------------------------------------- + # gRPC Service Methods (called by receiver thread) + # ------------------------------------------------------------------------- + + def Ready(self, request, context): # noqa: N802 + """Handle client ready signal. 
Resets server state for new session.""" + self._metrics.diagnostic.counter("client_ready", 1) + self._reset_server() + self.shutdown_event.clear() + return services_pb2.Empty() + + def SendTrajectoryChunk(self, request, context): # noqa: N802 + """Receive trajectory chunk from robot client for visualization.""" + if self._trajectory_viz_server is None: + return services_pb2.Empty() + + # Decode the packed float32 actions + num_actions = request.num_actions + action_dim = request.action_dim + if num_actions > 0 and action_dim > 0: + actions_flat = np.frombuffer(request.actions_f32, dtype=np.float32) + actions = actions_flat.reshape(num_actions, action_dim).tolist() + else: + actions = [] + + # Create EvActionChunk event and forward to viz server + event = EvActionChunk( + src_control_step=request.source_step, # proto field is source_step + actions=actions, + frozen_len=request.frozen_len, + timestamp=request.timestamp, + ) + self._trajectory_viz_server.on_chunk(event) + + return services_pb2.Empty() + + def SendPolicyInstructions(self, request, context): # noqa: N802 + """Receive and load policy from client instructions.""" + if not self.running: + return services_pb2.Empty() + + t_total_start = time.perf_counter() + + # Deserialize policy configuration + policy_specs = pickle.loads(request.data) # nosec + + if not isinstance(policy_specs, RemotePolicyConfig): + raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}") + + if policy_specs.policy_type not in SUPPORTED_POLICIES: + raise ValueError( + f"Policy type {policy_specs.policy_type} not supported. " + f"Supported policies: {SUPPORTED_POLICIES}" + ) + + self.device = policy_specs.device + self.policy_type = policy_specs.policy_type + self.lerobot_features = policy_specs.lerobot_features + self.actions_per_chunk = policy_specs.actions_per_chunk + + # Resize RTC chunk cache to match the client's chunk size so we always + # keep enough history for the full action horizon. 
+ self._action_cache = ActionChunkCache(max_size=self.actions_per_chunk) + + # Skip loading real policy in mock mode + if self.config.mock_policy: + self._metrics.diagnostic.counter("mock_policy_mode", 1) + self._policy_ready.set() + return services_pb2.Empty() + + # Load policy + policy_class = get_policy_class(self.policy_type) + + t_load_start = time.perf_counter() + self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path) + t_load_done = time.perf_counter() + + t_to_start = time.perf_counter() + self.policy.to(self.device) + t_to_done = time.perf_counter() + + inferred_horizon = _infer_model_action_horizon(getattr(self.policy, "config", None)) + if inferred_horizon is not None: + horizon_field, model_horizon = inferred_horizon + if self.actions_per_chunk > model_horizon: + raise ValueError( + "Requested actions_per_chunk " + f"({self.actions_per_chunk}) exceeds model-supported horizon " + f"({model_horizon}, from policy config field '{horizon_field}') " + f"for checkpoint '{policy_specs.pretrained_name_or_path}'. " + f"Set actions_per_chunk <= {model_horizon}." 
+ ) + + # Load preprocessor and postprocessor + device_override = {"device": self.device} + t_pp_start = time.perf_counter() + self.preprocessor, self.postprocessor = make_pre_post_processors( + self.policy.config, + pretrained_path=policy_specs.pretrained_name_or_path, + preprocessor_overrides={ + "device_processor": device_override, + "rename_observations_processor": {"rename_map": policy_specs.rename_map}, + }, + postprocessor_overrides={"device_processor": device_override}, + ) + t_pp_done = time.perf_counter() + self._metrics.diagnostic.timing_s("policy_load_ms", t_load_done - t_load_start) + self._metrics.diagnostic.timing_s("policy_to_ms", t_to_done - t_to_start) + self._metrics.diagnostic.timing_s("policy_processors_ms", t_pp_done - t_pp_start) + self._metrics.diagnostic.timing_s("policy_total_ms", time.perf_counter() - t_total_start) + + # Apply num_flow_matching_steps override if provided by client + # (Alex Soare optimization: Beta should scale with n) + num_flow_steps = getattr(policy_specs, "num_flow_matching_steps", None) + if num_flow_steps is not None: + cfg_obj = getattr(self.policy, "config", None) + if cfg_obj is not None: + # PI0/PI05 use num_inference_steps, SmolVLA uses num_steps + if hasattr(cfg_obj, "num_inference_steps"): + cfg_obj.num_inference_steps = num_flow_steps + elif hasattr(cfg_obj, "num_steps"): + cfg_obj.num_steps = num_flow_steps + else: + self._metrics.diagnostic.counter("num_flow_steps_override_ignored", 1) + + # Optional: enable RTC via client instructions (server-side inpainting) + if getattr(policy_specs, "rtc_enabled", False): + # Handle optional max_guidance_weight (None = use num_flow_matching_steps, Alex Soare opt) + max_gw_raw = getattr(policy_specs, "rtc_max_guidance_weight", None) + max_gw = float(max_gw_raw) if max_gw_raw is not None else None + + self._rtc_cfg = AsyncRTCConfig( + enabled=True, + prefix_attention_schedule=str( + getattr(policy_specs, "rtc_prefix_attention_schedule", "linear") + ), + 
max_guidance_weight=max_gw, + sigma_d=float(getattr(policy_specs, "rtc_sigma_d", 1.0)), + full_trajectory_alignment=bool(getattr(policy_specs, "rtc_full_trajectory_alignment", False)), + ) + # NOTE: We do NOT pass self.postprocessor to RTC guidance because: + # - RTC operates INSIDE the model's denoising loop in raw action space (e.g. 32 dims) + # - The postprocessor (NormalizeProcessor) expects executable action space (e.g. 6 dims) + # - These dimensions are incompatible; the model's action head converts at the end + # - For now, RTC guidance compares in raw model space (prev must match model dims) + rtc = AsyncRTCProcessor(self._rtc_cfg, postprocess=None) + + # Flow policies expect `policy.rtc_processor` and `policy.model.rtc_processor`. + self.policy.rtc_processor = rtc + model_value = getattr(self.policy, "model", None) + if model_value is not None: + model_value.rtc_processor = rtc + + # Satisfy policy-side `_rtc_enabled()` checks without importing RTCConfig. + cfg_obj = getattr(self.policy, "config", None) + if cfg_obj is not None: + with suppress(Exception): + cfg_obj.rtc_config = type("RTCConfigShim", (), {"enabled": True})() + + # Apply spike configuration from client (for experiments) + spikes = getattr(policy_specs, "spikes", []) + if spikes: + self._delay_simulator = SpikeDelaySimulator.from_dicts(spikes) + self._metrics.diagnostic.counter("spike_events_configured", len(spikes)) + + # Warmup: run dummy inference passes to trigger CUDA kernel compilation + # and memory allocation so the first real measurement isn't inflated. + if self.config.warmup_passes > 0: + self._warmup_model(num_passes=self.config.warmup_passes) + + self._policy_ready.set() + + # Start producer thread (if needed) to generate actions outside the RPC path (lower jitter). 
+ if self._producer_thread is None or not self._producer_thread.is_alive(): + self._producer_thread = threading.Thread( + target=self._inference_producer_loop, + name="policy_server_drtc_inference_producer", + daemon=True, + ) + self._producer_thread.start() + + return services_pb2.Empty() + + def SendObservations(self, request_iterator, context): # noqa: N802 + """Receive observations from client and enqueue for inference. + + This method is called by the gRPC receiver thread. + """ + t_total_start = time.perf_counter() + + # Receive observation bytes (stamp receive_time AFTER full payload + # arrives so that client-to-server latency captures the actual + # network transfer of the chunked image payload, not just the + # gRPC handler dispatch time). + t_recv_start = time.perf_counter() + received_bytes = receive_bytes_in_chunks(request_iterator, None, self.shutdown_event, self.logger) + t_recv_done = time.perf_counter() + receive_time = time.time() + + # Deserialize + t_deser_start = time.perf_counter() + timed_observation = pickle.loads(received_bytes) # nosec + t_deser_done = time.perf_counter() + + # Decode images + t_decode_start = time.perf_counter() + decoded_observation, _ = decode_images_from_transport(timed_observation.get_observation()) + timed_observation.observation = decoded_observation + t_decode_done = time.perf_counter() + + # Stamp the server receive time for granular latency decomposition + timed_observation.server_received_ts = receive_time + + obs_control_step = timed_observation.get_control_step() + obs_timestamp = timed_observation.get_timestamp() + + # Diagnostics + # Provide a stable `step` field for compact diagnostics. 
+ self._metrics.diagnostic.set_context( + step=obs_control_step, last_obs_step=obs_control_step, chunk_size=self.actions_per_chunk + ) + self._metrics.diagnostic.timing_s("obs_recv_ms", t_recv_done - t_recv_start) + self._metrics.diagnostic.timing_s("deser_ms", t_deser_done - t_deser_start) + self._metrics.diagnostic.timing_s("obs_decode_ms", t_decode_done - t_decode_start) + self._metrics.diagnostic.timing_s("obs_one_way_latency_ms", receive_time - obs_timestamp) + self._metrics.diagnostic.timing_s("obs_total_ms", time.perf_counter() - t_total_start) + + # Publish newest observation (monotone w.r.t. control_step) + self._obs_reg.update_if_newer(obs_control_step, timed_observation) + + return services_pb2.Empty() + + def StreamActionsDense(self, request, context): # noqa: N802 + """Server-streaming dense actions RPC (streaming-only action transport).""" + if not self._policy_ready.is_set(): + return + reader = self._action_reg.reader(initial_watermark=_INITIAL_K) + while self.running and context.is_active(): + state, _, is_new = reader.read_if_newer() + dense = state.value + if not is_new or dense is None: + time.sleep(0.01) + continue + yield dense + + # ------------------------------------------------------------------------- + # Inference Pipeline + # ------------------------------------------------------------------------- + + def _publish_dense(self, dense: services_pb2.ActionsDense) -> None: + control_step = int(dense.source_control_step) + self._action_reg.update_if_newer(control_step, dense) + + def _warmup_model(self, num_passes: int = 2) -> None: + """Run dummy inference passes to warm up CUDA kernels and memory allocations. + + The first forward pass through a PyTorch model on GPU triggers JIT compilation + of CUDA kernels and cuDNN workspace allocation, adding hundreds of milliseconds + to inference time. Running a few dummy passes here ensures this overhead is paid + during startup, not during the first real measurement. 
+ + Args: + num_passes: Number of dummy inference passes to run. + """ + if self.preprocessor is None or self.postprocessor is None: + self.logger.warning("Cannot warmup: pre/post processors not initialized") + return + if self.policy is None: + self.logger.warning("Cannot warmup: policy not loaded") + return + + self.logger.info(f"Warming up model with {num_passes} dummy inference pass(es)...") + t_warmup_start = time.perf_counter() + + try: + # Build a dummy observation matching the format produced by + # raw_observation_to_observation(): {OBS_STATE: (1, state_dim), image_keys: (1, C, H, W), task: str} + dummy_obs: dict[str, Any] = {} + + # State: derive dimensionality from lerobot_features + if self.lerobot_features: + state_features = self.lerobot_features.get("observation.state", []) + state_dim = len(state_features) if isinstance(state_features, (list, tuple)) else 6 + else: + state_dim = 6 + dummy_obs["observation.state"] = torch.zeros(1, state_dim) + + # Images: use policy's image_features to get (C, H, W) shapes + for key, feat in self.policy_image_features.items(): + c, h, w = feat.shape + # After prepare_image + unsqueeze: float32 in [0, 1], shape (1, C, H, W) + dummy_obs[key] = torch.zeros(1, c, h, w, dtype=torch.float32) + + # Task string (VLA models require this) + dummy_obs["task"] = "warmup" + + for i in range(num_passes): + t_pass_start = time.perf_counter() + + # Preprocess + obs = self.preprocessor(dummy_obs) + + # Inference -- call policy directly (not _get_action_chunk) + # to avoid recording warmup timings in diagnostic metrics. 
+ with torch.no_grad(): + action_tensor = self.policy.predict_action_chunk(obs) + + # Postprocess (same path as real inference) + if action_tensor.ndim != 3: + action_tensor = action_tensor.unsqueeze(0) + action_tensor = action_tensor[:, : self.actions_per_chunk, :] + b, t_dim, a = action_tensor.shape + flat = action_tensor.reshape(b * t_dim, a) + flat = self.postprocessor(flat) + + t_pass_done = time.perf_counter() + self.logger.info( + f" Warmup pass {i + 1}/{num_passes}: {(t_pass_done - t_pass_start) * 1000:.1f}ms" + ) + + t_warmup_done = time.perf_counter() + warmup_total_ms = (t_warmup_done - t_warmup_start) * 1000 + self.logger.info(f"Model warmup complete ({warmup_total_ms:.0f}ms total)") + self._metrics.diagnostic.timing_ms("warmup_total_ms", warmup_total_ms) + + except Exception as e: + self.logger.error(f"Warmup failed (non-fatal, first inference may be slow): {e}") + self._metrics.diagnostic.counter("warmup_failed", 1) + + def _inference_producer_loop(self) -> None: + """Continuously produce the latest action chunk from the latest observation (low jitter).""" + reader = self._obs_reg.reader(initial_watermark=_INITIAL_K) + consecutive_errors = 0 + + while self.running: + if not self._policy_ready.is_set(): + time.sleep(0.01) + continue + + state, _, is_new = reader.read_if_newer() + obs = state.value + if not is_new or obs is None: + time.sleep(0.01) + continue + + try: + t_total_start = time.perf_counter() + + # Apply simulated delay (for experiments) + self._delay_simulator.apply_delay() + + t_infer_start = time.perf_counter() + + # Use mock policy or real policy + if self.config.mock_policy: + dense = self._mock_predict_action_chunk_dense(obs) + else: + dense = self._predict_action_chunk_dense(obs) + t_infer_done = time.perf_counter() + + # Stamp server-side timestamps for granular latency decomposition + dense.server_obs_received_ts = float(getattr(obs, "server_received_ts", 0.0)) + dense.server_action_sent_ts = time.time() + + 
self._publish_dense(dense) + # Provide a stable `step` field for compact diagnostics. + self._metrics.diagnostic.set_context( + step=int(obs.get_control_step()), + last_infer_src_step=int(obs.get_control_step()), + chunk_size=self.actions_per_chunk, + ) + self._metrics.diagnostic.timing_s("infer_total_ms", t_infer_done - t_infer_start) + self._metrics.diagnostic.timing_s( + "producer_loop_total_ms", time.perf_counter() - t_total_start + ) + consecutive_errors = 0 + except Exception as e: + consecutive_errors += 1 + self.logger.error("Error in inference producer loop: %s", e, exc_info=True) + self._metrics.diagnostic.counter("inference_producer_error", 1) + # Exponential backoff: 0.1s, 0.2s, 0.4s, ... capped at 2s + backoff = min(0.1 * (2 ** (consecutive_errors - 1)), 2.0) + time.sleep(backoff) + + def _mock_predict_action_chunk_dense(self, observation_t: DrtcObservation) -> services_pb2.ActionsDense: + """Generate mock actions for simulation experiments (no real model).""" + action_dim = self.config.mock_action_dim + actions_per_chunk = self.actions_per_chunk or 50 + + # Generate random actions + actions_np = np.random.randn(actions_per_chunk, action_dim).astype(np.float32) * 0.1 + payload = np.asarray(actions_np, dtype=np.float32, order="C") + + dense = services_pb2.ActionsDense( + timestamp=float(observation_t.get_timestamp()), + source_control_step=int(observation_t.get_control_step()), + chunk_start_step=int(observation_t.chunk_start_step), + dt=float(self.config.environment_dt), + num_actions=int(payload.shape[0]), + action_dim=int(payload.shape[1]), + actions_f32=payload.tobytes(order="C"), + ) + return dense + + def _predict_action_chunk_dense(self, observation_t: DrtcObservation) -> services_pb2.ActionsDense: + """Run inference on an observation and return dense packed actions (lower jitter).""" + if self.actions_per_chunk is None: + raise RuntimeError("actions_per_chunk is not set; did SendPolicyInstructions run?") + if self.preprocessor is None or 
    def _predict_action_chunk_dense(self, observation_t: DrtcObservation) -> services_pb2.ActionsDense:
        """Run inference on an observation and return dense packed actions (lower jitter).

        Pipeline: strip RTC metadata -> prepare/preprocess observation ->
        (optionally) reconstruct an RTC in-painting prefix from cached past
        chunks -> run the policy -> cache the raw chunk -> vectorised
        postprocess -> optional trajectory-viz emission -> pack as a
        float32 ``ActionsDense`` proto.

        Args:
            observation_t: Wrapped observation carrying the raw observation dict
                (optionally including a ``"__rtc__"`` metadata entry), its control
                step, timestamp, and chunk start step.

        Returns:
            A ``services_pb2.ActionsDense`` message with row-major float32 actions.

        Raises:
            RuntimeError: If the chunk length or pre/post processors were never
                initialised (i.e. SendPolicyInstructions did not run).
            TypeError: If the postprocessor returns a non-tensor.
        """
        if self.actions_per_chunk is None:
            raise RuntimeError("actions_per_chunk is not set; did SendPolicyInstructions run?")
        if self.preprocessor is None or self.postprocessor is None:
            raise RuntimeError("pre/post processors not initialized; did SendPolicyInstructions run?")

        # Optional RTC metadata (client-provided hard-mask prefix + estimated delay).
        rtc_meta = None
        raw_obs_any = observation_t.get_observation()
        if isinstance(raw_obs_any, dict):
            rtc_meta = raw_obs_any.get("__rtc__")

        # Remove RTC metadata before policy preprocessing (avoid surprising processors).
        # A shallow copy is enough: only the top-level "__rtc__" key is removed.
        if rtc_meta is not None and isinstance(raw_obs_any, dict):
            raw_obs = dict(raw_obs_any)
            raw_obs.pop("__rtc__", None)
        else:
            raw_obs = raw_obs_any

        # 1. Prepare observation
        observation: Observation = raw_observation_to_observation(
            raw_obs,
            self.lerobot_features,
            self.policy_image_features,
        )

        # 2. Preprocess
        observation = self.preprocessor(observation)

        # 3. Inference (avoid autograd / reduce variance)
        # NOTE: Do NOT use `torch.inference_mode()` here: RTC guidance needs to temporarily
        # enable gradients for the inpainting correction term, and inference_mode cannot be
        # overridden. `torch.no_grad()` keeps the normal path efficient while still allowing
        # nested `torch.enable_grad()` for RTC.
        src_control_step = int(observation_t.get_control_step())

        with torch.no_grad():
            rtc_kwargs: dict[str, Any] = {}
            if rtc_meta is not None and self._rtc_cfg is not None and self._rtc_cfg.enabled:
                # Any failure while interpreting client RTC metadata falls back to
                # plain (non-guided) inference — see the except at the bottom.
                try:
                    d = int(rtc_meta.get("latency_steps", 0))
                    action_schedule_spans = rtc_meta.get("action_schedule_spans")

                    # overlap_end from client: where fresh region starts (H - max(s_min, d))
                    chunk_len = self.actions_per_chunk
                    overlap_end = int(rtc_meta.get("overlap_end") or (chunk_len - d))
                    self._metrics.diagnostic.counter("rtc_meta_seen", 1)

                    # Reconstruct prefix tensor from multiple cached chunks
                    if action_schedule_spans:
                        slices: list[torch.Tensor] = []
                        for control_src_step, start_idx, end_idx in action_schedule_spans:
                            cached_chunk = self._action_cache.get(int(control_src_step))
                            if cached_chunk is None:
                                self._metrics.diagnostic.counter("rtc_cache_miss", 1)
                            else:
                                self._metrics.diagnostic.counter("rtc_cache_hit", 1)
                            if cached_chunk is not None:
                                # Extract slice from cached chunk (B, T, A) or (T, A)
                                if cached_chunk.ndim == 2:
                                    slices.append(cached_chunk[start_idx:end_idx, :])
                                else:
                                    # Squeeze batch dim for concatenation
                                    slices.append(cached_chunk[0, start_idx:end_idx, :])

                        if slices:
                            # Concatenate all slices along time dimension -> (T_total, A)
                            prefix_tensor = torch.cat(slices, dim=0)
                            prefix_tensor = prefix_tensor.unsqueeze(0)  # (1, T_total, A)
                            prefix_len = prefix_tensor.shape[1]

                            # Clamp overlap_end to what we actually have in the prefix
                            # This allows graceful degradation when cache is incomplete
                            effective_overlap_end = min(overlap_end, prefix_len)

                            # Zero-pad to max_action_dim if model uses padded action space
                            max_action_dim = getattr(self.policy.config, "max_action_dim", None)
                            if max_action_dim is not None and prefix_tensor.shape[-1] < max_action_dim:
                                b, t, a = prefix_tensor.shape
                                padded = torch.zeros(
                                    b,
                                    t,
                                    max_action_dim,
                                    device=prefix_tensor.device,
                                    dtype=prefix_tensor.dtype,
                                )
                                padded[:, :, :a] = prefix_tensor
                                prefix_tensor = padded

                            rtc_kwargs = {
                                "inference_delay": d,
                                "prev_chunk_left_over": prefix_tensor.to(device=self.device),
                                "overlap_end": effective_overlap_end,  # Clamped for RTC guidance
                                "overlap_end_intended": overlap_end,  # Original for visualization
                            }
                            self._metrics.diagnostic.counter("rtc_applied", 1)
                        else:
                            self._metrics.diagnostic.counter("rtc_not_applied_no_slices", 1)
                    else:
                        self._metrics.diagnostic.counter("rtc_not_applied_empty_prefix", 1)
                except Exception:
                    self._metrics.diagnostic.counter("rtc_meta_error", 1)
                    rtc_kwargs = {}

            action_tensor = self._get_action_chunk(observation, **rtc_kwargs)

        # Ensure (B, T, A)
        if action_tensor.ndim != 3:
            action_tensor = action_tensor.unsqueeze(0)
        action_tensor = action_tensor[:, : self.actions_per_chunk, :]

        b, t, a = action_tensor.shape

        # Cache raw action chunk BEFORE postprocessing (for future RTC inpainting)
        # Key by control_step so RTC action_schedule_spans spans can look up the right chunk.
        if src_control_step >= 0:
            self._action_cache.put(src_control_step, action_tensor)

        # 4. Vectorized postprocess: (B, T, A_in) -> (B*T, A_in) -> (B, T, A_out)
        flat = action_tensor.reshape(b * t, a)
        flat = self.postprocessor(flat)
        if not isinstance(flat, torch.Tensor):
            raise TypeError(f"postprocessor must return torch.Tensor, got {type(flat)}")
        a_out = flat.shape[-1]
        action_tensor = flat.reshape(b, t, a_out)

        # Drop batch dim and move to CPU once
        actions_cpu = action_tensor.squeeze(0).detach().to("cpu")
        actions_np = actions_cpu.to(torch.float32).numpy()

        payload = np.asarray(actions_np, dtype=np.float32, order="C")

        # Emit action chunk to trajectory visualization (if enabled)
        if self._trajectory_viz_server is not None:
            # Build RTC params dict for visualization
            rtc_params_viz: dict[str, Any] | None = None
            prefix_weights_viz: list[float] | None = None

            if self._rtc_cfg is not None and self._rtc_cfg.enabled and rtc_kwargs:
                d_viz = rtc_kwargs.get("inference_delay", 0)
                # Use intended overlap_end for visualization (not clamped to prefix length)
                overlap_end_viz = rtc_kwargs.get(
                    "overlap_end_intended", rtc_kwargs.get("overlap_end", self.actions_per_chunk)
                )
                chunk_len_viz = self.actions_per_chunk

                rtc_params_viz = {
                    "d": d_viz,
                    "H": chunk_len_viz,
                    "overlap_end": overlap_end_viz,
                    "sigma_d": self._rtc_cfg.sigma_d,
                    "schedule": self._rtc_cfg.prefix_attention_schedule,
                    "max_guidance_weight": self._rtc_cfg.max_guidance_weight,
                    "full_trajectory_alignment": self._rtc_cfg.full_trajectory_alignment,
                }
                prefix_weights_viz = compute_prefix_weights_for_viz(
                    d_viz, overlap_end_viz, chunk_len_viz, self._rtc_cfg.prefix_attention_schedule
                )

            # Create and emit the event
            actions_list = actions_np.tolist()
            event = EvActionChunk(
                src_control_step=src_control_step,
                actions=actions_list,
                frozen_len=rtc_kwargs.get("inference_delay", 0) if rtc_kwargs else 0,
                timestamp=time.time(),
                rtc_params=rtc_params_viz,
                prefix_weights=prefix_weights_viz,
            )
            self._trajectory_viz_server.on_chunk(event)

        dense_kwargs: dict[str, Any] = {
            "timestamp": float(observation_t.get_timestamp()),
            "source_control_step": int(observation_t.get_control_step()),
            "chunk_start_step": int(observation_t.chunk_start_step),
            "dt": float(self.config.environment_dt),
            "num_actions": int(payload.shape[0]),
            "action_dim": int(payload.shape[1]),
            "actions_f32": payload.tobytes(order="C"),
        }
        dense = services_pb2.ActionsDense(**dense_kwargs)
        return dense
dense_kwargs: dict[str, Any] = { + "timestamp": float(observation_t.get_timestamp()), + "source_control_step": int(observation_t.get_control_step()), + "chunk_start_step": int(observation_t.chunk_start_step), + "dt": float(self.config.environment_dt), + "num_actions": int(payload.shape[0]), + "action_dim": int(payload.shape[1]), + "actions_f32": payload.tobytes(order="C"), + } + dense = services_pb2.ActionsDense(**dense_kwargs) + return dense + + def _get_action_chunk(self, observation: dict[str, torch.Tensor], **kwargs: Any) -> torch.Tensor: + """Get action chunk from the policy.""" + t0 = time.perf_counter() + chunk = self.policy.predict_action_chunk(observation, **kwargs) + t1 = time.perf_counter() + self._metrics.diagnostic.timing_s("policy_predict_ms", t1 - t0) + + if chunk.ndim != 3: + chunk = chunk.unsqueeze( + 0 + ) # Add batch dimension: (chunk_size, action_dim) -> (1, chunk_size, action_dim) + + return chunk[:, : self.actions_per_chunk, :] + + def stop(self) -> None: + """Stop the server.""" + self._reset_server() + self._metrics.diagnostic.stop() + + +@draccus.wrap() +def serve_drtc(cfg: PolicyServerDrtcConfig) -> None: + """Start the DRTC PolicyServer.""" + # Create server instance + policy_server = PolicyServerDrtc(cfg) + + # Setup gRPC server + server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) + services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) + bound_port = server.add_insecure_port(f"{cfg.host}:{cfg.port}") + if bound_port == 0: + raise RuntimeError( + f"Failed to bind gRPC server to {cfg.host}:{cfg.port}. " + "Is the port already in use, or are you binding to an unavailable interface?" 
@draccus.wrap()
def serve_drtc(cfg: PolicyServerDrtcConfig) -> None:
    """Start the DRTC PolicyServer.

    Builds the servicer, binds a 4-worker gRPC server to ``cfg.host:cfg.port``,
    and blocks until termination (Ctrl-C included). Cleanup runs in ``finally``
    so the policy server's threads are stopped even when startup fails after
    binding.

    Raises:
        RuntimeError: If the port cannot be bound (``add_insecure_port`` == 0).
    """
    # Create server instance
    policy_server = PolicyServerDrtc(cfg)

    # Setup gRPC server
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
    bound_port = server.add_insecure_port(f"{cfg.host}:{cfg.port}")
    # add_insecure_port returns 0 when binding failed.
    if bound_port == 0:
        raise RuntimeError(
            f"Failed to bind gRPC server to {cfg.host}:{cfg.port}. "
            "Is the port already in use, or are you binding to an unavailable interface?"
        )

    server_started = False
    try:
        server.start()
        server_started = True
        print(f"PolicyServerDrtc listening on {cfg.host}:{bound_port}")
        logging.getLogger("policy_server_drtc").info("gRPC server bound to %s:%s", cfg.host, bound_port)
        server.wait_for_termination()
    except KeyboardInterrupt:
        print("KeyboardInterrupt received; shutting down")
    except Exception:
        policy_server.logger.error("Policy server crashed", exc_info=True)
        raise
    finally:
        # Best-effort cleanup to avoid dangling threads on failures.
        try:
            policy_server.stop()
        except Exception:
            policy_server.logger.error("Error while stopping policy server", exc_info=True)
        # Only stop the gRPC server if start() actually succeeded.
        if server_started:
            server.stop(grace=5)
        print("Server terminated")


if __name__ == "__main__":
    serve_drtc()
self.logger.debug(f"Received actions on device: {received_device}") # Move actions to client_device (e.g., for downstream planners that need GPU) client_device = self.config.client_device if client_device != "cpu": for timed_action in timed_actions: - if timed_action.get_action().device.type != client_device: - timed_action.action = timed_action.get_action().to(client_device) + action = timed_action.get_action() + if self._get_action_device_type(action) != client_device: + if isinstance(action, np.ndarray): + timed_action.action = torch.from_numpy(action).to(client_device) + else: + timed_action.action = action.to(client_device) self.logger.debug(f"Converted actions to device: {client_device}") else: self.logger.debug(f"Actions kept on device: {client_device}") diff --git a/src/lerobot/async_inference/robot_client_drtc.py b/src/lerobot/async_inference/robot_client_drtc.py new file mode 100644 index 00000000000..c7424efde1e --- /dev/null +++ b/src/lerobot/async_inference/robot_client_drtc.py @@ -0,0 +1,1353 @@ +import logging +import pickle # nosec +import threading +import time +from collections import deque +from contextlib import suppress +from dataclasses import asdict, dataclass +from queue import Empty, Full, Queue +from typing import Any + +import grpc +import numpy as np +from sortedcontainers import SortedDict + +from lerobot.robots.utils import make_robot_from_config +from lerobot.transport import ( + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore +) +from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks + +from .configs_drtc import RobotClientDrtcConfig +from .constants import SUPPORTED_ROBOTS +from .drtc_timed import DrtcAction, DrtcObservation +from .helpers import ( + RawObservation, + RemotePolicyConfig, + get_logger, + map_robot_keys_to_lerobot_features, + visualize_action_queue_size, +) +from .lww_register import LWWReader, LWWRegister +from .utils.action_filter import ( + ActionFilter, + ButterworthFilter, 
@dataclass
class ScheduledAction:
    """A single action pinned to a specific execution step in the schedule."""

    # The action vector to execute at this step.
    action: np.ndarray
    # Control-loop tick t that produced this action (freshness key for merges).
    src_control_step: int
    # Action step n_k at which the source chunk begins (used for RTC offset math).
    chunk_start_step: int


@dataclass
class MergeStats:
    """Result of merging an action chunk into the schedule.

    Tracks the action discontinuity (L2 distance between old and new actions
    at overlapping timesteps) so RTC smoothness can be assessed.
    """

    # Number of overlapping non-hard-masked actions that were compared.
    overlap_count: int
    # Mean L2 distance across overlapping actions (0.0 when there is no overlap).
    mean_l2: float
    # Maximum L2 distance across overlapping actions (0.0 when there is no overlap).
    max_l2: float
class ActionSchedule:
    """Step-indexed action schedule backed by a SortedDict.

    Maps action step -> ScheduledAction, kept sorted by step so the lowest
    pending step can be popped in order. Merging follows a
    freshest-observation-wins policy keyed on ``src_control_step``.

    NOTE(review): no internal locking is visible here — callers appear to be
    responsible for thread confinement; confirm against the control loop.
    """

    def __init__(self):
        # step -> ScheduledAction, ordered by step (sortedcontainers.SortedDict).
        self._schedule: SortedDict[int, ScheduledAction] = SortedDict()

    def __len__(self) -> int:
        return len(self._schedule)

    def pop_front(self) -> tuple[int, np.ndarray, int, int] | None:
        """Pop and return the first (lowest action step) scheduled action.

        Returns:
            Tuple of (step, action, src_control_step, chunk_start_step) or None if empty.
        """
        if not self._schedule:
            return None
        # SortedDict maintains sorted key order; pop first (lowest key) item
        step, scheduled = self._schedule.popitem(0)
        return step, scheduled.action, scheduled.src_control_step, scheduled.chunk_start_step

    def get_masking_chunk_spans(
        self, *, current_step: int, max_len: int
    ) -> list[tuple[int, int, int]] | None:
        """Get list of (src_control_step, start_idx, end_idx) spans for RTC masking prefix.

        This returns information needed to look up raw actions in the server's cache
        (keyed by src_control_step). The offset within a cached chunk is computed as
        ``step - scheduled.chunk_start_step``.

        The prefix covers both hard mask and soft mask regions.
        Handles prefixes that span multiple source chunks due to merging.

        Args:
            current_step: The current action step being executed.
            max_len: Total number of actions to include (d + epsilon).

        Returns:
            List of (src_control_step, start_idx, end_idx) tuples in execution order,
            or None if empty. Each tuple specifies a contiguous slice from a cached
            chunk on the server.
        """
        if max_len <= 0:
            return None

        chunks: list[tuple[int, int, int]] = []
        current_src_control_step: int | None = None
        current_start: int | None = None
        current_end: int = 0
        count = 0

        # Walk scheduled actions in step order, coalescing runs that come from
        # the same source chunk at consecutive in-chunk indices into one span.
        for step, scheduled in self._schedule.items():
            if step <= current_step:
                continue

            # Index of this action within its source chunk (offset by chunk_start_step)
            chunk_idx = step - scheduled.chunk_start_step

            if current_src_control_step is None:
                # First action in prefix
                current_src_control_step = scheduled.src_control_step
                current_start = chunk_idx
                current_end = chunk_idx + 1
            elif scheduled.src_control_step == current_src_control_step and chunk_idx == current_end:
                # Contiguous with current span (same source, consecutive index)
                current_end = chunk_idx + 1
            else:
                # New span - save current and start new
                if current_start is not None:
                    chunks.append((current_src_control_step, current_start, current_end))
                current_src_control_step = scheduled.src_control_step
                current_start = chunk_idx
                current_end = chunk_idx + 1

            count += 1
            if count >= max_len:
                break

        # Save final span
        if current_src_control_step is not None and current_start is not None:
            chunks.append((current_src_control_step, current_start, current_end))

        return chunks if chunks else None

    def get_size(self) -> int:
        """Get the current schedule size."""
        return len(self._schedule)

    def is_empty(self) -> bool:
        """Check if schedule is empty."""
        return len(self._schedule) == 0

    def merge(
        self,
        incoming_actions: list[DrtcAction],
        src_control_step: int,
        chunk_start_step: int,
        current_action_step: int,
        logger: logging.Logger | None = None,
    ) -> MergeStats:
        """Merge incoming actions using freshest-observation-wins strategy.

        Already-executed steps are dropped; a step is overwritten only when the
        incoming chunk came from a strictly newer observation than the one that
        produced the currently scheduled action.

        Args:
            incoming_actions: List of DrtcAction from the server.
            src_control_step: The control-loop tick t that produced this chunk (freshness key).
            chunk_start_step: The action step n_k where this chunk starts.
            current_action_step: The most recently executed action step (n*).
            logger: Optional logger for debug output.

        Returns:
            MergeStats with L2 discrepancy metrics for overlapping actions.
        """
        # Use counters instead of per-action logging to avoid ~1ms per log call
        stale_count = 0
        inserted_count = 0
        updated_count = 0

        # Track L2 discrepancy for overlapping actions (non-hard-masked)
        l2_distances: list[float] = []

        for timed_action in incoming_actions:
            step = timed_action.get_action_step()
            action = timed_action.get_action()

            # Skip stale actions (already executed)
            if step <= current_action_step:
                stale_count += 1
                continue

            # TODO - revisit this check
            existing = self._schedule.get(step)
            if existing is None:
                self._schedule[step] = ScheduledAction(
                    action=action, src_control_step=src_control_step, chunk_start_step=chunk_start_step
                )
                inserted_count += 1
                continue

            # Compute L2 discrepancy for ALL overlapping actions (for analysis metrics)
            old_arr = np.asarray(existing.action, dtype=np.float32).reshape(-1)
            new_arr = np.asarray(action, dtype=np.float32).reshape(-1)
            if old_arr.shape == new_arr.shape and old_arr.size > 0:
                l2 = float(np.linalg.norm(new_arr - old_arr))
                l2_distances.append(l2)

            if src_control_step > existing.src_control_step:
                # Fresher observation wins (only for non-hard-masked actions)
                self._schedule[step] = ScheduledAction(
                    action=action, src_control_step=src_control_step, chunk_start_step=chunk_start_step
                )
                updated_count += 1

        # Single summary log instead of per-action logs (saves ~20ms for 23 log calls)
        if logger and stale_count:
            logger.debug(
                f"Merge stats: {stale_count} stale, {inserted_count} inserted, {updated_count} updated"
            )

        overlap_count = len(l2_distances)
        if overlap_count > 0:
            mean_l2 = float(np.mean(l2_distances))
            max_l2 = float(np.max(l2_distances))
        else:
            mean_l2 = 0.0
            max_l2 = 0.0

        return MergeStats(overlap_count=overlap_count, mean_l2=mean_l2, max_l2=max_l2)

    def clear(self) -> None:
        """Clear all scheduled actions."""
        self._schedule.clear()
@dataclass
class ObservationRequest:
    """Request for an observation capture, handed from the main thread to the obs sender."""

    # Control-loop tick t when this request was made (LWW key).
    control_step: int
    # Action step n_k where the resulting chunk should start.
    chunk_start_step: int
    # Task description string forwarded to the policy.
    task: str
    # Optional RTC metadata attached to the outgoing observation.
    rtc_meta: dict[str, Any] | None = None


@dataclass
class ReceivedActionChunk:
    """Action chunk received from the server, plus latency bookkeeping."""

    # DrtcAction list as decoded from the server response.
    actions: list[DrtcAction]
    # Control-loop tick t that produced this chunk.
    src_control_step: int
    # Action step n_k where this chunk starts.
    chunk_start_step: int
    # Measured round-trip time for this chunk (seconds).
    measured_latency: float
    # Wall-clock timestamps (Unix seconds) for latency decomposition; None when unknown.
    obs_sent_ts: float | None = None
    server_obs_received_ts: float | None = None
    server_action_sent_ts: float | None = None
    action_received_ts: float | None = None
    def __init__(self, config: RobotClientDrtcConfig):
        """Initialize the DRTC robot client.

        Connects the robot (real or mock), builds all network-fault simulators,
        the gRPC channel/stub, the latency estimator, the action schedule, the
        LWW registers used for inter-thread handoff, metrics, optional
        trajectory visualization, and the action filter.

        Side effects: connects to the robot hardware, starts the diagnostic
        metrics thread, and (when viz is enabled) starts background threads.

        Args:
            config: Configuration for the robot client.
        """
        self.config = config

        # Use mock robot when no physical robot is available
        if config.use_mock_robot:
            self.robot = MockRobot()
            self.robot.connect()
            # Mock features for simulation
            lerobot_features = {
                "observation.state": list(self.robot.state_features),
                "action": list(self.robot.action_features),
            }
        else:
            self.robot = make_robot_from_config(config.robot)
            self.robot.connect()
            lerobot_features = map_robot_keys_to_lerobot_features(self.robot)

        # Network-fault simulators for experiments (drop / duplicate / reorder /
        # disconnect), one per direction where applicable.
        self._obs_drop_sim = DropSimulator(config=config.drop_obs_config)
        self._action_drop_sim = DropSimulator(config=config.drop_action_config)
        self._obs_dup_sim = DuplicateSimulator(config=config.dup_obs_config)
        self._action_dup_sim = DuplicateSimulator(config=config.dup_action_config)
        self._obs_reorder_sim = ReorderSimulator(config=config.reorder_obs_config)
        self._action_reorder_sim = ReorderSimulator(config=config.reorder_action_config)
        self._disconnect_sim = DisconnectSimulator(config=config.disconnect_config)

        self.server_address = config.server_address
        # Policy configuration shipped to the server in SendPolicyInstructions.
        self.policy_config = RemotePolicyConfig(
            config.policy_type,
            config.pretrained_name_or_path,
            lerobot_features,
            config.actions_per_chunk,
            config.policy_device,
            rtc_enabled=config.rtc_enabled,
            rtc_max_guidance_weight=config.rtc_max_guidance_weight,
            rtc_prefix_attention_schedule=config.rtc_prefix_attention_schedule,
            rtc_sigma_d=config.rtc_sigma_d,
            rtc_full_trajectory_alignment=config.rtc_full_trajectory_alignment,
            num_flow_matching_steps=config.num_flow_matching_steps,
            spikes=config.spikes,
            diagnostics_verbose=config.metrics_diagnostic_verbose,
        )

        self.channel = grpc.insecure_channel(
            self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
        )
        self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)

        # Shutdown coordination
        self.shutdown_event = threading.Event()
        self._active_action_stream: grpc.Future | None = None  # Cancel on stop to unblock action_receiver

        # Action state: n(t), initialized to -1 per algorithm.
        # Note: Only the main control loop thread reads/writes action_step.
        self.action_step: int = -1

        # Control-loop tick counter t ∈ ℕ (monotone, incremented every tick).
        # Used as the LWW logical clock so that dropped messages never stall watermarks.
        self.control_step: int = 0

        # Latency estimation (configurable: JK or max_last_10)
        # Upper bound: d <= H/2 per RTC constraint (with s = d, d <= H - s becomes d <= H/2)
        self.latency_estimator = make_latency_estimator(
            kind=config.latency_estimator_type,
            fps=config.fps,
            alpha=config.latency_alpha,
            beta=config.latency_beta,
            k=config.latency_k,
            action_chunk_size=config.actions_per_chunk,
            s_min=config.s_min,
        )

        # Action schedule (replaces Queue with OrderedDict)
        self.action_schedule = ActionSchedule()

        # Cool-down counter O^c(t).
        # Note: Only the main control loop thread reads/writes obs_cooldown.
        self.obs_cooldown: int = 0

        # SPSC Mailboxes (one-slot queues)
        # Observation request register: main thread -> observation sender
        self._obs_request_reg: LWWRegister[ObservationRequest | None] = LWWRegister(
            initial_control_step=-1, initial_value=None
        )

        # Action register: action receiver -> main thread
        self._action_reg: LWWRegister[ReceivedActionChunk | None] = LWWRegister(
            initial_control_step=-1, initial_value=None
        )
        self._action_reader: LWWReader[ReceivedActionChunk | None] = self._action_reg.reader()

        # Synchronization barrier for thread startup
        self.start_barrier = threading.Barrier(3)  # 3 threads: main, obs sender, action receiver

        # Debug tracking (bounded to ~5 min at control rate to prevent unbounded growth)
        _max_queue_history = self.config.fps * 300  # 5 minutes
        self.action_queue_sizes: deque[int] = deque(maxlen=_max_queue_history)

        # Metrics (two categories):
        # - experiment: written to disk (CSV + trajectory JSON) when metrics_path is set
        # - diagnostic: periodic console output (avg/max timings) when enabled
        diag = DiagnosticMetrics(
            fps=config.fps,
            window_s=config.metrics_diagnostic_window_s,
            interval_s=config.metrics_diagnostic_interval_s,
            enabled=config.metrics_diagnostic_enabled,
            verbose=config.metrics_diagnostic_verbose,
            prefix="DIAG",
        )
        diag.start()

        exp: ExperimentMetricsWriter | None = None
        if config.metrics_path:
            exp = ExperimentMetricsWriter(
                path=config.metrics_path,
                simulation_config=self._build_simulation_config(),
                experiment_config=self._build_experiment_config(),
            )

        self._metrics = Metrics(experiment=exp, diagnostic=diag)

        # Trajectory visualization: send chunks to policy server via gRPC
        # Uses a queue + background thread to avoid blocking the control loop
        self._trajectory_chunk_queue: Queue[services_pb2.TrajectoryChunk] = Queue(maxsize=10)
        self._trajectory_sender_thread: threading.Thread | None = None
        self._trajectory_viz_client: TrajectoryVizClient | None = None
        if config.trajectory_viz_enabled:
            self._trajectory_sender_thread = threading.Thread(
                target=self._trajectory_chunk_sender,
                name="trajectory_chunk_sender",
                daemon=True,
            )
            self._trajectory_sender_thread.start()

            # WebSocket client for sending executed actions directly to viz server
            self._trajectory_viz_client = TrajectoryVizClient(ws_url=config.trajectory_viz_ws_url)
            self._trajectory_viz_client.start()

        # Action filter (class-based, with optional hard-mask lookahead)
        self._action_filter: ActionFilter = self._create_action_filter()
+ + Returns: + Configured ActionFilter instance. + """ + cfg = self.config + mode = cfg.action_filter_mode + + if mode == "none": + return NoFilter() + elif mode == "butterworth": + return ButterworthFilter( + cutoff=cfg.action_filter_butterworth_cutoff, + order=cfg.action_filter_butterworth_order, + fps=cfg.fps, + gain=cfg.action_filter_gain, + past_buffer_size=cfg.action_filter_past_buffer_size, + ) + else: + return NoFilter() + + def start(self) -> bool: + """Start the robot client and connect to the policy server.""" + try: + t_total_start = time.perf_counter() + + # Server handshake + t_ready_start = time.perf_counter() + self.stub.Ready(services_pb2.Empty()) + t_ready_done = time.perf_counter() + self._metrics.diagnostic.timing_s("ready_rpc_ms", t_ready_done - t_ready_start) + + # Send policy configuration + policy_config_bytes = pickle.dumps(self.policy_config) + policy_setup = services_pb2.PolicySetup(data=policy_config_bytes) + + t_policy_rpc_start = time.perf_counter() + self.stub.SendPolicyInstructions(policy_setup) + t_policy_rpc_done = time.perf_counter() + self._metrics.diagnostic.timing_s("policy_rpc_ms", t_policy_rpc_done - t_policy_rpc_start) + + self.shutdown_event.clear() + + # Seed cooldown with s_min so the trigger gate works before the + # first real RTT measurement. The estimator itself stays unseeded; + # it will initialise from the first real measurement (zero variance). 
+ self.obs_cooldown = self.config.s_min + self.config.epsilon + self._metrics.diagnostic.timing_s("client_init_total_ms", time.perf_counter() - t_total_start) + + return True + + except grpc.RpcError as e: + self.logger.error(f"Failed to connect to policy server: {e}") + return False + + def stop(self) -> None: + """Stop the robot client.""" + self.shutdown_event.set() + + # Cancel active gRPC action stream so action_receiver unblocks promptly + stream = self._active_action_stream + if stream is not None: + with suppress(Exception): + stream.cancel() + self._active_action_stream = None + + # Flush experiment metrics if enabled (disk output; behavior unchanged) + if self._metrics.experiment is not None and self.config.metrics_path: + self._metrics.experiment.flush(self.config.metrics_path) + + # Stop trajectory viz client if enabled + if self._trajectory_viz_client is not None: + self._trajectory_viz_client.stop() + + self.robot.disconnect() + + self.channel.close() + self._metrics.diagnostic.stop() + + def signal_stop(self) -> None: + """Signal the client to stop without disconnecting the robot. + + Use this when you want to stop the control loop but keep the robot + and server connection alive for subsequent experiments. + """ + self.shutdown_event.set() + + # Cancel active gRPC action stream so action_receiver unblocks promptly + stream = self._active_action_stream + if stream is not None: + with suppress(Exception): + stream.cancel() + self._active_action_stream = None + + # Flush experiment metrics if enabled (disk output; behavior unchanged) + if self._metrics.experiment is not None and self.config.metrics_path: + try: + self._metrics.experiment.flush(self.config.metrics_path) + except Exception as e: + import traceback as _tb + + self.logger.error(f"Failed to flush experiment metrics: {e}") + _tb.print_exc() + + def _build_experiment_config(self) -> dict: + """Build a serialisable dict of core experiment parameters. 
+ + Captures robot/hardware metadata, policy, DRTC, and + action-filter settings so the plotter can render a configuration + table in LaTeX output. + """ + # Build camera summary from robot config + cameras = getattr(self.config.robot, "cameras", {}) + num_cameras = len(cameras) + camera_parts = [] + for name, cam_cfg in cameras.items(): + w = getattr(cam_cfg, "width", "?") + h = getattr(cam_cfg, "height", "?") + camera_parts.append(f"{name} ({w}x{h})") + cameras_str = ", ".join(camera_parts) if camera_parts else "none" + + return { + # Robot / hardware + "robot_type": self.config.robot_type, + "gpu": self.config.gpu, + "client_host": self.config.client_host, + "server_host": self.config.server_host, + "num_cameras": num_cameras, + "cameras": cameras_str, + # Policy + "policy_type": self.config.policy_type, + "pretrained_name_or_path": self.config.pretrained_name_or_path, + "chunk_size": self.config.actions_per_chunk, + "fps": self.config.fps, + "s_min": self.config.s_min, + "epsilon": self.config.epsilon, + "latency_estimator_type": self.config.latency_estimator_type, + "latency_alpha": self.config.latency_alpha, + "latency_beta": self.config.latency_beta, + "latency_k": self.config.latency_k, + # Flow matching / RTC + "num_flow_matching_steps": self.config.num_flow_matching_steps, + "rtc_enabled": self.config.rtc_enabled, + "rtc_max_guidance_weight": self.config.rtc_max_guidance_weight, + "rtc_prefix_attention_schedule": self.config.rtc_prefix_attention_schedule, + "rtc_sigma_d": self.config.rtc_sigma_d, + "rtc_full_trajectory_alignment": self.config.rtc_full_trajectory_alignment, + # Action filter + "filter_type": self.config.action_filter_mode, + "filter_cutoff": self.config.action_filter_butterworth_cutoff, + "gain": self.config.action_filter_gain, + } + + def _build_simulation_config(self) -> dict: + """Build a serialisable dict of all configured simulation events. 
+ + Captures drop, duplicate, reorder, and spike configs so they can be + stored alongside trajectory data for post-hoc visualisation. + """ + + def _events_to_dicts(config, attr: str) -> list[dict]: + if config is None: + return [] + return [asdict(ev) for ev in getattr(config, attr, [])] + + return { + "drop_obs": _events_to_dicts(self.config.drop_obs_config, "drops"), + "drop_action": _events_to_dicts(self.config.drop_action_config, "drops"), + "dup_obs": _events_to_dicts(self.config.dup_obs_config, "duplicates"), + "dup_action": _events_to_dicts(self.config.dup_action_config, "duplicates"), + "reorder_obs": _events_to_dicts(self.config.reorder_obs_config, "reorders"), + "reorder_action": _events_to_dicts(self.config.reorder_action_config, "reorders"), + "disconnect": _events_to_dicts(self.config.disconnect_config, "disconnects"), + "spikes": list(self.config.spikes) if self.config.spikes else [], + } + + # ------------------------------------------------------------------------- + # Observation Sender Thread + # ------------------------------------------------------------------------- + + def observation_sender(self) -> None: + """Captures, encodes, and sends observations to the policy server.""" + self.start_barrier.wait() + + last_good_observation: RawObservation | None = None + last_good_observation_time: float | None = None + consecutive_capture_failures = 0 + reader = self._obs_request_reg.reader() + idle_start = time.perf_counter() + + while self.running: + try: + state, _, is_new = reader.read_if_newer() + request = state.value + if not is_new or request is None: + time.sleep(0.01) + continue + + # Emit wait time (how long obs sender was idle waiting for work) + self._metrics.diagnostic.timing_s("obs_wait_ms", time.perf_counter() - idle_start) + + t_capture_start = time.perf_counter() + + # Capture observation from robot + used_fallback = False + start_rtt_timestamp = time.time() + try: + raw_observation = self.robot.get_observation() + 
last_good_observation = raw_observation + last_good_observation_time = time.time() + consecutive_capture_failures = 0 + except Exception as e: + consecutive_capture_failures += 1 + if ( + self.config.obs_fallback_on_failure + and last_good_observation is not None + and last_good_observation_time is not None + and (time.time() - last_good_observation_time) <= self.config.obs_fallback_max_age_s + ): + used_fallback = True + raw_observation = last_good_observation + self._metrics.diagnostic.counter("obs_fallback_used", 1) + else: + self.logger.error( + "Observation capture failed (%s). No usable fallback (consecutive_failures=%s).", + e, + consecutive_capture_failures, + ) + continue + + # Avoid mutating cached observation dict if we are reusing it. + if used_fallback: + raw_observation = dict(raw_observation) + raw_observation["task"] = request.task + if request.rtc_meta is not None: + raw_observation["__rtc__"] = request.rtc_meta + + t_capture_done = time.perf_counter() + + # Encode images for transport + t_encode_start = time.perf_counter() + encoded_observation, _ = encode_images_for_transport(raw_observation, jpeg_quality=60) + t_encode_done = time.perf_counter() + self._metrics.diagnostic.timing_s("obs_encode_ms", t_encode_done - t_encode_start) + + # Create timed observation + timed_obs = DrtcObservation( + timestamp=start_rtt_timestamp, + control_step=request.control_step, + observation=encoded_observation, + chunk_start_step=request.chunk_start_step, + ) + + # Network disconnect simulation (blocks until window ends) + disconnect_sleep = self._disconnect_sim.wait_if_disconnected() + if disconnect_sleep > 0: + self._metrics.diagnostic.counter("disconnect_sim", 1) + if self._metrics.experiment is not None: + self._metrics.experiment.record_sim_event("disconnect") + continue + + # Check if observation should be dropped (simulation/experiments) + if self._obs_drop_sim.should_drop(): + self._metrics.diagnostic.counter("obs_dropped_sim", 1) + if 
self._metrics.experiment is not None: + self._metrics.experiment.record_sim_event("obs_dropped") + continue + + # Reorder injection (hold-and-swap before send) + obs_items = self._obs_reorder_sim.process(timed_obs) + if not obs_items: + self._metrics.diagnostic.counter("obs_reorder_held", 1) + if self._metrics.experiment is not None: + self._metrics.experiment.record_sim_event("obs_reorder_held") + continue + if len(obs_items) > 1: + self._metrics.diagnostic.counter("obs_reorder_swapped", 1) + if self._metrics.experiment is not None: + self._metrics.experiment.record_sim_event("obs_reorder_swapped") + + # Send each item (1 normally, 2 when a swap completes) + t_send_start = time.perf_counter() + for obs_item in obs_items: + self._send_observation(obs_item) + + # Duplicate injection (after send) + if self._obs_dup_sim.should_duplicate(): + self._send_observation(obs_item) + self._metrics.diagnostic.counter("obs_duplicated_sim", 1) + if self._metrics.experiment is not None: + self._metrics.experiment.record_sim_event("obs_duplicated") + t_send_done = time.perf_counter() + self._metrics.diagnostic.timing_s("obs_capture_ms", t_capture_done - t_capture_start) + self._metrics.diagnostic.timing_s("obs_send_ms", t_send_done - t_send_start) + idle_start = time.perf_counter() + + except Exception as e: + self.logger.error("Error in observation sender: %s", e, exc_info=True) + + def _send_observation(self, obs: DrtcObservation) -> bool: + """Send a timed observation to the policy server via gRPC.""" + try: + observation_bytes = pickle.dumps(obs) + observation_iterator = send_bytes_in_chunks( + observation_bytes, + services_pb2.Observation, + log_prefix="[CLIENT] Observation", + silent=True, + ) + _ = self.stub.SendObservations(observation_iterator) + return True + except grpc.RpcError as e: + self.logger.error(f"Error sending observation: {e}") + return False + + # ------------------------------------------------------------------------- + # Trajectory Chunk Sender Thread + # 
------------------------------------------------------------------------- + + def _trajectory_chunk_sender(self) -> None: + """Background thread that sends trajectory chunks to the policy server.""" + while self.running: + try: + # Wait for a chunk to send (with timeout to check shutdown) + try: + chunk = self._trajectory_chunk_queue.get(timeout=0.1) + except Empty: + continue + + # Send to server (best-effort, don't block on errors) + try: + self.stub.SendTrajectoryChunk(chunk) + except grpc.RpcError as e: + self.logger.debug("Trajectory chunk send failed: %s", e) + self._metrics.diagnostic.counter("trajectory_chunk_send_rpc_error", 1) + + except Exception as e: + self.logger.error("Error in trajectory chunk sender: %s", e, exc_info=True) + self._metrics.diagnostic.counter("trajectory_chunk_sender_error", 1) + + def _queue_trajectory_chunk( + self, + src_control_step: int, + actions: list[np.ndarray], + frozen_len: int, + ) -> None: + """Queue a trajectory chunk for sending to the policy server (non-blocking).""" + if not actions: + return + + # Convert actions to packed float32 bytes + action_dim = actions[0].shape[0] if len(actions) > 0 else 0 + actions_array = np.stack([a.astype(np.float32) for a in actions], axis=0) + actions_bytes = actions_array.tobytes() + + chunk = services_pb2.TrajectoryChunk( + source_step=src_control_step, + num_actions=len(actions), + action_dim=action_dim, + actions_f32=actions_bytes, + frozen_len=frozen_len, + timestamp=time.time(), + ) + + # Non-blocking put: drop if queue is full + try: + self._trajectory_chunk_queue.put_nowait(chunk) + except Full: + # Drop oldest and add new + with suppress(Empty): + self._trajectory_chunk_queue.get_nowait() + with suppress(Full): + self._trajectory_chunk_queue.put_nowait(chunk) + + # ------------------------------------------------------------------------- + # Action Receiver Thread + # ------------------------------------------------------------------------- + + def action_receiver(self) -> 
None: + """Receives actions from the server via streaming.""" + self.start_barrier.wait() + last_chunk_time: float | None = None + while self.running: + try: + t_rpc_start = time.perf_counter() + stream = self.stub.StreamActionsDense(services_pb2.Empty()) + self._active_action_stream = stream # Store for cancellation on stop + t_rpc_done = time.perf_counter() + self._metrics.diagnostic.timing_s("rpc_ms", t_rpc_done - t_rpc_start) + + for dense in stream: + if not self.running: + break + t_chunk_received = time.perf_counter() + # Emit chunk gap timing (time since last chunk) + if last_chunk_time is not None: + self._metrics.diagnostic.timing_s("chunk_gap_ms", t_chunk_received - last_chunk_time) + last_chunk_time = t_chunk_received + + # Network disconnect simulation (blocks until window ends) + disconnect_sleep = self._disconnect_sim.wait_if_disconnected() + if disconnect_sleep > 0: + self._metrics.diagnostic.counter("disconnect_sim", 1) + if self._metrics.experiment is not None: + self._metrics.experiment.record_sim_event("disconnect") + continue + + # Reorder injection (hold-and-swap before handle) + dense_items = self._action_reorder_sim.process(dense) + if not dense_items: + self._metrics.diagnostic.counter("action_reorder_held", 1) + if self._metrics.experiment is not None: + self._metrics.experiment.record_sim_event("action_reorder_held") + continue + if len(dense_items) > 1: + self._metrics.diagnostic.counter("action_reorder_swapped", 1) + if self._metrics.experiment is not None: + self._metrics.experiment.record_sim_event("action_reorder_swapped") + + for dense_item in dense_items: + self._handle_actions_dense(dense_item, rpc_ms=0.0) + + # Duplicate injection (after handle) + if self._action_dup_sim.should_duplicate(): + self._handle_actions_dense(dense_item, rpc_ms=0.0) + self._metrics.diagnostic.counter("action_chunk_duplicated_sim", 1) + if self._metrics.experiment is not None: + self._metrics.experiment.record_sim_event("action_duplicated") + + except 
grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNIMPLEMENTED: + self.logger.error( + "Server does not implement StreamActionsDense. " + "This client is streaming-only for actions; please update the server." + ) + self.stop() + return + self.logger.error(f"Error in StreamActionsDense: {e}") + time.sleep(0.1) + + def _handle_actions_dense(self, dense: services_pb2.ActionsDense, rpc_ms: float) -> None: + """Decode a dense action chunk into DrtcAction list and publish to main thread.""" + receive_time = time.time() + + num_actions = int(dense.num_actions) + action_dim = int(dense.action_dim) + if num_actions <= 0 or action_dim <= 0: + return + + t_deser_start = time.perf_counter() + actions = np.frombuffer(dense.actions_f32, dtype=np.float32) + if actions.size != num_actions * action_dim: + raise ValueError( + f"ActionsDense buffer size mismatch: {actions.size} != {num_actions * action_dim}" + ) + actions = actions.reshape(num_actions, action_dim) + t_deser_done = time.perf_counter() + + timestamp = float(dense.timestamp) + source_control_step = int(dense.source_control_step) + chunk_start_step = int(dense.chunk_start_step) + dt = float(dense.dt) + + measured_latency = receive_time - timestamp + timed_actions = [ + DrtcAction( + timestamp=timestamp + i * dt, + control_step=source_control_step, + action_step=chunk_start_step + i, + action=actions[i], + ) + for i in range(num_actions) + ] + + # Extract raw timestamps for the round-trip journey (stored in CSV as-is) + server_obs_received_ts = float(dense.server_obs_received_ts) + server_action_sent_ts = float(dense.server_action_sent_ts) + if server_obs_received_ts > 0 and server_action_sent_ts > 0: + obs_sent_ts = timestamp + action_received_ts = receive_time + else: + obs_sent_ts = None + server_obs_received_ts = None + server_action_sent_ts = None + action_received_ts = None + + self._metrics.diagnostic.timing_ms("rpc_ms", rpc_ms) + self._metrics.diagnostic.timing_s("deser_ms", t_deser_done - t_deser_start) + 
self._metrics.diagnostic.timing_s("total_latency_rtt_ms", measured_latency) + + # Check if action chunk should be dropped (simulation/experiments) + if self._action_drop_sim.should_drop(): + self._metrics.diagnostic.counter("action_chunk_dropped_sim", 1) + if self._metrics.experiment is not None: + self._metrics.experiment.record_sim_event("action_dropped") + return + + self._publish_received_actions( + timed_actions=timed_actions, + src_control_step=source_control_step, + chunk_start_step=chunk_start_step, + measured_latency=measured_latency, + obs_sent_ts=obs_sent_ts, + server_obs_received_ts=server_obs_received_ts, + server_action_sent_ts=server_action_sent_ts, + action_received_ts=action_received_ts, + ) + + def _publish_received_actions( + self, + *, + timed_actions: list[DrtcAction], + src_control_step: int, + chunk_start_step: int, + measured_latency: float, + obs_sent_ts: float | None = None, + server_obs_received_ts: float | None = None, + server_action_sent_ts: float | None = None, + action_received_ts: float | None = None, + ) -> None: + chunk = ReceivedActionChunk( + actions=timed_actions, + src_control_step=src_control_step, + chunk_start_step=chunk_start_step, + measured_latency=measured_latency, + obs_sent_ts=obs_sent_ts, + server_obs_received_ts=server_obs_received_ts, + server_action_sent_ts=server_action_sent_ts, + action_received_ts=action_received_ts, + ) + _, accepted = self._action_reg.update_if_newer(control_step=src_control_step, value=chunk) + + if self._metrics.experiment is not None: + self._metrics.experiment.record_register_event( + register_name="client_action", + control_step=src_control_step, + chunk_start_step=chunk_start_step, + accepted=accepted, + ) + + # ------------------------------------------------------------------------- + # Main Thread: Control Loop + # ------------------------------------------------------------------------- + + def control_loop(self, task: str | None = None) -> None: + """Main control loop following 
Algorithm 1 from the paper. + + This loop: + 1. Executes actions if available + 2. Checks inference trigger condition and requests observations + 3. Processes incoming action chunks + 4. Maintains control frequency + + Args: + task: Optional task override (uses config.task if not provided). + """ + self.start_barrier.wait() + + task = task or self.config.task + + prev_loop_start: float | None = None + next_tick: float | None = time.perf_counter() if self.config.control_use_deadline_clock else None + + while self.running: + t_loop_start = time.perf_counter() + if prev_loop_start is not None: + self._metrics.diagnostic.timing_s("loop_dt_ms", t_loop_start - prev_loop_start) + prev_loop_start = t_loop_start + + # Experiment metrics tracking for this tick + _tick_obs_triggered = False + _tick_action_received = False + _tick_measured_latency_ms: float | None = None + _tick_obs_sent_ts: float | None = None + _tick_server_obs_received_ts: float | None = None + _tick_server_action_sent_ts: float | None = None + _tick_action_received_ts: float | None = None + _tick_chunk_overlap_count: int | None = None + _tick_chunk_mean_l2: float | None = None + _tick_chunk_max_l2: float | None = None + + # Phase timing tracking + _phase_exec_ms = 0.0 + _phase_trigger_ms = 0.0 + _phase_merge_ms = 0.0 + + # --------------------------------------------------------------------- + # Step 1: Execute action if available + # --------------------------------------------------------------------- + t_phase1_start = time.perf_counter() + if not self.action_schedule.is_empty(): + result = self.action_schedule.pop_front() + if result is not None: + step, action, src_control_step, chunk_start_step = result + + # Apply action filter to reduce jitter from policy micro-updates + ctx = FilterContext(action=action) + filtered_action = self._action_filter.apply(ctx) + + t_send_start = time.perf_counter() + self.robot.send_action(self._action_array_to_dict(filtered_action)) + t_send_done = time.perf_counter() 
+ + # Keep action_step aligned with the schedule's action-step keys. + # Only the main control loop thread writes this. + self.action_step = step + self._metrics.diagnostic.timing_s("send_action_ms", t_send_done - t_send_start) + + # Stream executed action to the visualization server (best-effort). + if self._trajectory_viz_client is not None: + self._trajectory_viz_client.on_executed_action( + EvExecutedAction( + step=step, + action=filtered_action.tolist(), + timestamp=time.time(), + ) + ) + + # Record executed action for experiment trajectory visualization + if self._metrics.experiment is not None: + self._metrics.experiment.record_executed_action( + step=step, + action=filtered_action, + src_control_step=src_control_step, + chunk_start_step=chunk_start_step, + ) + + t_phase1_end = time.perf_counter() + _phase_exec_ms = self._ms(t_phase1_end - t_phase1_start) + + # Track queue size for debugging and starvation detection + schedule_size = self.action_schedule.get_size() + self.action_queue_sizes.append(schedule_size) + is_starved = schedule_size == 0 + if is_starved: + self._metrics.diagnostic.counter("starvation", 1) + + # --------------------------------------------------------------------- + # Step 2: Check inference trigger condition + # --------------------------------------------------------------------- + t_phase2_start = time.perf_counter() + latency_steps = self.latency_estimator.estimate_steps + epsilon = self.config.epsilon + s_min = self.config.s_min + chunk_len = self.config.actions_per_chunk + + trigger_threshold = chunk_len - s_min + if self.config.cooldown_enabled: + should_trigger = schedule_size <= trigger_threshold and self.obs_cooldown == 0 + else: + # Classic async baseline: always trigger when schedule is low + should_trigger = schedule_size <= trigger_threshold + + if should_trigger: + current_step = self.current_action_step + + # Clamp to 0 so the server produces chunks starting at 0 on startup (consistent with the + # original async 
inference implementation that uses max(latest_action, 0)). + rtc_meta: dict[str, Any] | None = None + if self.config.rtc_enabled: + t_rtc_start = time.perf_counter() + + # RTC paper: effective execution horizon is s = max(s_min, d) + # - d = latency_steps = hard mask region (weight 1.0) + # - overlap_end = H - s = where fresh region starts + # - Soft mask region: [d, overlap_end) with decaying weight + d = int(latency_steps) + s = max(s_min, d) # Effective execution horizon + overlap_end = chunk_len - s # Where fresh region starts + + # Get masking spans from schedule (handles multi-chunk prefixes) + # Returns list of (src_step, start_idx, end_idx) for server cache lookup + action_schedule_spans = self.action_schedule.get_masking_chunk_spans( + current_step=current_step, max_len=overlap_end + ) + + rtc_meta = { + "enabled": True, + "latency_steps": d, # Hard mask region [0, d) + "action_schedule_spans": action_schedule_spans, # List of (src_step, start, end) or None + "overlap_end": overlap_end, # H - max(s_min, d): where fresh region starts + } + t_rtc_end = time.perf_counter() + self._metrics.diagnostic.timing_s("rtc_build_ms", t_rtc_end - t_rtc_start) + + request = ObservationRequest( + control_step=self.control_step, + chunk_start_step=max(current_step, 0), + task=task, + rtc_meta=rtc_meta, + ) + + # Always reset cooldown when trigger fires (before attempting put) + # Cooldown = latency_steps + epsilon (buffer to prevent over-triggering) + if self.config.cooldown_enabled: + self.obs_cooldown = latency_steps + epsilon + + # Publish newest request (monotone w.r.t. 
control_step t) + _, obs_accepted = self._obs_request_reg.update_if_newer( + control_step=request.control_step, + value=request, + ) + + if self._metrics.experiment is not None: + self._metrics.experiment.record_register_event( + register_name="client_obs_request", + control_step=request.control_step, + accepted=obs_accepted, + ) + + _tick_obs_triggered = True + self._metrics.diagnostic.counter("obs_triggered", 1) + else: + # Decrement cooldown: O^c(t+1) = max(O^c(t) - 1, 0) + # Only decrement in 'cooldown' mode (default behavior for drop recovery) + # In 'merge_reset' mode, cooldown is only reset when actions are merged + if self.config.cooldown_enabled and self.config.inference_reset_mode == "cooldown": + self.obs_cooldown = max(self.obs_cooldown - 1, 0) + + t_phase2_end = time.perf_counter() + _phase_trigger_ms = self._ms(t_phase2_end - t_phase2_start) + + # --------------------------------------------------------------------- + # Step 3: Check for incoming action chunks + # --------------------------------------------------------------------- + t_phase3_start = time.perf_counter() + state, _, is_new = self._action_reader.read_if_newer() + chunk = state.value + if is_new and chunk is not None: + current_step = self.current_action_step + latency_steps = self.latency_estimator.estimate_steps + + # Update latency estimate + self.latency_estimator.update(chunk.measured_latency) + + # Merge actions into schedule + merge_stats = self.action_schedule.merge( + incoming_actions=chunk.actions, + src_control_step=chunk.src_control_step, + chunk_start_step=chunk.chunk_start_step, + current_action_step=current_step, + ) + + _tick_action_received = True + _tick_measured_latency_ms = self._ms(chunk.measured_latency) + _tick_obs_sent_ts = chunk.obs_sent_ts + _tick_server_obs_received_ts = chunk.server_obs_received_ts + _tick_server_action_sent_ts = chunk.server_action_sent_ts + _tick_action_received_ts = chunk.action_received_ts + + # Track discrepancy stats from the merge + 
_tick_chunk_overlap_count = merge_stats.overlap_count + _tick_chunk_mean_l2 = merge_stats.mean_l2 + _tick_chunk_max_l2 = merge_stats.max_l2 + + # In merge_reset mode, reset cooldown when actions are merged + # This mimics RTC-style behavior where inference readiness is gated + # by action arrival rather than time-based cooldown + if self.config.inference_reset_mode == "merge_reset": + self.obs_cooldown = 0 + + # Send action chunk to policy server for trajectory visualization + if self.config.trajectory_viz_enabled and chunk.actions: + # Extract action arrays from the DRTC action payloads + actions_arrays = [ta.action for ta in chunk.actions] + self._queue_trajectory_chunk( + src_control_step=chunk.src_control_step, + actions=actions_arrays, + frozen_len=latency_steps, + ) + + # Record chunk for experiment trajectory visualization + if self._metrics.experiment is not None and chunk.actions: + actions_arrays = [ta.action for ta in chunk.actions] + self._metrics.experiment.record_chunk( + src_control_step=chunk.src_control_step, + actions=actions_arrays, + frozen_len=int(latency_steps), + chunk_start_step=chunk.chunk_start_step, + ) + + t_phase3_end = time.perf_counter() + _phase_merge_ms = self._ms(t_phase3_end - t_phase3_start) + + # Diagnostic phase timings (avg/max only; printed periodically by DiagnosticMetrics) + self._metrics.diagnostic.timing_ms("phase_exec_ms", _phase_exec_ms) + self._metrics.diagnostic.timing_ms("phase_trigger_ms", _phase_trigger_ms) + self._metrics.diagnostic.timing_ms("phase_merge_ms", _phase_merge_ms) + + # Advance the control-loop clock (always monotone, even when no action executes) + self.control_step += 1 + + # --------------------------------------------------------------------- + # Step 4: Maintain control frequency + # --------------------------------------------------------------------- + elapsed = time.perf_counter() - t_loop_start + if next_tick is None: + sleep_s = max(0.0, self.config.environment_dt - elapsed) + if sleep_s > 
0: + time.sleep(sleep_s) + else: + self._metrics.diagnostic.counter("overrun", 1) + else: + # Deadline-based clock: reduces drift and jitter when occasional overruns happen. + next_tick += self.config.environment_dt + now = time.perf_counter() + sleep_s = next_tick - now + if sleep_s > 0: + time.sleep(sleep_s) + else: + # If we're behind, count an overrun and re-anchor to now to avoid runaway lag. + self._metrics.diagnostic.counter("overrun", 1) + next_tick = now + + self._metrics.diagnostic.set_context( + step=self.current_action_step, + schedule_size=self.action_schedule.get_size(), + latency_steps=self.latency_estimator.estimate_steps, + cooldown=self.obs_cooldown, + s_min=self.config.s_min, + fps=self.config.fps, + ) + + # Record experiment metrics for this tick + if self._metrics.experiment is not None: + self._metrics.experiment.record_tick( + step=self.current_action_step, + schedule_size=self.action_schedule.get_size(), + latency_estimate_steps=self.latency_estimator.estimate_steps, + latency_estimate_ms=self.latency_estimator.estimate_seconds * 1000.0, + cooldown=self.obs_cooldown, + obs_triggered=_tick_obs_triggered, + action_received=_tick_action_received, + measured_latency_ms=_tick_measured_latency_ms, + obs_sent_ts=_tick_obs_sent_ts, + server_obs_received_ts=_tick_server_obs_received_ts, + server_action_sent_ts=_tick_server_action_sent_ts, + action_received_ts=_tick_action_received_ts, + chunk_overlap_count=_tick_chunk_overlap_count, + chunk_mean_l2=_tick_chunk_mean_l2, + chunk_max_l2=_tick_chunk_max_l2, + ) + + def _action_array_to_dict(self, action_array: np.ndarray) -> dict[str, float]: + """Convert action array to dictionary keyed by robot action features.""" + return {key: action_array[i].item() for i, key in enumerate(self.robot.action_features)} + + +def async_client_drtc(cfg: RobotClientDrtcConfig) -> None: + """Run the DRTC async inference client.""" + + if cfg.robot.type not in SUPPORTED_ROBOTS: + raise ValueError(f"Robot {cfg.robot.type} 
not yet supported!") + + client = RobotClientDrtc(cfg) + + if client.start(): + # Start observation sender thread + obs_sender_thread = threading.Thread( + target=client.observation_sender, + name="observation_sender", + daemon=True, + ) + + # Start action receiver thread + action_receiver_thread = threading.Thread( + target=client.action_receiver, + name="action_receiver", + daemon=True, + ) + + obs_sender_thread.start() + action_receiver_thread.start() + + try: + # Main thread runs the control loop + client.control_loop() + + finally: + client.stop() + obs_sender_thread.join(timeout=2.0) + action_receiver_thread.join(timeout=2.0) + + if cfg.debug_visualize_queue_size: + visualize_action_queue_size(client.action_queue_sizes) + + +if __name__ == "__main__": + import draccus + + draccus.wrap()(async_client_drtc)() diff --git a/src/lerobot/async_inference/rtc_guidance.py b/src/lerobot/async_inference/rtc_guidance.py new file mode 100644 index 00000000000..ac9d9892dfb --- /dev/null +++ b/src/lerobot/async_inference/rtc_guidance.py @@ -0,0 +1,266 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Server-scoped Real-Time Chunking (RTC) guidance for async inference. + +This intentionally does NOT depend on `lerobot.policies.rtc.*`. + +It provides a minimal interface compatible with the flow-policy sampling code paths +that expect `rtc_processor.denoise_step(...)`. 
"""Async RTC (Real-Time Chunking) guidance for flow-matching policies.

Wraps an existing per-step denoiser callable (as used by the PI0/PI05/SmolVLA
sampling loops) with RTC-style prefix in-painting guidance.
"""

from __future__ import annotations

import logging
import math
from collections.abc import Callable
from dataclasses import dataclass

import torch
from torch import Tensor

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class AsyncRTCConfig:
    """Configuration for async RTC guidance.

    Attributes:
        enabled: Whether RTC guidance is enabled.
        prefix_attention_schedule: Schedule for prefix attention weights (zeros|ones|linear|exp).
        max_guidance_weight: Maximum guidance weight for clamping. If None, uses
            num_flow_matching_steps (Alex Soare optimization).
        sigma_d: Prior variance σ_d for the guidance weight formula. Lower values (e.g., 0.2)
            give stronger guidance and smoother transitions. Default 1.0 matches original RTC.
            Reference: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html
        full_trajectory_alignment: If True, skip gradient computation and use error directly
            as correction. Faster and smoother when distance between chunks is small.
    """

    enabled: bool = False
    prefix_attention_schedule: str = "linear"
    max_guidance_weight: float | None = None  # None = use num_flow_matching_steps (Alex Soare opt)
    sigma_d: float = 1.0  # Prior variance (0.2 = stronger guidance, 1.0 = original RTC)
    full_trajectory_alignment: bool = False  # Skip gradient for faster/smoother transitions


class AsyncRTCProcessor:
    """RTC-style prefix guidance wrapper around an existing denoiser.

    The call signature matches what PI0/PI05/SmolVLA flow models use in their sampling loop.
    """

    def __init__(self, cfg: AsyncRTCConfig, *, postprocess: Callable[[Tensor], Tensor] | None = None):
        self.cfg = cfg
        self._postprocess = postprocess

    def is_debug_enabled(self) -> bool:
        # Debug tracking is a no-op for this processor; kept for interface parity.
        return False

    def track(self, **_kwargs) -> None:
        # No-op hook kept for interface parity with debug-enabled processors.
        return

    def denoise_step(
        self,
        x_t: Tensor,
        prev_chunk_left_over: Tensor | None,
        inference_delay: int | None,
        time: float | Tensor,
        original_denoise_step_partial: Callable[[Tensor], Tensor],
        overlap_end: int | None = None,
        num_flow_matching_steps: int | None = None,
        # Backwards compat: policies pass execution_horizon
        execution_horizon: int | None = None,
    ) -> Tensor:
        """RTC guidance wrapper around an existing denoiser.

        Args:
            x_t: Current noisy sample tensor.
            prev_chunk_left_over: Previous chunk for inpainting guidance.
            inference_delay: Latency in action steps (d).
            time: Current denoising timestep (1 = noise, 0 = clean).
            original_denoise_step_partial: Callable that computes base velocity given x_t.
            overlap_end: Where soft masking region ends (H - d). If None, computed from
                chunk size and inference_delay.
            num_flow_matching_steps: Number of flow matching steps. Used as max_guidance_weight
                when cfg.max_guidance_weight is None (Alex Soare optimization).
            execution_horizon: Deprecated alias for overlap_end (for policy compatibility).

        Returns:
            Guided velocity tensor.
        """
        # No guidance if disabled or missing prefix / delay.
        if not self.cfg.enabled or prev_chunk_left_over is None or inference_delay is None:
            return original_denoise_step_partial(x_t)

        # Backwards compat: use execution_horizon if overlap_end not provided.
        if overlap_end is None and execution_horizon is not None:
            overlap_end = execution_horizon

        tau = 1 - time  # match existing RTC convention (inverted time)

        x_t_local = x_t.clone().detach()

        squeezed = False
        if x_t_local.ndim < 3:
            x_t_local = x_t_local.unsqueeze(0)
            squeezed = True

        prev = prev_chunk_left_over
        if prev.ndim < 3:
            prev = prev.unsqueeze(0)

        batch_size, chunk_t, chunk_a = x_t_local.shape
        prev_a = prev.shape[2]

        # Compute overlap_end if not provided: H - d (maximum soft masking).
        if overlap_end is None:
            overlap_end = chunk_t - inference_delay

        # Clamp overlap_end to available prefix length.
        prefix_len = prev.shape[1]
        overlap_end = min(overlap_end, prefix_len)

        # With server-side zero-padding to max_action_dim, dimensions should now match.
        # Log at debug level if they still differ (shouldn't happen after the fix).
        # NOTE: uses the module-level logger instead of importing logging inside this
        # hot sampling-loop method.
        if prev_a != chunk_a:
            logger.debug("RTC dimension mismatch: prev_a=%d, chunk_a=%d", prev_a, chunk_a)

        # Determine target action dimension: when postprocess is used, comparison happens
        # in executable action space (prev's dimension), not raw model space.
        target_a = prev_a if self._postprocess is not None else chunk_a

        # Pad prefix temporal dimension to match chunk_t, but keep action dimension as target_a.
        if prev.shape[1] < chunk_t:
            padded = torch.zeros(
                batch_size, chunk_t, target_a, device=x_t_local.device, dtype=x_t_local.dtype
            )
            padded[:, : prev.shape[1], :] = prev.to(device=x_t_local.device, dtype=x_t_local.dtype)
            prev = padded
        else:
            prev = prev[:, :chunk_t, :target_a].to(device=x_t_local.device, dtype=x_t_local.dtype)

        # Build weights: frozen [0, d), soft mask [d, overlap_end), fresh [overlap_end, H)
        weights_1d = self._get_prefix_weights(inference_delay, overlap_end, chunk_t).to(x_t_local.device)
        weights = weights_1d.unsqueeze(0).unsqueeze(-1)  # (1, T, 1)

        # We need gradients for the correction term (and optional postprocess), but we do NOT want
        # to build a backward graph through the denoiser/model parameters.
        with torch.enable_grad():
            with torch.no_grad():
                v_t = original_denoise_step_partial(x_t_local)

            x_t_local.requires_grad_(True)

            # Match policy-side convention: x1_t = x_t - time * v_t
            time_tensor = torch.as_tensor(time, device=x_t_local.device, dtype=x_t_local.dtype)
            x1_t = x_t_local - time_tensor * v_t.detach()

            # If we're guiding in executable-action space, apply the (differentiable)
            # postprocessor here. Used when the client only provides frozen actions
            # in executable space.
            x1_t_for_loss = x1_t
            if self._postprocess is not None:
                x1_t_for_loss = self._postprocess(x1_t_for_loss)

            err = (prev - x1_t_for_loss) * weights

            # Compute correction term. If full_trajectory_alignment is enabled, skip the
            # gradient and use the weighted error directly (faster and smoother when the
            # distance between chunks is small); otherwise take the VJP of the x1
            # prediction w.r.t. the noisy sample.
            if self.cfg.full_trajectory_alignment:
                correction = err
            else:
                correction = torch.autograd.grad(
                    x1_t_for_loss, x_t_local, err.detach(), retain_graph=False
                )[0]

        # Alex Soare optimization: use num_flow_matching_steps as max_guidance_weight if not set.
        # Reference: https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html
        max_gw = self.cfg.max_guidance_weight
        if max_gw is None:
            max_gw = float(num_flow_matching_steps) if num_flow_matching_steps is not None else 10.0
        max_guidance_weight = torch.as_tensor(max_gw, device=x_t_local.device)

        tau_tensor = torch.as_tensor(tau, device=x_t_local.device, dtype=x_t_local.dtype)
        squared_one_minus_tau = (1 - tau_tensor) ** 2

        # Alex Soare's formula with prior variance σ_d:
        #   Original RTC: inv_r2 = ((1-τ)² + τ²) / (1-τ)²
        #   With σ_d:     inv_r2 = ((1-τ)² + τ² * σ_d²) / ((1-τ)² * σ_d²)
        # Lower σ_d (e.g., 0.2) = narrower prior = stronger guidance = smoother transitions.
        prior_variance = torch.as_tensor(self.cfg.sigma_d**2, device=x_t_local.device, dtype=x_t_local.dtype)
        inv_r2 = (squared_one_minus_tau + tau_tensor**2 * prior_variance) / (
            squared_one_minus_tau * prior_variance
        )

        # FIX: torch.nan_to_num documents posinf as a Number; pass the Python float
        # bound rather than a 0-dim tensor (avoids relying on implicit conversion).
        c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=float(max_gw))
        guidance_weight = torch.nan_to_num(c * inv_r2, posinf=float(max_gw))
        guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)

        result = v_t - guidance_weight * correction
        if squeezed:
            result = result.squeeze(0)
        return result

    def _get_prefix_weights(self, start: int, end: int, total: int) -> Tensor:
        """Build the (total,) prefix-attention weight vector for the configured schedule.

        Positions [0, start) are frozen (weight 1), [start, end) are softly masked
        per the schedule, and [end, total) receive weight 0.
        """
        start = int(start)
        end = int(end)
        total = int(total)

        start = min(start, end)
        schedule = (self.cfg.prefix_attention_schedule or "linear").lower()

        if schedule == "zeros":
            weights = torch.zeros(total)
            weights[: min(start, total)] = 1.0
            return weights
        if schedule == "ones":
            weights = torch.ones(total)
            weights[max(end, 0) :] = 0.0
            return weights

        lin = self._linweights(start, end, total)
        if schedule == "exp":
            lin = lin * torch.expm1(lin).div(math.e - 1)

        weights = self._add_trailing_zeros(lin, total, end)
        weights = self._add_leading_ones(weights, start, total)
        return weights

    @staticmethod
    def _linweights(start: int, end: int, total: int) -> Tensor:
        # Open linspace from 1 to 0 over the soft-masked region (endpoints excluded).
        skip_steps_at_end = max(total - end, 0)
        linspace_steps = total - skip_steps_at_end - start
        if end <= start or linspace_steps <= 0:
            return torch.tensor([])
        return torch.linspace(1, 0, linspace_steps + 2)[1:-1]

    @staticmethod
    def _add_trailing_zeros(weights: Tensor, total: int, end: int) -> Tensor:
        # Zero weight for positions at/after the end of the soft-masked region.
        zeros_len = total - end
        if zeros_len <= 0:
            return weights
        return torch.cat([weights, torch.zeros(zeros_len)])

    @staticmethod
    def _add_leading_ones(weights: Tensor, start: int, total: int) -> Tensor:
        # Full weight (frozen) for positions before the soft-masked region.
        ones_len = min(start, total)
        if ones_len <= 0:
            return weights
        return torch.cat([torch.ones(ones_len), weights])
+ +"""Utility modules for async inference.""" + +from .action_filter import ( + ActionFilter, + ButterworthFilter, + FilterContext, + NoFilter, +) +from .latency_estimation import ( + JKLatencyEstimator, + LatencyEstimator, + LatencyEstimatorBase, + MaxLast10Estimator, + make_latency_estimator, +) +from .metrics import ( + DiagnosticMetrics, + EvActionChunk, + EvExecutedAction, + ExperimentMetricsWriter, + ExperimentTick, + Metrics, +) +from .simulation import DropSimulator, MockRobot, SpikeDelayConfig, SpikeDelaySimulator, SpikeEvent + +__all__ = [ + "ActionFilter", + "ButterworthFilter", + "DiagnosticMetrics", + "DropSimulator", + "EvActionChunk", + "EvExecutedAction", + "ExperimentMetricsWriter", + "ExperimentTick", + "FilterContext", + "JKLatencyEstimator", + "LatencyEstimator", + "LatencyEstimatorBase", + "make_latency_estimator", + "MaxLast10Estimator", + "Metrics", + "MockRobot", + "NoFilter", + "SpikeDelayConfig", + "SpikeDelaySimulator", + "SpikeEvent", +] diff --git a/src/lerobot/async_inference/utils/action_filter.py b/src/lerobot/async_inference/utils/action_filter.py new file mode 100644 index 00000000000..0c052ab3bfb --- /dev/null +++ b/src/lerobot/async_inference/utils/action_filter.py @@ -0,0 +1,126 @@ +"""Action filters for reducing jitter and smoothing robot control signals. + +This module provides a class-based hierarchy of action filters that can be +applied to robot control signals to reduce high-frequency noise and jitter +from policy micro-updates without significantly impacting intentional motion. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import numpy as np +from scipy.signal import butter, sosfilt, sosfilt_zi + + +@dataclass +class FilterContext: + """Context passed to filters each tick. + + Attributes: + action: The current action to filter. + """ + + action: np.ndarray + + +class ActionFilter(ABC): + """Base class for action filters. 
+ + All filter implementations should inherit from this class and implement + the apply() method. + """ + + @abstractmethod + def apply(self, ctx: FilterContext) -> np.ndarray: + """Apply filter and return filtered action. + + Args: + ctx: Filter context containing current action and optional + frozen lookahead actions. + + Returns: + The filtered action array. + """ + pass + + @abstractmethod + def reset(self) -> None: + """Reset filter state (optional override).""" + pass + + +class NoFilter(ActionFilter): + """Pass-through filter that returns actions unchanged.""" + + def apply(self, ctx: FilterContext) -> np.ndarray: + return ctx.action + + def reset(self) -> None: + return None + + +class ButterworthFilter(ActionFilter): + """Butterworth low-pass filter for action smoothing. + + Provides frequency-selective filtering to attenuate high-frequency noise + while passing intentional low-frequency motion with minimal phase lag. + """ + + def __init__( + self, + cutoff: float, + order: int, + fps: float, + gain: float, + past_buffer_size: int, + ): + """Initialize the Butterworth filter. + + Args: + cutoff: Cutoff frequency in Hz. + order: Filter order (1-4). + fps: Control loop frequency in Hz. + gain: Amplitude gain compensation factor. + past_buffer_size: Number of past actions to keep in buffer. 
+ """ + self.cutoff = cutoff + self.order = order + self.fps = fps + self.gain = gain + self.past_buffer_size = past_buffer_size + self._sos: np.ndarray | None = None + self._zi: np.ndarray | None = None + self._prev: np.ndarray | None = None + + def _init_filter(self, action: np.ndarray) -> None: + """Initialize filter coefficients and state.""" + nyquist = self.fps / 2.0 + normalized = min(max(self.cutoff / nyquist, 0.01), 0.99) + self._sos = butter(self.order, normalized, btype="low", output="sos") + zi_single = sosfilt_zi(self._sos) + self._zi = np.array([zi_single * action[j] for j in range(len(action))]) + + def apply(self, ctx: FilterContext) -> np.ndarray: + if self._sos is None: + self._init_filter(ctx.action) + self._prev = ctx.action.copy() + return ctx.action + + # Apply causal (stateful) filter for consistent smoothing + filtered = np.zeros_like(ctx.action) + for j in range(len(ctx.action)): + out, self._zi[j] = sosfilt(self._sos, [ctx.action[j]], zi=self._zi[j]) + filtered[j] = out[0] + + # Apply gain compensation + if self.gain != 1.0 and self._prev is not None: + delta = filtered - self._prev + filtered = self._prev + delta * self.gain + + self._prev = filtered.copy() + return filtered + + def reset(self) -> None: + self._sos = None + self._zi = None + self._prev = None diff --git a/src/lerobot/async_inference/utils/compression.py b/src/lerobot/async_inference/utils/compression.py new file mode 100644 index 00000000000..b02660bde90 --- /dev/null +++ b/src/lerobot/async_inference/utils/compression.py @@ -0,0 +1,85 @@ +from typing import Any + +import cv2 # type: ignore +import numpy as np + + +def _is_uint8_hwc3_image(x: Any) -> bool: + if not isinstance(x, np.ndarray): + return False + if x.dtype != np.uint8: + return False + if x.ndim != 3: + return False + h, w, c = x.shape + if h <= 0 or w <= 0: + return False + return c == 3 + + +def encode_images_for_transport( + observation: Any, + jpeg_quality: int, +) -> tuple[Any, dict[str, int]]: + 
def decode_images_from_transport(observation: Any) -> tuple[Any, dict[str, int]]:
    """Recursively decode JPEG-marked images back into uint8 HWC3 RGB numpy arrays.

    Args:
        observation: Arbitrarily nested dict/list/tuple structure; dicts carrying
            the ``__lerobot_image_encoding__ == "jpeg"`` marker are replaced by
            the decoded RGB image array. Everything else is returned unchanged.

    Returns:
        (decoded_observation, stats) where stats counts images decoded and
        raw/encoded byte totals.

    Raises:
        TypeError: If a JPEG payload dict has no bytes under ``"data"``.
        RuntimeError: If OpenCV fails to decode a JPEG payload.
    """
    stats = {"images_decoded": 0, "raw_bytes_total": 0, "encoded_bytes_total": 0}

    def _maybe_decode_payload(x: Any) -> Any:
        if isinstance(x, dict) and x.get("__lerobot_image_encoding__") == "jpeg":
            data = x.get("data")
            if not isinstance(data, (bytes, bytearray)):
                raise TypeError("JPEG payload missing bytes 'data'")

            # FIX: lazy import so merely importing this module does not hard-require
            # OpenCV; cv2 is only needed when an encoded payload is present.
            import cv2  # type: ignore

            buf = np.frombuffer(data, dtype=np.uint8)
            bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
            if bgr is None:
                raise RuntimeError("OpenCV failed to decode JPEG payload")

            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            stats["images_decoded"] += 1
            stats["encoded_bytes_total"] += len(data)
            stats["raw_bytes_total"] += int(rgb.nbytes)
            return rgb

        if isinstance(x, dict):
            return {k: _maybe_decode_payload(v) for k, v in x.items()}
        if isinstance(x, list):
            return [_maybe_decode_payload(v) for v in x]
        if isinstance(x, tuple):
            return tuple(_maybe_decode_payload(v) for v in x)
        return x

    return _maybe_decode_payload(observation), stats
Used as the pre-measurement + fallback so estimate_steps returns s_min before any real RTT arrives. + """ + self._fps = fps + self._action_chunk_size = action_chunk_size + self._s_min = s_min + + @property + def fps(self) -> float: + return self._fps + + @abstractmethod + def update(self, measured_rtt: float) -> None: + """Update the latency estimate with a new RTT measurement.""" + ... + + @property + @abstractmethod + def estimate_seconds(self) -> float: + """Get the latency estimate in seconds.""" + ... + + @property + def estimate_steps(self) -> int: + """Get the latency estimate quantized to action steps. + + Upper-bounded by H/2 per RTC constraint: with s = d, d <= H - s becomes d <= H/2. + This ensures real-time execution is achievable. If the actual delay exceeds + this bound, the system gracefully degrades to synchronous-with-inpainting behavior. + """ + raw = max(1, math.ceil(self.estimate_seconds * self._fps)) + if self._action_chunk_size is not None: + d_max = self._action_chunk_size // 2 + return min(raw, max(1, d_max)) + return raw + + @abstractmethod + def reset(self) -> None: + """Reset the estimator state.""" + ... + + +class JKLatencyEstimator(LatencyEstimatorBase): + """Jacobson-Karels style latency estimator with exponential smoothing. + + Maintains a smoothed mean and deviation estimate of round-trip latency, + combining them to produce a conservative estimate that adapts to variance. + + Attributes: + fps: Control loop frequency for quantizing to action steps. + alpha: Smoothing factor for mean (default 0.125 per RFC 6298). + beta: Smoothing factor for deviation (default 0.25 per RFC 6298). + k: Scaling factor for deviation in estimate (paper suggests K=1 for faster recovery). 
+ """ + + def __init__( + self, + fps: float, + alpha: float = 0.125, + beta: float = 0.25, + k: float = 1.0, + action_chunk_size: int | None = None, + s_min: int = 1, + ): + super().__init__(fps, action_chunk_size, s_min=s_min) + self.alpha = alpha + self.beta = beta + self.k = k + self.smoothed_rtt: float = 0.0 + self.rtt_deviation: float = 0.0 + self._initialized: bool = False + + def update(self, measured_rtt: float) -> None: + """Update the latency estimate with a new RTT measurement.""" + if not self._initialized: + self.smoothed_rtt = measured_rtt + self.rtt_deviation = 0 + self._initialized = True + return + + error = measured_rtt - self.smoothed_rtt + self.smoothed_rtt = (1 - self.alpha) * self.smoothed_rtt + self.alpha * measured_rtt + self.rtt_deviation = (1 - self.beta) * self.rtt_deviation + self.beta * abs(error) + + @property + def estimate_seconds(self) -> float: + """Get the latency estimate in seconds: ℓ̂ = ℓ̄ + K·σ""" + if not self._initialized: + return self._s_min / self._fps + return self.smoothed_rtt + self.k * self.rtt_deviation + + def reset(self) -> None: + """Reset the estimator state.""" + self.smoothed_rtt = 0.0 + self.rtt_deviation = 0.0 + self._initialized = False + + +class MaxLast10Estimator(LatencyEstimatorBase): + """Conservative latency estimator using max of last 10 measurements (RTC-style). + + Returns the maximum RTT observed in the last 10 measurements, providing a + conservative bound that is less adaptive but more stable under spikes. 
+ """ + + def __init__( + self, + fps: float, + window_size: int = 10, + action_chunk_size: int | None = None, + s_min: int = 1, + ): + super().__init__(fps, action_chunk_size, s_min=s_min) + self._window_size = window_size + self._buffer: deque[float] = deque(maxlen=window_size) + + def update(self, measured_rtt: float) -> None: + """Add a new RTT measurement to the window.""" + self._buffer.append(measured_rtt) + + @property + def estimate_seconds(self) -> float: + """Get the latency estimate as max of last N measurements.""" + if not self._buffer: + return self._s_min / self._fps + return max(self._buffer) + + def reset(self) -> None: + """Reset the estimator state.""" + self._buffer.clear() + + +class FixedLatencyEstimator(LatencyEstimatorBase): + """Fixed latency estimator for baseline comparisons (SmolVLA-style). + + Returns a fixed, user-specified latency estimate regardless of measurements. + This represents the behavior of systems that assume a constant network latency + rather than adapting to actual conditions. + + Note: The estimate is still quantized to at least 1 action step, and + upper-bounded by H/2 if action_chunk_size is provided. + """ + + def __init__( + self, + fps: float, + fixed_latency_s: float = 0.1, + action_chunk_size: int | None = None, + s_min: int = 1, + ): + """Initialize with a fixed latency value. + + Args: + fps: Control loop frequency. + fixed_latency_s: Fixed latency estimate in seconds (default 100ms). + action_chunk_size: Prediction horizon H for upper bound clamping to H/2. + s_min: Minimum execution horizon (unused by fixed estimator, passed to base). 
+ """ + super().__init__(fps, action_chunk_size, s_min=s_min) + self._fixed_latency_s = fixed_latency_s + + def update(self, measured_rtt: float) -> None: + """No-op: fixed estimator ignores measurements.""" + pass + + @property + def estimate_seconds(self) -> float: + """Get the fixed latency estimate in seconds.""" + return self._fixed_latency_s + + def reset(self) -> None: + """No-op: fixed estimator has no state to reset.""" + pass + + +# Backwards compatibility alias +LatencyEstimator = JKLatencyEstimator + + +def make_latency_estimator( + kind: str, + fps: float, + alpha: float = 0.125, + beta: float = 0.25, + k: float = 1.0, + fixed_latency_s: float = 0.1, + action_chunk_size: int | None = None, + s_min: int = 1, +) -> LatencyEstimatorBase: + """Factory function to create a latency estimator. + + Args: + kind: Type of estimator: + - "jk": Jacobson-Karels (adaptive, fast recovery) + - "max_last_10": Max of last 10 (conservative, RTC-style) + - "fixed": Fixed latency (non-adaptive baseline) + fps: Control loop frequency. + alpha: JK smoothing factor for mean. + beta: JK smoothing factor for deviation. + k: JK scaling factor for deviation. + fixed_latency_s: Fixed latency in seconds (only used if kind="fixed"). + action_chunk_size: Prediction horizon H. If provided, enables upper bound + clamping to H/2 per RTC constraint (with s = d, d <= H - s becomes d <= H/2). + s_min: Minimum execution horizon in steps. Used as the pre-measurement + fallback so estimate_steps returns s_min before any real RTT arrives. + + Returns: + A latency estimator instance. 
+ """ + if kind == "jk": + return JKLatencyEstimator( + fps=fps, + alpha=alpha, + beta=beta, + k=k, + action_chunk_size=action_chunk_size, + s_min=s_min, + ) + elif kind == "max_last_10": + return MaxLast10Estimator( + fps=fps, + action_chunk_size=action_chunk_size, + s_min=s_min, + ) + elif kind == "fixed": + return FixedLatencyEstimator( + fps=fps, + fixed_latency_s=fixed_latency_s, + action_chunk_size=action_chunk_size, + s_min=s_min, + ) + else: + raise ValueError(f"Unknown latency estimator type: {kind}. Use 'jk', 'max_last_10', or 'fixed'.") diff --git a/src/lerobot/async_inference/utils/metrics.py b/src/lerobot/async_inference/utils/metrics.py new file mode 100644 index 00000000000..c6db105e71e --- /dev/null +++ b/src/lerobot/async_inference/utils/metrics.py @@ -0,0 +1,802 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Metrics collection for async inference. + +This module provides two categories of metrics: +- **experiment**: per-tick metrics + trajectory data written to disk (CSV + JSON). + This is the current/default behavior and should remain stable for experiments. +- **diagnostic**: lightweight timing/counter summaries printed to the console + (avg/max only; no percentiles). 
from __future__ import annotations

import csv
import json
import logging
import threading
import time
from collections import defaultdict, deque
from contextlib import contextmanager, suppress
from dataclasses import dataclass
from pathlib import Path
from queue import Empty, Full, Queue
from typing import Any

import numpy as np

logger = logging.getLogger(__name__)


# =============================================================================
# Trajectory-viz event types (runtime streaming)
# =============================================================================


class EvActionChunk(tuple):
    """Action chunk event for real-time trajectory visualization.

    Implemented as a thin tuple subclass so it can be handed between threads
    with minimal overhead; named read-only accessors wrap the positional slots.
    """

    __slots__ = ()
    _fields = ("src_control_step", "actions", "frozen_len", "timestamp", "rtc_params", "prefix_weights")

    def __new__(
        cls,
        *,
        src_control_step: int,
        actions: list[list[float]],
        frozen_len: int,
        timestamp: float,
        rtc_params: dict | None = None,
        prefix_weights: list[float] | None = None,
    ):
        values = (src_control_step, actions, frozen_len, timestamp, rtc_params, prefix_weights)
        return tuple.__new__(cls, values)

    @property
    def src_control_step(self) -> int:
        """Control step t that triggered this chunk's inference."""
        return self[0]

    @property
    def actions(self) -> list[list[float]]:
        """(T, A) action chunk as nested lists."""
        return self[1]

    @property
    def frozen_len(self) -> int:
        """Number of frozen (already-committed) actions in the chunk."""
        return self[2]

    @property
    def timestamp(self) -> float:
        """Wall-clock time (Unix seconds) the event was created."""
        return self[3]

    @property
    def rtc_params(self) -> dict | None:
        """Optional RTC parameters attached to the chunk."""
        return self[4]

    @property
    def prefix_weights(self) -> list[float] | None:
        """Optional per-step prefix attention weights."""
        return self[5]


class EvExecutedAction(tuple):
    """Single executed action event for real-time visualization."""

    __slots__ = ()
    _fields = ("step", "action", "timestamp")

    def __new__(cls, *, step: int, action: list[float], timestamp: float):
        return tuple.__new__(cls, (step, action, timestamp))

    @property
    def step(self) -> int:
        """Action step number."""
        return self[0]

    @property
    def action(self) -> list[float]:
        """Action values (one per joint)."""
        return self[1]

    @property
    def timestamp(self) -> float:
        """Wall-clock time (Unix seconds) the action was executed."""
        return self[2]


# =============================================================================
# Experiment metrics (disk output; keep stable)
# =============================================================================


@dataclass
class ExperimentTick:
    """One control-loop tick worth of experiment data."""

    t: float  # Wall-clock timestamp (Unix seconds)
    step: int  # Action step n(t)
    schedule_size: int  # |ψ(t)|
    latency_estimate_steps: int  # ℓ̂_Δ
    latency_estimate_ms: float  # ℓ̂ in milliseconds (unquantized)
    cooldown: int  # O^c(t)
    stall: int  # 1 if schedule_size == 0, else 0
    obs_triggered: int  # 1 if an observation request was triggered this tick
    action_received: int  # 1 if an action chunk was merged this tick
    measured_latency_ms: float | None  # RTT of the received chunk (if any)
    # Raw wall-clock timestamps of the round-trip journey (Unix seconds); durations
    # are derivable, e.g. client->server = server_obs_received_ts - obs_sent_ts.
    obs_sent_ts: float | None  # Client observation send timestamp
    server_obs_received_ts: float | None  # Server observation receive timestamp
    server_action_sent_ts: float | None  # Server action send timestamp
    action_received_ts: float | None  # Client action receive timestamp
    # Action discontinuity metrics (L2 distance between overlapping chunks)
    chunk_overlap_count: int | None  # How many overlapping actions were compared
    chunk_mean_l2: float | None  # Mean L2 distance over the overlap
    chunk_max_l2: float | None  # Max L2 distance over the overlap


@dataclass
class TrajectoryChunk:
    """Recorded action chunk for trajectory visualization."""

    src_control_step: int  # Provenance: control step t that triggered inference
    actions: list[list[float]]  # (T, A) action chunk as nested lists
    frozen_len: int  # Number of frozen actions in this chunk
    t: float  # Timestamp (Unix seconds)
    chunk_start_step: int | None = None  # Start step of this chunk (provenance)


@dataclass
class ExecutedAction:
    """Recorded executed action for trajectory visualization."""

    step: int  # Action step number
    action: list[float]  # Action values (one per joint)
    src_control_step: int  # Provenance: control step t that produced this action
    chunk_start_step: int  # Provenance: start step of the source chunk
    t: float  # Timestamp (Unix seconds)


@dataclass
class SimEvent:
    """A recorded simulation event (drop, reorder, duplicate, etc.)."""

    event_type: str  # e.g. "obs_dropped", "action_dropped", "obs_reorder_held"
    t: float  # Timestamp (Unix seconds)


@dataclass
class RegisterEvent:
    """A recorded LWW register write attempt (client-side)."""

    t: float  # Wall-clock timestamp (Unix seconds)
    register_name: str  # e.g. "client_obs_request", "client_action"
    control_step: int  # The control_step used as the LWW key
    chunk_start_step: int | None  # Only meaningful for action registers
    accepted: bool  # Whether update_if_newer accepted the write
"client_obs_request", "client_action" + control_step: int # The control_step used as the LWW key + chunk_start_step: int | None # Only meaningful for action registers + accepted: bool # Whether update_if_newer accepted the write + + +class ExperimentMetricsWriter: + """Collects per-tick experiment metrics and writes to CSV. + + Also collects trajectory data (action chunks and executed actions) + for post-hoc visualization of how chunks overlap and transition. + + Memory is bounded: + - ``_ticks`` are auto-flushed to CSV when the buffer exceeds + ``auto_flush_threshold`` (default 50 000 ≈ 16 min @ 50 Hz). + - ``_chunks`` and ``_executed`` use bounded deques (most-recent data + kept; oldest evicted). + - Running summary counters survive auto-flushes so ``get_summary()`` + always covers the full run. + """ + + _CSV_FIELDNAMES = [ + "t", + "step", + "schedule_size", + "latency_estimate_steps", + "latency_estimate_ms", + "cooldown", + "stall", + "obs_triggered", + "action_received", + "measured_latency_ms", + "obs_sent_ts", + "server_obs_received_ts", + "server_action_sent_ts", + "action_received_ts", + "chunk_overlap_count", + "chunk_mean_l2", + "chunk_max_l2", + ] + + def __init__( + self, + path: str | Path | None = None, + auto_flush_threshold: int = 50_000, + max_trajectory_entries: int = 10_000, + simulation_config: dict | None = None, + experiment_config: dict | None = None, + ): + self._path: Path | None = Path(path) if path else None + self._auto_flush_threshold = auto_flush_threshold + self._simulation_config: dict = simulation_config or {} + self._experiment_config: dict = experiment_config or {} + + # Lock to serialise flush operations. signal_stop() and stop() + # can both call flush() from different threads; without a lock the + # same ticks are written twice (once in "w" mode, once in "a") and + # the resulting CSV contains duplicate/corrupted rows. 
+ self._flush_lock = threading.Lock() + + # Tick buffer (flushed periodically to CSV) + self._ticks: list[ExperimentTick] = [] + self._csv_header_written = False + + # Trajectory buffers (bounded deques — most recent data kept) + self._chunks: deque[TrajectoryChunk] = deque(maxlen=max_trajectory_entries) + self._executed: deque[ExecutedAction] = deque(maxlen=max_trajectory_entries) + + # Simulation event log (bounded deque) + self._sim_events: deque[SimEvent] = deque(maxlen=max_trajectory_entries) + + # LWW register event log (bounded deque) + self._register_events: deque[RegisterEvent] = deque(maxlen=max_trajectory_entries) + + # Running summary counters (survive auto-flushes) + self._total_ticks: int = 0 + self._total_stalls: int = 0 + self._total_obs_triggered: int = 0 + self._total_action_received: int = 0 + self._l2_count: int = 0 + self._l2_mean_sum: float = 0.0 + self._l2_mean_max: float = 0.0 + self._l2_max_max: float = 0.0 + + # ------------------------------------------------------------------ + # Recording + # ------------------------------------------------------------------ + + def record_tick( + self, + *, + step: int, + schedule_size: int, + latency_estimate_steps: int, + latency_estimate_ms: float, + cooldown: int, + obs_triggered: bool = False, + action_received: bool = False, + measured_latency_ms: float | None = None, + obs_sent_ts: float | None = None, + server_obs_received_ts: float | None = None, + server_action_sent_ts: float | None = None, + action_received_ts: float | None = None, + chunk_overlap_count: int | None = None, + chunk_mean_l2: float | None = None, + chunk_max_l2: float | None = None, + ) -> None: + """Record a single tick of experiment data.""" + tick = ExperimentTick( + t=time.time(), + step=step, + schedule_size=schedule_size, + latency_estimate_steps=latency_estimate_steps, + latency_estimate_ms=latency_estimate_ms, + cooldown=cooldown, + stall=1 if schedule_size == 0 else 0, + obs_triggered=1 if obs_triggered else 0, + 
action_received=1 if action_received else 0, + measured_latency_ms=measured_latency_ms, + obs_sent_ts=obs_sent_ts, + server_obs_received_ts=server_obs_received_ts, + server_action_sent_ts=server_action_sent_ts, + action_received_ts=action_received_ts, + chunk_overlap_count=chunk_overlap_count, + chunk_mean_l2=chunk_mean_l2, + chunk_max_l2=chunk_max_l2, + ) + self._ticks.append(tick) + + # Update running summary counters + self._total_ticks += 1 + if schedule_size == 0: + self._total_stalls += 1 + if obs_triggered: + self._total_obs_triggered += 1 + if action_received: + self._total_action_received += 1 + if chunk_mean_l2 is not None: + self._l2_count += 1 + self._l2_mean_sum += chunk_mean_l2 + self._l2_mean_max = max(self._l2_mean_max, chunk_mean_l2) + if chunk_max_l2 is not None: + self._l2_max_max = max(self._l2_max_max, chunk_max_l2) + + # Auto-flush when buffer exceeds threshold + if len(self._ticks) >= self._auto_flush_threshold: + self._auto_flush_ticks() + + def record_chunk( + self, + *, + src_control_step: int, + actions: list[np.ndarray] | list[list[float]], + frozen_len: int, + chunk_start_step: int | None = None, + ) -> None: + """Record an action chunk for trajectory visualization. + + Args: + src_control_step: The control step t that triggered this chunk's inference. + actions: List of action arrays (T, A) - can be numpy arrays or lists. + frozen_len: Number of frozen actions in this chunk. + chunk_start_step: The start step of this chunk (provenance). 
+ """ + # Convert numpy arrays to lists for JSON serialization + actions_list: list[list[float]] = [] + for action in actions: + if isinstance(action, np.ndarray): + actions_list.append(action.tolist()) + else: + actions_list.append(list(action)) + + chunk = TrajectoryChunk( + src_control_step=src_control_step, + actions=actions_list, + frozen_len=frozen_len, + t=time.time(), + chunk_start_step=chunk_start_step, + ) + self._chunks.append(chunk) # deque evicts oldest automatically + + def record_executed_action( + self, + *, + step: int, + action: np.ndarray | list[float], + src_control_step: int, + chunk_start_step: int, + ) -> None: + """Record an executed action for trajectory visualization. + + Args: + step: The action step number. + action: The action values sent to the robot. + src_control_step: The control step t that produced this action. + chunk_start_step: The start step of the source chunk. + """ + action_list = action.tolist() if isinstance(action, np.ndarray) else list(action) + + executed = ExecutedAction( + step=step, + action=action_list, + src_control_step=src_control_step, + chunk_start_step=chunk_start_step, + t=time.time(), + ) + self._executed.append(executed) # deque evicts oldest automatically + + def record_sim_event(self, event_type: str) -> None: + """Record a simulation event (drop, reorder, duplicate, etc.). + + Args: + event_type: Event identifier, e.g. ``"obs_dropped"``, ``"action_dropped"``. + """ + self._sim_events.append(SimEvent(event_type=event_type, t=time.time())) + + def record_register_event( + self, + *, + register_name: str, + control_step: int, + accepted: bool, + chunk_start_step: int | None = None, + ) -> None: + """Record an LWW register write attempt. + + Args: + register_name: Identifier for the register, e.g. ``"client_obs_request"``. + control_step: The control_step used as the LWW key. + accepted: Whether ``update_if_newer`` accepted the write. 
+            chunk_start_step: Start step of the source chunk (action registers only).
+        """
+        self._register_events.append(
+            RegisterEvent(
+                t=time.time(),
+                register_name=register_name,
+                control_step=control_step,
+                chunk_start_step=chunk_start_step,
+                accepted=accepted,
+            )
+        )
+
+    # ------------------------------------------------------------------
+    # Flushing
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def _tick_to_row(tick: ExperimentTick) -> dict[str, Any]:
+        """Convert an ExperimentTick to a CSV-row dict.
+
+        Optional fields that are ``None`` are written as empty strings so
+        the corresponding CSV cells are simply left blank.
+        """
+        return {
+            "t": tick.t,
+            "step": tick.step,
+            "schedule_size": tick.schedule_size,
+            "latency_estimate_steps": tick.latency_estimate_steps,
+            "latency_estimate_ms": tick.latency_estimate_ms,
+            "cooldown": tick.cooldown,
+            "stall": tick.stall,
+            "obs_triggered": tick.obs_triggered,
+            "action_received": tick.action_received,
+            "measured_latency_ms": tick.measured_latency_ms if tick.measured_latency_ms is not None else "",
+            "obs_sent_ts": tick.obs_sent_ts if tick.obs_sent_ts is not None else "",
+            "server_obs_received_ts": tick.server_obs_received_ts
+            if tick.server_obs_received_ts is not None
+            else "",
+            "server_action_sent_ts": tick.server_action_sent_ts
+            if tick.server_action_sent_ts is not None
+            else "",
+            "action_received_ts": tick.action_received_ts if tick.action_received_ts is not None else "",
+            "chunk_overlap_count": tick.chunk_overlap_count if tick.chunk_overlap_count is not None else "",
+            "chunk_mean_l2": tick.chunk_mean_l2 if tick.chunk_mean_l2 is not None else "",
+            "chunk_max_l2": tick.chunk_max_l2 if tick.chunk_max_l2 is not None else "",
+        }
+
+    def _auto_flush_ticks(self) -> None:
+        """Incrementally flush buffered ticks to CSV and clear the buffer.
+
+        Thread-safe: uses ``_flush_lock`` so that concurrent calls from
+        ``signal_stop()`` (timer thread) and ``stop()`` (main thread) are
+        serialised and each tick is written exactly once.
+ """ + with self._flush_lock: + if not self._path or not self._ticks: + return + self._path.parent.mkdir(parents=True, exist_ok=True) + + # Snapshot and clear under lock so no tick is written twice. + ticks_to_write = list(self._ticks) + self._ticks.clear() + + mode = "a" if self._csv_header_written else "w" + with open(self._path, mode, newline="") as f: + writer = csv.DictWriter(f, fieldnames=self._CSV_FIELDNAMES) + if not self._csv_header_written: + writer.writeheader() + self._csv_header_written = True + for tick in ticks_to_write: + writer.writerow(self._tick_to_row(tick)) + + def flush(self, path: str | Path | None = None) -> None: + """Write remaining metrics to disk. + + Args: + path: Override path (updates the stored path). If *None*, uses + the path provided at construction time. + """ + if path is not None: + self._path = Path(path) + if self._path is None: + return + + # Drain any remaining ticks to CSV + self._auto_flush_ticks() + + # Write trajectory data to JSON (from bounded deques) + if self._chunks or self._executed: + trajectory_path = self._path.with_suffix(".trajectory.json") + trajectory_data: dict[str, Any] = { + "chunks": [ + { + "source_step": c.src_control_step, + "actions": c.actions, + "frozen_len": c.frozen_len, + "t": c.t, + "chunk_start_step": c.chunk_start_step, + } + for c in self._chunks + ], + "executed": [ + { + "step": e.step, + "action": e.action, + "src_control_step": e.src_control_step, + "chunk_start_step": e.chunk_start_step, + "t": e.t, + } + for e in self._executed + ], + "sim_events": [{"event_type": ev.event_type, "t": ev.t} for ev in self._sim_events], + "register_events": [ + { + "t": rev.t, + "register_name": rev.register_name, + "control_step": rev.control_step, + "chunk_start_step": rev.chunk_start_step, + "accepted": rev.accepted, + } + for rev in self._register_events + ], + } + # Include the full simulation config so the plotter can overlay + # configured windows/events (drops, spikes, duplicates, reorder). 
+ if self._simulation_config: + trajectory_data["simulation_config"] = self._simulation_config + # Include experiment config (policy, latency, filter params) for + # the plotter to render a configuration table in the LaTeX output. + if self._experiment_config: + trajectory_data["experiment_config"] = self._experiment_config + with open(trajectory_path, "w") as f: + json.dump(trajectory_data, f) + + # ------------------------------------------------------------------ + # Summary + # ------------------------------------------------------------------ + + def get_summary(self) -> dict[str, Any]: + """Get summary statistics from collected data. + + Uses running counters so the summary covers the full run even + after auto-flushes have cleared the tick buffer. + """ + if self._total_ticks == 0: + return {} + + summary: dict[str, Any] = { + "total_ticks": self._total_ticks, + "stall_count": self._total_stalls, + "stall_fraction": self._total_stalls / self._total_ticks, + "obs_triggered_count": self._total_obs_triggered, + "action_received_count": self._total_action_received, + } + + # Add L2 discrepancy summary if we have data + if self._l2_count > 0: + summary["mean_l2_avg"] = self._l2_mean_sum / self._l2_count + summary["mean_l2_max"] = self._l2_mean_max + summary["chunk_count"] = self._l2_count + summary["max_l2_max"] = self._l2_max_max + + # Add trajectory stats (bounded deque sizes, not full-run counts) + summary["trajectory_chunks"] = len(self._chunks) + summary["trajectory_executed"] = len(self._executed) + + return summary + + +# ============================================================================= +# Diagnostic metrics (console output; avg/max only) +# ============================================================================= + + +def _format_avg_max(values: list[float]) -> str: + if not values: + return "n/a" + avg = float(sum(values) / len(values)) + vmax = float(max(values)) + return f"{avg:.2f}/{vmax:.2f}" + + +class _EvTiming(tuple): + __slots__ = 
() + + def __new__(cls, name: str, ms: float): + return tuple.__new__(cls, (name, float(ms))) + + @property + def name(self) -> str: + return self[0] + + @property + def ms(self) -> float: + return self[1] + + +class _EvCounter(tuple): + __slots__ = () + + def __new__(cls, name: str, inc: int): + return tuple.__new__(cls, (name, int(inc))) + + @property + def name(self) -> str: + return self[0] + + @property + def inc(self) -> int: + return self[1] + + +class _EvContext(tuple): + __slots__ = () + + def __new__(cls, ctx: dict[str, Any]): + return tuple.__new__(cls, (ctx,)) + + @property + def ctx(self) -> dict[str, Any]: + return self[0] + + +class DiagnosticMetrics: + """Lossy, queue-based diagnostic metrics with periodic console summaries.""" + + def __init__( + self, + *, + fps: int, + window_s: float = 10.0, + interval_s: float = 2.0, + enabled: bool = True, + verbose: bool = False, + prefix: str = "DIAG", + ): + self._enabled = bool(enabled) + self._fps = int(fps) + self._window_s = float(window_s) + self._interval_s = float(interval_s) + self._verbose = bool(verbose) + self._prefix = str(prefix) + + self._shutdown = threading.Event() + self._queue: Queue = Queue(maxsize=4096) + self._thread: threading.Thread | None = None + + @staticmethod + def _ms(seconds: float) -> float: + return seconds * 1000.0 + + def start(self) -> None: + if not self._enabled: + return + if self._thread is not None and self._thread.is_alive(): + return + self._thread = threading.Thread( + target=self._consumer_loop, + name="metrics_diagnostic_consumer", + daemon=True, + ) + self._thread.start() + + def stop(self, timeout_s: float = 1.0) -> None: + if not self._enabled: + return + self._shutdown.set() + if self._thread is not None: + self._thread.join(timeout=timeout_s) + + def timing_ms(self, name: str, ms: float) -> None: + if not self._enabled: + return + with suppress(Full): + self._queue.put_nowait(_EvTiming(str(name), float(ms))) + + def timing_s(self, name: str, seconds: float) 
-> None: + self.timing_ms(name, self._ms(seconds)) + + def counter(self, name: str, inc: int = 1) -> None: + if not self._enabled: + return + with suppress(Full): + self._queue.put_nowait(_EvCounter(str(name), int(inc))) + + def set_context(self, **ctx: Any) -> None: + if not self._enabled: + return + with suppress(Full): + self._queue.put_nowait(_EvContext(dict(ctx))) + + @contextmanager + def time_block(self, name: str): + if not self._enabled: + yield + return + t0 = time.perf_counter() + try: + yield + finally: + self.timing_s(name, time.perf_counter() - t0) + + def _consumer_loop(self) -> None: + maxlen = max(10, int(self._fps * self._window_s)) + timings: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=maxlen)) + counters_total: dict[str, int] = defaultdict(int) + latest_ctx: dict[str, Any] = {} + + last_emit = time.perf_counter() + while not self._shutdown.is_set(): + try: + ev = self._queue.get(timeout=0.1) + except Empty: + ev = None + + if isinstance(ev, _EvTiming): + timings[ev.name].append(ev.ms) + elif isinstance(ev, _EvCounter): + counters_total[ev.name] += ev.inc + elif isinstance(ev, _EvContext): + latest_ctx = ev.ctx + + now = time.perf_counter() + if (now - last_emit) < self._interval_s: + continue + + last_emit = now + + # Default: compact summary of core fields + total RTT. + core_keys = ["step", "schedule_size", "latency_steps", "cooldown", "chunk_size", "s_min", "fps"] + core_ctx = " ".join(f"{k}={latest_ctx[k]}" for k in core_keys if k in latest_ctx) + + rtt_key = "total_latency_rtt_ms" + rtt_part = "" + if rtt_key in timings: + rtt_part = f"{rtt_key}(avg/max)={_format_avg_max(list(timings[rtt_key]))}" + else: + rtt_part = f"{rtt_key}(avg/max)=n/a" + + if not self._verbose: + # Skip emit when there is no data yet (startup period before + # the control loop or action receiver have produced any events). 
+ if not core_ctx and rtt_key not in timings: + continue + parts = [p for p in [core_ctx, rtt_part] if p] + logger.info(f"{self._prefix} | " + " ".join(parts)) + continue + + # Verbose: include all context keys plus timing/counter details. + # Skip emit when there is no data yet (startup period). + if not latest_ctx and not timings and not counters_total: + continue + ctx_part = " ".join(f"{k}={v}" for k, v in latest_ctx.items()) + + # Prefer a stable ordering for common names; append others alphabetically. + preferred = [ + rtt_key, + "loop_dt_ms", + "phase_exec_ms", + "phase_trigger_ms", + "phase_merge_ms", + "send_action_ms", + "obs_wait_ms", + "obs_capture_ms", + "obs_encode_ms", + "obs_send_ms", + "rpc_ms", + "deser_ms", + "rtc_build_ms", + "chunk_gap_ms", + "policy_predict_ms", + "infer_total_ms", + "policy_load_ms", + "obs_recv_ms", + "obs_decode_ms", + ] + remaining = sorted([k for k in timings if k not in preferred]) + keys = [k for k in preferred if k in timings] + remaining + + timing_part = " ".join(f"{k}(avg/max)={_format_avg_max(list(timings[k]))}" for k in keys) + counter_part = " ".join(f"{k}={v}" for k, v in sorted(counters_total.items())) + + parts = [p for p in [ctx_part, timing_part, counter_part] if p] + logger.info(f"{self._prefix} | " + " | ".join(parts)) + + +@dataclass +class Metrics: + """Bundle of experiment (disk) + diagnostic (console) metrics.""" + + experiment: ExperimentMetricsWriter | None = None + diagnostic: DiagnosticMetrics | None = None diff --git a/src/lerobot/async_inference/utils/simulation.py b/src/lerobot/async_inference/utils/simulation.py new file mode 100644 index 00000000000..1cc33d44347 --- /dev/null +++ b/src/lerobot/async_inference/utils/simulation.py @@ -0,0 +1,662 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulation helpers for async inference experiments. + +Provides mock robot, drop simulation, and latency spike simulation +for testing and experimentation without real hardware. +""" + +import time +from dataclasses import dataclass, field +from typing import Any + +import numpy as np + +# ============================================================================= +# Configuration Dataclasses +# ============================================================================= + + +@dataclass +class SpikeEvent: + """A single spike event that fires once at a specific time. + + Attributes: + start_s: When to trigger the spike (seconds from experiment start) + delay_ms: How much delay to add when triggered (milliseconds) + """ + + start_s: float # When to trigger (seconds from start) + delay_ms: float # How much delay to add (milliseconds) + + +@dataclass +class SpikeDelayConfig: + """Configuration for explicit spike injection. 
+ + Example usage: + # Add 2s spike at 5s and 1s spike at 15s into the experiment + config = SpikeDelayConfig( + spikes=[ + SpikeEvent(start_s=5.0, delay_ms=2000), + SpikeEvent(start_s=15.0, delay_ms=1000), + ] + ) + + # Or from dicts (for JSON/CLI compatibility) + config = SpikeDelayConfig.from_dicts([ + {"start_s": 5.0, "delay_ms": 2000}, + {"start_s": 15.0, "delay_ms": 1000}, + ]) + """ + + spikes: list[SpikeEvent] = field(default_factory=list) + + @classmethod + def from_dicts(cls, spike_dicts: list[dict]) -> "SpikeDelayConfig": + """Create config from list of dicts (for JSON/CLI compatibility).""" + spikes = [SpikeEvent(start_s=d["start_s"], delay_ms=d["delay_ms"]) for d in spike_dicts] + return cls(spikes=spikes) + + +@dataclass +class DropEvent: + """A single drop event with start time and duration. + + Attributes: + start_s: When to start dropping (seconds from experiment start) + duration_s: How long to drop (seconds) + """ + + start_s: float # When to start dropping (seconds from start) + duration_s: float # How long to drop (seconds) + + +@dataclass +class DropConfig: + """Configuration for drop injection using explicit time-based events. 
+
+    Example usage:
+        # Drop for 1 second starting at 5s into the experiment
+        config = DropConfig(drops=[
+            DropEvent(start_s=5.0, duration_s=1.0),
+        ])
+
+        # Multiple drop events
+        config = DropConfig(drops=[
+            DropEvent(start_s=5.0, duration_s=1.0),
+            DropEvent(start_s=15.0, duration_s=2.0),
+        ])
+
+        # Or from dicts (for JSON/CLI compatibility)
+        config = DropConfig.from_dicts([
+            {"start_s": 5.0, "duration_s": 1.0},
+            {"start_s": 15.0, "duration_s": 2.0},
+        ])
+    """
+
+    drops: list[DropEvent] = field(default_factory=list)
+
+    @classmethod
+    def from_dicts(cls, drop_dicts: list[dict]) -> "DropConfig":
+        """Create config from list of dicts (for JSON/CLI compatibility)."""
+        drops = [DropEvent(start_s=d["start_s"], duration_s=d["duration_s"]) for d in drop_dicts]
+        return cls(drops=drops)
+
+
+# =============================================================================
+# Mock Robot
+# =============================================================================
+
+
+class MockRobot:
+    """Mock robot for simulation experiments (no hardware required)."""
+
+    def __init__(self, action_dim: int = 6, state_dim: int = 6):
+        self._action_dim = action_dim
+        self._state_dim = state_dim
+        # Connection flag is tracked but not enforced: get_observation() and
+        # send_action() work regardless of connect()/disconnect() state.
+        self._connected = False
+        # Counts observations served by get_observation().
+        self._step = 0
+
+    @property
+    def action_features(self) -> list[str]:
+        """Names of the action dimensions (``joint_0`` .. ``joint_{A-1}``)."""
+        return [f"joint_{i}" for i in range(self._action_dim)]
+
+    @property
+    def state_features(self) -> list[str]:
+        """Names of the state dimensions (``state_0`` .. ``state_{S-1}``)."""
+        return [f"state_{i}" for i in range(self._state_dim)]
+
+    def connect(self) -> None:
+        """Mark the mock robot as connected (no hardware side effects)."""
+        self._connected = True
+
+    def disconnect(self) -> None:
+        """Mark the mock robot as disconnected (no hardware side effects)."""
+        self._connected = False
+
+    def get_observation(self) -> dict[str, Any]:
+        """Return synthetic observation (random state, placeholder images)."""
+        self._step += 1
+        obs = {}
+        # Random joint state (scalar float values)
+        for feat in self.state_features:
+            obs[feat] = float(np.random.randn())
+        # Placeholder image (small random RGB)
+        obs["camera1"] = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+        return obs
+
+    def send_action(self, action: dict[str, float]) -> None:
+        """No-op for mock robot."""
+        pass
+
+
+# =============================================================================
+# Drop Simulator
+# =============================================================================
+
+
+class DropSimulator:
+    """Simulates explicit time-based drops for observations/actions.
+
+    Drops occur during specified time windows relative to experiment start.
+
+    Example:
+        # Using DropConfig (preferred)
+        config = DropConfig(drops=[
+            DropEvent(start_s=5.0, duration_s=1.0),
+            DropEvent(start_s=15.0, duration_s=2.0),
+        ])
+        sim = DropSimulator(config=config)
+
+        # Or from dicts
+        sim = DropSimulator.from_dicts([
+            {"start_s": 5.0, "duration_s": 1.0},
+            {"start_s": 15.0, "duration_s": 2.0},
+        ])
+    """
+
+    def __init__(self, config: DropConfig | None = None):
+        self._drops: list[DropEvent] = config.drops if config else []
+        # Lazily anchored on the first should_drop() call, so drop windows
+        # are measured from first use rather than from construction.
+        self._start_time: float | None = None
+
+    @classmethod
+    def from_dicts(cls, drop_dicts: list[dict]) -> "DropSimulator":
+        """Create simulator from list of dicts (for JSON/CLI compatibility)."""
+        config = DropConfig.from_dicts(drop_dicts)
+        return cls(config=config)
+
+    def should_drop(self) -> bool:
+        """Check if the current event should be dropped.
+
+        Returns True if current time falls within any drop event window.
+ """ + if not self._drops: + return False + + now = time.time() + if self._start_time is None: + self._start_time = now + + elapsed = now - self._start_time + + # Check if we're within any drop window + for drop in self._drops: + if drop.start_s <= elapsed < drop.start_s + drop.duration_s: + print("DROP OBSERVATION TRUE") + return True + + return False + + def reset(self) -> None: + """Reset the simulator start time.""" + self._start_time = None + + +# ============================================================================= +# Spike Delay Simulator +# ============================================================================= + + +class SpikeDelaySimulator: + """Simulates explicit latency spike events for experiments. + + Each spike fires once when the elapsed time crosses its start_s threshold, + adding the specified delay_ms of latency. + + Example: + config = SpikeDelayConfig(spikes=[ + SpikeEvent(start_s=5.0, delay_ms=2000), + SpikeEvent(start_s=15.0, delay_ms=1000), + ]) + sim = SpikeDelaySimulator(config=config) + + # Or from dicts + sim = SpikeDelaySimulator.from_dicts([ + {"start_s": 5.0, "delay_ms": 2000}, + {"start_s": 15.0, "delay_ms": 1000}, + ]) + """ + + def __init__(self, config: SpikeDelayConfig | None = None): + self._spikes: list[SpikeEvent] = config.spikes if config else [] + self._fired: set[int] = set() # Track which spike indices have fired + self._start_time: float | None = None + + @classmethod + def from_dicts(cls, spike_dicts: list[dict]) -> "SpikeDelaySimulator": + """Create simulator from list of dicts (for JSON/CLI compatibility).""" + config = SpikeDelayConfig.from_dicts(spike_dicts) + return cls(config=config) + + def get_delay(self) -> float: + """Get delay if a spike should fire now, else 0. + + Each spike fires exactly once when elapsed time crosses its start_s. + Returns the delay in seconds. 
+ """ + if not self._spikes: + return 0.0 + + now = time.time() + if self._start_time is None: + self._start_time = now + + elapsed = now - self._start_time + + # Check each spike - fire once when elapsed >= start_s + for i, spike in enumerate(self._spikes): + if i not in self._fired and elapsed >= spike.start_s: + self._fired.add(i) + return spike.delay_ms / 1000.0 + + return 0.0 + + def apply_delay(self) -> None: + """Sleep for any pending spike delay.""" + delay = self.get_delay() + if delay > 0: + time.sleep(delay) + + def reset(self) -> None: + """Reset the simulator (clear start time and fired spikes).""" + self._start_time = None + self._fired.clear() + + def pending_spikes(self) -> int: + """Return count of spikes that haven't fired yet.""" + return len(self._spikes) - len(self._fired) + + +# ============================================================================= +# Duplicate Simulator +# ============================================================================= + + +@dataclass +class DuplicateEvent: + """A single duplicate event with start time and duration. + + Attributes: + start_s: When to start duplicating (seconds from experiment start) + duration_s: How long to duplicate (seconds) + """ + + start_s: float # When to start duplicating (seconds from start) + duration_s: float # How long to duplicate (seconds) + + +@dataclass +class DuplicateConfig: + """Configuration for duplicate injection using explicit time-based events. 
+ + Example usage: + # Duplicate for 1 second starting at 5s into the experiment + config = DuplicateConfig(duplicates=[ + DuplicateEvent(start_s=5.0, duration_s=1.0), + ]) + + # Or from dicts (for JSON/CLI compatibility) + config = DuplicateConfig.from_dicts([ + {"start_s": 5.0, "duration_s": 1.0}, + ]) + """ + + duplicates: list[DuplicateEvent] = field(default_factory=list) + + @classmethod + def from_dicts(cls, dup_dicts: list[dict]) -> "DuplicateConfig": + """Create config from list of dicts (for JSON/CLI compatibility).""" + duplicates = [DuplicateEvent(start_s=d["start_s"], duration_s=d["duration_s"]) for d in dup_dicts] + return cls(duplicates=duplicates) + + +class DuplicateSimulator: + """Simulates explicit time-based duplicates for observations/actions. + + During a duplicate window, ``should_duplicate()`` returns True, causing + the caller to send/handle the same message a second time. The + server-side LWW registers and schedule merge absorb the duplicate + (same ``control_step`` / ``src_control_step``). + + Example: + config = DuplicateConfig(duplicates=[ + DuplicateEvent(start_s=5.0, duration_s=1.0), + ]) + sim = DuplicateSimulator(config=config) + + # Or from dicts + sim = DuplicateSimulator.from_dicts([ + {"start_s": 5.0, "duration_s": 1.0}, + ]) + """ + + def __init__(self, config: DuplicateConfig | None = None): + self._duplicates: list[DuplicateEvent] = config.duplicates if config else [] + self._start_time: float | None = None + + @classmethod + def from_dicts(cls, dup_dicts: list[dict]) -> "DuplicateSimulator": + """Create simulator from list of dicts (for JSON/CLI compatibility).""" + config = DuplicateConfig.from_dicts(dup_dicts) + return cls(config=config) + + def should_duplicate(self) -> bool: + """Check if the current message should be duplicated. + + Returns True if current time falls within any duplicate event window. 
+ """ + if not self._duplicates: + return False + + now = time.time() + if self._start_time is None: + self._start_time = now + + elapsed = now - self._start_time + + return any(dup.start_s <= elapsed < dup.start_s + dup.duration_s for dup in self._duplicates) + + def reset(self) -> None: + """Reset the simulator start time.""" + self._start_time = None + + +# ============================================================================= +# Reorder Simulator (Hold-and-Swap) +# ============================================================================= + + +@dataclass +class ReorderEvent: + """A single reorder event with start time and duration. + + During the window the simulator holds one message and lets the next + pass through, then releases the held message -- creating a single + pairwise swap. + + Attributes: + start_s: When to start reordering (seconds from experiment start) + duration_s: How long the reorder window lasts (seconds) + """ + + start_s: float # When to start (seconds from start) + duration_s: float # Window duration (seconds) + + +@dataclass +class ReorderConfig: + """Configuration for reorder injection using explicit time-based events. + + Example usage: + config = ReorderConfig(reorders=[ + ReorderEvent(start_s=5.0, duration_s=2.0), + ]) + + # Or from dicts (for JSON/CLI compatibility) + config = ReorderConfig.from_dicts([ + {"start_s": 5.0, "duration_s": 2.0}, + ]) + """ + + reorders: list[ReorderEvent] = field(default_factory=list) + + @classmethod + def from_dicts(cls, reorder_dicts: list[dict]) -> "ReorderConfig": + """Create config from list of dicts (for JSON/CLI compatibility).""" + reorders = [ReorderEvent(start_s=d["start_s"], duration_s=d["duration_s"]) for d in reorder_dicts] + return cls(reorders=reorders) + + +class ReorderSimulator: + """Simulates pairwise message reordering (hold-and-swap). + + Design (Option A from the plan): + - Outside any reorder window: ``process()`` passes items straight through. 
+ - When a reorder window opens: the *first* item is held, and the *second* + item passes through immediately. On the *third* call (or any call + after the window closes) the held item is released ahead of the new one, + completing the swap. + + This creates a single pairwise swap per window -- the simplest reordering + that still exercises the LWW join. + + The caller wraps the send/handle path: + + items = self._reorder_sim.process(item) + for i in items: + send(i) # may yield 0, 1, or 2 items + + Example: + config = ReorderConfig(reorders=[ + ReorderEvent(start_s=5.0, duration_s=2.0), + ]) + sim = ReorderSimulator(config=config) + """ + + def __init__(self, config: ReorderConfig | None = None): + self._reorders: list[ReorderEvent] = config.reorders if config else [] + self._start_time: float | None = None + # Hold buffer: at most one item held while waiting for the swap partner + self._held: Any = None + self._holding: bool = False + + @classmethod + def from_dicts(cls, reorder_dicts: list[dict]) -> "ReorderSimulator": + """Create simulator from list of dicts (for JSON/CLI compatibility).""" + config = ReorderConfig.from_dicts(reorder_dicts) + return cls(config=config) + + def _in_reorder_window(self) -> bool: + """Check if current time falls within any reorder window.""" + if not self._reorders: + return False + + now = time.time() + if self._start_time is None: + self._start_time = now + + elapsed = now - self._start_time + + for reorder in self._reorders: + if reorder.start_s <= elapsed < reorder.start_s + reorder.duration_s: + return True + + return False + + def process(self, item: Any) -> list: + """Process an item through the hold-and-swap reorder logic. 
@dataclass
class DisconnectEvent:
    """One scheduled network outage.

    Attributes:
        start_s: Offset from experiment start at which the outage begins (seconds).
        duration_s: Length of the outage (seconds).
    """

    start_s: float  # When to start (seconds from start)
    duration_s: float  # How long the disconnect lasts (seconds)


@dataclass
class DisconnectConfig:
    """Configuration for disconnect injection using explicit time-based events.

    A disconnect blocks *both* the observation sender and action receiver
    threads for the configured duration, simulating a full network outage.

    Example usage:
        config = DisconnectConfig(disconnects=[
            DisconnectEvent(start_s=5.0, duration_s=3.0),
        ])

        # Or from dicts (for JSON/CLI compatibility)
        config = DisconnectConfig.from_dicts([
            {"start_s": 5.0, "duration_s": 3.0},
        ])
    """

    disconnects: list[DisconnectEvent] = field(default_factory=list)

    @classmethod
    def from_dicts(cls, disconnect_dicts: list[dict]) -> "DisconnectConfig":
        """Create config from list of dicts (for JSON/CLI compatibility)."""
        events = [
            DisconnectEvent(start_s=entry["start_s"], duration_s=entry["duration_s"])
            for entry in disconnect_dicts
        ]
        return cls(disconnects=events)


class DisconnectSimulator:
    """Simulates network disconnects by blocking caller threads.

    Both the observation sender and action receiver should call
    ``wait_if_disconnected()`` on each iteration. When the current time
    falls inside a disconnect window, the call sleeps until the window
    ends and returns the sleep duration so the caller can record a sim
    event. Outside any window it returns immediately with 0.

    Example:
        config = DisconnectConfig(disconnects=[
            DisconnectEvent(start_s=5.0, duration_s=3.0),
        ])
        sim = DisconnectSimulator(config=config)

        # In a sender/receiver loop:
        slept = sim.wait_if_disconnected()
        if slept > 0:
            record_sim_event("disconnect")
            continue
    """

    def __init__(self, config: DisconnectConfig | None = None):
        self._disconnects: list[DisconnectEvent] = config.disconnects if config else []
        self._start_time: float | None = None

    @classmethod
    def from_dicts(cls, disconnect_dicts: list[dict]) -> "DisconnectSimulator":
        """Create simulator from list of dicts (for JSON/CLI compatibility)."""
        return cls(config=DisconnectConfig.from_dicts(disconnect_dicts))

    def _active_window_end(self) -> float | None:
        """Return the absolute wall-clock end of the active disconnect window, or None."""
        if not self._disconnects:
            return None

        now = time.time()
        if self._start_time is None:
            # Lazily anchor the experiment clock on the first poll.
            self._start_time = now
        elapsed = now - self._start_time

        # First event whose window contains `elapsed` wins.
        return next(
            (
                self._start_time + event.start_s + event.duration_s
                for event in self._disconnects
                if event.start_s <= elapsed < event.start_s + event.duration_s
            ),
            None,
        )

    def is_disconnected(self) -> bool:
        """Check if the network is currently disconnected."""
        return self._active_window_end() is not None

    def wait_if_disconnected(self) -> float:
        """Block until the current disconnect window ends.

        Returns:
            The number of seconds slept (0 if not disconnected).
        """
        deadline = self._active_window_end()
        if deadline is None:
            return 0.0

        remaining = max(0.0, deadline - time.time())
        if remaining > 0:
            time.sleep(remaining)
        return remaining

    def reset(self) -> None:
        """Reset the simulator start time so the schedule re-anchors on next use."""
        self._start_time = None

RTC Trajectory Visualization

+
Disconnected
+ + +
+
RTC Guidance Parameters
+
+
+ H (chunk size) + -- +
+
+ d (delay) + -- +
+
+ overlap_end + -- +
+
+ σ_d + -- +
+
+ schedule + -- +
+
+ max β + -- +
+
+
+
+ Frozen [0, d) + -- +
+
+ Soft Mask [d, overlap_end) + -- +
+
+ Fresh [overlap_end, H) + -- +
+
+
+

Prefix Weights (current chunk)

+ +
+
+ Hard Mask + [0, 0) +
+
+ Soft Mask + [0, 0) +
+
+ Fresh + [0, 0) +
+
+
+
+ + +
+
Session Rolling Average (for parameter sweep comparison)
+
+
+
+ Mean Discrepancy + -- + samples: 0 | max: -- +
+ +
+
+

Rolling Average Over Time

+ +
+
+
+ + +
+
+
+ Transition Discrepancy (current) + Avg: -- +
+
+
Waiting for overlapping chunks...
+
+
+ +
+ + +
+ Debug info will appear here... +
+ +
+ +
+ +

Executed Actions (actually sent to robot)

+
+ +
+ +
+
Chunk Legend (source_step)
+
+
+ +
+ Waiting for data... +
+ + + + diff --git a/src/lerobot/async_inference/utils/trajectory_viz.py b/src/lerobot/async_inference/utils/trajectory_viz.py new file mode 100644 index 00000000000..eaa0e7be626 --- /dev/null +++ b/src/lerobot/async_inference/utils/trajectory_viz.py @@ -0,0 +1,403 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Real-time trajectory visualization server for RTC inpainting assessment. + +This module provides a WebSocket + HTTP server that streams action chunk data +to a browser-based visualization. It shows per-motor trajectories from up to +10 previous action chunks, with each chunk in a different color. 
class TrajectoryVizServer:
    """WebSocket server that broadcasts action chunk data to connected clients.

    Runs two servers: an HTTP server (daemon thread) serving the static
    visualization page, and a WebSocket server (asyncio) that fans queued
    chunk/action payloads out to every connected browser.
    """

    def __init__(self, ws_port: int = 8089, http_port: int = 8088):
        self.ws_port = ws_port
        self.http_port = http_port
        # Fan-in queue with TWO producers: on_chunk() runs on the robot
        # client's callback thread, _handler() runs on the asyncio loop.
        # _broadcaster() is the single consumer.
        self._chunk_queue: Queue[dict] = Queue(maxsize=100)
        self._clients: set = set()
        self._shutdown = threading.Event()

    def _enqueue_latest(self, payload: dict) -> None:
        """Best-effort enqueue: when the queue is full, drop the oldest entry.

        Because the queue has two producers, the re-insert after dropping the
        oldest item can itself race and find the queue full again; in that
        case the new payload is dropped instead of letting queue.Full escape
        into the producer thread (the original code only caught Empty here).
        """
        try:
            self._chunk_queue.put_nowait(payload)
        except Full:
            try:
                self._chunk_queue.get_nowait()
            except Empty:
                pass
            try:
                self._chunk_queue.put_nowait(payload)
            except Full:
                # Lost the race with the other producer -- drop this payload.
                pass

    def on_chunk(self, event: EvActionChunk) -> None:
        """Callback to forward action chunks."""
        self._enqueue_latest(
            {
                "type": "action_chunk",
                "source_step": event.src_control_step,
                "actions": event.actions,
                "frozen_len": event.frozen_len,
                "timestamp": event.timestamp,
                # RTC visualization fields (may be None)
                "rtc_params": event.rtc_params,
                "prefix_weights": event.prefix_weights,
            }
        )

    async def _handler(self, websocket):
        """Handle a WebSocket connection for its lifetime."""
        self._clients.add(websocket)
        try:
            async for message in websocket:
                # Messages from clients (e.g., TrajectoryVizClient) are
                # re-queued so they get broadcast to all other clients
                # (browsers), e.g. executed_action from the robot client.
                try:
                    payload = json.loads(message)
                except json.JSONDecodeError:
                    continue  # ignore malformed frames
                self._enqueue_latest(payload)
        finally:
            self._clients.discard(websocket)

    async def _broadcaster(self):
        """Background task that broadcasts queued payloads to all connected clients."""
        while not self._shutdown.is_set():
            try:
                # Poll the thread-safe queue without blocking the event loop.
                await asyncio.sleep(0.01)
                try:
                    payload = self._chunk_queue.get_nowait()
                except Empty:
                    continue

                if self._clients:
                    message = json.dumps(payload)
                    # Broadcast to all connected clients; per-client send
                    # failures must not abort the broadcast.
                    await asyncio.gather(
                        *[client.send(message) for client in self._clients],
                        return_exceptions=True,
                    )
            except Exception as exc:
                logger.debug(f"Broadcaster error: {exc}")

    async def _run_websocket_server(self):
        """Run the WebSocket server until cancelled."""
        try:
            import websockets
        except ImportError:
            logger.error("websockets package not installed. Run: uv pip install websockets")
            return

        host = ALL_INTERFACES_HOST
        async with websockets.serve(self._handler, host, self.ws_port):
            logger.info(f"WebSocket server started on ws://{host}:{self.ws_port}")
            broadcaster_task = asyncio.create_task(self._broadcaster())
            try:
                await asyncio.Future()  # Run forever
            finally:
                broadcaster_task.cancel()

    def _run_http_server(self):
        """Run the HTTP server for static files (blocking)."""
        # Serve files from the directory containing the HTML file.
        viz_dir = Path(__file__).parent

        class Handler(SimpleHTTPRequestHandler):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, directory=str(viz_dir), **kwargs)

            def log_message(self, format, *args):
                # Suppress access logs
                pass

            def do_GET(self):
                # Serve trajectory_viz.html as index
                if self.path == "/" or self.path == "/index.html":
                    self.path = "/trajectory_viz.html"
                return super().do_GET()

        host = ALL_INTERFACES_HOST
        server = HTTPServer((host, self.http_port), Handler)
        logger.info(f"HTTP server started on http://{host}:{self.http_port}")
        server.serve_forever()

    def start(self):
        """Start both HTTP and WebSocket servers (blocks on the WebSocket loop)."""
        # Start HTTP server in a daemon thread.
        http_thread = threading.Thread(target=self._run_http_server, daemon=True)
        http_thread.start()

        # Run WebSocket server in the main thread's event loop.
        asyncio.run(self._run_websocket_server())

    def stop(self):
        """Signal shutdown."""
        self._shutdown.set()
class TrajectoryVizClient:
    """WebSocket client that sends action chunk data to a visualization server.

    Payload producers (``on_chunk`` / ``on_executed_action``) run on the robot
    client's callback threads; a background daemon thread owns the event loop
    that drains the queue and ships payloads to the server.
    """

    # Rate limit for "not connected" warnings (seconds between warnings)
    _NOT_CONNECTED_WARN_INTERVAL = 5.0

    def __init__(self, ws_url: str = "ws://localhost:8089"):
        self.ws_url = ws_url
        self._ws = None
        self._loop = None
        self._thread: threading.Thread | None = None
        self._queue: Queue[dict] = Queue(maxsize=100)
        self._shutdown = threading.Event()
        self._connected = False
        self._connection_attempted = False
        self._last_not_connected_warn: float = 0.0
        self._dropped_while_disconnected: int = 0

    def _enqueue_latest(self, payload: dict) -> None:
        """Best-effort enqueue: when the queue is full, drop the oldest entry.

        The re-insert after dropping the oldest item can race with another
        producer thread and find the queue full again; in that case the new
        payload is dropped instead of letting queue.Full escape into the
        caller's callback thread (the original code only caught Empty here).
        """
        try:
            self._queue.put_nowait(payload)
        except Full:
            try:
                self._queue.get_nowait()
            except Empty:
                pass
            try:
                self._queue.put_nowait(payload)
            except Full:
                # Lost the race with another producer -- drop this payload.
                pass

    def start(self) -> None:
        """Start the WebSocket client in a background thread."""
        logger.info(f"Starting trajectory viz client, will connect to {self.ws_url}")
        self._thread = threading.Thread(target=self._run, daemon=True, name="trajectory_viz_client")
        self._thread.start()

    def stop(self) -> None:
        """Stop the WebSocket client."""
        self._shutdown.set()
        if self._dropped_while_disconnected > 0:
            logger.warning(
                f"Trajectory viz client stopped. Total chunks dropped while disconnected: "
                f"{self._dropped_while_disconnected}"
            )

    def _run(self) -> None:
        """Run the WebSocket client event loop (thread entry point)."""
        import asyncio

        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._loop.run_until_complete(self._connect_and_send())

    async def _connect_and_send(self) -> None:
        """Connect to the server and send queued data, reconnecting on failure."""
        try:
            import websockets
        except ImportError:
            logger.error("websockets package not installed. Run: uv pip install websockets")
            return

        while not self._shutdown.is_set():
            self._connection_attempted = True
            try:
                logger.info(f"Attempting to connect to visualization server at {self.ws_url}...")
                async with websockets.connect(self.ws_url) as ws:
                    self._connected = True
                    if self._dropped_while_disconnected > 0:
                        logger.info(
                            f"Connected to visualization server at {self.ws_url} "
                            f"(dropped {self._dropped_while_disconnected} chunks while disconnected)"
                        )
                        self._dropped_while_disconnected = 0
                    else:
                        logger.info(f"Connected to visualization server at {self.ws_url}")

                    while not self._shutdown.is_set():
                        try:
                            # Non-blocking check for data
                            await asyncio.sleep(0.01)
                            try:
                                payload = self._queue.get_nowait()
                                await ws.send(json.dumps(payload))
                            except Empty:
                                continue
                        except Exception as exc:
                            logger.warning(f"WebSocket send error: {exc}")
                            break

            except Exception as exc:
                previously_connected = self._connected
                self._connected = False
                if previously_connected:
                    logger.warning(f"Disconnected from visualization server: {exc}")
                else:
                    logger.warning(
                        f"Failed to connect to visualization server at {self.ws_url}: {exc}. "
                        f"Make sure the server is running: python -m lerobot.async_inference.utils.trajectory_viz"
                    )
                # Back off before the next reconnect attempt.
                await asyncio.sleep(2.0)

    def on_chunk(self, event: EvActionChunk) -> None:
        """Callback to queue an action chunk for sending."""
        if not self._connected:
            self._dropped_while_disconnected += 1
            # Rate-limited warning to avoid log spam
            now = time.time()
            if now - self._last_not_connected_warn > self._NOT_CONNECTED_WARN_INTERVAL:
                self._last_not_connected_warn = now
                if self._connection_attempted:
                    logger.warning(
                        f"Trajectory viz: not connected to server, dropping chunks "
                        f"(total dropped: {self._dropped_while_disconnected}). "
                        f"Is the viz server running at {self.ws_url}?"
                    )
                else:
                    logger.debug("Trajectory viz: waiting for connection to establish...")
            return

        self._enqueue_latest(
            {
                "type": "action_chunk",
                "source_step": event.src_control_step,
                "actions": event.actions,
                "frozen_len": event.frozen_len,
                "timestamp": event.timestamp,
                # RTC visualization fields (may be None)
                "rtc_params": event.rtc_params,
                "prefix_weights": event.prefix_weights,
            }
        )

    def on_executed_action(self, event: EvExecutedAction) -> None:
        """Callback to queue an executed action for sending."""
        if not self._connected:
            # Don't count/warn for executed actions, just silently drop
            return

        self._enqueue_latest(
            {
                "type": "executed_action",
                "step": event.step,
                "action": event.action,
                "timestamp": event.timestamp,
            }
        )
import logging
import math

# Track which (schedule, d, overlap_end) combos have been logged to avoid spam
_prefix_weights_logged: set[tuple[str, int, int]] = set()


def compute_prefix_weights_for_viz(
    d: int, overlap_end: int, chunk_len: int, schedule: str = "linear"
) -> list[float]:
    """Compute prefix weights for RTC visualization.

    Args:
        d: Inference delay (hard mask region ends at d).
        overlap_end: Where soft masking ends (H - d with s=d).
        chunk_len: Total chunk size.
        schedule: Weight schedule ("linear" or "exp").

    Returns:
        List of `chunk_len` floats, each in [0, 1]:
        - [0, d): weight = 1.0 (hard mask)
        - [d, overlap_end): weight decays 1->0 (soft mask)
        - [overlap_end, chunk_len): weight = 0.0 (fresh)
    """
    normalized_schedule = schedule.lower()
    soft_span = overlap_end - d

    def _weight_at(index: int) -> float:
        # Hard mask region: frozen prefix keeps full weight.
        if index < d:
            return 1.0
        # Fresh region (or degenerate soft span): no weight.
        if index >= overlap_end or soft_span <= 0:
            return 0.0
        # Soft masking region: linear ramp from 1 down to 0 across the span.
        w = 1.0 - (index - d) / soft_span
        if normalized_schedule == "exp" and w > 0:
            # Exponential-style variant of the ramp (decays faster than linear).
            w = w * (math.expm1(w) / (math.e - 1))
        return w

    weights = [_weight_at(i) for i in range(chunk_len)]

    # Log weight samples once per unique (schedule, d, overlap_end) to verify formula.
    dedup_key = (normalized_schedule, d, overlap_end)
    if dedup_key not in _prefix_weights_logged and chunk_len > 0:
        _prefix_weights_logged.add(dedup_key)
        log = logging.getLogger("policy_server_drtc")
        probe_indices = (d, (d + overlap_end) // 2, overlap_end - 1)
        samples = [(i, weights[i]) for i in probe_indices if 0 <= i < len(weights)]
        log.info(
            "RTC prefix weights (%s): d=%d, overlap_end=%d, H=%d, samples=%s",
            schedule,
            d,
            overlap_end,
            chunk_len,
            [(f"w[{i}]", f"{w:.3f}") for i, w in samples],
        )

    return weights
b/src/lerobot/datasets/lerobot_dataset.py @@ -1898,3 +1898,4 @@ def __repr__(self): f" Transformations: {self.image_transforms},\n" f")" ) + \ No newline at end of file diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index d50d8652a81..0829477eff3 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -587,4 +587,4 @@ def _make_processors_from_policy_config( ) module = importlib.import_module(module_path) function = getattr(module, function_name) - return function(config, dataset_stats=dataset_stats) + return function(config, dataset_stats=dataset_stats) \ No newline at end of file diff --git a/src/lerobot/transport/services.proto b/src/lerobot/transport/services.proto index ea0c12de673..15b84c5832a 100644 --- a/src/lerobot/transport/services.proto +++ b/src/lerobot/transport/services.proto @@ -40,15 +40,19 @@ service AsyncInference { // Policy -> Robot to share actions predicted for given observations rpc SendObservations(stream Observation) returns (Empty); rpc GetActions(Empty) returns (Actions); + // Low-jitter path: server-streamed dense action chunks (preferred). 
+ rpc StreamActionsDense(Empty) returns (stream ActionsDense); rpc SendPolicyInstructions(PolicySetup) returns (Empty); rpc Ready(Empty) returns (Empty); + // Robot -> Policy to send trajectory data for visualization + rpc SendTrajectoryChunk(TrajectoryChunk) returns (Empty); } enum TransferState { - TRANSFER_UNKNOWN = 0; - TRANSFER_BEGIN = 1; - TRANSFER_MIDDLE = 2; - TRANSFER_END = 3; + TRANSFER_UNKNOWN = 0; + TRANSFER_BEGIN = 1; + TRANSFER_MIDDLE = 2; + TRANSFER_END = 3; } // Messages @@ -70,7 +74,7 @@ message InteractionMessage { // Messages message Observation { // sent by Robot, to remote Policy - TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size + TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size bytes data = 2; } @@ -79,9 +83,39 @@ message Actions { bytes data = 1; } +message ActionsDense { + // Timestamp of the source observation (Unix time, seconds). + double timestamp = 1; + // Control-loop timestep t that triggered this inference (LWW logical clock). + int64 source_control_step = 2; + // Action timestep spacing (seconds). + float dt = 3; + // Actions array shape: (num_actions, action_dim). + int32 num_actions = 4; + int32 action_dim = 5; + // Packed row-major float32 actions, length = num_actions*action_dim*4 bytes. + bytes actions_f32 = 6; + // Action step n_k at which this chunk starts in execution space. + int64 chunk_start_step = 7; + // Timestamp when the server received this observation (Unix time, seconds). + double server_obs_received_ts = 8; + // Timestamp when the server finished inference and is about to send (Unix time, seconds). 
+ double server_action_sent_ts = 9; +} + message PolicySetup { // sent by Robot to remote server, to init Policy bytes data = 1; } +message TrajectoryChunk { + // sent by Robot to Policy server for trajectory visualization + int64 source_step = 1; // Chunk provenance (observation step that triggered inference) + int32 num_actions = 2; // Number of actions in the chunk (T) + int32 action_dim = 3; // Action dimension (A) + bytes actions_f32 = 4; // Packed row-major float32 actions, length = T*A*4 bytes + int32 frozen_len = 5; // Number of frozen actions in this chunk + double timestamp = 6; // Arrival time (Unix time, seconds) +} + message Empty {} diff --git a/src/lerobot/transport/services_pb2.py b/src/lerobot/transport/services_pb2.py index 05f2d174fd4..3a71c627d4a 100644 --- a/src/lerobot/transport/services_pb2.py +++ b/src/lerobot/transport/services_pb2.py @@ -23,15 +23,15 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 
\x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xf5\x01\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\xe1\x01\n\x0c\x41\x63tionsDense\x12\x11\n\ttimestamp\x18\x01 \x01(\x01\x12\x1b\n\x13source_control_step\x18\x02 \x01(\x03\x12\n\n\x02\x64t\x18\x03 \x01(\x02\x12\x13\n\x0bnum_actions\x18\x04 \x01(\x05\x12\x12\n\naction_dim\x18\x05 \x01(\x05\x12\x13\n\x0b\x61\x63tions_f32\x18\x06 \x01(\x0c\x12\x18\n\x10\x63hunk_start_step\x18\x07 
\x01(\x03\x12\x1e\n\x16server_obs_received_ts\x18\x08 \x01(\x01\x12\x1d\n\x15server_action_sent_ts\x18\t \x01(\x01\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x8b\x01\n\x0fTrajectoryChunk\x12\x13\n\x0bsource_step\x18\x01 \x01(\x03\x12\x13\n\x0bnum_actions\x18\x02 \x01(\x05\x12\x12\n\naction_dim\x18\x03 \x01(\x05\x12\x13\n\x0b\x61\x63tions_f32\x18\x04 \x01(\x0c\x12\x12\n\nfrozen_len\x18\x05 \x01(\x05\x12\x11\n\ttimestamp\x18\x06 \x01(\x01\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xfd\x02\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x41\n\x12StreamActionsDense\x12\x10.transport.Empty\x1a\x17.transport.ActionsDense0\x01\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty\x12\x43\n\x13SendTrajectoryChunk\x12\x1a.transport.TrajectoryChunk\x1a\x10.transport.Emptyb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_TRANSFERSTATE']._serialized_start=431 - _globals['_TRANSFERSTATE']._serialized_end=527 + _globals['_TRANSFERSTATE']._serialized_start=801 + 
_globals['_TRANSFERSTATE']._serialized_end=897 _globals['_TRANSITION']._serialized_start=47 _globals['_TRANSITION']._serialized_end=123 _globals['_PARAMETERS']._serialized_start=125 @@ -42,12 +42,16 @@ _globals['_OBSERVATION']._serialized_end=366 _globals['_ACTIONS']._serialized_start=368 _globals['_ACTIONS']._serialized_end=391 - _globals['_POLICYSETUP']._serialized_start=393 - _globals['_POLICYSETUP']._serialized_end=420 - _globals['_EMPTY']._serialized_start=422 - _globals['_EMPTY']._serialized_end=429 - _globals['_LEARNERSERVICE']._serialized_start=530 - _globals['_LEARNERSERVICE']._serialized_end=787 - _globals['_ASYNCINFERENCE']._serialized_start=790 - _globals['_ASYNCINFERENCE']._serialized_end=1035 + _globals['_ACTIONSDENSE']._serialized_start=394 + _globals['_ACTIONSDENSE']._serialized_end=619 + _globals['_POLICYSETUP']._serialized_start=621 + _globals['_POLICYSETUP']._serialized_end=648 + _globals['_TRAJECTORYCHUNK']._serialized_start=651 + _globals['_TRAJECTORYCHUNK']._serialized_end=790 + _globals['_EMPTY']._serialized_start=792 + _globals['_EMPTY']._serialized_end=799 + _globals['_LEARNERSERVICE']._serialized_start=900 + _globals['_LEARNERSERVICE']._serialized_end=1157 + _globals['_ASYNCINFERENCE']._serialized_start=1160 + _globals['_ASYNCINFERENCE']._serialized_end=1541 # @@protoc_insertion_point(module_scope) diff --git a/src/lerobot/transport/services_pb2_grpc.py b/src/lerobot/transport/services_pb2_grpc.py index 35a01b6754e..b3eda689e64 100644 --- a/src/lerobot/transport/services_pb2_grpc.py +++ b/src/lerobot/transport/services_pb2_grpc.py @@ -254,6 +254,11 @@ def __init__(self, channel): request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, response_deserializer=lerobot_dot_transport_dot_services__pb2.Actions.FromString, _registered_method=True) + self.StreamActionsDense = channel.unary_stream( + '/transport.AsyncInference/StreamActionsDense', + 
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.ActionsDense.FromString, + _registered_method=True) self.SendPolicyInstructions = channel.unary_unary( '/transport.AsyncInference/SendPolicyInstructions', request_serializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString, @@ -264,6 +269,11 @@ def __init__(self, channel): request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) + self.SendTrajectoryChunk = channel.unary_unary( + '/transport.AsyncInference/SendTrajectoryChunk', + request_serializer=lerobot_dot_transport_dot_services__pb2.TrajectoryChunk.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) class AsyncInferenceServicer: @@ -285,6 +295,13 @@ def GetActions(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def StreamActionsDense(self, request, context): + """Low-jitter path: server-streamed dense action chunks (preferred). 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def SendPolicyInstructions(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -297,6 +314,13 @@ def Ready(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def SendTrajectoryChunk(self, request, context): + """Robot -> Policy to send trajectory data for visualization + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_AsyncInferenceServicer_to_server(servicer, server): rpc_method_handlers = { @@ -310,6 +334,11 @@ def add_AsyncInferenceServicer_to_server(servicer, server): request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, response_serializer=lerobot_dot_transport_dot_services__pb2.Actions.SerializeToString, ), + 'StreamActionsDense': grpc.unary_stream_rpc_method_handler( + servicer.StreamActionsDense, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.ActionsDense.SerializeToString, + ), 'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler( servicer.SendPolicyInstructions, request_deserializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.FromString, @@ -320,6 +349,11 @@ def add_AsyncInferenceServicer_to_server(servicer, server): request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, ), + 'SendTrajectoryChunk': grpc.unary_unary_rpc_method_handler( + servicer.SendTrajectoryChunk, + request_deserializer=lerobot_dot_transport_dot_services__pb2.TrajectoryChunk.FromString, + 
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'transport.AsyncInference', rpc_method_handlers) @@ -387,6 +421,33 @@ def GetActions(request, metadata, _registered_method=True) + @staticmethod + def StreamActionsDense(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/transport.AsyncInference/StreamActionsDense', + lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_transport_dot_services__pb2.ActionsDense.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def SendPolicyInstructions(request, target, @@ -440,3 +501,30 @@ def Ready(request, timeout, metadata, _registered_method=True) + + @staticmethod + def SendTrajectoryChunk(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/transport.AsyncInference/SendTrajectoryChunk', + lerobot_dot_transport_dot_services__pb2.TrajectoryChunk.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/tests/async_inference/test_drtc_timed.py b/tests/async_inference/test_drtc_timed.py new file mode 100644 index 00000000000..3ef2fdedd85 --- /dev/null +++ b/tests/async_inference/test_drtc_timed.py @@ -0,0 +1,44 @@ +import pickle + +import numpy as np + +from lerobot.async_inference.drtc_timed import DrtcAction, 
DrtcObservation +from lerobot.utils.constants import OBS_STATE + + +def test_drtc_action_getters_and_pickle_roundtrip(): + action = np.arange(6, dtype=np.float32) + timed_action = DrtcAction(timestamp=123.4, control_step=7, action_step=11, action=action) + + assert timed_action.get_timestamp() == 123.4 + assert timed_action.get_control_step() == 7 + assert timed_action.get_action_step() == 11 + np.testing.assert_array_equal(timed_action.get_action(), action) + + reloaded = pickle.loads(pickle.dumps(timed_action)) # nosec B301 + assert reloaded.get_control_step() == 7 + assert reloaded.get_action_step() == 11 + np.testing.assert_array_equal(reloaded.get_action(), action) + + +def test_drtc_observation_getters_and_pickle_roundtrip(): + observation = {OBS_STATE: [1.0, 2.0, 3.0]} + timed_observation = DrtcObservation( + timestamp=456.7, + control_step=13, + observation=observation, + chunk_start_step=17, + server_received_ts=460.0, + ) + + assert timed_observation.get_timestamp() == 456.7 + assert timed_observation.get_control_step() == 13 + assert timed_observation.get_observation() == observation + assert timed_observation.chunk_start_step == 17 + assert timed_observation.server_received_ts == 460.0 + + reloaded = pickle.loads(pickle.dumps(timed_observation)) # nosec B301 + assert reloaded.get_control_step() == 13 + assert reloaded.get_observation() == observation + assert reloaded.chunk_start_step == 17 + assert reloaded.server_received_ts == 460.0 diff --git a/tests/async_inference/test_lww_register.py b/tests/async_inference/test_lww_register.py new file mode 100644 index 00000000000..bef4be96a91 --- /dev/null +++ b/tests/async_inference/test_lww_register.py @@ -0,0 +1,127 @@ +"""Tests for the LWW register primitive. + +The LWW (Last-Write-Wins) register uses a single logical clock (action step) +for all causality relationships. See robot_client_drtc.py for the full +causality model documentation. 
+""" + +from lerobot.async_inference.lww_register import LWWCursor, LWWRegister, LWWState + + +def test_lww_state_idempotent() -> None: + s = LWWState(control_step=3, value="x") + assert (s | s) == s + + +def test_lww_state_commutative_distinct_control_step() -> None: + a = LWWState(control_step=1, value="a") + b = LWWState(control_step=2, value="b") + assert (a | b) == (b | a) + assert (a | b).control_step == 2 + + +def test_lww_state_associative_distinct_control_step() -> None: + b = LWWState(control_step=2, value="b") + a = LWWState(control_step=1, value="a") + c = LWWState(control_step=3, value="c") + assert ((a | b) | c) == (a | (b | c)) + assert (a | (b | c)).control_step == 3 + + +def test_lww_state_equal_control_step_stability_when_values_equal() -> None: + # Our join is stable on ties: equal-control_step should not introduce changes. + a1 = LWWState(control_step=7, value={"x": 1}) + a2 = LWWState(control_step=7, value={"x": 1}) + assert (a1 | a2) == a1 + assert (a2 | a1) == a2 + + +def test_register_starts_at_initial_state() -> None: + reg: LWWRegister[str | None] = LWWRegister(initial_control_step=-1, initial_value=None) + s = reg.read() + assert s.control_step == -1 + assert s.value is None + + +def test_register_monotone_ignores_stale_and_equal_control_step_updates() -> None: + reg: LWWRegister[str | None] = LWWRegister(initial_control_step=-1, initial_value=None) + + reg.update(1, "v1") + assert reg.read() == LWWState(control_step=1, value="v1") + + # stale + reg.update(0, "stale") + assert reg.read() == LWWState(control_step=1, value="v1") + + # equal-control_step (should not overwrite) + reg.update(1, "v1-duplicate") + assert reg.read() == LWWState(control_step=1, value="v1") + + +def test_register_out_of_order_tolerates_gaps() -> None: + reg: LWWRegister[str] = LWWRegister(initial_control_step=-1, initial_value="") + reg.update(10, "ten") + reg.update(5, "five") # out-of-order + assert reg.read() == LWWState(control_step=10, value="ten") + + +def 
test_register_control_step_never_decreases_under_updates() -> None: + reg: LWWRegister[int] = LWWRegister(initial_control_step=-1, initial_value=0) + steps = [3, 1, 5, 5, 2, 9, 4] + last_step = reg.read().control_step + for step in steps: + reg.update(step, step) + new_step = reg.read().control_step + assert new_step >= last_step + last_step = new_step + + +def test_register_read_returns_latest_state() -> None: + reg: LWWRegister[str | None] = LWWRegister(initial_control_step=-1, initial_value=None) + assert reg.read().value is None + reg.update(0, "zero") + reg.update(2, "two") + assert reg.read() == LWWState(control_step=2, value="two") + + +def test_cursor_is_monotone_semilattice() -> None: + c1 = LWWCursor(watermark=1) + c2 = LWWCursor(watermark=2) + c3 = LWWCursor(watermark=3) + assert (c1 | c1) == c1 + assert (c1 | c2) == (c2 | c1) == c2 + assert ((c1 | c2) | c3) == (c1 | (c2 | c3)) == c3 + + +def test_read_if_newer_returns_is_new_once() -> None: + reg: LWWRegister[str | None] = LWWRegister(initial_control_step=-1, initial_value=None) + reader = reg.reader() + + # Nothing newer than cursor (control_step=-1) + state, cursor2, is_new = reader.read_if_newer() + assert is_new is False + assert cursor2 == LWWCursor(watermark=-1) + assert state.control_step == -1 + + # Update to control_step=0; first read should be new + reg.update_if_newer(0, "zero") + state, cursor, is_new = reader.read_if_newer() + assert is_new is True + assert cursor.watermark == 0 + assert state == LWWState(control_step=0, value="zero") + + # Second read without update should not be new + state2, cursor2, is_new2 = reader.read_if_newer() + assert is_new2 is False + assert cursor2 == cursor + assert state2 == state + + +def test_update_if_newer_reports_rejection_on_stale_or_equal_control_step() -> None: + reg: LWWRegister[str] = LWWRegister(initial_control_step=-1, initial_value="") + _, did_update1 = reg.update_if_newer(1, "one") + assert did_update1 is True + _, did_update_stale = 
reg.update_if_newer(0, "stale") + assert did_update_stale is False + _, did_update_equal = reg.update_if_newer(1, "equal") + assert did_update_equal is False diff --git a/tests/async_inference/test_rtc_guidance.py b/tests/async_inference/test_rtc_guidance.py new file mode 100644 index 00000000000..c0dae904b95 --- /dev/null +++ b/tests/async_inference/test_rtc_guidance.py @@ -0,0 +1,264 @@ +# Copyright 2026 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for AsyncRTCProcessor in async inference. + +These tests verify that RTC guidance works correctly when: +1. No postprocess is used (raw model space) +2. Postprocess preserves dimensions +3. 
Postprocess changes dimensions (e.g., 32 -> 6) +""" + +from __future__ import annotations + +import torch + +from lerobot.async_inference.rtc_guidance import AsyncRTCConfig, AsyncRTCProcessor + + +class TestAsyncRTCProcessor: + """Tests for AsyncRTCProcessor.denoise_step with various postprocess configurations.""" + + def test_denoise_step_without_postprocess(self) -> None: + """Test RTC guidance without any postprocess (raw model space).""" + cfg = AsyncRTCConfig(enabled=True) + rtc = AsyncRTCProcessor(cfg, postprocess=None) + + # Both x_t and prev are in the same action space (32 dims) + x_t = torch.randn(1, 50, 32) + prev = torch.randn(1, 10, 32) + + def mock_denoise(x: torch.Tensor) -> torch.Tensor: + return torch.randn_like(x) + + result = rtc.denoise_step( + x_t=x_t, + prev_chunk_left_over=prev, + inference_delay=5, + time=0.5, + original_denoise_step_partial=mock_denoise, + overlap_end=10, # H - s where s = d = 5, so overlap at 50 - 5 = 45, but we test with explicit value + ) + + assert result.shape == x_t.shape + + def test_denoise_step_with_dimension_preserving_postprocess(self) -> None: + """Test RTC guidance with a postprocess that preserves dimensions.""" + cfg = AsyncRTCConfig(enabled=True) + + # Postprocess that keeps the same dimension (e.g., unnormalization) + def postprocess_identity(x_bta: torch.Tensor) -> torch.Tensor: + return x_bta * 2.0 # Just scales, preserves shape + + rtc = AsyncRTCProcessor(cfg, postprocess=postprocess_identity) + + x_t = torch.randn(1, 50, 32) + prev = torch.randn(1, 10, 32) + + def mock_denoise(x: torch.Tensor) -> torch.Tensor: + return torch.randn_like(x) + + result = rtc.denoise_step( + x_t=x_t, + prev_chunk_left_over=prev, + inference_delay=5, + time=0.5, + original_denoise_step_partial=mock_denoise, + overlap_end=10, + ) + + assert result.shape == x_t.shape + + def test_denoise_step_with_dimension_changing_postprocess(self) -> None: + """Test RTC guidance when postprocess changes action dimension (32 -> 6). 
+ + This is the key test case that was failing. The client sends frozen actions + in executable action space (6 dims), while the model operates in raw action + space (32 dims). The postprocess converts model output to executable space. + """ + cfg = AsyncRTCConfig(enabled=True) + + # Postprocess that changes dimensions: 32 -> 6 + # This simulates what happens in the real system where model output + # (32 dims) is transformed to robot action space (6 dims) + def postprocess_32_to_6(x_bta: torch.Tensor) -> torch.Tensor: + b, t, a_in = x_bta.shape + assert a_in == 32, f"Expected 32 dims, got {a_in}" + # Simulate reduction to 6 dims (e.g., taking first 6 or linear projection) + return x_bta[:, :, :6] + + rtc = AsyncRTCProcessor(cfg, postprocess=postprocess_32_to_6) + + # x_t is in raw model space (32 dims) + x_t = torch.randn(1, 50, 32) + # prev is in executable action space (6 dims) - what the client sends + prev = torch.randn(1, 10, 6) + + def mock_denoise(x: torch.Tensor) -> torch.Tensor: + return torch.randn_like(x) + + # This should NOT raise "size of tensor a (32) must match size of tensor b (6)" + result = rtc.denoise_step( + x_t=x_t, + prev_chunk_left_over=prev, + inference_delay=5, + time=0.5, + original_denoise_step_partial=mock_denoise, + overlap_end=10, + ) + + # Output should match input shape (still in raw model space) + assert result.shape == x_t.shape + + def test_denoise_step_with_prev_longer_than_chunk(self) -> None: + """Test when prev chunk is longer than the model's chunk size.""" + cfg = AsyncRTCConfig(enabled=True) + + def postprocess_32_to_6(x_bta: torch.Tensor) -> torch.Tensor: + return x_bta[:, :, :6] + + rtc = AsyncRTCProcessor(cfg, postprocess=postprocess_32_to_6) + + x_t = torch.randn(1, 50, 32) + # prev is longer than x_t's temporal dimension + prev = torch.randn(1, 100, 6) + + def mock_denoise(x: torch.Tensor) -> torch.Tensor: + return torch.randn_like(x) + + result = rtc.denoise_step( + x_t=x_t, + prev_chunk_left_over=prev, + 
inference_delay=5, + time=0.5, + original_denoise_step_partial=mock_denoise, + overlap_end=10, + ) + + assert result.shape == x_t.shape + + def test_denoise_step_disabled(self) -> None: + """Test that RTC is bypassed when disabled.""" + cfg = AsyncRTCConfig(enabled=False) + rtc = AsyncRTCProcessor(cfg, postprocess=None) + + x_t = torch.randn(1, 50, 32) + prev = torch.randn(1, 10, 6) + + call_count = 0 + + def mock_denoise(x: torch.Tensor) -> torch.Tensor: + nonlocal call_count + call_count += 1 + return torch.randn_like(x) + + result = rtc.denoise_step( + x_t=x_t, + prev_chunk_left_over=prev, + inference_delay=5, + time=0.5, + original_denoise_step_partial=mock_denoise, + overlap_end=10, + ) + + # Should call denoise once and return directly + assert call_count == 1 + assert result.shape == x_t.shape + + def test_denoise_step_no_prev_chunk(self) -> None: + """Test that RTC is bypassed when no prev_chunk_left_over is provided.""" + cfg = AsyncRTCConfig(enabled=True) + rtc = AsyncRTCProcessor(cfg, postprocess=None) + + x_t = torch.randn(1, 50, 32) + + def mock_denoise(x: torch.Tensor) -> torch.Tensor: + return torch.randn_like(x) + + result = rtc.denoise_step( + x_t=x_t, + prev_chunk_left_over=None, + inference_delay=5, + time=0.5, + original_denoise_step_partial=mock_denoise, + overlap_end=10, + ) + + assert result.shape == x_t.shape + + def test_denoise_step_2d_input(self) -> None: + """Test that 2D input (without batch dim) is handled correctly.""" + cfg = AsyncRTCConfig(enabled=True) + + def postprocess_32_to_6(x_bta: torch.Tensor) -> torch.Tensor: + return x_bta[:, :, :6] + + rtc = AsyncRTCProcessor(cfg, postprocess=postprocess_32_to_6) + + # 2D input (no batch dimension) + x_t = torch.randn(50, 32) + prev = torch.randn(10, 6) + + def mock_denoise(x: torch.Tensor) -> torch.Tensor: + return torch.randn_like(x) + + result = rtc.denoise_step( + x_t=x_t, + prev_chunk_left_over=prev, + inference_delay=5, + time=0.5, + original_denoise_step_partial=mock_denoise, + 
overlap_end=10, + ) + + # Should squeeze back to 2D + assert result.shape == x_t.shape + + +class TestPrefixWeights: + """Tests for the prefix weight calculation.""" + + def test_linear_schedule(self) -> None: + """Test linear weight schedule.""" + cfg = AsyncRTCConfig(enabled=True, prefix_attention_schedule="linear") + rtc = AsyncRTCProcessor(cfg) + + weights = rtc._get_prefix_weights(start=5, end=15, total=50) + assert weights.shape == (50,) + # First 5 should be 1.0 + assert torch.allclose(weights[:5], torch.ones(5)) + # After 15 should be 0.0 + assert torch.allclose(weights[15:], torch.zeros(35)) + + def test_zeros_schedule(self) -> None: + """Test zeros weight schedule.""" + cfg = AsyncRTCConfig(enabled=True, prefix_attention_schedule="zeros") + rtc = AsyncRTCProcessor(cfg) + + weights = rtc._get_prefix_weights(start=5, end=15, total=50) + assert weights.shape == (50,) + # First 5 should be 1.0, rest 0.0 + assert torch.allclose(weights[:5], torch.ones(5)) + assert torch.allclose(weights[5:], torch.zeros(45)) + + def test_ones_schedule(self) -> None: + """Test ones weight schedule.""" + cfg = AsyncRTCConfig(enabled=True, prefix_attention_schedule="ones") + rtc = AsyncRTCProcessor(cfg) + + weights = rtc._get_prefix_weights(start=5, end=15, total=50) + assert weights.shape == (50,) + # First 15 should be 1.0, rest 0.0 + assert torch.allclose(weights[:15], torch.ones(15)) + assert torch.allclose(weights[15:], torch.zeros(35)) diff --git a/tests/async_inference/test_rtc_transport.py b/tests/async_inference/test_rtc_transport.py new file mode 100644 index 00000000000..88247aa50a8 --- /dev/null +++ b/tests/async_inference/test_rtc_transport.py @@ -0,0 +1,42 @@ +# Copyright 2026 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np + + +def test_rtc_frozen_prefix_roundtrip_bytes() -> None: + t = 7 + a = 13 + x = np.random.randn(t, a).astype(np.float32) + + payload = x.tobytes(order="C") + y = np.frombuffer(payload, dtype=np.float32) + assert y.size == t * a + y = y.reshape(t, a) + + assert y.shape == x.shape + assert np.allclose(y, x) + + +def test_actionsdense_raw_payload_sizes_match() -> None: + # This mirrors the assumption used by the client and server: + # `actions_f32` is packed float32 with shape (t, a). + t = 5 + a = 4 + nbytes_expected = t * a * 4 + + x = np.zeros((t, a), dtype=np.float32, order="C") + assert len(x.tobytes(order="C")) == nbytes_expected