Commit a4b2db0

Merge branch 'main' into patch-2

2 parents: 5875c00 + 3423421
10 files changed: +908, -28 lines

examples/aloha_real/Dockerfile

Lines changed: 1 addition & 1 deletion

```diff
@@ -6,7 +6,7 @@
 # Run the container:
 # docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
 
-FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
+FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
 
 SHELL ["/bin/bash", "-c"]
 
 ENV DEBIAN_FRONTEND=noninteractive
```

examples/aloha_real/real_env.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -112,12 +112,17 @@ def _reset_joints(self):
         )
 
     def _reset_gripper(self):
-        """Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
+        """Set to position mode and do position resets: first close then open. Then change back to PWM mode.
+
+        NOTE: This diverges from the original Aloha code, which first opens then closes the gripper. Pi internal
+        Aloha data was collected with the gripper starting in the open position. Leaving the grippers fully closed
+        was also found to increase the frequency of motor faults.
+        """
         robot_utils.move_grippers(
-            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
+            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
         )
         robot_utils.move_grippers(
-            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
+            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
         )
 
     def get_observation(self):
```

examples/droid/README.md

Lines changed: 24 additions & 0 deletions

````diff
@@ -44,3 +44,27 @@ The script will ask you to enter a free-form language instruction for the robot
 | Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
 | Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
 | Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, make sure you are feeding the desired camera to the policy). Use `ZED_Explore` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
+
+
+# Running RoboArena Baseline Policies
+
+We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot.
+
+```
+# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid
+
+# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid
+
+# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid
+
+# Trained from PaliGemma, using FSQ tokenizer.
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid
+
+# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
+```
+
+You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
````
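Once one of these servers is up, a robot-side client can query it for action chunks. Below is a minimal sketch, assuming the `openpi_client` package and its `WebsocketClientPolicy` wrapper; the observation keys shown follow the DROID convention but are illustrative, so check the DROID example client for the exact schema each policy expects.

```python
# Minimal client sketch (assumes openpi_client's WebsocketClientPolicy;
# observation keys are illustrative -- see the DROID example for the real schema).
import numpy as np
from openpi_client import websocket_client_policy

policy = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
dummy_obs = {
    "observation/exterior_image_1_left": np.zeros((224, 224, 3), dtype=np.uint8),
    "observation/wrist_image_left": np.zeros((224, 224, 3), dtype=np.uint8),
    "observation/joint_position": np.zeros(7, dtype=np.float32),
    "observation/gripper_position": np.zeros(1, dtype=np.float32),
    "prompt": "pick up the marker",
}
action_chunk = policy.infer(dummy_obs)["actions"]  # one chunk of robot actions
```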

examples/libero/convert_libero_data_to_lerobot.py

Lines changed: 5 additions & 7 deletions

```diff
@@ -14,13 +14,13 @@
 `uv pip install tensorflow tensorflow_datasets`
 
 You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
-The resulting dataset will get saved to the $LEROBOT_HOME directory.
+The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
 Running this conversion script will take approximately 30 minutes.
 """
 
 import shutil
 
-from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
+from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 import tensorflow_datasets as tfds
 import tyro
@@ -36,7 +36,7 @@
 
 def main(data_dir: str, *, push_to_hub: bool = False):
     # Clean up any existing dataset in the output directory
-    output_path = LEROBOT_HOME / REPO_NAME
+    output_path = HF_LEROBOT_HOME / REPO_NAME
     if output_path.exists():
         shutil.rmtree(output_path)
 
@@ -85,12 +85,10 @@ def main(data_dir: str, *, push_to_hub: bool = False):
                 "wrist_image": step["observation"]["wrist_image"],
                 "state": step["observation"]["state"],
                 "actions": step["action"],
+                "task": step["language_instruction"].decode(),
             }
         )
-        dataset.save_episode(task=step["language_instruction"].decode())
-
-        # Consolidate the dataset, skip computing stats since we will do that later
-        dataset.consolidate(run_compute_stats=False)
+        dataset.save_episode()
 
     # Optionally push to the Hugging Face Hub
     if push_to_hub:
```
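The last hunk tracks a LeRobot dataset API change: the language instruction now travels with every frame under a `task` key instead of being passed to `save_episode`, and the separate `consolidate` pass is gone. A sketch of the resulting episode loop, with the surrounding script elided (`episode` is an illustrative stand-in for the RLDS episode iterator):

```python
# Sketch of the post-change episode loop; `episode` is an illustrative
# stand-in for the RLDS episode iterator used by the conversion script.
for step in episode:
    dataset.add_frame(
        {
            "wrist_image": step["observation"]["wrist_image"],
            "state": step["observation"]["state"],
            "actions": step["action"],
            # The task string is now a per-frame field...
            "task": step["language_instruction"].decode(),
        }
    )
# ...so save_episode() takes no arguments, and no consolidate() call is needed.
dataset.save_episode()
```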

src/openpi/models/pi0_fast.py

Lines changed: 20 additions & 9 deletions

```diff
@@ -1,5 +1,6 @@
 import dataclasses
 import logging
+from typing import Any
 
 import einops
 import flax.nnx as nnx
@@ -82,6 +83,11 @@ class Pi0FASTConfig(_model.BaseModelConfig):
     action_horizon: int = 32
     max_token_len: int = 250
 
+    # Tokenizer for the fast model.
+    fast_model_tokenizer: Any | None = None
+    # Keyword arguments for the fast model tokenizer.
+    fast_model_tokenizer_kwargs: dict[str, Any] | None = None
+
     @property
     @override
     def model_type(self) -> _model.ModelType:
@@ -265,14 +271,17 @@ def sample_actions(
         output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))
 
         def step(carry):
-            last_logit, output_tokens, cache, _, step = carry
+            rng, last_logit, output_tokens, cache, _, step = carry
 
             # Sample token from last logit
-            if temperature > 0.0:
-                last_logit = last_logit / temperature
-                token = jax.random.categorical(rng, last_logit, axis=-1)
-            else:
-                token = jnp.argmax(last_logit, axis=-1)
+            # Split RNG for this step
+            rng, rng_step = jax.random.split(rng)
+            token = jax.lax.cond(
+                temperature > 0.0,
+                lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),
+                lambda _: jnp.argmax(last_logit, axis=-1),
+                operand=None,
+            )
             output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)
 
             # Check for early stopping --> stop if all batch elements have EOS token
@@ -291,12 +300,14 @@ def step(carry):
                 embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache
             )
 
-            return last_logit, output_tokens, kv_cache, all_eos, step + 1
+            return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1
 
         def cond(carry):
-            _, _, _, all_eos, step = carry
+            _, _, _, _, all_eos, step = carry
             return (~all_eos) & (step < max_decoding_steps)
 
         # Use lax.while_loop so we can jit the full decoding loop.
-        _, output_tokens, _, _, _ = jax.lax.while_loop(cond, step, (last_logit, output_tokens, kv_cache, False, 0))
+        _, _, output_tokens, _, _, _ = jax.lax.while_loop(
+            cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0)
+        )
         return output_tokens
```
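Two things change in this decoding loop: the PRNG key is now threaded through the `lax.while_loop` carry and split once per step (the old code closed over a single `rng`, so every step reused the same key), and the Python-level `if temperature > 0.0` becomes a `jax.lax.cond`, which keeps the branch inside the traced graph. A self-contained sketch of the same pattern, with illustrative names (`next_logits_fn` stands in for the model's single-step decode call):

```python
# Self-contained sketch of the pattern above; next_logits_fn is an
# illustrative stand-in for the model's single-step decode call.
import jax
import jax.numpy as jnp

def greedy_or_sampled_decode(rng, init_logit, next_logits_fn, max_steps, temperature):
    def step(carry):
        rng, logit, toks, i = carry
        rng, rng_step = jax.random.split(rng)  # fresh subkey every iteration
        token = jax.lax.cond(
            temperature > 0.0,
            lambda _: jax.random.categorical(rng_step, logit / temperature, axis=-1),
            lambda _: jnp.argmax(logit, axis=-1),
            operand=None,
        )
        toks = toks.at[:, i].set(token)  # record token at this decoding step
        return rng, next_logits_fn(token), toks, i + 1

    def cond(carry):
        *_, i = carry
        return i < max_steps

    toks = jnp.zeros((init_logit.shape[0], max_steps), dtype=jnp.int32)
    _, _, toks, _ = jax.lax.while_loop(cond, step, (rng, init_logit, toks, 0))
    return toks
```

Unlike a Python `if`, `lax.cond` traces both branches, so the loop stays jit-compatible even when `temperature` is an abstract traced value rather than a Python float.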
