Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
528 changes: 528 additions & 0 deletions examples/remote_rtc/eval_dataset.py

Large diffs are not rendered by default.

632 changes: 632 additions & 0 deletions examples/remote_rtc/eval_with_real_robot.py

Large diffs are not rendered by default.

509 changes: 509 additions & 0 deletions examples/remote_rtc/rtc_policy_server.py

Large diffs are not rendered by default.

226 changes: 168 additions & 58 deletions examples/rtc/eval_with_real_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.hub import HubMixin
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.utils import init_logging

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -174,8 +175,13 @@ class RTCDemoConfig(HubMixin):
)

torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
default="reduce-overhead",
metadata={
"help": (
"Compilation mode (default, reduce-overhead, max-autotune, "
"max-autotune-no-cudagraphs)"
)
},
)

torch_compile_disable_cudagraphs: bool = field(
Expand All @@ -186,6 +192,11 @@ class RTCDemoConfig(HubMixin):
},
)

compile_warmup_delay: list[int] = field(
default_factory=lambda: [0, 4],
metadata={"help": "Warmup inference delays per call, e.g. [0,4]. Empty list disables warmup."},
)

def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
Expand All @@ -200,6 +211,9 @@ def __post_init__(self):
if self.robot is None:
raise ValueError("Robot configuration must be provided")

if any(delay < 0 for delay in self.compile_warmup_delay):
raise ValueError("All compile_warmup_delay values must be >= 0")

@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
Expand All @@ -210,10 +224,104 @@ def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)


def _prepare_policy_inputs(
    robot: RobotWrapper,
    robot_observation_processor,
    dataset_features,
    policy_device: str,
    preprocessor,
    task: str,
):
    """Build one preprocessed observation batch for policy inference.

    Reads a raw observation from the robot, runs it through the robot
    observation processor, converts each feature into a batched torch tensor
    on the policy device, then applies the policy preprocessor.

    NOTE(review): image features appear to arrive as HWC uint8 arrays and are
    converted to CHW float32 in [0, 1] — confirm against the dataset features.
    """
    raw_obs = robot.get_observation()
    processed_obs = robot_observation_processor(raw_obs)
    frame = build_dataset_frame(dataset_features, processed_obs, prefix="observation")

    for key, value in frame.items():
        tensor = torch.from_numpy(value)
        if "image" in key:
            # uint8 [0, 255] HWC -> float32 [0, 1] CHW
            tensor = (tensor.type(torch.float32) / 255).permute(2, 0, 1).contiguous()
        # Add a leading batch dimension and move onto the policy device.
        frame[key] = tensor.unsqueeze(0).to(policy_device)

    # The task is wrapped in a list (one entry per batch element); robot_type
    # falls back to an empty string when the robot exposes no name.
    frame["task"] = [task]
    frame["robot_type"] = robot.robot.name if hasattr(robot.robot, "name") else ""
    return preprocessor(frame)


def run_compile_warmup(
    policy,
    robot: RobotWrapper,
    robot_observation_processor,
    dataset_features,
    preprocessor,
    cfg: RTCDemoConfig,
) -> None:
    """Trigger torch.compile ahead of time with throwaway inference calls.

    Each delay in ``cfg.compile_warmup_delay`` produces one
    ``predict_action_chunk`` call so that compilation (and any recompilation
    for the prev-chunk code path) happens before the robot starts moving.
    A no-op when compilation is disabled or the delay list is empty.
    """
    delays = list(cfg.compile_warmup_delay)
    if not (cfg.use_torch_compile and delays):
        return

    logger.info(
        "Running compile warmup before RTC start (%d calls), delays=%s",
        len(delays),
        delays,
    )

    device = policy.config.device
    n_steps = len(delays)
    leftover: Tensor | None = None

    t_total = time.perf_counter()
    for step, delay in enumerate(delays):
        t_step = time.perf_counter()
        leftover_desc = "None" if leftover is None else f"shape {tuple(leftover.shape)}"
        logger.info(
            "Compile warmup step %d/%d (delay=%d, prev=%s)...",
            step + 1,
            n_steps,
            delay,
            leftover_desc,
        )

        batch = _prepare_policy_inputs(
            robot=robot,
            robot_observation_processor=robot_observation_processor,
            dataset_features=dataset_features,
            policy_device=device,
            preprocessor=preprocessor,
            task=cfg.task,
        )

        with torch.no_grad():
            chunk = policy.predict_action_chunk(
                batch,
                inference_delay=delay,
                prev_chunk_left_over=leftover,
            )

        logger.info(
            "Compile warmup step %d/%d done in %.1fs",
            step + 1,
            n_steps,
            time.perf_counter() - t_step,
        )

        actions = chunk.squeeze(0).clone()
        horizon = int(actions.shape[0])

        # Hand the tail of this chunk to the next warmup call as the
        # "previous chunk", so the non-None prev-chunk path is also compiled.
        if step < n_steps - 1:
            next_delay = delays[step + 1]
            leftover = actions[next_delay:].clone() if next_delay < horizon else None

    logger.info("Compile warmup finished in %.1fs", time.perf_counter() - t_total)


