
Commit 3423421

add roboarena policies (#573)

2 parents 84cd4dc + db3374b

File tree

6 files changed: +869 -2 lines changed


examples/droid/README.md

Lines changed: 24 additions & 0 deletions
@@ -44,3 +44,27 @@ The script will ask you to enter a free-form language instruction for the robot
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
| Policy does not perform the task well | In our experiments, the policy could perform simple tabletop manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene: the policy is conditioned on only a single external camera plus the wrist camera, so make sure you are feeding it the desired camera (use `ZED_Explore` to check the view). Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |

# Running RoboArena Baseline Policies

We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies, then follow the instructions above to run evaluation on the DROID robot.

```
# Trained from PaliGemma, using an RT-2 / OpenVLA style binning tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid

# Trained from PaliGemma, using the FAST tokenizer (the universal FAST+ tokenizer).
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid

# Trained from PaliGemma, using a FAST tokenizer trained specifically on the DROID dataset.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid

# Trained from PaliGemma, using the FSQ tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid

# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
```

You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
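
Once a server is up, you can sanity-check it from Python before moving to the robot. A minimal sketch using the openpi client, assuming the server listens on the default port 8000 and that your observation dict follows the DROID example's layout (both assumptions, not part of this commit):

```
# Minimal sanity-check sketch (assumes a server started with one of the
# commands above, listening on the default port 8000).
from openpi_client import websocket_client_policy

client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)

# Placeholder observation: the real keys (external/wrist camera images, joint
# state, prompt) are defined by the DROID runtime, not shown here.
obs: dict = {}
action_chunk = client.infer(obs)["actions"]  # (action_horizon, action_dim) chunk
```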

src/openpi/models/pi0_fast.py

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
import dataclasses
import logging
from typing import Any

import einops
import flax.nnx as nnx
@@ -82,6 +83,11 @@ class Pi0FASTConfig(_model.BaseModelConfig):
    action_horizon: int = 32
    max_token_len: int = 250

    # Tokenizer for the fast model.
    fast_model_tokenizer: Any | None = None
    # Keyword arguments for the fast model tokenizer.
    fast_model_tokenizer_kwargs: dict[str, Any] | None = None

    @property
    @override
    def model_type(self) -> _model.ModelType:
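
The two new fields let a config hand the FAST model a tokenizer plus its constructor kwargs. A hypothetical wiring sketch, not taken from this commit (whether the field holds a class or an instance is an assumption here; see roboarena_config.py for the actual usage):

```
# Hypothetical sketch: pointing Pi0FASTConfig at a RoboArena baseline tokenizer.
# Field values are illustrative; see roboarena_config.py for the real configs.
import openpi.models.pi0_fast as pi0_fast
import openpi.models.tokenizer as tokenizer

config = pi0_fast.Pi0FASTConfig(
    action_horizon=32,
    max_token_len=250,
    # Class and constructor kwargs are stored separately, so the config itself
    # stays declarative.
    fast_model_tokenizer=tokenizer.BinningTokenizer,
    fast_model_tokenizer_kwargs={"max_len": 256, "n_bins": 256},
)
```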

src/openpi/models/tokenizer.py

Lines changed: 236 additions & 0 deletions
@@ -1,9 +1,13 @@
import logging
import os

import jax
import numpy as np
import orbax.checkpoint as ocp
import sentencepiece
from transformers import AutoProcessor

import openpi.models.utils.fsq_tokenizer as fsq_tokenizer
import openpi.shared.download as download

@@ -125,3 +129,235 @@ def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens


###########################################################################
## The tokenizers below are used for RoboArena baseline implementations. ##
## They are *not* used for pi0-style models.                             ##
###########################################################################

class BinningTokenizer:
    """
    Standard RT-2 / OpenVLA style binning tokenizer.
    """

    def __init__(self, max_len: int = 256, n_bins: int = 256):
        self._max_len = max_len
        self._n_bins = n_bins

        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def tokenize(
        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Tokenize a prompt and state into a sequence of tokens.

        Args:
            prompt: The text prompt to tokenize.
            state: The state array to discretize and tokenize.
            actions: Must be None. Action encoding is not currently supported.

        Returns:
            A tuple of (tokens, token_mask, ar_mask, loss_mask).

        Raises:
            NotImplementedError: If actions is not None.
        """
        cleaned_text = prompt.lower().strip().replace("_", " ")

        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        # Convention: prefix includes prompt and string-representation of state, followed by ';'
        state_str = " ".join(map(str, discretized_state))
        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)

        if actions is not None:
            raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)")
        postfix_tokens = []

        # Create output token sequence & masks
        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
        tokens = prefix_tokens + postfix_tokens
        token_mask = [True] * len(tokens)
        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only

        # Pad tokens to max length
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]

        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)

    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())

        # If no action marker is present in the decoded output, return a zero action chunk
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)

        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        if len(action_tokens) < action_horizon * action_dim:
            return np.zeros([action_horizon, action_dim], dtype=np.float32)
        action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim])
        return action_tokens / self._n_bins * 2 - 1

    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens

class FSQTokenizer:
    """
    FSQ tokenizer from the FAST paper baselines.
    """

    def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None):
        self._max_len = max_len

        assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
        # Download tokenizer
        path = download.maybe_download(fsq_tokenizer_path)
        tok_path = os.path.join(path, os.listdir(path)[0])  # noqa: PTH118

        # Split step from path
        step = int(tok_path.split("/")[-1])
        base_path = tok_path.rsplit("/", 1)[0]

        mgr = ocp.CheckpointManager(
            base_path,
            item_handlers={
                "params": ocp.StandardCheckpointHandler(),
                "opt_state": ocp.StandardCheckpointHandler(),
                "config": ocp.JsonCheckpointHandler(),
            },
            options=ocp.CheckpointManagerOptions(max_to_keep=1),
        )

        try:
            restored = mgr.restore(
                step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore())
            )
            config = restored["config"]
            self._params = restored["params"]
            self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}"
            ) from e

        # Compile tokenize and detokenize functions
        self._tokenize_fn = jax.jit(
            lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize)
        )
        self._detokenize_fn = jax.jit(
            lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize)
        )

        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def tokenize(
        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        cleaned_text = prompt.lower().strip().replace("_", " ")

        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        # Convention: prefix includes prompt and string-representation of state, followed by ';'
        state_str = " ".join(map(str, discretized_state))
        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)

        if actions is not None:
            raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)")
        postfix_tokens = []

        # Create output token sequence & masks
        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
        tokens = prefix_tokens + postfix_tokens
        token_mask = [True] * len(tokens)
        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only

        # Pad tokens to max length
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]

        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)

    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())

        # If no action marker is present in the decoded output, return a zero action chunk
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)

        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        try:
            # Move computation to CPU and compile on-demand
            device = jax.devices("cpu")[0]
            with jax.default_device(device):
                detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0]
                return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim])
        except Exception as e:
            logging.warning(f"Error decoding FSQ: {e}")
            return np.zeros((action_horizon, action_dim))

    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
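
For orientation, here is a small inference-style usage sketch for `BinningTokenizer`; the state vector and dimensions are made up for illustration. Note that `extract_actions` maps bin indices back to [-1, 1] via `token / n_bins * 2 - 1`.

```
# Illustrative round trip with BinningTokenizer (values are made up).
# Note: the constructor downloads the PaliGemma tokenizer from GCS.
import numpy as np

tok = BinningTokenizer(max_len=256, n_bins=256)

# State is assumed to be normalized to [-1, 1] before discretization.
state = np.array([0.1, -0.5, 0.0, 0.9, -1.0, 0.3, 0.7, -0.2])
tokens, token_mask, ar_mask, loss_mask = tok.tokenize("pick up the cup", state, actions=None)
# All four arrays are padded/truncated to max_len; ar_mask is all zeros here
# because the prefix (prompt + state) is attended bidirectionally and there is
# no postfix at inference time.

# After the policy generates output tokens, decode them into an action chunk;
# a zero chunk is returned if no "Action: " marker is found.
# actions = tok.extract_actions(generated_tokens, action_horizon=32, action_dim=8)
```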
