Commit df866f6
authored
Move filtering idle actions after flatten for DroidRldsDataset (#605)
Turns out Karl's idle-action-filtering-bug spidey senses were correct.
Right now, filtering of idle actions in `DroidRldsDataset` is happening
before the `dataset.flatten()` operation. This means that the filtering
function is being applied at the trajectory level, rather than the frame
level. It thus filters out all _episodes_ wherein the first
`action_chunk_size // 2` frames' action chunks are sufficiently
different from the first frame's action chunk. This is not the intended
behavior; the filtering should happen on a per-frame basis.
This issue may also explain why non-zero temperature sampling (as added
by [this PR](Physical-Intelligence/openpi#550))
is often needed to get the DROID Franka out of home position for
policies fine-tuned with this codebase.
Luckily, this is a very easy fix: just swap the `dataset =
dataset.filter(filter_idle)` and the `dataset =
dataset.flatten(num_parallel_calls=num_parallel_calls)` steps :)
As an example of the bug, I have included a minimal code snippet with a
version of `DroidRldsDataset` that can filter before or after the
flattening operation. When filtering before flattening and then loading
a batch of 256 frames, many of them do not satisfy
`tf.reduce_any(np.abs(actions[:action_chunk_size // 2] - actions[:1]) >
1e-3)`, which is the criterion the filter checks for.
Filtering after flattening fixes this issue.
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from openpi.training.droid_rlds_dataset import DroidActionSpace
import numpy as np
import tensorflow as tf
# The below is a copy-pasted version of DroidRldsDataset from openpi.training.droid_rlds_dataset,
# but changed to have the option to flatten then filter (correct) OR filter then flatten (incorrect,
# current implementation) for demonstrative purposes.
class DroidRldsDataset:
    """Iterable DROID dataset built from RLDS via a tf.data pipeline.

    Copy of ``openpi.training.droid_rlds_dataset.DroidRldsDataset``, modified
    to expose ``flatten_then_filter_idle`` so the idle-filter ordering bug can
    be demonstrated: when ``filter`` runs before ``flatten``, the idle
    predicate sees whole trajectories (and effectively only tests the first
    frames' chunks), whereas after ``flatten`` it tests each frame's chunk.
    """

    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        *,  # Force keyword-only arguments
        shuffle: bool = True,
        action_chunk_size: int = 16,
        # We default to joint position actions, since they allow policy evaluation in simulation.
        action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION,
        max_loaded_steps_per_episode: int = 100,
        # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random.
        shuffle_buffer_size: int = 250_000,
        num_parallel_reads: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        num_parallel_calls: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        flatten_then_filter_idle: bool = True,  # NEWLY ADDED: True = fixed order, False = buggy (current) order
    ):
        # Import tensorflow here to not make it mandatory in case RLDS data loader is not used.
        import dlimp as dl
        import tensorflow as tf
        import tensorflow_datasets as tfds

        # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX)
        tf.config.set_visible_devices([], "GPU")

        builder = tfds.builder("droid", data_dir=data_dir)
        # At this point each dataset element is a full trajectory (dict of per-step tensors).
        dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads)

        # Filter out any unsuccessful trajectories -- we use the file name to check this
        dataset = dataset.filter(
            lambda traj: tf.strings.regex_full_match(
                traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*"
            )
        )

        # Repeat dataset so we never run out of data.
        dataset = dataset.repeat()

        def restructure(traj):
            """Reformat observation and action keys, sample language instruction."""
            # Important: we use joint *position* action space -- easier to simulate!
            actions = tf.concat(
                (
                    (
                        traj["action_dict"]["joint_position"]
                        if action_space == DroidActionSpace.JOINT_POSITION
                        else traj["action_dict"]["joint_velocity"]
                    ),
                    traj["action_dict"]["gripper_position"],
                ),
                axis=-1,
            )
            # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time).
            # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera.
            exterior_img = tf.cond(
                tf.random.uniform(shape=[]) > 0.5,
                lambda: traj["observation"]["exterior_image_1_left"],
                lambda: traj["observation"]["exterior_image_2_left"],
            )
            wrist_img = traj["observation"]["wrist_image_left"]
            # Randomly sample one of the three language instructions
            instruction = tf.random.shuffle(
                [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]]
            )[0]
            return {
                "actions": actions,
                "observation": {
                    "image": exterior_img,
                    "wrist_image": wrist_img,
                    "joint_position": traj["observation"]["joint_position"],
                    "gripper_position": traj["observation"]["gripper_position"],
                },
                "prompt": instruction,
            }

        dataset = dataset.traj_map(restructure, num_parallel_calls)

        def chunk_actions(traj):
            """Splits episode into action chunks.

            After this map, traj["actions"] has one chunk of the next
            ``action_chunk_size`` actions per step in the trajectory.
            """
            traj_len = tf.shape(traj["actions"])[0]
            # For each step in the trajectory, construct indices for the next n actions
            action_chunk_indices = tf.broadcast_to(
                tf.range(action_chunk_size)[None],
                [traj_len, action_chunk_size],
            ) + tf.broadcast_to(
                tf.range(traj_len)[:, None],
                [traj_len, action_chunk_size],
            )
            # Cap to length of the sequence --> final chunks will repeat the last action
            # This makes sense, since we are using absolute joint + gripper position actions
            action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1)
            # Gather the actions for each chunk
            traj["actions"] = tf.gather(traj["actions"], action_chunk_indices)
            return traj

        dataset = dataset.traj_map(chunk_actions, num_parallel_calls)

        def filter_idle(traj):
            """Filter out chunks with idle actions.
            --> we filter if at least first half of chunk does not move.

            NOTE(review): this predicate is only correct on flattened
            (per-frame) elements, where traj["actions"][: k] slices within
            one chunk. Applied before flatten(), the same slice indexes
            *frames* of a whole trajectory -- the bug this snippet shows.
            """
            if action_space == DroidActionSpace.JOINT_POSITION:
                # Compute delta to first position in action chunk
                return tf.reduce_any(tf.abs(traj["actions"][: action_chunk_size // 2] - traj["actions"][:1]) > 1e-3)
            # Velocity actions: idle means the first half of the chunk is ~zero.
            return tf.reduce_any(tf.abs(traj["actions"][: action_chunk_size // 2]) > 1e-3)

        if flatten_then_filter_idle:
            # Flatten then filter (fixed correct version)
            dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)
            dataset = dataset.filter(filter_idle)
        else:
            # Filter then flatten (what it currently is, incorrect)
            dataset = dataset.filter(filter_idle)
            dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)

        # Decode images: RLDS saves encoded images, only decode now for efficiency
        def decode_images(traj):
            traj["observation"]["image"] = tf.io.decode_image(
                traj["observation"]["image"], expand_animations=False, dtype=tf.uint8
            )
            traj["observation"]["wrist_image"] = tf.io.decode_image(
                traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8
            )
            return traj

        dataset = dataset.frame_map(decode_images, num_parallel_calls)

        # Shuffle, batch
        dataset = dataset.shuffle(shuffle_buffer_size)
        dataset = dataset.batch(batch_size)
        # Note =>> Seems to reduce memory usage without affecting speed?
        dataset = dataset.with_ram_budget(1)

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Yields numpy batches (dicts matching the restructure() schema).
        yield from self.dataset.as_numpy_iterator()

    def __len__(self):
        # This is the approximate number of samples in DROID after filtering.
        # Easier to hardcode than to iterate through the dataset and compute it.
        return 20_000_000
action_chunk_size = 16
batch_size = 256
path_to_droid = "/path/to/droid"


def _report_idle_chunks(dataset):
    """Pull one batch and print the max delta of every chunk whose first
    half barely moves -- i.e. chunks the idle filter should have removed.
    A working filter produces no output."""
    batch = next(iter(dataset))
    for action in batch["actions"]:
        delta = np.abs(action[: action_chunk_size // 2] - action[:1])
        if not tf.reduce_any(delta > 1e-3):
            print(delta.max())


# Incorrect current implementation: filtering before flatten applies the
# idle check per trajectory instead of per frame, so idle chunks leak
# through and the prints above fire.
_report_idle_chunks(
    DroidRldsDataset(
        path_to_droid,
        action_chunk_size=action_chunk_size,
        batch_size=batch_size,
        shuffle=False,
        shuffle_buffer_size=256,
        flatten_then_filter_idle=False,  # False reproduces the current (buggy) order
    )
)

# Correct implementation: flatten first, then filter per frame.
# With the fix, nothing gets printed.
_report_idle_chunks(
    DroidRldsDataset(
        path_to_droid,
        action_chunk_size=action_chunk_size,
        batch_size=batch_size,
        shuffle=False,
        shuffle_buffer_size=256,
        flatten_then_filter_idle=True,  # True applies the fixed order
    )
)
```1 file changed
+4
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
116 | 116 | | |
117 | 117 | | |
118 | 118 | | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
119 | 123 | | |
120 | 124 | | |
121 | 125 | | |
| |||
127 | 131 | | |
128 | 132 | | |
129 | 133 | | |
130 | | - | |
131 | | - | |
132 | | - | |
133 | 134 | | |
134 | 135 | | |
135 | 136 | | |
| |||
0 commit comments