def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
dataset_features,
preprocessor,
postprocessor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
Expand All @@ -224,6 +332,9 @@ def get_actions(
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
dataset_features: Dataset feature definitions for observation conversion
preprocessor: Policy preprocessor
postprocessor: Policy postprocessor
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
Expand All @@ -232,27 +343,10 @@ def get_actions(
logger.info("[GET_ACTIONS] Starting get actions thread")

latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
time_per_chunk = 1.0 / cfg.fps

dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
policy_device = policy.config.device

# Load preprocessor and postprocessor from pretrained files
# The stats are embedded in the processor .safetensors files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")

preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)

logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")

get_actions_threshold = cfg.action_queue_size_to_get_new_actions

if not cfg.rtc.enabled:
Expand All @@ -267,37 +361,18 @@ def get_actions(
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)

obs = robot.get_observation()

# Apply robot observation processor
obs_processed = robot_observation_processor(obs)

obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)

for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)

obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
preprocessed_obs = _prepare_policy_inputs(
robot=robot,
robot_observation_processor=robot_observation_processor,
dataset_features=dataset_features,
policy_device=policy_device,
preprocessor=preprocessor,
task=cfg.task,
)

preproceseded_obs = preprocessor(obs_with_policy_features)

# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
preprocessed_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
Expand Down Expand Up @@ -388,8 +463,9 @@ def _apply_torch_compile(policy, cfg: RTCDemoConfig):
Policy with compiled predict_action_chunk method
"""

# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
# PI models handle their own compilation via config.compile_model
# Note: policy.type is nn.Module.type() (a method), use policy.config.type instead.
if policy.config.type in ("pi05", "pi0"):
return policy

try:
Expand All @@ -406,21 +482,19 @@ def _apply_torch_compile(policy, cfg: RTCDemoConfig):
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")

# Compile the predict_action_chunk method
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_mode = cfg.torch_compile_mode
if cfg.torch_compile_disable_cudagraphs and compile_mode == "max-autotune":
compile_mode = "max-autotune-no-cudagraphs"

compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
"mode": compile_mode,
}

# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}

original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("Successfully compiled predict_action_chunk")
logger.info("Successfully compiled predict_action_chunk")

except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
Expand Down Expand Up @@ -454,6 +528,13 @@ def demo_cli(cfg: RTCDemoConfig):

if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
config.compile_model = cfg.use_torch_compile
config.compile_mode = cfg.torch_compile_mode

# Enable persistent compile cache so recompilation is skipped across runs
if cfg.use_torch_compile:
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.fx_graph_remote_cache = False
logger.info("Enabled persistent FX graph cache for torch.compile")

if config.use_peft:
from peft import PeftConfig, PeftModel
Expand Down Expand Up @@ -494,13 +575,41 @@ def demo_cli(cfg: RTCDemoConfig):
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()

# Load preprocessor and postprocessor (needed for warmup and get_actions)
dataset_features = hw_to_dataset_features(robot_wrapper.observation_features(), "observation")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None,
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)

# Run compile warmup before starting RTC
run_compile_warmup(
policy=policy,
robot=robot_wrapper,
robot_observation_processor=robot_observation_processor,
dataset_features=dataset_features,
preprocessor=preprocessor,
cfg=cfg,
)

# Wait for user input to start
input("Press enter to start RTC")

# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)

# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
args=(
policy, robot_wrapper, robot_observation_processor,
dataset_features, preprocessor, postprocessor,
action_queue, shutdown_event, cfg,
),
daemon=True,
name="GetActions",
)
Expand Down Expand Up @@ -556,5 +665,6 @@ def demo_cli(cfg: RTCDemoConfig):


if __name__ == "__main__":
register_third_party_plugins()
demo_cli()
logging.info("RTC demo finished")
Loading
Loading