diff --git a/examples/remote_rtc/eval_dataset.py b/examples/remote_rtc/eval_dataset.py new file mode 100644 index 00000000000..4dfeb9794a8 --- /dev/null +++ b/examples/remote_rtc/eval_dataset.py @@ -0,0 +1,528 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Evaluate Real-Time Chunking (RTC) on dataset samples using remote inference. + +This script evaluates RTC performance on dataset samples by communicating +with a remote RTC policy server. It compares action predictions with and +without RTC, measuring consistency and ground truth alignment. + +The server runs the heavy policy inference on a powerful machine (e.g., with GPU), +while this client can run on a lightweight computer. 
+ +Usage: + # First, start the server on a powerful machine: + python examples/remote_rtc/rtc_policy_server.py \ + --host=0.0.0.0 \ + --port=8080 + + # Then, run this evaluation script: + python examples/remote_rtc/eval_dataset.py \ + --server_address=192.168.1.100:8080 \ + --policy_type=smolvla \ + --pretrained_name_or_path=helper2424/smolvla_check_rtc_last3 \ + --policy_device=cuda \ + --dataset.repo_id=helper2424/check_rtc \ + --rtc.execution_horizon=8 \ + --seed=10 + + # With Pi0.5 policy: + python examples/remote_rtc/eval_dataset.py \ + --server_address=192.168.1.100:8080 \ + --policy_type=pi05 \ + --pretrained_name_or_path=lerobot/pi05_libero_finetuned \ + --policy_device=cuda \ + --dataset.repo_id=HuggingFaceVLA/libero \ + --rtc.execution_horizon=10 \ + --seed=10 +""" + +import logging +import os +import pickle # nosec +import random +import time +from dataclasses import asdict, dataclass, field +from pprint import pformat +from typing import Any + +import draccus +import grpc +import matplotlib.pyplot as plt +import numpy as np +import torch + +from lerobot.configs.default import DatasetConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import RTCAttentionSchedule +from lerobot.datasets.factory import resolve_delta_timestamps +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.rtc.profiling import RTCProfiler, RTCProfilingRecord +from lerobot.policies.rtc.remote import RTCActionData, RTCObservationData, RTCRemotePolicyConfig +from lerobot.transport import ( + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore +) +from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def set_seed(seed: int): + """Set random seed for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + 
torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +@dataclass +class RTCEvalConfig: + """Configuration for remote RTC dataset evaluation.""" + + # Policy configuration (required fields first) + policy_type: str = field(metadata={"help": "Type of policy (smolvla, pi0, pi05)"}) + pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"}) + + # Dataset configuration + dataset: DatasetConfig = field(default_factory=DatasetConfig) + + # Policy device + policy_device: str = field(default="cuda", metadata={"help": "Device for policy inference on server"}) + + # Network configuration + server_address: str = field(default="localhost:8080", metadata={"help": "Server address"}) + + # RTC configuration + rtc: RTCConfig = field( + default_factory=lambda: RTCConfig( + enabled=True, + execution_horizon=20, + max_guidance_weight=10.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + ) + ) + + # Evaluation parameters + seed: int = field(default=42, metadata={"help": "Random seed"}) + inference_delay: int = field(default=4, metadata={"help": "Simulated inference delay"}) + output_dir: str = field(default="rtc_remote_eval_output", metadata={"help": "Output directory"}) + enable_profiling: bool = field( + default=False, + metadata={"help": "Collect per-request timing and save profiling artifacts"}, + ) + profiling_run_name: str = field( + default="remote_rtc_dataset", + metadata={"help": "Filename prefix for profiling artifacts"}, + ) + verbose_request_logging: bool = field( + default=False, + metadata={"help": "Enable per-request timing logs"}, + ) + use_torch_compile: bool = field( + default=False, + metadata={"help": "Enable torch.compile on the server policy"}, + ) + torch_compile_mode: str = field( + default="reduce-overhead", + metadata={"help": "torch.compile mode (reduce-overhead, 
max-autotune, default)"}, + ) + + def __post_init__(self): + if not self.server_address: + raise ValueError("server_address cannot be empty") + if not self.policy_type: + raise ValueError("policy_type cannot be empty") + if not self.pretrained_name_or_path: + raise ValueError("pretrained_name_or_path cannot be empty") + + +class RTCEvaluator: + """Evaluator for RTC on dataset samples using remote inference.""" + + def __init__(self, cfg: RTCEvalConfig): + self.cfg = cfg + self.request_idx = 0 + self.sim_queue_size = 0 + self.sim_action_index = 0 + + # Load dataset + logger.info(f"Loading dataset: {cfg.dataset.repo_id}") + + # Get metadata for delta_timestamps calculation + logger.debug("Getting dataset metadata...") + ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id) + + # Create a temporary policy config to resolve delta_timestamps + logger.debug("Loading policy config...") + policy_cfg = PreTrainedConfig.from_pretrained(cfg.pretrained_name_or_path) + delta_timestamps = resolve_delta_timestamps(policy_cfg, ds_meta) + + logger.debug("Creating LeRobotDataset...") + self.dataset = LeRobotDataset( + cfg.dataset.repo_id, + delta_timestamps=delta_timestamps, + ) + logger.info(f"Dataset loaded: {len(self.dataset)} samples") + + # Note: Preprocessing is done on server side, not client + # Initialize gRPC connection + logger.debug(f"Creating gRPC channel to {cfg.server_address}...") + self.channel = grpc.insecure_channel( + cfg.server_address, + grpc_channel_options(initial_backoff="0.1s"), + ) + self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel) + + # Create lerobot features from dataset + self.lerobot_features = {} + self.profiler = RTCProfiler(cfg.enable_profiling, cfg.output_dir, cfg.profiling_run_name) + + logger.info(f"Ready to connect to server at {cfg.server_address}") + + def connect(self) -> bool: + """Connect to server and send policy instructions.""" + try: + logger.debug("Sending Ready signal to server...") + start_time = time.perf_counter() + 
self.stub.Ready(services_pb2.Empty()) + logger.info(f"Connected to server in {time.perf_counter() - start_time:.4f}s") + + # Send policy configuration + logger.debug("Sending policy instructions...") + policy_config = RTCRemotePolicyConfig( + policy_type=self.cfg.policy_type, + pretrained_name_or_path=self.cfg.pretrained_name_or_path, + lerobot_features=self.lerobot_features, + rtc_config=self.cfg.rtc, + device=self.cfg.policy_device, + use_torch_compile=self.cfg.use_torch_compile, + torch_compile_mode=self.cfg.torch_compile_mode, + ) + + policy_config_bytes = pickle.dumps(policy_config) + self.stub.SendPolicyInstructions(services_pb2.PolicySetup(data=policy_config_bytes)) + + logger.info(f"Policy instructions sent | Type: {self.cfg.policy_type}") + return True + + except grpc.RpcError as e: + logger.error(f"Failed to connect to server: {e}") + return False + + def _request_actions( + self, + observation: dict[str, Any], + inference_delay: int, + prev_chunk_left_over: torch.Tensor | None, + execution_horizon: int, + label: str, + ) -> RTCActionData: + """Send observation and get actions from remote server.""" + logger.debug(f"Preparing observation (delay={inference_delay}, horizon={execution_horizon})...") + + t_start = time.perf_counter() + queue_size_before = self.sim_queue_size + action_index_before = self.sim_action_index + + rtc_obs = RTCObservationData( + observation=observation, + timestamp=time.time(), + timestep=action_index_before, + inference_delay=inference_delay, + prev_chunk_left_over=prev_chunk_left_over, + execution_horizon=execution_horizon, + ) + + obs_bytes = pickle.dumps(rtc_obs) + t_pickle = time.perf_counter() + pickle_ms = (t_pickle - t_start) * 1000 + + logger.debug(f"Sending observation ({len(obs_bytes)} bytes, pickle: {pickle_ms:.1f}ms)...") + obs_iterator = send_bytes_in_chunks( + obs_bytes, + services_pb2.Observation, + log_prefix="[CLIENT] Observation", + silent=True, + ) + self.stub.SendObservations(obs_iterator) + t_send = 
time.perf_counter() + send_ms = (t_send - t_pickle) * 1000 + + # Get actions + logger.debug("Waiting for actions from server...") + actions_response = self.stub.GetActions(services_pb2.Empty()) + t_response = time.perf_counter() + roundtrip_ms = (t_response - t_send) * 1000 + + if len(actions_response.data) == 0: + raise RuntimeError("Empty response from server") + + action_data = pickle.loads(actions_response.data) # nosec + t_unpickle = time.perf_counter() + unpickle_ms = (t_unpickle - t_response) * 1000 + + total_ms = (t_unpickle - t_start) * 1000 + chunk_size = int(action_data.actions.shape[0]) + realized_delay = max(int(inference_delay), 0) + queue_size_after = max(chunk_size - realized_delay, 0) + self.sim_queue_size = queue_size_after + self.sim_action_index = 0 + server_timing = getattr(action_data, "timing", None) + + self.profiler.add( + RTCProfilingRecord( + request_idx=self.request_idx, + timestamp=time.time(), + label=label, + payload_bytes=len(obs_bytes), + queue_size_before=queue_size_before, + queue_size_after=queue_size_after, + action_index_before=action_index_before, + inference_delay_requested=inference_delay, + realized_delay=realized_delay, + client_pickle_ms=pickle_ms, + client_send_ms=send_ms, + client_get_actions_ms=roundtrip_ms, + client_unpickle_ms=unpickle_ms, + client_total_ms=total_ms, + server_queue_wait_ms=(server_timing.queue_wait_ms if server_timing is not None else None), + server_preprocess_ms=(server_timing.preprocess_ms if server_timing is not None else None), + server_inference_ms=(server_timing.inference_ms if server_timing is not None else None), + server_postprocess_ms=(server_timing.postprocess_ms if server_timing is not None else None), + server_pickle_ms=server_timing.pickle_ms if server_timing is not None else None, + server_total_ms=server_timing.total_ms if server_timing is not None else None, + ) + ) + self.request_idx += 1 + + if self.cfg.verbose_request_logging: + logger.info( + f"Actions received | " + f"pickle: 
{pickle_ms:.1f}ms | " + f"send: {send_ms:.1f}ms | " + f"roundtrip: {roundtrip_ms:.1f}ms | " + f"unpickle: {unpickle_ms:.1f}ms | " + f"total: {total_ms:.1f}ms" + ) + return action_data + + def run_evaluation(self): + """Run evaluation comparing RTC and non-RTC on dataset samples.""" + logger.info("Starting evaluation...") + os.makedirs(self.cfg.output_dir, exist_ok=True) + logger.info(f"Output directory: {self.cfg.output_dir}") + + if not self.connect(): + logger.error("Failed to connect to server") + return + + logger.info("=" * 60) + logger.info("Starting RTC evaluation on dataset samples") + logger.info(f"Inference delay: {self.cfg.inference_delay}") + logger.info(f"Execution horizon: {self.cfg.rtc.execution_horizon}") + logger.info("=" * 60) + + # Load two random samples (send raw to server for preprocessing) + logger.debug("Loading samples from dataset...") + data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True) + loader_iter = iter(data_loader) + first_sample = next(loader_iter) + second_sample = next(loader_iter) + logger.debug("Samples loaded (sending raw to server)") + + # Step 1: Generate previous chunk (without RTC) + logger.info("=" * 60) + logger.info("Step 1: Generating previous chunk (baseline)") + logger.info("=" * 60) + + set_seed(self.cfg.seed) + + prev_chunk_response = self._request_actions( + observation=first_sample, # Send raw sample + inference_delay=0, + prev_chunk_left_over=None, + execution_horizon=0, + label="prev_chunk_baseline", + ) + prev_chunk_left_over = prev_chunk_response.original_actions[:25] + logger.info(f"Previous chunk shape: {prev_chunk_left_over.shape}") + + # Step 2: Generate actions WITHOUT RTC + logger.info("=" * 60) + logger.info("Step 2: Generating actions WITHOUT RTC") + logger.info("=" * 60) + + set_seed(self.cfg.seed) + + no_rtc_response = self._request_actions( + observation=second_sample, # Send raw sample + inference_delay=0, + prev_chunk_left_over=None, + execution_horizon=0, + 
label="no_rtc", + ) + no_rtc_actions = no_rtc_response.original_actions + logger.info(f"No-RTC actions shape: {no_rtc_actions.shape}") + + # Step 3: Generate actions WITH RTC + logger.info("=" * 60) + logger.info("Step 3: Generating actions WITH RTC") + logger.info("=" * 60) + + set_seed(self.cfg.seed) + + rtc_response = self._request_actions( + observation=second_sample, # Send raw sample + inference_delay=self.cfg.inference_delay, + prev_chunk_left_over=prev_chunk_left_over, + execution_horizon=self.cfg.rtc.execution_horizon, + label="rtc", + ) + rtc_actions = rtc_response.original_actions + logger.info(f"RTC actions shape: {rtc_actions.shape}") + + # Plot comparison + logger.info("=" * 80) + logger.info("Plotting results...") + self._plot_comparison(rtc_actions, no_rtc_actions, prev_chunk_left_over) + + logger.info("=" * 80) + logger.info("Evaluation completed successfully") + + profiling_artifacts = self.profiler.finalize() + if profiling_artifacts: + logger.info("Saved profiling artifacts:") + for name, path in profiling_artifacts.items(): + logger.info(f" - {name}: {path}") + + # Cleanup + self.channel.close() + + def _plot_comparison( + self, + rtc_actions: torch.Tensor, + no_rtc_actions: torch.Tensor, + prev_chunk: torch.Tensor, + ): + """Plot comparison of RTC vs non-RTC actions.""" + rtc_plot = rtc_actions.cpu().numpy() + no_rtc_plot = no_rtc_actions.cpu().numpy() + prev_chunk_plot = prev_chunk.cpu().numpy() + + num_dims = min(rtc_plot.shape[-1], 6) + + fig, axes = plt.subplots(num_dims, 1, figsize=(16, 12)) + fig.suptitle("Remote RTC Evaluation: Action Comparison", fontsize=16) + + for dim_idx in range(num_dims): + ax = axes[dim_idx] if num_dims > 1 else axes + + # Plot previous chunk (ground truth) + ax.plot( + range(len(prev_chunk_plot)), + prev_chunk_plot[:, dim_idx], + color="red", + linewidth=2.5, + alpha=0.8, + label="Previous Chunk (Ground Truth)" if dim_idx == 0 else None, + ) + + # Plot no-RTC actions + ax.plot( + range(len(no_rtc_plot)), + 
no_rtc_plot[:, dim_idx], + color="blue", + linewidth=2, + alpha=0.7, + label="No RTC" if dim_idx == 0 else None, + ) + + # Plot RTC actions + ax.plot( + range(len(rtc_plot)), + rtc_plot[:, dim_idx], + color="green", + linewidth=2, + alpha=0.7, + label="RTC" if dim_idx == 0 else None, + ) + + # Add vertical lines for inference delay and execution horizon + if self.cfg.inference_delay > 0: + ax.axvline( + x=self.cfg.inference_delay - 1, + color="orange", + linestyle="--", + alpha=0.5, + label=f"Inference Delay ({self.cfg.inference_delay})" if dim_idx == 0 else None, + ) + + if self.cfg.rtc.execution_horizon > 0: + ax.axvline( + x=self.cfg.rtc.execution_horizon, + color="purple", + linestyle="--", + alpha=0.5, + label=f"Execution Horizon ({self.cfg.rtc.execution_horizon})" if dim_idx == 0 else None, + ) + + ax.set_ylabel(f"Dim {dim_idx}", fontsize=10) + ax.grid(True, alpha=0.3) + + axes[-1].set_xlabel("Step", fontsize=10) if num_dims > 1 else axes.set_xlabel("Step", fontsize=10) + + # Add legend + handles, labels = (axes[0] if num_dims > 1 else axes).get_legend_handles_labels() + fig.legend( + handles, + labels, + loc="center right", + fontsize=9, + bbox_to_anchor=(1.0, 0.5), + framealpha=0.9, + ) + + output_path = os.path.join(self.cfg.output_dir, "remote_rtc_comparison.png") + fig.tight_layout(rect=[0, 0, 0.85, 1]) + fig.savefig(output_path, dpi=150, bbox_inches="tight") + logger.info(f"Saved comparison plot to {output_path}") + plt.close(fig) + + +@draccus.wrap() +def main(cfg: RTCEvalConfig): + """Main entry point for remote RTC dataset evaluation.""" + set_seed(cfg.seed) + + logger.info("Configuration:\n%s", pformat(asdict(cfg))) + logger.info("=" * 80) + logger.info("Remote RTC Dataset Evaluation") + logger.info("=" * 80) + + evaluator = RTCEvaluator(cfg) + evaluator.run_evaluation() + + +if __name__ == "__main__": + main() diff --git a/examples/remote_rtc/eval_with_real_robot.py b/examples/remote_rtc/eval_with_real_robot.py new file mode 100644 index 
00000000000..5b9b8575f0a --- /dev/null +++ b/examples/remote_rtc/eval_with_real_robot.py @@ -0,0 +1,632 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Evaluate Real-Time Chunking (RTC) with a real robot using remote inference. + +This script controls a robot and communicates with a remote RTC policy server, +managing action queues with RTC-specific merging for smooth action execution. + +The server runs the heavy policy inference on a powerful machine (e.g., with GPU), +while this client runs on a lightweight computer connected to the robot. 
+ +Usage: + # First, start the server on a powerful machine: + python examples/remote_rtc/rtc_policy_server.py \ + --host=0.0.0.0 \ + --port=8080 + + # Then, run this client on the robot's computer: + + # Run with SO100 robot and SmolVLA policy + python examples/remote_rtc/eval_with_real_robot.py \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58FA0834591 \ + --robot.id=so100_follower \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --server_address=192.168.1.100:8080 \ + --policy_type=smolvla \ + --pretrained_name_or_path=helper2424/smolvla_check_rtc_last3 \ + --policy_device=cuda \ + --task="Move the object" \ + --rtc.enabled=true \ + --rtc.execution_horizon=20 \ + --duration=120 + + # Run with Pi0.5 policy + python examples/remote_rtc/eval_with_real_robot.py \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58FA0834591 \ + --robot.id=so100_follower \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --server_address=192.168.1.100:8080 \ + --policy_type=pi05 \ + --pretrained_name_or_path=lerobot/pi05_libero_finetuned \ + --policy_device=cuda \ + --task="Pick up the cube" \ + --rtc.enabled=true \ + --rtc.execution_horizon=20 \ + --duration=120 +""" + +import logging +import math +import pickle # nosec +import sys +import threading +import time +import traceback +from dataclasses import asdict, dataclass, field +from pprint import pformat +from typing import Any + +import draccus +import grpc +import torch + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.configs.types import RTCAttentionSchedule +from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.policies.rtc.action_queue import ActionQueue +from lerobot.policies.rtc.configuration_rtc 
import RTCConfig +from lerobot.policies.rtc.latency_tracker import LatencyTracker +from lerobot.policies.rtc.profiling import RTCProfiler, RTCProfilingRecord +from lerobot.policies.rtc.remote import RTCActionData, RTCObservationData, RTCRemotePolicyConfig +from lerobot.processor.factory import ( + make_default_robot_action_processor, + make_default_robot_observation_processor, +) +from lerobot.rl.process import ProcessSignalHandler +from lerobot.robots import Robot, RobotConfig, make_robot_from_config +from lerobot.transport import ( + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore +) +from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks +from lerobot.utils.import_utils import register_third_party_plugins +from lerobot.utils.utils import init_logging + +logger = logging.getLogger(__name__) + + +@dataclass +class RobotClientConfig: + """Configuration for RTC Robot Client.""" + + # Robot configuration + robot: RobotConfig = field(metadata={"help": "Robot configuration"}) + + # Policy configuration + policy_type: str = field(metadata={"help": "Type of policy (smolvla, pi0, pi05)"}) + pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"}) + policy_device: str = field(default="cuda", metadata={"help": "Device for policy inference on server"}) + + # Network configuration + server_address: str = field(default="localhost:8080", metadata={"help": "Server address"}) + + # Task configuration + task: str = field(default="", metadata={"help": "Task instruction"}) + + # RTC configuration + rtc: RTCConfig = field( + default_factory=lambda: RTCConfig( + enabled=True, + execution_horizon=20, + max_guidance_weight=10.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + ) + ) + + # Control configuration + fps: float = field(default=10.0, metadata={"help": "Action execution frequency (Hz)"}) + duration: float = field(default=60.0, metadata={"help": "Duration to run (seconds)"}) + + # Action queue 
threshold - when queue size drops below this, request new actions + action_queue_threshold: int = field( + default=30, + metadata={"help": "Request new actions when queue size drops below this value"}, + ) + enable_profiling: bool = field( + default=False, + metadata={"help": "Collect per-request timings and queue metrics"}, + ) + profiling_output_dir: str = field( + default="rtc_remote_profile_output", + metadata={"help": "Directory for profiling artifacts"}, + ) + profiling_run_name: str = field( + default="remote_rtc_robot", + metadata={"help": "Filename prefix for profiling artifacts"}, + ) + verbose_request_logging: bool = field( + default=False, + metadata={"help": "Enable per-request timing logs"}, + ) + use_torch_compile: bool = field( + default=False, + metadata={"help": "Enable torch.compile on the server policy"}, + ) + torch_compile_mode: str = field( + default="reduce-overhead", + metadata={"help": "torch.compile mode (reduce-overhead, max-autotune, default)"}, + ) + compile_warmup_delay: list[int] = field( + default_factory=lambda: [0, 4], + metadata={"help": "Warmup inference delays per call, e.g. [0,4,5,6]. 
Empty list disables warmup."}, + ) + + def __post_init__(self): + if not self.server_address: + raise ValueError("server_address cannot be empty") + if not self.policy_type: + raise ValueError("policy_type cannot be empty") + if not self.pretrained_name_or_path: + raise ValueError("pretrained_name_or_path cannot be empty") + if any(delay < 0 for delay in self.compile_warmup_delay): + raise ValueError("All compile_warmup_delay values must be >= 0") + + @property + def environment_dt(self) -> float: + return 1 / self.fps + + +class RobotWrapper: + """Thread-safe wrapper for robot access.""" + + def __init__(self, robot: Robot): + self.robot = robot + self.lock = threading.Lock() + + def get_observation(self) -> dict[str, Any]: + with self.lock: + return self.robot.get_observation() + + def send_action(self, action: Any): + with self.lock: + return self.robot.send_action(action) + + def observation_features(self) -> list[str]: + with self.lock: + return self.robot.observation_features + + def action_features(self) -> list[str]: + with self.lock: + return self.robot.action_features + + +class RobotClient: + """Robot client with RTC action queue management.""" + + def __init__(self, config: RobotClientConfig): + self.config = config + self.shutdown_event = threading.Event() + self.request_idx = 0 + + # Initialize robot + logger.info(f"Initializing robot: {config.robot.type}") + self.robot = make_robot_from_config(config.robot) + self.robot.connect() + self.robot_wrapper = RobotWrapper(self.robot) + + # Create lerobot features mapping + self.lerobot_features = hw_to_dataset_features(self.robot.observation_features, "observation") + + # Initialize gRPC connection + self.channel = grpc.insecure_channel( + config.server_address, + grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s"), + ) + self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel) + + # Initialize RTC action queue + self.action_queue = ActionQueue(config.rtc) + + # Latency tracking for 
inference delay calculation + self.latency_tracker = LatencyTracker() + self.profiler = RTCProfiler( + config.enable_profiling, + config.profiling_output_dir, + config.profiling_run_name, + ) + + # Robot processors + self.robot_observation_processor = make_default_robot_observation_processor() + self.robot_action_processor = make_default_robot_action_processor() + + logger.info(f"RobotClient initialized, connecting to {config.server_address}") + + @property + def running(self): + return not self.shutdown_event.is_set() + + def start(self) -> bool: + """Connect to server and send policy instructions.""" + try: + # Handshake + start_time = time.perf_counter() + self.stub.Ready(services_pb2.Empty()) + logger.info(f"Connected to server in {time.perf_counter() - start_time:.4f}s") + + # Send policy configuration + policy_config = RTCRemotePolicyConfig( + policy_type=self.config.policy_type, + pretrained_name_or_path=self.config.pretrained_name_or_path, + lerobot_features=self.lerobot_features, + rtc_config=self.config.rtc, + device=self.config.policy_device, + use_torch_compile=self.config.use_torch_compile, + torch_compile_mode=self.config.torch_compile_mode, + ) + + policy_config_bytes = pickle.dumps(policy_config) + self.stub.SendPolicyInstructions(services_pb2.PolicySetup(data=policy_config_bytes)) + + logger.info( + f"Policy instructions sent | " + f"Type: {self.config.policy_type} | " + f"Device: {self.config.policy_device} | " + f"Compile: {self.config.use_torch_compile} ({self.config.torch_compile_mode})" + ) + + return True + + except grpc.RpcError as e: + logger.error(f"Failed to connect to server: {e}") + return False + + def stop(self): + """Stop the client and cleanup.""" + self.shutdown_event.set() + self.robot.disconnect() + self.channel.close() + logger.info("Client stopped") + + def save_profiling_artifacts(self) -> dict[str, str]: + artifacts = self.profiler.finalize() + if artifacts: + logger.info("Saved profiling artifacts:") + for name, path in 
artifacts.items(): + logger.info(f" - {name}: {path}") + return artifacts + + def _prepare_observation(self, task: str) -> dict[str, Any]: + """Capture and prepare observation for sending to server.""" + raw_obs = self.robot_wrapper.get_observation() + + # Apply robot observation processor + obs_processed = self.robot_observation_processor(raw_obs) + + # Build dataset frame with proper keys + obs_with_features = build_dataset_frame(self.lerobot_features, obs_processed, prefix="observation") + + # Convert to tensors and prepare for policy + for name in obs_with_features: + obs_with_features[name] = torch.from_numpy(obs_with_features[name]) + if "image" in name: + obs_with_features[name] = obs_with_features[name].type(torch.float32) / 255 + obs_with_features[name] = obs_with_features[name].permute(2, 0, 1).contiguous() + obs_with_features[name] = obs_with_features[name].unsqueeze(0) + + obs_with_features["task"] = [task] + obs_with_features["robot_type"] = self.robot.name if hasattr(self.robot, "name") else "" + + return obs_with_features + + def _run_remote_request( + self, + observation: dict[str, Any], + *, + action_index_before: int, + queue_size_before: int, + inference_delay: int, + prev_actions: torch.Tensor | None, + execution_horizon: int, + label: str, + merge_actions: bool, + observation_ms: float, + ) -> tuple[RTCActionData, float, int]: + request_idx = self.request_idx + request_start = time.perf_counter() + + rtc_obs = RTCObservationData( + observation=observation, + timestamp=time.time(), + timestep=action_index_before, + inference_delay=inference_delay, + prev_chunk_left_over=prev_actions, + execution_horizon=execution_horizon, + ) + + obs_bytes = pickle.dumps(rtc_obs) + pickle_done = time.perf_counter() + client_pickle_ms = (pickle_done - request_start) * 1000 + + obs_iterator = send_bytes_in_chunks( + obs_bytes, + services_pb2.Observation, + log_prefix="[CLIENT] Observation", + silent=True, + ) + self.stub.SendObservations(obs_iterator) + send_done 
= time.perf_counter() + client_send_ms = (send_done - pickle_done) * 1000 + + actions_response = self.stub.GetActions(services_pb2.Empty()) + response_done = time.perf_counter() + client_get_actions_ms = (response_done - send_done) * 1000 + + if len(actions_response.data) == 0: + raise RuntimeError("Empty response from server") + + rtc_action_data: RTCActionData = pickle.loads(actions_response.data) # nosec + unpickle_done = time.perf_counter() + client_unpickle_ms = (unpickle_done - response_done) * 1000 + + new_latency = unpickle_done - request_start + client_total_ms = new_latency * 1000 + time_per_step = 1.0 / self.config.fps + new_delay = math.ceil(new_latency / time_per_step) + applied_delay = new_delay + + if merge_actions: + applied_delay = self.action_queue.merge( + rtc_action_data.original_actions, + rtc_action_data.actions, + new_delay, + action_index_before, + ) + queue_size_after = self.action_queue.qsize() + else: + queue_size_after = queue_size_before + + server_timing = getattr(rtc_action_data, "timing", None) + self.profiler.add( + RTCProfilingRecord( + request_idx=request_idx, + timestamp=time.time(), + label=label, + payload_bytes=len(obs_bytes), + queue_size_before=queue_size_before, + queue_size_after=queue_size_after, + action_index_before=action_index_before, + inference_delay_requested=inference_delay, + realized_delay=applied_delay, + client_observation_ms=observation_ms, + client_pickle_ms=client_pickle_ms, + client_send_ms=client_send_ms, + client_get_actions_ms=client_get_actions_ms, + client_unpickle_ms=client_unpickle_ms, + client_total_ms=client_total_ms, + server_queue_wait_ms=(server_timing.queue_wait_ms if server_timing is not None else None), + server_preprocess_ms=(server_timing.preprocess_ms if server_timing is not None else None), + server_inference_ms=(server_timing.inference_ms if server_timing is not None else None), + server_postprocess_ms=(server_timing.postprocess_ms if server_timing is not None else None), + 
server_pickle_ms=(server_timing.pickle_ms if server_timing is not None else None), + server_total_ms=server_timing.total_ms if server_timing is not None else None, + ) + ) + self.request_idx += 1 + + if self.config.verbose_request_logging: + logger.info( + f"[GET_ACTIONS] {label} | " + f"total: {client_total_ms:.1f}ms | " + f"delay: {applied_delay} | " + f"queue: {queue_size_after}" + ) + + return rtc_action_data, new_latency, applied_delay + + def warmup_compiled_policy(self) -> None: + warmup_delays = list(self.config.compile_warmup_delay) + if len(warmup_delays) == 0: + return + + logger.info( + "Running remote warmup requests: %d, delays=%s", + len(warmup_delays), + warmup_delays, + ) + prev_actions = None + + for warmup_idx, delay in enumerate(warmup_delays): + observation_start = time.perf_counter() + observation = self._prepare_observation(self.config.task) + observation_ms = (time.perf_counter() - observation_start) * 1000 + + try: + rtc_action_data, warmup_latency, _ = self._run_remote_request( + observation, + action_index_before=0, + queue_size_before=self.action_queue.qsize(), + inference_delay=delay, + prev_actions=prev_actions, + execution_horizon=self.config.rtc.execution_horizon, + label="warmup", + merge_actions=False, + observation_ms=observation_ms, + ) + logger.info("Warmup %d/%d: %.1fms", warmup_idx + 1, len(warmup_delays), warmup_latency * 1000) + except RuntimeError: + logger.warning("Warmup request returned empty response, stopping warmup early") + break + + if warmup_idx < len(warmup_delays) - 1: + chunk_size = int(rtc_action_data.original_actions.shape[0]) + next_delay = warmup_delays[warmup_idx + 1] + if next_delay < chunk_size: + prev_actions = rtc_action_data.original_actions[next_delay:].clone() + else: + prev_actions = None + + self.action_queue.clear() + self.latency_tracker = LatencyTracker() + logger.info("Remote warmup finished") + + def get_actions_thread(self): + """Thread function to request action chunks from remote server.""" 
+ try: + logger.info("[GET_ACTIONS] Starting get actions thread") + + threshold = self.config.action_queue_threshold + + if not self.config.rtc.enabled: + threshold = 0 + + while self.running: + if self.action_queue.qsize() <= threshold: + queue_size_before = self.action_queue.qsize() + action_index_before = self.action_queue.get_action_index() + prev_actions = self.action_queue.get_left_over() + + # Calculate inference delay from latency + time_per_step = 1.0 / self.config.fps + inference_delay = math.ceil(self.latency_tracker.max() / time_per_step) + + # Prepare observation + observation_start = time.perf_counter() + observation = self._prepare_observation(self.config.task) + observation_ms = (time.perf_counter() - observation_start) * 1000 + + try: + _, new_latency, new_delay = self._run_remote_request( + observation, + action_index_before=action_index_before, + queue_size_before=queue_size_before, + inference_delay=inference_delay, + prev_actions=prev_actions, + execution_horizon=self.config.rtc.execution_horizon, + label="robot_live", + merge_actions=True, + observation_ms=observation_ms, + ) + except RuntimeError: + logger.warning("[GET_ACTIONS] Empty response from server") + continue + self.latency_tracker.add(new_latency) + + # Warn if threshold is too small + if self.config.action_queue_threshold < self.config.rtc.execution_horizon + new_delay: + logger.warning( + "[GET_ACTIONS] action_queue_threshold too small. 
" + f"Should be > execution_horizon + delay = " + f"{self.config.rtc.execution_horizon + new_delay}" + ) + + else: + time.sleep(0.01) + + logger.info("[GET_ACTIONS] Thread shutting down") + + except Exception as e: + logger.error(f"[GET_ACTIONS] Fatal error: {e}") + traceback.print_exc() + sys.exit(1) + + def actor_thread(self): + """Thread function to execute actions on the robot.""" + try: + logger.info("[ACTOR] Starting actor thread") + + action_count = 0 + action_interval = 1.0 / self.config.fps + + while self.running: + start_time = time.perf_counter() + + action = self.action_queue.get() + + if action is not None: + action = action.cpu() + action_dict = { + key: action[i].item() for i, key in enumerate(self.robot_wrapper.action_features()) + } + action_processed = self.robot_action_processor((action_dict, None)) + self.robot_wrapper.send_action(action_processed) + action_count += 1 + + dt = time.perf_counter() - start_time + time.sleep(max(0, action_interval - dt - 0.001)) + + logger.info(f"[ACTOR] Thread shutting down. 
Total actions: {action_count}") + + except Exception as e: + logger.error(f"[ACTOR] Fatal error: {e}") + traceback.print_exc() + sys.exit(1) + + +@draccus.wrap() +def main(cfg: RobotClientConfig): + """Main entry point for RTC Robot Client.""" + init_logging() + logger.info(pformat(asdict(cfg))) + + signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False) + shutdown_event = signal_handler.shutdown_event + + client = RobotClient(cfg) + + if not client.start(): + logger.error("Failed to connect to server") + return + + client.warmup_compiled_policy() + + # Start threads + get_actions_thread = threading.Thread( + target=client.get_actions_thread, + daemon=True, + name="GetActions", + ) + get_actions_thread.start() + + actor_thread = threading.Thread( + target=client.actor_thread, + daemon=True, + name="Actor", + ) + actor_thread.start() + + logger.info(f"Running for {cfg.duration} seconds...") + start_time = time.time() + + try: + while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration: + time.sleep(1.0) + + if int(time.time() - start_time) % 5 == 0: + logger.info(f"[MAIN] Queue size: {client.action_queue.qsize()}") + + except KeyboardInterrupt: + logger.info("Interrupted by user") + + finally: + logger.info("Shutting down...") + client.shutdown_event.set() + + get_actions_thread.join(timeout=5) + actor_thread.join(timeout=5) + + client.save_profiling_artifacts() + client.stop() + logger.info("Cleanup completed") + + +if __name__ == "__main__": + register_third_party_plugins() + main() diff --git a/examples/remote_rtc/rtc_policy_server.py b/examples/remote_rtc/rtc_policy_server.py new file mode 100644 index 00000000000..7eca8f8cdf9 --- /dev/null +++ b/examples/remote_rtc/rtc_policy_server.py @@ -0,0 +1,509 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +RTC Policy Server - Remote inference server with Real-Time Chunking support. + +This server runs diffusion-based policies (SmolVLA, Pi0, Pi0.5) with RTC on a powerful +remote machine, allowing lightweight robot computers to control robots smoothly. + +Usage: + python examples/remote_rtc/rtc_policy_server.py \ + --host=0.0.0.0 \ + --port=8080 +""" + +import contextlib +import logging +import pickle # nosec +import threading +import time +from concurrent import futures +from dataclasses import asdict, dataclass, field +from pprint import pformat +from queue import Empty, Queue + +import draccus +import grpc +import torch + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.policies.factory import get_policy_class, make_pre_post_processors +from lerobot.policies.rtc.remote import ( + RTCActionData, + RTCObservationData, + RTCRemotePolicyConfig, + RTCTimingData, +) +from lerobot.processor import PolicyProcessorPipeline +from lerobot.transport import ( + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore +) +from lerobot.transport.utils import receive_bytes_in_chunks +from lerobot.utils.utils import init_logging + +logger = logging.getLogger(__name__) + + +SUPPORTED_POLICIES = ["smolvla", "pi0", "pi05"] + + +@dataclass +class RTCPolicyServerConfig: + """Configuration for RTC Policy Server.""" + + host: str = field(default="0.0.0.0", metadata={"help": "Host address to bind the server 
to"}) + port: int = field(default=8080, metadata={"help": "Port number to bind the server to"}) + obs_queue_timeout: float = field( + default=1.0, metadata={"help": "Timeout for observation queue in seconds"} + ) + + verbose_request_logging: bool = field( + default=False, + metadata={"help": "Enable detailed per-request timing logs"}, + ) + client_unavailable_timeout_s: float = field( + default=2.0, + metadata={ + "help": ( + "Reset/unload the server (freeing VRAM) if no client RPCs arrive for this many seconds. " + "Set <= 0 to disable." + ) + }, + ) + + def __post_init__(self): + if self.port < 1 or self.port > 65535: + raise ValueError(f"Port must be between 1 and 65535, got {self.port}") + + +class RTCPolicyServer(services_pb2_grpc.AsyncInferenceServicer): + """gRPC server for RTC policy inference.""" + + def __init__(self, config: RTCPolicyServerConfig): + self.config = config + self.shutdown_event = threading.Event() + self.observation_queue: Queue[RTCObservationData] = Queue(maxsize=1) + self._rpc_state_lock = threading.Lock() + self._active_rpcs = 0 + self._client_unavailable_timer: threading.Timer | None = None + self._has_received_observation = False + + # Policy components (initialized by SendPolicyInstructions) + self.device = None + self.policy_type = None + self.lerobot_features = None + self.policy = None + self.preprocessor: PolicyProcessorPipeline | None = None + self.postprocessor: PolicyProcessorPipeline | None = None + + logger.info(f"RTCPolicyServer initialized with config: {config}") + + @property + def running(self): + return not self.shutdown_event.is_set() + + def _rpc_enter(self) -> None: + with self._rpc_state_lock: + self._active_rpcs += 1 + if self._client_unavailable_timer is not None: + self._client_unavailable_timer.cancel() + self._client_unavailable_timer = None + + def _rpc_exit(self) -> None: + with self._rpc_state_lock: + self._active_rpcs = max(0, self._active_rpcs - 1) + if self._active_rpcs == 0: + if not 
self._has_received_observation: + return + timeout_s = self.config.client_unavailable_timeout_s + if timeout_s <= 0: + return + if self._client_unavailable_timer is not None: + self._client_unavailable_timer.cancel() + self._client_unavailable_timer = threading.Timer(timeout_s, self._on_client_unavailable) + self._client_unavailable_timer.daemon = True + self._client_unavailable_timer.start() + + def _unload_policy(self, reason: str) -> None: + with self._rpc_state_lock: + self._has_received_observation = False + + if self.policy is None: + return + + logger.warning("Unloading policy to free VRAM (reason=%s)", reason) + + policy = self.policy + preprocessor = self.preprocessor + postprocessor = self.postprocessor + device = self.device + + self.policy = None + self.preprocessor = None + self.postprocessor = None + self.device = None + self.policy_type = None + self.lerobot_features = None + self.observation_queue = Queue(maxsize=1) + + del policy, preprocessor, postprocessor + + try: + import gc + + gc.collect() + except Exception as e: + logger.debug("gc.collect failed: %s", e) + + if device is not None and torch.cuda.is_available() and "cuda" in str(device): + try: + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + except Exception as e: + logger.debug("Failed to clear CUDA cache: %s", e) + + def _on_client_unavailable(self) -> None: + with self._rpc_state_lock: + if self._active_rpcs != 0: + return + self._client_unavailable_timer = None + + self._reset_server() + self._unload_policy(reason=f"client_unavailable_{self.config.client_unavailable_timeout_s}s") + + def _reset_server(self) -> None: + """Reset server state when new client connects.""" + self.shutdown_event.set() + self.observation_queue = Queue(maxsize=1) + + def Ready(self, request, context): # noqa: N802 + """Handle client ready signal.""" + self._rpc_enter() + context.add_callback(self._rpc_exit) + client_id = context.peer() + logger.info(f"Client {client_id} connected and ready") + 
self._reset_server() + with self._rpc_state_lock: + no_other_rpcs = self._active_rpcs == 1 + if no_other_rpcs: + self._unload_policy(reason="new_client_ready") + self.shutdown_event.clear() + return services_pb2.Empty() + + def SendPolicyInstructions(self, request, context): # noqa: N802 + """Receive policy configuration from client and initialize policy.""" + self._rpc_enter() + context.add_callback(self._rpc_exit) + if not self.running: + logger.warning("Server is not running. Ignoring policy instructions.") + return services_pb2.Empty() + + client_id = context.peer() + policy_specs = pickle.loads(request.data) # nosec + + if not isinstance(policy_specs, RTCRemotePolicyConfig): + raise TypeError(f"Expected RTCRemotePolicyConfig, got {type(policy_specs)}") + + if policy_specs.policy_type not in SUPPORTED_POLICIES: + raise ValueError( + f"Policy type {policy_specs.policy_type} not supported. " + f"Supported policies: {SUPPORTED_POLICIES}" + ) + + logger.info( + f"Receiving policy instructions from {client_id} | " + f"Policy type: {policy_specs.policy_type} | " + f"Pretrained: {policy_specs.pretrained_name_or_path} | " + f"Device: {policy_specs.device}" + ) + + self.device = policy_specs.device + self.policy_type = policy_specs.policy_type + self.lerobot_features = policy_specs.lerobot_features + + # Load policy + self._unload_policy(reason="replacing_existing_policy") + policy_class = get_policy_class(self.policy_type) + start = time.perf_counter() + + use_compile = getattr(policy_specs, "use_torch_compile", False) + compile_mode = getattr(policy_specs, "torch_compile_mode", "reduce-overhead") + + # Load policy config, applying client overrides + policy_cfg = PreTrainedConfig.from_pretrained(policy_specs.pretrained_name_or_path) + policy_cfg.device = policy_specs.device + + chunk_size = getattr(policy_specs, "chunk_size", None) + n_action_steps = getattr(policy_specs, "n_action_steps", None) + if chunk_size is not None: + policy_cfg.chunk_size = chunk_size + 
logger.info(f"Overriding chunk_size={chunk_size}") + if n_action_steps is not None: + policy_cfg.n_action_steps = n_action_steps + logger.info(f"Overriding n_action_steps={n_action_steps}") + + if use_compile and self.policy_type in ["pi05", "pi0"]: + torch._inductor.config.fx_graph_cache = True + torch._inductor.config.fx_graph_remote_cache = False + logger.info("Enabled persistent FX graph cache for torch.compile") + policy_cfg.compile_model = True + if compile_mode == "max-autotune": + compile_mode = "max-autotune-no-cudagraphs" + policy_cfg.compile_mode = compile_mode + + self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path, config=policy_cfg) + self.policy.to(self.device) + self.policy.eval() + + # Configure RTC from client config + rtc_config = getattr(policy_specs, "rtc_config", None) + if rtc_config is not None: + self.policy.config.rtc_config = rtc_config + self.policy.init_rtc_processor() + + # Apply torch.compile for non-pi0 policies (pi0/pi05 handle it internally) + if use_compile and self.policy_type not in ("pi05", "pi0"): + try: + logger.info("Applying torch.compile to predict_action_chunk...") + self.policy.predict_action_chunk = torch.compile( + self.policy.predict_action_chunk, + backend="inductor", + mode=compile_mode, + ) + logger.info("Successfully compiled predict_action_chunk") + except Exception as e: + logger.error(f"Failed to apply torch.compile: {e}") + + # Load preprocessor and postprocessor + device_override = {"device": self.device} + self.preprocessor, self.postprocessor = make_pre_post_processors( + self.policy.config, + pretrained_path=policy_specs.pretrained_name_or_path, + preprocessor_overrides={ + "device_processor": device_override, + "rename_observations_processor": {"rename_map": policy_specs.rename_map}, + }, + postprocessor_overrides={"device_processor": device_override}, + ) + + end = time.perf_counter() + logger.info(f"Policy loaded on {self.device} in {end - start:.4f} seconds") + 
logger.info(f"RTC config: {self.policy.config.rtc_config}") + + return services_pb2.Empty() + + def SendObservations(self, request_iterator, context): # noqa: N802 + """Receive observations with RTC parameters from client.""" + self._rpc_enter() + context.add_callback(self._rpc_exit) + logger.debug("SendObservations called, receiving data...") + t_start = time.perf_counter() + + received_bytes = receive_bytes_in_chunks(request_iterator, None, self.shutdown_event, logger) + if received_bytes is None: + return services_pb2.Empty() + + with self._rpc_state_lock: + self._has_received_observation = True + + t_receive = time.perf_counter() + receive_ms = (t_receive - t_start) * 1000 + + rtc_obs_data: RTCObservationData = pickle.loads(received_bytes) # nosec + t_unpickle = time.perf_counter() + unpickle_ms = (t_unpickle - t_receive) * 1000 + + if self.config.verbose_request_logging: + prev_shape = ( + tuple(rtc_obs_data.prev_chunk_left_over.shape) + if rtc_obs_data.prev_chunk_left_over is not None + else None + ) + logger.info( + f"Observation received | " + f"bytes: {len(received_bytes)} | " + f"receive: {receive_ms:.1f}ms | " + f"unpickle: {unpickle_ms:.1f}ms | " + f"inference_delay: {rtc_obs_data.inference_delay} | " + f"execution_horizon: {rtc_obs_data.execution_horizon} | " + f"prev_chunk_left_over: {prev_shape}" + ) + + # Enqueue observation (replacing old one if queue is full) + if self.observation_queue.full(): + with contextlib.suppress(Empty): + self.observation_queue.get_nowait() + + rtc_obs_data._server_receive_time = t_start # Store for end-to-end timing + self.observation_queue.put(rtc_obs_data) + logger.debug("Observation queued") + + return services_pb2.Empty() + + def GetActions(self, request, context): # noqa: N802 + """Run RTC inference and return actions to client.""" + self._rpc_enter() + context.add_callback(self._rpc_exit) + try: + if self.policy is None or self.preprocessor is None or self.postprocessor is None: + return 
services_pb2.Actions(data=b"") + + logger.debug("GetActions called, waiting for observation...") + wait_start = time.perf_counter() + rtc_obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout) + wait_end = time.perf_counter() + + logger.debug( + f"Running inference | delay={rtc_obs.inference_delay} | horizon={rtc_obs.execution_horizon}" + ) + + t_start = time.perf_counter() + + # Preprocess observation + logger.debug("Preprocessing observation...") + observation = rtc_obs.observation + preprocessed_obs = self.preprocessor(observation) + t_preprocess = time.perf_counter() + + # Run policy with RTC parameters + logger.debug("Running predict_action_chunk...") + with torch.no_grad(): + actions = self.policy.predict_action_chunk( + preprocessed_obs, + inference_delay=rtc_obs.inference_delay, + prev_chunk_left_over=rtc_obs.prev_chunk_left_over, + execution_horizon=rtc_obs.execution_horizon, + ) + t_inference = time.perf_counter() + logger.debug("predict_action_chunk completed") + + logger.debug("Postprocessing actions...") + # Store original actions for RTC tracking + original_actions = actions.squeeze(0).clone() + + # Postprocess actions + postprocessed_actions = self.postprocessor(actions) + postprocessed_actions = postprocessed_actions.squeeze(0) + t_postprocess = time.perf_counter() + + # Calculate detailed timing + queue_wait_ms = (wait_end - wait_start) * 1000 + preprocess_ms = (t_preprocess - t_start) * 1000 + inference_ms = (t_inference - t_preprocess) * 1000 + postprocess_ms = (t_postprocess - t_inference) * 1000 + server_compute_total_ms = queue_wait_ms + preprocess_ms + inference_ms + postprocess_ms + + # Create response + rtc_action_data = RTCActionData( + actions=postprocessed_actions.cpu(), + original_actions=original_actions.cpu(), + timestamp=time.time(), + timestep=rtc_obs.timestep, + timing=RTCTimingData( + queue_wait_ms=queue_wait_ms, + preprocess_ms=preprocess_ms, + inference_ms=inference_ms, + postprocess_ms=postprocess_ms, + 
total_ms=server_compute_total_ms,
+                ),
+            )
+
+            actions_bytes = pickle.dumps(rtc_action_data)
+            t_pickle = time.perf_counter()
+            pickle_ms = (t_pickle - t_postprocess) * 1000
+            total_ms = (t_pickle - t_start) * 1000
+
+            # Calculate server-side total if we have receive time
+            server_total_ms = None
+            if hasattr(rtc_obs, "_server_receive_time"):
+                server_total_ms = (t_pickle - rtc_obs._server_receive_time) * 1000
+
+            log_message = (
+                f"Actions ready | "
+                f"queue_wait: {queue_wait_ms:.1f}ms | "
+                f"preprocess: {preprocess_ms:.1f}ms | "
+                f"inference: {inference_ms:.1f}ms | "
+                f"postprocess: {postprocess_ms:.1f}ms | "
+                f"pickle: {pickle_ms:.1f}ms | "
+                f"total: {total_ms:.1f}ms"
+                + (f" | server_total: {server_total_ms:.1f}ms" if server_total_ms is not None else "")
+                + f" | shape: {postprocessed_actions.shape}"
+            )
+            if self.config.verbose_request_logging:
+                logger.info(log_message)
+            else:
+                logger.debug(log_message)
+
+            return services_pb2.Actions(data=actions_bytes)
+
+        except Empty:
+            logger.debug("GetActions timeout - no observation in queue")
+            return services_pb2.Actions(data=b"")
+
+        except Exception as e:
+            logger.error(f"Error in GetActions: {e}")
+            import traceback
+
+            traceback.print_exc()
+            return services_pb2.Actions(data=b"")
+
+    def stop(self):
+        """Stop the server."""
+        with self._rpc_state_lock:
+            if self._client_unavailable_timer is not None:
+                self._client_unavailable_timer.cancel()
+                self._client_unavailable_timer = None
+        self._reset_server()
+        self._unload_policy(reason="server_stop")
+        logger.info("Server stopping...")
+
+
+@draccus.wrap()
+def serve(cfg: RTCPolicyServerConfig):
+    """Start the RTC Policy Server."""
+    init_logging()
+    logger.info("Configuration:\n%s", pformat(asdict(cfg)))
+
+    logger.info("Creating RTCPolicyServer...")
+    policy_server = RTCPolicyServer(cfg)
+
+    logger.info("Creating gRPC server...")
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
+    services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server,
server) + server.add_insecure_port(f"{cfg.host}:{cfg.port}") + + server.start() + + logger.info("=" * 60) + logger.info(f"RTC Policy Server running on {cfg.host}:{cfg.port}") + logger.info("Waiting for client connections...") + logger.info("Press Ctrl+C to stop") + logger.info("=" * 60) + + try: + server.wait_for_termination() + except KeyboardInterrupt: + logger.info("Shutting down...") + policy_server.stop() + server.stop(grace=5) + + logger.info("Server terminated") + + +if __name__ == "__main__": + serve() diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 4c803eb7e42..1af59a801cf 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -101,6 +101,7 @@ from lerobot.robots.utils import make_robot_from_config from lerobot.utils.constants import OBS_IMAGES from lerobot.utils.hub import HubMixin +from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.utils import init_logging logging.basicConfig(level=logging.INFO) @@ -174,8 +175,13 @@ class RTCDemoConfig(HubMixin): ) torch_compile_mode: str = field( - default="default", - metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"}, + default="reduce-overhead", + metadata={ + "help": ( + "Compilation mode (default, reduce-overhead, max-autotune, " + "max-autotune-no-cudagraphs)" + ) + }, ) torch_compile_disable_cudagraphs: bool = field( @@ -186,6 +192,11 @@ class RTCDemoConfig(HubMixin): }, ) + compile_warmup_delay: list[int] = field( + default_factory=lambda: [0, 4], + metadata={"help": "Warmup inference delays per call, e.g. [0,4]. Empty list disables warmup."}, + ) + def __post_init__(self): # HACK: We parse again the cli args here to get the pretrained path if there was one. 
policy_path = parser.get_path_arg("policy") @@ -200,6 +211,9 @@ def __post_init__(self): if self.robot is None: raise ValueError("Robot configuration must be provided") + if any(delay < 0 for delay in self.compile_warmup_delay): + raise ValueError("All compile_warmup_delay values must be >= 0") + @classmethod def __get_path_fields__(cls) -> list[str]: """This enables the parser to load config from the policy using `--policy.path=local/dir`""" @@ -210,10 +224,104 @@ def is_image_key(k: str) -> bool: return k.startswith(OBS_IMAGES) +def _prepare_policy_inputs( + robot: RobotWrapper, + robot_observation_processor, + dataset_features, + policy_device: str, + preprocessor, + task: str, +): + """Prepare a single observation for policy inference.""" + obs = robot.get_observation() + obs_processed = robot_observation_processor(obs) + obs_with_policy_features = build_dataset_frame(dataset_features, obs_processed, prefix="observation") + + for name in obs_with_policy_features: + obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) + if "image" in name: + obs_with_policy_features[name] = obs_with_policy_features[name].type(torch.float32) / 255 + obs_with_policy_features[name] = obs_with_policy_features[name].permute(2, 0, 1).contiguous() + obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) + obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) + + obs_with_policy_features["task"] = [task] + obs_with_policy_features["robot_type"] = robot.robot.name if hasattr(robot.robot, "name") else "" + return preprocessor(obs_with_policy_features) + + +def run_compile_warmup( + policy, + robot: RobotWrapper, + robot_observation_processor, + dataset_features, + preprocessor, + cfg: RTCDemoConfig, +) -> None: + """Run warmup inference calls to trigger torch.compile before the robot starts moving.""" + warmup_delays = list(cfg.compile_warmup_delay) + if not cfg.use_torch_compile or len(warmup_delays) == 0: + 
return + + logger.info( + "Running compile warmup before RTC start (%d calls), delays=%s", + len(warmup_delays), + warmup_delays, + ) + + policy_device = policy.config.device + warmup_prev_actions: Tensor | None = None + + warmup_total_start = time.perf_counter() + for warmup_idx, warmup_delay in enumerate(warmup_delays): + step_start = time.perf_counter() + logger.info( + "Compile warmup step %d/%d (delay=%d, prev=%s)...", + warmup_idx + 1, + len(warmup_delays), + warmup_delay, + "None" if warmup_prev_actions is None else f"shape {tuple(warmup_prev_actions.shape)}", + ) + preprocessed_obs = _prepare_policy_inputs( + robot=robot, + robot_observation_processor=robot_observation_processor, + dataset_features=dataset_features, + policy_device=policy_device, + preprocessor=preprocessor, + task=cfg.task, + ) + + with torch.no_grad(): + actions = policy.predict_action_chunk( + preprocessed_obs, + inference_delay=warmup_delay, + prev_chunk_left_over=warmup_prev_actions, + ) + + step_elapsed = time.perf_counter() - step_start + logger.info("Compile warmup step %d/%d done in %.1fs", warmup_idx + 1, len(warmup_delays), step_elapsed) + + original_actions = actions.squeeze(0).clone() + chunk_size = int(original_actions.shape[0]) + + if warmup_idx < len(warmup_delays) - 1: + next_delay = warmup_delays[warmup_idx + 1] + if next_delay < chunk_size: + warmup_prev_actions = original_actions[next_delay:].clone() + else: + warmup_prev_actions = None + + total_elapsed = time.perf_counter() - warmup_total_start + logger.info("Compile warmup finished in %.1fs", total_elapsed) + + def get_actions( policy, robot: RobotWrapper, robot_observation_processor, + dataset_features, + preprocessor, + postprocessor, action_queue: ActionQueue, shutdown_event: Event, cfg: RTCDemoConfig, @@ -224,6 +332,9 @@ def get_actions( policy: The policy instance (SmolVLA, Pi0, etc.) 
robot: The robot instance for getting observations robot_observation_processor: Processor for raw robot observations + dataset_features: Dataset feature definitions for observation conversion + preprocessor: Policy preprocessor + postprocessor: Policy postprocessor action_queue: Queue to put new action chunks shutdown_event: Event to signal shutdown cfg: Demo configuration @@ -232,27 +343,10 @@ def get_actions( logger.info("[GET_ACTIONS] Starting get actions thread") latency_tracker = LatencyTracker() # Track latency of action chunks - fps = cfg.fps - time_per_chunk = 1.0 / fps + time_per_chunk = 1.0 / cfg.fps - dataset_features = hw_to_dataset_features(robot.observation_features(), "observation") policy_device = policy.config.device - # Load preprocessor and postprocessor from pretrained files - # The stats are embedded in the processor .safetensors files - logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}") - - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - dataset_stats=None, # Will load from pretrained processor files - preprocessor_overrides={ - "device_processor": {"device": cfg.policy.device}, - }, - ) - - logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats") - get_actions_threshold = cfg.action_queue_size_to_get_new_actions if not cfg.rtc.enabled: @@ -267,37 +361,18 @@ def get_actions( inference_latency = latency_tracker.max() inference_delay = math.ceil(inference_latency / time_per_chunk) - obs = robot.get_observation() - - # Apply robot observation processor - obs_processed = robot_observation_processor(obs) - - obs_with_policy_features = build_dataset_frame( - dataset_features, obs_processed, prefix="observation" - ) - - for name in obs_with_policy_features: - obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) - if "image" in name: - 
obs_with_policy_features[name] = ( - obs_with_policy_features[name].type(torch.float32) / 255 - ) - obs_with_policy_features[name] = ( - obs_with_policy_features[name].permute(2, 0, 1).contiguous() - ) - obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) - obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) - - obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string! - obs_with_policy_features["robot_type"] = ( - robot.robot.name if hasattr(robot.robot, "name") else "" + preprocessed_obs = _prepare_policy_inputs( + robot=robot, + robot_observation_processor=robot_observation_processor, + dataset_features=dataset_features, + policy_device=policy_device, + preprocessor=preprocessor, + task=cfg.task, ) - preproceseded_obs = preprocessor(obs_with_policy_features) - # Generate actions WITH RTC actions = policy.predict_action_chunk( - preproceseded_obs, + preprocessed_obs, inference_delay=inference_delay, prev_chunk_left_over=prev_actions, ) @@ -388,8 +463,9 @@ def _apply_torch_compile(policy, cfg: RTCDemoConfig): Policy with compiled predict_action_chunk method """ - # PI models handle their own compilation - if policy.type == "pi05" or policy.type == "pi0": + # PI models handle their own compilation via config.compile_model + # Note: policy.type is nn.Module.type() (a method), use policy.config.type instead. 
+ if policy.config.type in ("pi05", "pi0"): return policy try: @@ -406,21 +482,19 @@ def _apply_torch_compile(policy, cfg: RTCDemoConfig): logger.info(f" Mode: {cfg.torch_compile_mode}") logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}") - # Compile the predict_action_chunk method - # - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t) + compile_mode = cfg.torch_compile_mode + if cfg.torch_compile_disable_cudagraphs and compile_mode == "max-autotune": + compile_mode = "max-autotune-no-cudagraphs" + compile_kwargs = { "backend": cfg.torch_compile_backend, - "mode": cfg.torch_compile_mode, + "mode": compile_mode, } - # Disable CUDA graphs if requested (prevents tensor aliasing issues) - if cfg.torch_compile_disable_cudagraphs: - compile_kwargs["options"] = {"triton.cudagraphs": False} - original_method = policy.predict_action_chunk compiled_method = torch.compile(original_method, **compile_kwargs) policy.predict_action_chunk = compiled_method - logger.info("✓ Successfully compiled predict_action_chunk") + logger.info("Successfully compiled predict_action_chunk") except Exception as e: logger.error(f"Failed to apply torch.compile: {e}") @@ -454,6 +528,13 @@ def demo_cli(cfg: RTCDemoConfig): if cfg.policy.type == "pi05" or cfg.policy.type == "pi0": config.compile_model = cfg.use_torch_compile + config.compile_mode = cfg.torch_compile_mode + + # Enable persistent compile cache so recompilation is skipped across runs + if cfg.use_torch_compile: + torch._inductor.config.fx_graph_cache = True + torch._inductor.config.fx_graph_remote_cache = False + logger.info("Enabled persistent FX graph cache for torch.compile") if config.use_peft: from peft import PeftConfig, PeftModel @@ -494,13 +575,41 @@ def demo_cli(cfg: RTCDemoConfig): robot_observation_processor = make_default_robot_observation_processor() robot_action_processor = make_default_robot_action_processor() + # Load preprocessor and postprocessor (needed for 
warmup and get_actions) + dataset_features = hw_to_dataset_features(robot_wrapper.observation_features(), "observation") + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=None, + preprocessor_overrides={ + "device_processor": {"device": cfg.policy.device}, + }, + ) + + # Run compile warmup before starting RTC + run_compile_warmup( + policy=policy, + robot=robot_wrapper, + robot_observation_processor=robot_observation_processor, + dataset_features=dataset_features, + preprocessor=preprocessor, + cfg=cfg, + ) + + # Wait for user input to start + input("Press enter to start RTC") + # Create action queue for communication between threads action_queue = ActionQueue(cfg.rtc) # Start chunk requester thread get_actions_thread = Thread( target=get_actions, - args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg), + args=( + policy, robot_wrapper, robot_observation_processor, + dataset_features, preprocessor, postprocessor, + action_queue, shutdown_event, cfg, + ), daemon=True, name="GetActions", ) @@ -556,5 +665,6 @@ def demo_cli(cfg: RTCDemoConfig): if __name__ == "__main__": + register_third_party_plugins() demo_cli() logging.info("RTC demo finished") diff --git a/pyproject.toml b/pyproject.toml index f4fb7d24967..9f53316d345 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,9 +76,9 @@ dependencies = [ "pyserial>=3.5,<4.0", "wandb>=0.24.0,<0.25.0", - "torch>=2.2.1,<2.11.0", # TODO: Bump dependency - "torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency - "torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency + "torch>=2.2.1,<2.8.0", # TODO: Bump dependency + "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and 
(sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency + "torchvision>=0.21.0,<0.23.0", # TODO: Bump dependency "draccus==0.10.0", # TODO: Remove == "gymnasium>=1.1.1,<2.0.0", @@ -96,7 +96,7 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.10.0"] -transformers-dep = ["transformers>=4.57.1,<5.0.0"] +transformers-dep = ["transformers>=4.47.0,<6.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] @@ -135,6 +135,8 @@ wallx = [ "torchdiffeq==0.2.5", "qwen_vl_utils==0.0.11" ] +# NOTE: pi extra has no conflicts declared due to uv bug with git URL dependencies +# Do not install pi with transformers-dep extras simultaneously pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] groot = [ @@ -220,6 +222,19 @@ lerobot = ["envs/*.json"] [tool.setuptools.packages.find] where = ["src"] +[tool.uv.sources] +torch = [ + { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +torchvision = [ + { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true + [tool.ruff] target-version = "py310" line-length = 110 @@ -399,83 +414,10 @@ ignore_errors = false # ignore_errors = false [tool.uv] -# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0 +# wallx has incompatible versions with many extras conflicts = [ - [ - { extra = "wallx" }, - { extra = "transformers-dep" }, - ], - [ - { extra = "wallx" }, - { extra = "pi" }, - ], - [ 
- { extra = "wallx" }, - { extra = "smolvla" }, - ], - [ - { extra = "wallx" }, - { extra = "groot" }, - ], - [ - { extra = "wallx" }, - { extra = "xvla" }, - ], - [ - { extra = "wallx" }, - { extra = "sarm" }, - ], - [ - { extra = "wallx" }, - { extra = "hilserl" }, - ], - [ - { extra = "wallx" }, - { extra = "libero" }, - ], - [ - { extra = "wallx" }, - { extra = "peft" }, - ], - [ - { extra = "wallx" }, - { extra = "all" }, - ], - # pi uses custom branch which conflicts with transformers-dep - [ - { extra = "pi" }, - { extra = "transformers-dep" }, - ], - [ - { extra = "pi" }, - { extra = "smolvla" }, - ], - [ - { extra = "pi" }, - { extra = "groot" }, - ], - [ - { extra = "pi" }, - { extra = "xvla" }, - ], - [ - { extra = "pi" }, - { extra = "sarm" }, - ], - [ - { extra = "pi" }, - { extra = "hilserl" }, - ], - [ - { extra = "pi" }, - { extra = "libero" }, - ], - [ - { extra = "pi" }, - { extra = "peft" }, - ], - [ - { extra = "pi" }, - { extra = "all" }, - ], + [{ extra = "wallx" }, { extra = "all" }], + [{ extra = "wallx" }, { extra = "peft" }], + [{ extra = "wallx" }, { extra = "sarm" }], + [{ extra = "wallx" }, { extra = "pi" }], ] diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 104ec63bf2e..e9ab77fcf50 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -17,6 +17,7 @@ import builtins import logging import math +import threading from collections import deque from pathlib import Path from typing import TYPE_CHECKING, Literal, TypedDict @@ -785,7 +786,9 @@ def sample_actions( masks, noise=None, num_steps=None, - **kwargs: Unpack[ActionSelectKwargs], + inference_delay: int | None = None, + prev_chunk_left_over: Tensor | None = None, + execution_horizon: int | None = None, ) -> Tensor: """Do a full inference forward and compute the action.""" if num_steps is None: @@ -824,6 +827,7 @@ def sample_actions( for step in range(num_steps): time = 1.0 + 
step * dt time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) + rtc_time_tensor = torch.tensor(time, dtype=torch.float32, device=device) def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): return self.denoise_step( @@ -834,15 +838,11 @@ def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): ) if self._rtc_enabled(): - inference_delay = kwargs.get("inference_delay") - prev_chunk_left_over = kwargs.get("prev_chunk_left_over") - execution_horizon = kwargs.get("execution_horizon") - v_t = self.rtc_processor.denoise_step( x_t=x_t, prev_chunk_left_over=prev_chunk_left_over, inference_delay=inference_delay, - time=time, + time=rtc_time_tensor, original_denoise_step_partial=denoise_step_partial_call, execution_horizon=execution_horizon, ) @@ -1105,9 +1105,36 @@ def _fix_pytorch_state_dict_keys( def get_optim_params(self) -> dict: return self.parameters() + def _new_action_queue(self) -> deque: + """Create a fresh action queue honoring n_action_steps.""" + return deque(maxlen=self.config.n_action_steps) + + def _get_thread_action_queue(self) -> deque: + """Return the action queue scoped to the current thread.""" + if not hasattr(self, "_thread_local"): + self._thread_local = threading.local() + action_queue = getattr(self._thread_local, "action_queue", None) + if action_queue is None: + action_queue = self._new_action_queue() + self._thread_local.action_queue = action_queue + return action_queue + + @property + def _action_queue(self) -> deque: + """Expose the thread-local action queue (backwards compatible attribute).""" + return self._get_thread_action_queue() + + @_action_queue.setter + def _action_queue(self, queue: deque) -> None: + if not hasattr(self, "_thread_local"): + self._thread_local = threading.local() + + self._thread_local.action_queue = queue + def reset(self): """Reset internal state - called when environment resets.""" - self._action_queue = deque(maxlen=self.config.n_action_steps) + 
self._thread_local = threading.local() + self._action_queue = self._new_action_queue() self._queues = { ACTION: deque(maxlen=self.config.n_action_steps), } @@ -1225,8 +1252,43 @@ def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[Action images, img_masks = self._preprocess_images(batch) tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] - # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05) - actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs) + # Normalize RTC inputs before the compiled sample_actions boundary to + # prevent dynamo recompilation: + # - Convert inference_delay / execution_horizon from Python int to tensor + # (dynamo specializes on int values, recompiling for each unique value; + # tensors are dynamic). + # - Pad prev_chunk_left_over to (chunk_size, max_action_dim) so the compiled + # function always sees the same input shape. + model_device = tokens.device + + inference_delay = kwargs.get("inference_delay") + if inference_delay is not None and not isinstance(inference_delay, torch.Tensor): + inference_delay = torch.tensor(inference_delay, dtype=torch.long, device=model_device) + + execution_horizon = kwargs.get("execution_horizon") + if execution_horizon is not None and not isinstance(execution_horizon, torch.Tensor): + execution_horizon = torch.tensor(execution_horizon, dtype=torch.long, device=model_device) + + prev_chunk_left_over = kwargs.get("prev_chunk_left_over") + if prev_chunk_left_over is not None: + prev_chunk_left_over = prev_chunk_left_over.to(device=model_device) + target_t, target_a = self.config.chunk_size, self.config.max_action_dim + cur_t, cur_a = prev_chunk_left_over.shape[-2], prev_chunk_left_over.shape[-1] + if cur_t < target_t or cur_a < target_a: + pad_shape = (*prev_chunk_left_over.shape[:-2], target_t, target_a) + padded = torch.zeros(pad_shape, device=model_device, 
dtype=prev_chunk_left_over.dtype) + padded[..., :cur_t, :cur_a] = prev_chunk_left_over + prev_chunk_left_over = padded + + actions = self.model.sample_actions( + images, + img_masks, + tokens, + masks, + inference_delay=inference_delay, + prev_chunk_left_over=prev_chunk_left_over, + execution_horizon=execution_horizon, + ) # Unpad actions to actual action dimension original_action_dim = self.config.output_features[ACTION].shape[0] diff --git a/src/lerobot/policies/rtc/action_queue.py b/src/lerobot/policies/rtc/action_queue.py index 3f253ff8ca4..b5d75b0dc4a 100644 --- a/src/lerobot/policies/rtc/action_queue.py +++ b/src/lerobot/policies/rtc/action_queue.py @@ -102,6 +102,16 @@ def empty(self) -> bool: length = len(self.queue) return length - self.last_index <= 0 + def clear(self) -> None: + """Clear all actions from the queue, resetting to empty state. + + Used when switching tasks so stale actions are not executed. + """ + with self.lock: + self.queue = None + self.original_queue = None + self.last_index = 0 + def get_action_index(self) -> int: """Get the current action consumption index. @@ -131,27 +141,67 @@ def merge( processed_actions: Tensor, real_delay: int, action_index_before_inference: int | None = 0, - ): + ) -> int: """Merge new actions into the queue. This method operates differently based on RTC mode: - RTC enabled: Replaces the queue, accounting for inference delay - RTC disabled: Appends to the queue, maintaining continuity + For RTC mode, the delay used for slicing is determined by ground truth + (actual actions consumed during inference) when available, falling back + to the latency-based estimate when the queue was empty. + Args: original_actions: Unprocessed actions from policy (time_steps, action_dim). processed_actions: Post-processed actions for robot (time_steps, action_dim). - real_delay: Number of time steps of inference delay. + real_delay: Number of time steps of inference delay (ceil of latency). 
action_index_before_inference: Index before inference started, for validation. + + Returns: + int: The actual delay applied (number of actions skipped). """ with self.lock: - self._check_delays(real_delay, action_index_before_inference) - if self.cfg.enabled: - self._replace_actions_queue(original_actions, processed_actions, real_delay) - return + actual_delay = self._compute_actual_delay(real_delay, action_index_before_inference) + self._replace_actions_queue(original_actions, processed_actions, actual_delay) + return actual_delay self._append_actions_queue(original_actions, processed_actions) + return 0 + + def _compute_actual_delay(self, real_delay: int, action_index_before_inference: int | None) -> int: + """Compute the actual delay to use for queue slicing. + + Uses ground truth (actions consumed during inference) when the queue + was active. Falls back to latency-based estimate when the queue was + empty (e.g. first request after warmup). + + Args: + real_delay: Latency-based delay estimate (ceil). + action_index_before_inference: Action index when inference started. + + Returns: + int: Delay to use for slicing the new action chunk. + """ + if action_index_before_inference is None: + return real_delay + + indexes_diff = self.last_index - action_index_before_inference + + if indexes_diff <= 0: + # Queue was empty during inference (first request or queue ran dry). + # Robot didn't move, so no actions are stale — skip nothing. + return 0 + + if indexes_diff != real_delay: + logger.debug( + "[ACTION_QUEUE] Using actual consumed actions as delay: %d (latency estimate was %d)", + indexes_diff, + real_delay, + ) + + return indexes_diff def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int): """Replace the queue with new actions (RTC mode). 
@@ -195,25 +245,3 @@ def _append_actions_queue(self, original_actions: Tensor, processed_actions: Ten self.queue = self.queue[self.last_index :] self.last_index = 0 - - def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None): - """Validate that computed delays match expectations. - - Compares the delay computed from inference latency with the actual - number of actions consumed during inference. - - Args: - real_delay: Delay computed from inference latency. - action_index_before_inference: Action index when inference started. - """ - if action_index_before_inference is None: - return - - indexes_diff = self.last_index - action_index_before_inference - if indexes_diff != real_delay: - # Let's check that action index difference (real delay calculated based on action queue) - # is the same as delay calculated based on inference latency - logger.warning( - f"[ACTION_QUEUE] Indexes diff is not equal to real delay. " - f"Indexes diff: {indexes_diff}, real delay: {real_delay}" - ) diff --git a/src/lerobot/policies/rtc/modeling_rtc.py b/src/lerobot/policies/rtc/modeling_rtc.py index 280905adf9f..e9714367508 100644 --- a/src/lerobot/policies/rtc/modeling_rtc.py +++ b/src/lerobot/policies/rtc/modeling_rtc.py @@ -150,25 +150,21 @@ def denoise_step( right-padded with zeros to match ``T``. - Prefix weights are constructed via ``get_prefix_weights(inference_delay, execution_horizon, T)`` and broadcast to ``(B, T, A)``. - - Guidance correction is computed via autograd using ``x1_t = x_t + time * v_t`` and - ``error = (prev_chunk_left_over - x1_t) * weights``. + - Guidance correction is ``err = (prev_chunk_left_over - x1_t) * weights`` + using the identity approximation (no backward pass needed). - The final guidance weight is clamped by ``max_guidance_weight`` from the config. 
Reference: https://www.physicalintelligence.company/download/real_time_chunking.pdf """ - # In the original implementation, the time goes from 0 to 1 and - # In our implementation, the time goes from 1 to 0 - # So we need to invert the time - tau = 1 - time - if prev_chunk_left_over is None: # First step, no guidance - return v_t v_t = original_denoise_step_partial(x_t) return v_t - x_t = x_t.clone().detach() + time_tensor = torch.as_tensor(time, dtype=x_t.dtype, device=x_t.device) + tau_tensor = 1 - time_tensor squeezed = False if len(x_t.shape) < 3: @@ -183,17 +179,21 @@ def denoise_step( if execution_horizon is None: execution_horizon = self.rtc_config.execution_horizon - # If the previous action chunk is to short then it doesn't make sense to use long execution horizon - # because there is nothing to merge - if execution_horizon > prev_chunk_left_over.shape[1]: - execution_horizon = prev_chunk_left_over.shape[1] + # Clamp execution_horizon to prev_chunk length (compile-friendly) + prev_chunk_len = prev_chunk_left_over.shape[1] + if isinstance(execution_horizon, Tensor): + execution_horizon = torch.clamp(execution_horizon, max=prev_chunk_len) + elif execution_horizon > prev_chunk_len: + execution_horizon = prev_chunk_len batch_size = x_t.shape[0] action_chunk_size = x_t.shape[1] action_dim = x_t.shape[2] if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim: - padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device) + padded = torch.zeros( + batch_size, action_chunk_size, action_dim, device=x_t.device, dtype=x_t.dtype + ) padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over prev_chunk_left_over = padded @@ -202,23 +202,22 @@ def denoise_step( ) weights = ( - self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size) - .to(x_t.device) + self.get_prefix_weights( + inference_delay, execution_horizon, action_chunk_size, 
device=x_t.device, dtype=x_t.dtype + ) .unsqueeze(0) .unsqueeze(-1) ) - with torch.enable_grad(): - v_t = original_denoise_step_partial(x_t) - x_t.requires_grad_(True) + # Identity approximation (J ≈ I): correction = err, no backward pass needed. + v_t = original_denoise_step_partial(x_t) + x1_t = x_t - time_tensor * v_t # noqa: N806 + err = (prev_chunk_left_over - x1_t) * weights + correction = err.detach() - x1_t = x_t - time * v_t # noqa: N806 - err = (prev_chunk_left_over - x1_t) * weights - grad_outputs = err.clone().detach() - correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0] - - max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight) - tau_tensor = torch.as_tensor(tau) + max_guidance_weight = torch.as_tensor( + self.rtc_config.max_guidance_weight, dtype=x_t.dtype, device=x_t.device + ) squared_one_minus_tau = (1 - tau_tensor) ** 2 inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau) c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight) @@ -247,24 +246,40 @@ def denoise_step( return result - def get_prefix_weights(self, start, end, total): - start = min(start, end) - - if self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS: - weights = torch.zeros(total) - weights[:start] = 1.0 - elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ONES: - weights = torch.ones(total) - weights[end:] = 0.0 - elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR: - lin_weights = self._linweights(start, end, total) - weights = self._add_trailing_zeros(lin_weights, total, end) - weights = self._add_leading_ones(weights, start, total) - elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.EXP: - lin_weights = self._linweights(start, end, total) - lin_weights = lin_weights * torch.expm1(lin_weights).div(math.e - 1) - weights = self._add_trailing_zeros(lin_weights, total, end) - weights = 
self._add_leading_ones(weights, start, total) + def get_prefix_weights( + self, + start: int | Tensor, + end: int | Tensor, + total: int, + *, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, + ): + # Pure tensor implementation — no .item() calls, no dynamic slicing, + # so this can live inside a torch.compile full-graph. + idx = torch.arange(total, device=device, dtype=dtype) + start_f = torch.as_tensor(start, device=device, dtype=dtype).reshape(()) + end_f = torch.as_tensor(end, device=device, dtype=dtype).reshape(()) + start_f = torch.minimum(start_f, end_f) + + schedule = self.rtc_config.prefix_attention_schedule + + if schedule == RTCAttentionSchedule.ZEROS: + weights = torch.where(idx < start_f, 1.0, 0.0) + elif schedule == RTCAttentionSchedule.ONES: + weights = torch.where(idx < end_f, 1.0, 0.0) + elif schedule in (RTCAttentionSchedule.LINEAR, RTCAttentionSchedule.EXP): + range_len = end_f - start_f + # Equivalent to linspace(1, 0, range_len+2)[1:-1] mapped onto [start, end) + # For absolute index i in [start, end): weight = (end - i) / (range_len + 1) + lin = (end_f - idx) / (range_len + 1) + + if schedule == RTCAttentionSchedule.EXP: + lin = lin * torch.expm1(lin) / (math.e - 1) + + leading_mask = idx < start_f + middle_mask = (idx >= start_f) & (idx < end_f) + weights = torch.where(leading_mask, 1.0, torch.where(middle_mask, lin, 0.0)) return weights diff --git a/src/lerobot/policies/rtc/profiling.py b/src/lerobot/policies/rtc/profiling.py new file mode 100644 index 00000000000..01a1d88b0de --- /dev/null +++ b/src/lerobot/policies/rtc/profiling.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Profiling utilities for remote RTC runs.""" + +import logging +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +import numpy as np + +logger = logging.getLogger(__name__) + + +@dataclass +class RTCProfilingRecord: + """Per-request timing and queue metrics for remote RTC.""" + + request_idx: int + timestamp: float + label: str + payload_bytes: int | None = None + + queue_size_before: int | None = None + queue_size_after: int | None = None + action_index_before: int | None = None + inference_delay_requested: int | None = None + realized_delay: int | None = None + + client_observation_ms: float | None = None + client_pickle_ms: float | None = None + client_send_ms: float | None = None + client_get_actions_ms: float | None = None + client_unpickle_ms: float | None = None + client_total_ms: float | None = None + + server_queue_wait_ms: float | None = None + server_preprocess_ms: float | None = None + server_inference_ms: float | None = None + server_postprocess_ms: float | None = None + server_pickle_ms: float | None = None + server_total_ms: float | None = None + + +class RTCProfiler: + """Stores profiling records and writes parquet + plot artifacts.""" + + def __init__(self, enabled: bool, output_dir: str, run_name: str): + self.enabled = enabled + self.output_dir = Path(output_dir) + self.run_name = run_name + self._records: list[RTCProfilingRecord] = [] + + def add(self, record: RTCProfilingRecord) -> None: + if not self.enabled: + return + self._records.append(record) + + def finalize(self) -> dict[str, 
str]: + if not self.enabled or not self._records: + return {} + + self.output_dir.mkdir(parents=True, exist_ok=True) + + parquet_path = self.output_dir / f"{self.run_name}_profiling.parquet" + plot_path = self.output_dir / f"{self.run_name}_profiling.png" + + self._save_parquet(parquet_path, [asdict(r) for r in self._records]) + self._save_plot(plot_path) + + return {"parquet": str(parquet_path), "plot": str(plot_path)} + + def _save_parquet(self, path: Path, rows: list[dict[str, Any]]) -> None: + try: + import pyarrow as pa # noqa: PLC0415 + import pyarrow.parquet as pq # noqa: PLC0415 + except ImportError: + logger.warning("pyarrow not installed, skipping parquet export for %s", path.name) + return + + if not rows: + return + table = pa.Table.from_pylist(rows) + pq.write_table(table, path) + + def _save_plot(self, path: Path) -> None: + try: + import matplotlib.pyplot as plt # noqa: PLC0415 + except ImportError: + logger.warning("matplotlib not installed, skipping profiling plot.") + return + + x = np.arange(len(self._records)) + has_queue = any(r.queue_size_before is not None or r.queue_size_after is not None for r in self._records) + + nrows = 2 if has_queue else 1 + fig, axes = plt.subplots(nrows, 1, figsize=(16, 5 * nrows), sharex=True) + if nrows == 1: + axes = [axes] + + ax_timing = axes[0] + self._plot_field(ax_timing, x, "client_total_ms", "Client Total") + self._plot_field(ax_timing, x, "client_observation_ms", "Client Observation") + self._plot_field(ax_timing, x, "client_send_ms", "Client Send") + self._plot_field(ax_timing, x, "client_get_actions_ms", "Client GetActions Wait") + self._plot_field(ax_timing, x, "server_total_ms", "Server Total") + self._plot_field(ax_timing, x, "server_inference_ms", "Server Inference") + self._plot_field(ax_timing, x, "server_preprocess_ms", "Server Preprocess") + self._plot_field(ax_timing, x, "server_postprocess_ms", "Server Postprocess") + ax_timing.set_ylabel("Milliseconds") + ax_timing.set_title("Remote RTC Timing 
Breakdown") + ax_timing.grid(True, alpha=0.3) + ax_timing.legend(loc="upper right") + + if has_queue: + ax_queue = axes[1] + self._plot_field(ax_queue, x, "queue_size_before", "Queue Before") + self._plot_field(ax_queue, x, "queue_size_after", "Queue After") + self._plot_field(ax_queue, x, "inference_delay_requested", "Requested Delay") + self._plot_field(ax_queue, x, "realized_delay", "Realized Delay") + ax_queue.set_ylabel("Steps") + ax_queue.set_title("Queue and Delay Dynamics") + ax_queue.grid(True, alpha=0.3) + ax_queue.legend(loc="upper right") + + axes[-1].set_xlabel("Request Index") + fig.tight_layout() + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + + def _plot_field(self, ax, x: np.ndarray, field_name: str, label: str) -> None: + values = [getattr(r, field_name) for r in self._records] + if not any(v is not None for v in values): + return + arr = np.array([np.nan if v is None else float(v) for v in values], dtype=np.float64) + ax.plot(x, arr, label=label, linewidth=2, alpha=0.85) diff --git a/src/lerobot/policies/rtc/remote.py b/src/lerobot/policies/rtc/remote.py new file mode 100644 index 00000000000..5288ac4b8b9 --- /dev/null +++ b/src/lerobot/policies/rtc/remote.py @@ -0,0 +1,93 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Data classes for remote RTC policy inference. 
+ +These classes define the communication protocol between RTC policy servers +and clients (robots, simulations, or evaluation scripts). +""" + +from dataclasses import dataclass, field +from typing import Any + +import torch + +from lerobot.policies.rtc.configuration_rtc import RTCConfig + + +@dataclass +class RTCRemotePolicyConfig: + """Configuration sent by client to initialize policy on server. + + This is sent from clients to the server when establishing a connection, + telling the server which policy to load and how to configure it. + """ + + policy_type: str + pretrained_name_or_path: str + lerobot_features: dict[str, Any] + rtc_config: RTCConfig | None = None + device: str = "cuda" + rename_map: dict[str, str] = field(default_factory=dict) + use_torch_compile: bool = False + torch_compile_mode: str = "reduce-overhead" + chunk_size: int | None = None + n_action_steps: int | None = None + + +@dataclass +class RTCObservationData: + """Observation data with RTC parameters sent from client to server. + + Contains the observation dict along with RTC-specific parameters + needed for inference: + - inference_delay: Number of steps the inference is expected to take + - prev_chunk_left_over: Unconsumed actions from previous chunk for RTC guidance + - execution_horizon: How far into the future to plan + """ + + observation: dict[str, Any] + timestamp: float + timestep: int + inference_delay: int + prev_chunk_left_over: torch.Tensor | None + execution_horizon: int + + +@dataclass +class RTCActionData: + """Action data returned from server to client. + + Contains both the postprocessed actions (ready for robot execution) + and the original actions (for RTC left-over tracking in the action queue). 
+ """ + + actions: torch.Tensor # Postprocessed actions for robot + original_actions: torch.Tensor # Original actions for RTC left-over tracking + timestamp: float + timestep: int + timing: "RTCTimingData | None" = None + + +@dataclass +class RTCTimingData: + """Timing breakdown for one remote inference request.""" + + queue_wait_ms: float | None = None + preprocess_ms: float | None = None + inference_ms: float | None = None + postprocess_ms: float | None = None + pickle_ms: float | None = None + total_ms: float | None = None