Skip to content

Commit df866f6

Browse files
authored
Move filtering idle actions after flatten for DroidRldsDataset (#605)
Turns out Karl's idle-action-filtering-bug spidey senses were correct. Right now, filtering of idle actions in `DroidRldsDataset` is happening before the `dataset.flatten()` operation. This means that the filtering function is being applied at the trajectory level, rather than the frame level. It thus filters out all _episodes_ wherein the first `action_chunk_size // 2` frames' action chunks are sufficiently different from the first frame's action chunk. This is not the intended behavior, the filtering should happen on a per-frame basis. This issue may also explain why non-zero temperature sampling (as added by [this PR](Physical-Intelligence/openpi#550)) is often needed to get the DROID Franka out of home position for policies fine-tuned with this codebase. Luckily, this is a very easy fix: just swap the `dataset = dataset.filter(filter_idle)` and the `dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)` steps :) As an example of the bug, I have included a minimal code snippet with a version of `DroidRldsDataset` that can filter before or after the flattening operation. When filtering before flattening and then loading a batch of 256 frames, many of them do not satisfy `tf.reduce_any(np.abs(actions[:action_chunk_size // 2] - actions[:1]) > 1e-3)`, which is the criteria by which the filter is checking for. Filtering after flattening fixes this issue. ```python import os os.environ["CUDA_VISIBLE_DEVICES"] = "0" from openpi.training.droid_rlds_dataset import DroidActionSpace import numpy as np import tensorflow as tf # The below is a copy-pasted version of DroidRldsDataset from openpi.training.droid_rlds_dataset, # but changed to have the option to flatten then filter (correct) OR filter then flatten (incorrect, # current implementation) for demonstrative purposes. class DroidRldsDataset: def __init__( self, data_dir: str, batch_size: int, *, # Force keyword-only arguments shuffle: bool = True, action_chunk_size: int = 16, # We default to joint position actions, since they allow policy evaluation in simulation. action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION, max_loaded_steps_per_episode: int = 100, # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random. shuffle_buffer_size: int = 250_000, num_parallel_reads: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level num_parallel_calls: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level flatten_then_filter_idle: bool = True, # NEWLY ADDED ): # Import tensorflow here to not make it mandatory in case RLDS data loader is not used. import dlimp as dl import tensorflow as tf import tensorflow_datasets as tfds # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX) tf.config.set_visible_devices([], "GPU") builder = tfds.builder("droid", data_dir=data_dir) dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads) # Filter out any unsuccessful trajectories -- we use the file name to check this dataset = dataset.filter( lambda traj: tf.strings.regex_full_match( traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*" ) ) # Repeat dataset so we never run out of data. dataset = dataset.repeat() def restructure(traj): """Reformat observation and action keys, sample language instruction.""" # Important: we use joint *position* action space -- easier to simulate! actions = tf.concat( ( ( traj["action_dict"]["joint_position"] if action_space == DroidActionSpace.JOINT_POSITION else traj["action_dict"]["joint_velocity"] ), traj["action_dict"]["gripper_position"], ), axis=-1, ) # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time). # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera. exterior_img = tf.cond( tf.random.uniform(shape=[]) > 0.5, lambda: traj["observation"]["exterior_image_1_left"], lambda: traj["observation"]["exterior_image_2_left"], ) wrist_img = traj["observation"]["wrist_image_left"] # Randomly sample one of the three language instructions instruction = tf.random.shuffle( [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]] )[0] return { "actions": actions, "observation": { "image": exterior_img, "wrist_image": wrist_img, "joint_position": traj["observation"]["joint_position"], "gripper_position": traj["observation"]["gripper_position"], }, "prompt": instruction, } dataset = dataset.traj_map(restructure, num_parallel_calls) def chunk_actions(traj): """Splits episode into action chunks.""" traj_len = tf.shape(traj["actions"])[0] # For each step in the trajectory, construct indices for the next n actions action_chunk_indices = tf.broadcast_to( tf.range(action_chunk_size)[None], [traj_len, action_chunk_size], ) + tf.broadcast_to( tf.range(traj_len)[:, None], [traj_len, action_chunk_size], ) # Cap to length of the sequence --> final chunks will repeat the last action # This makes sense, since we are using absolute joint + gripper position actions action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1) # Gather the actions for each chunk traj["actions"] = tf.gather(traj["actions"], action_chunk_indices) return traj dataset = dataset.traj_map(chunk_actions, num_parallel_calls) def filter_idle(traj): """Filter out chunks with idle actions. --> we filter if at least first half of chunk does not move. """ if action_space == DroidActionSpace.JOINT_POSITION: # Compute delta to first position in action chunk return tf.reduce_any(tf.abs(traj["actions"][: action_chunk_size // 2] - traj["actions"][:1]) > 1e-3) return tf.reduce_any(tf.abs(traj["actions"][: action_chunk_size // 2]) > 1e-3) if flatten_then_filter_idle: # Flatten then filter (fixed correct version) dataset = dataset.flatten(num_parallel_calls=num_parallel_calls) dataset = dataset.filter(filter_idle) else: # Filter then flatten (what it currently is, incorrect) dataset = dataset.filter(filter_idle) dataset = dataset.flatten(num_parallel_calls=num_parallel_calls) # Decode images: RLDS saves encoded images, only decode now for efficiency def decode_images(traj): traj["observation"]["image"] = tf.io.decode_image( traj["observation"]["image"], expand_animations=False, dtype=tf.uint8 ) traj["observation"]["wrist_image"] = tf.io.decode_image( traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8 ) return traj dataset = dataset.frame_map(decode_images, num_parallel_calls) # Shuffle, batch dataset = dataset.shuffle(shuffle_buffer_size) dataset = dataset.batch(batch_size) # Note =>> Seems to reduce memory usage without affecting speed? dataset = dataset.with_ram_budget(1) self.dataset = dataset self.batch_size = batch_size self.shuffle = shuffle def __iter__(self): yield from self.dataset.as_numpy_iterator() def __len__(self): # This is the approximate number of samples in DROID after filtering. # Easier to hardcode than to iterate through the dataset and compute it. return 20_000_000 action_chunk_size = 16 batch_size = 256 path_to_droid = "/path/to/droid" # Incorrect current implementation filter_then_flatten_dataset = DroidRldsDataset( path_to_droid, action_chunk_size=action_chunk_size, batch_size=batch_size, shuffle=False, shuffle_buffer_size=256, flatten_then_filter_idle=False # False is the current version ) for batch in filter_then_flatten_dataset: break for i in range(len(batch["actions"])): action = batch["actions"][i] delta = np.abs(action[: action_chunk_size // 2] - action[:1]) if not tf.reduce_any(delta > 1e-3): # If filter_idle is working, this should never happen, # yet it does end up printing. print(delta.max()) # Now change to correct implementation (flatten first, then filter idle) flatten_then_filter_dataset = DroidRldsDataset( path_to_droid, action_chunk_size=action_chunk_size, batch_size=batch_size, shuffle=False, shuffle_buffer_size=256, flatten_then_filter_idle=True # Now set to True ) for batch in flatten_then_filter_dataset: break for i in range(len(batch["actions"])): action = batch["actions"][i] delta = np.abs(action[: action_chunk_size // 2] - action[:1]) if not tf.reduce_any(delta > 1e-3): # With the fix, this never happens! print(delta.max()) ```
2 parents 51fb06b + 64e6406 commit df866f6

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/openpi/training/droid_rlds_dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ def chunk_actions(traj):
116116

117117
dataset = dataset.traj_map(chunk_actions, num_parallel_calls)
118118

119+
# Flatten: map from trajectory dataset to dataset of individual action chunks
120+
dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)
121+
122+
# Filter out frames where actions are idle. Must be done after flattening, as filter should apply per-frame.
119123
def filter_idle(traj):
120124
"""Filter out chunks with idle actions.
121125
--> we filter if at least first half of chunk does not move.
@@ -127,9 +131,6 @@ def filter_idle(traj):
127131

128132
dataset = dataset.filter(filter_idle)
129133

130-
# Flatten: map from trajectory dataset to dataset of individual action chunks
131-
dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)
132-
133134
# Decode images: RLDS saves encoded images, only decode now for efficiency
134135
def decode_images(traj):
135136
traj["observation"]["image"] = tf.io.decode_image(

0 commit comments

Comments
 (0)