-
Notifications
You must be signed in to change notification settings - Fork 4k
Expand file tree
/
Copy pathcontrol_utils.py
More file actions
240 lines (199 loc) · 9.17 KB
/
control_utils.py
File metadata and controls
240 lines (199 loc) · 9.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
########################################################################################
# Utilities
########################################################################################
import logging
import os
import platform
import sys
import threading
from contextlib import nullcontext
from copy import copy
from functools import cache
from typing import Any
import numpy as np
import torch
from deepdiff import DeepDiff
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.robots import Robot
@cache
def is_headless():
    """
    Report whether the process appears to run without a graphical display.

    Only Linux is actually probed: the environment is considered to have a display
    when either `DISPLAY` (X11) or `WAYLAND_DISPLAY` (Wayland) is set to a non-empty
    value. On every other platform the function returns False without any check.
    Because of `@cache`, the check (and its warning) runs at most once per process.

    Returns:
        True if no display was detected on Linux, False otherwise.
    """
    # Non-Linux platforms are never reported as headless.
    if platform.system() != "Linux":
        return False

    # An unset or empty variable counts as "no display" for that protocol.
    display_found = any(os.environ.get(var) for var in ("DISPLAY", "WAYLAND_DISPLAY"))
    if display_found:
        return False

    logging.warning(
        "No display detected (DISPLAY and WAYLAND_DISPLAY are unset). "
        "Switching to headless mode. "
        "As a result, the video stream from the cameras won't be shown."
    )
    return True
def predict_action(
    observation: dict[str, np.ndarray],
    policy: PreTrainedPolicy,
    device: torch.device,
    preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
    use_amp: bool,
    task: str | None = None,
    robot_type: str | None = None,
):
    """
    Run one inference step: turn a raw observation into a postprocessed action.

    The pipeline, executed under `torch.inference_mode()`:
        1. Shallow-copy the observation dict so the caller's dict is not rebound.
        2. Call `prepare_observation_for_inference` with the device, task and robot type.
        3. Run the `preprocessor` pipeline on the prepared observation.
        4. Ask the policy for an action via `policy.select_action`.
        5. Run the `postprocessor` pipeline on that action and return its output as-is.

    Note that this function itself does not squeeze a batch dimension or move the action
    to another device; if that is needed it has to happen in `postprocessor` or the caller.

    Args:
        observation: A dictionary of NumPy arrays representing the robot's current observation.
        policy: The `PreTrainedPolicy` model queried for the action.
        device: The `torch.device` handed to the observation preparation step; its type also
            decides whether autocast may be enabled.
        preprocessor: The `PolicyProcessorPipeline` applied to the prepared observation.
        postprocessor: The `PolicyProcessorPipeline` applied to the policy's action.
        use_amp: Enables `torch.autocast`, but only when `device.type` is "cuda".
        task: An optional task string forwarded to the observation preparation step.
        robot_type: An optional robot type string forwarded to the observation preparation step.

    Returns:
        Whatever `postprocessor` returns for the selected action (presumably a `torch.Tensor`
        -- confirm against the postprocessor in use).
    """
    # Shallow copy: the keys are rebound below without touching the caller's dict.
    obs = copy(observation)

    with torch.inference_mode():
        # Mixed precision is opt-in and CUDA-only; everywhere else this is a no-op context.
        amp_enabled = device.type == "cuda" and use_amp
        amp_context = torch.autocast(device_type=device.type) if amp_enabled else nullcontext()

        with amp_context:
            # Per the upstream comment: converts to pytorch format (channel first, float32
            # in [0,1], with batch dimension) -- behavior lives in the helper, not here.
            obs = prepare_observation_for_inference(obs, device, task, robot_type)
            obs = preprocessor(obs)

            # Query the policy for the next action given the current observation,
            # then hand it to the postprocessing pipeline.
            raw_action = policy.select_action(obs)
            action = postprocessor(raw_action)

    return action
