diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py
index 546b3d67f1b..737c14acae3 100644
--- a/src/lerobot/datasets/dataset_tools.py
+++ b/src/lerobot/datasets/dataset_tools.py
@@ -25,6 +25,7 @@
 import logging
 import shutil
+from collections import defaultdict
 from collections.abc import Callable
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
@@ -45,6 +46,8 @@
     DEFAULT_DATA_FILE_SIZE_IN_MB,
     DEFAULT_DATA_PATH,
     DEFAULT_EPISODES_PATH,
+    DEFAULT_SUBTASKS_PATH,
+    flatten_dict,
     get_parquet_file_size_in_mb,
     load_episodes,
     update_chunk_file_indices,
@@ -141,6 +144,315 @@ def delete_episodes(
     return new_dataset
 
 
+def trim_episode_start(
+    dataset: LeRobotDataset,
+    seconds: float,
+    episode_indices: list[int] | None = None,
+    output_dir: str | Path | None = None,
+    repo_id: str | None = None,
+) -> LeRobotDataset:
+    """Trim the first N seconds from selected episodes and create a new dataset.
+
+    The operation rewrites data parquet files and updates episode metadata so that:
+    - frame_index starts at 0 for each trimmed episode
+    - timestamp starts at 0 for each trimmed episode
+    - global index remains contiguous across the full dataset
+    - dataset_from_index / dataset_to_index reflect new frame ranges
+
+    Video files are copied as-is and per-episode video timestamps are shifted forward
+    for trimmed episodes.
+
+    Episodes selected for trimming that are too short (length <= trim_frames) are
+    dropped from the output dataset.
+
+    Args:
+        dataset: The source LeRobotDataset.
+        seconds: Number of seconds to remove from episode starts.
+        episode_indices: Optional list of episode indices to trim. If None, trims all episodes.
+        output_dir: Directory to save the new dataset. If None, uses default location.
+        repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
+    """
+    if seconds <= 0:
+        raise ValueError(f"seconds must be strictly positive, got {seconds}")
+
+    if dataset.meta.episodes is None:
+        dataset.meta.episodes = load_episodes(dataset.meta.root)
+
+    trim_frames = int(seconds * dataset.meta.fps)
+    if trim_frames <= 0:
+        raise ValueError(
+            f"seconds={seconds} corresponds to 0 frames at fps={dataset.meta.fps}. "
+            "Increase seconds so at least one frame is trimmed."
+        )
+
+    if episode_indices is None:
+        episode_indices = list(range(dataset.meta.total_episodes))
+
+    if len(episode_indices) == 0:
+        raise ValueError("No episodes specified to trim")
+
+    episode_indices = sorted(set(episode_indices))
+    valid_indices = set(range(dataset.meta.total_episodes))
+    invalid = set(episode_indices) - valid_indices
+    if invalid:
+        raise ValueError(f"Invalid episode indices: {invalid}")
+
+    too_short = sorted(
+        ep_idx for ep_idx in episode_indices if int(dataset.meta.episodes[ep_idx]["length"]) <= trim_frames
+    )
+    trim_set = set(episode_indices)
+    skipped_set = set(too_short)
+    trim_set -= skipped_set
+
+    if too_short:
+        logging.warning(
+            f"Skipping {len(too_short)} episode(s) that are too short to trim "
+            f"({trim_frames} frames): {too_short}"
+        )
+
+    episodes_to_keep = [ep_idx for ep_idx in range(dataset.meta.total_episodes) if ep_idx not in skipped_set]
+    if not episodes_to_keep:
+        raise ValueError(
+            "All episodes selected for trimming are too short and would be skipped. "
+            "Try a smaller trim duration."
+        )
+
+    logging.info(
+        f"Trimming {len(trim_set)} episode(s) by {seconds}s and keeping {len(episodes_to_keep)} "
+        f"episode(s) in output"
+    )
+
+    if repo_id is None:
+        repo_id = f"{dataset.repo_id}_trimmed"
+    output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
+
+    new_meta = LeRobotDatasetMetadata.create(
+        repo_id=repo_id,
+        fps=dataset.meta.fps,
+        features=dataset.meta.features,
+        robot_type=dataset.meta.robot_type,
+        root=output_dir,
+        use_videos=len(dataset.meta.video_keys) > 0,
+        chunks_size=dataset.meta.chunks_size,
+        data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
+        video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
+    )
+
+    if dataset.meta.tasks is not None:
+        write_tasks(dataset.meta.tasks, new_meta.root)
+        new_meta.tasks = dataset.meta.tasks.copy()
+
+    subtasks_path = dataset.root / DEFAULT_SUBTASKS_PATH
+    if subtasks_path.exists():
+        dst_subtasks_path = new_meta.root / DEFAULT_SUBTASKS_PATH
+        dst_subtasks_path.parent.mkdir(parents=True, exist_ok=True)
+        shutil.copy(subtasks_path, dst_subtasks_path)
+
+    episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)}
+    trim_duration_s = trim_frames / dataset.meta.fps
+
+    episode_lengths: dict[int, int] = {}
+    episode_ranges: dict[int, tuple[int, int]] = {}
+    total_frames = 0
+    for old_ep_idx in episodes_to_keep:
+        new_ep_idx = episode_mapping[old_ep_idx]
+        src_length = int(dataset.meta.episodes[old_ep_idx]["length"])
+        new_length = src_length - trim_frames if old_ep_idx in trim_set else src_length
+        episode_lengths[new_ep_idx] = new_length
+        episode_ranges[new_ep_idx] = (total_frames, total_frames + new_length)
+        total_frames += new_length
+
+    numeric_features = {
+        k: v
+        for k, v in dataset.meta.features.items()
+        if v["dtype"] not in ["image", "video", "string"]
+    }
+    episode_stats_parts: dict[int, list[dict[str, dict]]] = defaultdict(list)
+    episode_file_metadata: dict[int, dict[str, int]] = {}
+
+    data_dir = dataset.root / DATA_DIR
+    parquet_files = sorted(data_dir.glob("*/*.parquet"))
+    if not parquet_files:
+        raise ValueError(f"No parquet files found in {data_dir}")
+
+    for src_path in tqdm(parquet_files, desc="Trimming data files"):
+        df = pd.read_parquet(src_path).reset_index(drop=True)
+
+        if len(df) == 0:
+            continue
+
+        if skipped_set:
+            keep_mask = ~df["episode_index"].isin(skipped_set)
+            if not keep_mask.all():
+                df = df.loc[keep_mask].copy().reset_index(drop=True)
+
+        if len(df) == 0:
+            continue
+
+        if trim_set:
+            trim_mask = df["episode_index"].isin(trim_set) & (df["frame_index"] < trim_frames)
+            if trim_mask.any():
+                df = df.loc[~trim_mask].copy().reset_index(drop=True)
+
+        if len(df) == 0:
+            continue
+
+        relative_path = src_path.relative_to(dataset.root)
+        chunk_idx = int(relative_path.parts[1].split("-")[1])
+        file_idx = int(relative_path.parts[2].split("-")[1].split(".")[0])
+
+        for old_ep_idx in sorted(df["episode_index"].unique().tolist()):
+            ep_mask = df["episode_index"] == old_ep_idx
+            new_ep_idx = episode_mapping[old_ep_idx]
+
+            if old_ep_idx in trim_set:
+                df.loc[ep_mask, "frame_index"] = df.loc[ep_mask, "frame_index"] - trim_frames
+                shifted_timestamps = df.loc[ep_mask, "timestamp"].to_numpy(dtype=np.float64) - trim_duration_s
+                df.loc[ep_mask, "timestamp"] = np.clip(shifted_timestamps, a_min=0.0, a_max=None)
+
+            df.loc[ep_mask, "episode_index"] = new_ep_idx
+
+            ep_start, _ = episode_ranges[new_ep_idx]
+            new_indices = ep_start + df.loc[ep_mask, "frame_index"].to_numpy(dtype=np.int64)
+            df.loc[ep_mask, "index"] = new_indices
+
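+            # An episode's frames are assumed to live in a single parquet file:
+            # the episode metadata written below stores exactly one
+            # (chunk_index, file_index) pair per episode. If rows for the same
+            # episode show up in a second file, fail loudly rather than write
+            # inconsistent metadata.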
+            if new_ep_idx in episode_file_metadata:
+                existing = episode_file_metadata[new_ep_idx]
+                if (
+                    existing["data/chunk_index"] != chunk_idx
+                    or existing["data/file_index"] != file_idx
+                ):
+                    raise ValueError(
+                        f"Episode {old_ep_idx} spans multiple data files. "
+                        "trim_episode_start currently expects one data file per episode."
+                    )
+            else:
+                episode_file_metadata[new_ep_idx] = {
+                    "data/chunk_index": chunk_idx,
+                    "data/file_index": file_idx,
+                }
+
+            if numeric_features:
+                ep_df = df.loc[ep_mask]
+                episode_data: dict[str, np.ndarray] = {}
+                episode_feature_spec: dict[str, dict] = {}
+
+                for key, feature in numeric_features.items():
+                    if key not in ep_df.columns:
+                        continue
+
+                    values = ep_df[key].to_numpy()
+                    if len(values) == 0:
+                        continue
+
+                    first_value = values[0]
+                    if isinstance(first_value, np.ndarray):
+                        episode_data[key] = np.stack(values)
+                    elif isinstance(first_value, (list, tuple)):
+                        episode_data[key] = np.stack(values)
+                    else:
+                        episode_data[key] = np.asarray(values)
+
+                    episode_feature_spec[key] = feature
+
+                if episode_data:
+                    episode_stats_parts[new_ep_idx].append(
+                        compute_episode_stats(episode_data, episode_feature_spec)
+                    )
+
+        df["index"] = df["index"].astype(np.int64)
+        if "frame_index" in df.columns:
+            df["frame_index"] = df["frame_index"].astype(np.int64)
+
+        dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+        dst_path.parent.mkdir(parents=True, exist_ok=True)
+        _write_parquet(df, dst_path, new_meta)
+
+    all_episode_stats = []
+    for old_ep_idx in tqdm(episodes_to_keep, desc="Writing episode metadata"):
+        new_ep_idx = episode_mapping[old_ep_idx]
+
+        if new_ep_idx not in episode_file_metadata:
+            raise ValueError(f"Missing data file metadata for episode {old_ep_idx}")
+
+        from_idx, to_idx = episode_ranges[new_ep_idx]
+        src_episode = dataset.meta.episodes[old_ep_idx]
+        ep_data_meta = episode_file_metadata[new_ep_idx]
+
+        stats_parts = episode_stats_parts.get(new_ep_idx, [])
+        ep_stats = aggregate_stats(stats_parts) if len(stats_parts) > 1 else (stats_parts[0] if stats_parts else {})
+        if ep_stats:
+            all_episode_stats.append(ep_stats)
+
+        episode_meta = {
+            "data/chunk_index": ep_data_meta["data/chunk_index"],
+            "data/file_index": ep_data_meta["data/file_index"],
+            "dataset_from_index": from_idx,
+            "dataset_to_index": to_idx,
+        }
+
+        for video_key in dataset.meta.video_keys:
+            from_ts = src_episode[f"videos/{video_key}/from_timestamp"]
+            if old_ep_idx in trim_set:
+                from_ts += trim_duration_s
+            episode_meta.update(
+                {
+                    f"videos/{video_key}/chunk_index": src_episode[f"videos/{video_key}/chunk_index"],
+                    f"videos/{video_key}/file_index": src_episode[f"videos/{video_key}/file_index"],
+                    f"videos/{video_key}/from_timestamp": from_ts,
+                    f"videos/{video_key}/to_timestamp": src_episode[f"videos/{video_key}/to_timestamp"],
+                }
+            )
+
+        episode_dict = {
+            "episode_index": new_ep_idx,
+            "tasks": src_episode["tasks"],
+            "length": episode_lengths[new_ep_idx],
+        }
+        episode_dict.update(episode_meta)
+        if ep_stats:
+            episode_dict.update(flatten_dict({"stats": ep_stats}))
+
+        new_meta._save_episode_metadata(episode_dict)
+
+    new_meta._close_writer()
+
+    if new_meta.video_keys:
+        _copy_videos(dataset, new_meta)
+
+    new_meta.info.update(
+        {
+            "total_episodes": len(episodes_to_keep),
+            "total_frames": total_frames,
+            "total_tasks": len(new_meta.tasks) if new_meta.tasks is not None else 0,
+            "splits": {"train": f"0:{len(episodes_to_keep)}"},
+        }
+    )
+
+    if new_meta.video_keys and dataset.meta.video_keys:
+        for key in new_meta.video_keys:
+            if key in dataset.meta.features:
+                new_meta.info["features"][key]["info"] = dataset.meta.info["features"][key].get("info", {})
+
+    write_info(new_meta.info, new_meta.root)
+
+    merged_stats = aggregate_stats(all_episode_stats) if all_episode_stats else {}
+    if dataset.meta.stats:
+        for key, value in dataset.meta.stats.items():
+            if key not in merged_stats:
+                merged_stats[key] = value
+    if merged_stats:
+        write_stats(merged_stats, new_meta.root)
+
+    return LeRobotDataset(
+        repo_id=repo_id,
+        root=output_dir,
+        image_transforms=dataset.image_transforms,
+        delta_timestamps=dataset.delta_timestamps,
+        tolerance_s=dataset.tolerance_s,
+    )
+
+
 def split_dataset(
     dataset: LeRobotDataset,
     splits: dict[str, float | list[int]],
diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py
index 49825317db3..695c663873a 100644
--- a/src/lerobot/scripts/lerobot_edit_dataset.py
+++ b/src/lerobot/scripts/lerobot_edit_dataset.py
@@ -117,6 +117,13 @@
         --operation.new_task "Default task" \
         --operation.episode_tasks '{"5": "Special task for episode 5"}'
 
+Trim first 3 seconds from all episodes:
+    python -m lerobot.scripts.lerobot_edit_dataset \
+        --repo_id lerobot/pusht \
+        --new_repo_id lerobot/pusht_trim3s \
+        --operation.type trim_episode_start \
+        --operation.seconds 3.0
+
 Convert image dataset to video format and save locally:
     lerobot-edit-dataset \
         --repo_id lerobot/pusht_image \
@@ -170,6 +177,7 @@
     modify_tasks,
     remove_feature,
     split_dataset,
+    trim_episode_start,
 )
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.utils.constants import HF_LEROBOT_HOME
@@ -215,6 +223,13 @@ class ModifyTasksConfig(OperationConfig):
     episode_tasks: dict[str, str] | None = None
 
 
+@OperationConfig.register_subclass("trim_episode_start")
+@dataclass
+class TrimEpisodeStartConfig(OperationConfig):
+    seconds: float | None = None
+    episode_indices: list[int] | None = None
+
+
 @OperationConfig.register_subclass("convert_image_to_video")
 @dataclass
 class ConvertImageToVideoConfig(OperationConfig):
@@ -464,6 +479,41 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
         modified_dataset.push_to_hub()
 
 
+def handle_trim_episode_start(cfg: EditDatasetConfig) -> None:
+    if not isinstance(cfg.operation, TrimEpisodeStartConfig):
+        raise ValueError("Operation config must be TrimEpisodeStartConfig")
+
+    if cfg.operation.seconds is None:
+        raise ValueError("seconds must be specified for trim_episode_start operation")
+
+    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
+    output_repo_id, output_dir = get_output_path(
+        cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
+    )
+
+    if cfg.new_repo_id is None:
+        dataset.root = Path(str(dataset.root) + "_old")
+
+    logging.info(
+        f"Trimming first {cfg.operation.seconds}s from episodes "
+        f"{cfg.operation.episode_indices if cfg.operation.episode_indices else 'ALL'} in {cfg.repo_id}"
+    )
+    new_dataset = trim_episode_start(
+        dataset=dataset,
+        seconds=cfg.operation.seconds,
+        episode_indices=cfg.operation.episode_indices,
+        output_dir=output_dir,
+        repo_id=output_repo_id,
+    )
+
+    logging.info(f"Dataset saved to {output_dir}")
+    logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
+
+    if cfg.push_to_hub:
+        logging.info(f"Pushing to hub as {output_repo_id}")
+        LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
+
+
 def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
     # Note: Parser may create any config type with the right fields, so we access fields directly
     # instead of checking isinstance()
@@ -594,6 +644,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
         handle_remove_feature(cfg)
     elif operation_type == "modify_tasks":
         handle_modify_tasks(cfg)
+    elif operation_type == "trim_episode_start":
+        handle_trim_episode_start(cfg)
     elif operation_type == "convert_image_to_video":
         handle_convert_image_to_video(cfg)
     elif operation_type == "info":
diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py
index 1de1996307b..3f84176827e 100644
--- a/tests/datasets/test_dataset_tools.py
+++ b/tests/datasets/test_dataset_tools.py
@@ -29,6 +29,7 @@
     modify_tasks,
     remove_feature,
     split_dataset,
+    trim_episode_start,
 )
 from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
 
@@ -142,6 +143,104 @@ def test_delete_empty_list(sample_dataset, tmp_path):
     )
 
 
+def test_trim_episode_start_updates_indices(sample_dataset, tmp_path):
+    """Test that trimming episode starts updates frame/timestamp/index metadata consistently."""
+    output_dir = tmp_path / "trimmed"
+    trim_seconds = 0.1  # 3 frames at 30 FPS
+    trim_frames = int(trim_seconds * sample_dataset.meta.fps)
+
+    with (
+        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+    ):
+        mock_get_safe_version.return_value = "v3.0"
+        mock_snapshot_download.return_value = str(output_dir)
+
+        new_dataset = trim_episode_start(
+            sample_dataset,
+            seconds=trim_seconds,
+            output_dir=output_dir,
+        )
+
+    expected_length = 10 - trim_frames
+    assert new_dataset.meta.total_episodes == sample_dataset.meta.total_episodes
+    assert new_dataset.meta.total_frames == sample_dataset.meta.total_episodes * expected_length
+
+    indices = [int(i.item()) for i in new_dataset.hf_dataset["index"]]
+    assert indices == list(range(new_dataset.meta.total_frames))
+
+    episode_indices = [int(i.item()) for i in new_dataset.hf_dataset["episode_index"]]
+    frame_indices = [int(i.item()) for i in new_dataset.hf_dataset["frame_index"]]
+    timestamps = [float(i.item()) for i in new_dataset.hf_dataset["timestamp"]]
+
+    for ep_idx in range(sample_dataset.meta.total_episodes):
+        ep_frame_indices = [f for e, f in zip(episode_indices, frame_indices, strict=False) if e == ep_idx]
+        ep_timestamps = [t for e, t in zip(episode_indices, timestamps, strict=False) if e == ep_idx]
+
+        assert len(ep_frame_indices) == expected_length
+        assert ep_frame_indices == list(range(expected_length))
+        assert ep_timestamps[0] == pytest.approx(0.0)
+        assert ep_timestamps[-1] == pytest.approx((expected_length - 1) / sample_dataset.meta.fps)
+
+        ep_meta = new_dataset.meta.episodes[ep_idx]
+        assert int(ep_meta["length"]) == expected_length
+        assert int(ep_meta["dataset_from_index"]) == ep_idx * expected_length
+        assert int(ep_meta["dataset_to_index"]) == (ep_idx + 1) * expected_length
+
+
+def test_trim_episode_start_skips_too_short_episodes(tmp_path, empty_lerobot_dataset_factory):
+    """Test that too-short episodes are skipped and remaining episodes are reindexed."""
+    features = {
+        "action": {"dtype": "float32", "shape": (2,), "names": None},
+        "observation.state": {"dtype": "float32", "shape": (2,), "names": None},
+        "observation.images.top": {"dtype": "image", "shape": (32, 32, 3), "names": None},
+    }
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "source", features=features)
+
+    for ep_len in [10, 2, 10]:
+        for _ in range(ep_len):
+            dataset.add_frame(
+                {
+                    "action": np.random.randn(2).astype(np.float32),
+                    "observation.state": np.random.randn(2).astype(np.float32),
+                    "observation.images.top": np.random.randint(0, 255, size=(32, 32, 3), dtype=np.uint8),
+                    "task": "task",
+                }
+            )
+        dataset.save_episode()
+    dataset.finalize()
+
+    trim_seconds = 0.1  # 3 frames at 30 FPS
+    with (
+        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+    ):
+        mock_get_safe_version.return_value = "v3.0"
+        mock_snapshot_download.return_value = str(tmp_path / "trimmed")
+
+        new_dataset = trim_episode_start(
+            dataset,
+            seconds=trim_seconds,
+            output_dir=tmp_path / "trimmed",
+        )
+
+    # Episode 1 is too short and gets skipped. Remaining episodes are trimmed and reindexed.
+    assert new_dataset.meta.total_episodes == 2
+    assert new_dataset.meta.total_frames == 14
+    assert sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) == [0, 1]
+    assert [int(ep["length"]) for ep in new_dataset.meta.episodes] == [7, 7]
+
+
+def test_trim_episode_start_rejects_when_all_selected_are_too_short(sample_dataset, tmp_path):
+    """Test that trimming fails when all selected episodes are too short and would be skipped."""
+    with pytest.raises(ValueError, match="All episodes selected for trimming are too short"):
+        trim_episode_start(
+            sample_dataset,
+            seconds=1.0,  # 30 frames > 10-frame episodes
+            output_dir=tmp_path / "trimmed",
+        )
+
+
 def test_split_by_episodes(sample_dataset, tmp_path):
     """Test splitting dataset by specific episode indices."""
     splits = {
diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py
index 4d758ae35cd..7449ed5aa87 100644
--- a/tests/scripts/test_edit_dataset_parsing.py
+++ b/tests/scripts/test_edit_dataset_parsing.py
@@ -28,6 +28,7 @@
     RemoveFeatureConfig,
     SplitConfig,
     _validate_config,
+    TrimEpisodeStartConfig,
 )
 
 
@@ -47,6 +48,7 @@ class TestOperationTypeParsing:
             ("merge", MergeConfig),
             ("remove_feature", RemoveFeatureConfig),
             ("modify_tasks", ModifyTasksConfig),
+            ("trim_episode_start", TrimEpisodeStartConfig),
             ("convert_image_to_video", ConvertImageToVideoConfig),
             ("info", InfoConfig),
         ],
@@ -77,6 +79,7 @@ def test_non_merge_requires_repo_id(self):
             ("merge", MergeConfig),
             ("remove_feature", RemoveFeatureConfig),
             ("modify_tasks", ModifyTasksConfig),
+            ("trim_episode_start", TrimEpisodeStartConfig),
             ("convert_image_to_video", ConvertImageToVideoConfig),
             ("info", InfoConfig),
         ],
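
Below the patch, for orientation: a minimal usage sketch of the new Python API (not part of the diff). It assumes a locally available dataset; the repo ids are placeholders, and the fps value in the comment is only an example.

    from lerobot.datasets.dataset_tools import trim_episode_start
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset("lerobot/pusht")  # placeholder repo id

    # Trim the first half second from episodes 0 and 2 only. trim_frames is
    # int(seconds * fps), so at e.g. 10 fps this removes 5 frames; a selected
    # episode with length <= 5 would be skipped, and unselected episodes are
    # copied through with contiguous re-indexing.
    trimmed = trim_episode_start(
        dataset,
        seconds=0.5,
        episode_indices=[0, 2],
        repo_id="lerobot/pusht_trimmed",  # defaults to f"{repo_id}_trimmed" if omitted
    )
    print(trimmed.meta.total_episodes, trimmed.meta.total_frames)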