def init_keyboard_listener():
    """
    Initializes a background keyboard listener for real-time user interaction.

    Keys are read from stdin with `readchar.readkey()` on a daemon thread. When stdin is
    not a TTY (e.g. piped input), no thread is started and keyboard control is unavailable;
    in that case `readchar` is not even imported, so the fallback works without it installed.

    Key bindings:
        - Right arrow: sets ``exit_early``.
        - Left arrow: sets ``rerecord_episode`` and ``exit_early``.
        - Escape: sets ``stop_recording`` and ``exit_early``, then ends the listener.

    Returns:
        A tuple containing:
        - A started ``threading.Thread`` with an extra ``stop()`` attribute, or ``None`` if
          stdin is not a TTY. ``stop()`` only sets a flag that is checked between key reads,
          so the thread ends after the next key press (``readkey`` presumably blocks until
          then -- confirm against the readchar documentation).
        - A dictionary of event flags (``exit_early``, ``rerecord_episode``, ``stop_recording``)
          that are set by the corresponding key presses.
    """
    events = {
        "exit_early": False,
        "rerecord_episode": False,
        "stop_recording": False,
    }

    if not sys.stdin.isatty():
        logging.warning(
            "Stdin is not a TTY. Keyboard inputs will not be available. "
            "You won't be able to change the control flow with keyboard shortcuts."
        )
        return None, events

    # Imported only once a TTY is confirmed, so the non-TTY fallback above does not
    # require the optional `readchar` dependency.
    import readchar

    _stop = threading.Event()

    def listen():
        while not _stop.is_set():
            try:
                key = readchar.readkey()
            except Exception as exc:
                # Best effort: a failed read ends the listener instead of crashing the
                # program, but say so unless we were asked to stop anyway.
                if not _stop.is_set():
                    logging.warning("Keyboard listener stopped after a read error: %s", exc)
                break

            if key == readchar.key.RIGHT:
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
            elif key == readchar.key.LEFT:
                print("Left arrow key pressed. Re-recording episode...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == readchar.key.ESC:
                print("Escape key pressed. Stopping data recording...")
                events["stop_recording"] = True
                events["exit_early"] = True
                break

            # `stop_recording` may also have been set from outside this thread.
            if events["stop_recording"]:
                break

    listener = threading.Thread(target=listen, daemon=True)
    # Compatibility shim: lets callers do listener.stop(). Attached before the thread
    # starts so the attribute exists for the thread's whole lifetime.
    listener.stop = _stop.set
    listener.start()

    return listener, events
def sanity_check_dataset_name(repo_id, policy_cfg):
    """
    Validates the dataset repository name against the presence of a policy configuration.

    This function enforces a naming convention: the dataset name (the part of `repo_id`
    after the last "/") should start with "eval_" if and only if a policy configuration
    is provided.

    Args:
        repo_id: The repository ID of the dataset, typically "namespace/name". An ID without
            a namespace is accepted too; the whole string is then used as the dataset name.
        policy_cfg: The configuration object for the policy, or `None`. Its `type` attribute
            is only read to build the error message.

    Raises:
        ValueError: If the naming convention is violated.
    """
    # Take the last path component instead of unpacking `split("/")`, which failed with an
    # unrelated "not enough/too many values to unpack" error for IDs without exactly one "/".
    dataset_name = repo_id.rsplit("/", 1)[-1]
    is_eval_name = dataset_name.startswith("eval_")

    # Valid combinations:
    #   - the name doesn't start with "eval_" and there is no policy
    #   - the name starts with "eval_" and there is a policy

    # Name starts with "eval_" but the policy is missing.
    if is_eval_name and policy_cfg is None:
        raise ValueError(
            f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
        )

    # Name does not start with "eval_" but a policy is provided.
    if not is_eval_name and policy_cfg is not None:
        raise ValueError(
            f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
        )
def sanity_check_dataset_robot_compatibility(
    dataset: LeRobotDataset, robot: Robot, fps: int, features: dict
) -> None:
    """
    Verify that an existing dataset's metadata matches the current recording setup.

    Three fields are compared with `DeepDiff`: the robot type, the fps, and the feature
    dictionary (the session's `features` merged with `DEFAULT_FEATURES`, the latter taking
    precedence on key collisions). Any path ending in `['info']` is ignored by the diff.
    All mismatching fields are collected and reported together in a single error.

    Args:
        dataset: The `LeRobotDataset` whose stored metadata is checked.
        robot: The `Robot` instance providing the current `robot_type`.
        fps: The current recording frequency (frames per second).
        features: The dictionary of features for the current recording session.

    Raises:
        ValueError: If at least one of the compared fields differs.
    """
    # (field name, value stored in the dataset, value from the current setup)
    comparisons = (
        ("robot_type", dataset.meta.robot_type, robot.robot_type),
        ("fps", dataset.fps, fps),
        ("features", dataset.features, {**features, **DEFAULT_FEATURES}),
    )

    # A non-empty DeepDiff result is truthy and means the two values differ.
    problems = [
        f"{name}: expected {current}, got {stored}"
        for name, stored, current in comparisons
        if DeepDiff(stored, current, exclude_regex_paths=[r".*\['info'\]$"])
    ]

    if not problems:
        return

    raise ValueError(
        "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(problems)
    )