312 changes: 312 additions & 0 deletions src/lerobot/datasets/dataset_tools.py
@@ -25,6 +25,7 @@

import logging
import shutil
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
@@ -45,6 +46,8 @@
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
flatten_dict,
get_parquet_file_size_in_mb,
load_episodes,
update_chunk_file_indices,
@@ -141,6 +144,315 @@ def delete_episodes(
return new_dataset


def trim_episode_start(
dataset: LeRobotDataset,
seconds: float,
episode_indices: list[int] | None = None,
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Trim the first N seconds from selected episodes and create a new dataset.

The operation rewrites data parquet files and updates episode metadata so that:
- frame_index starts at 0 for each trimmed episode
- timestamp starts at 0 for each trimmed episode
- global index remains contiguous across the full dataset
- dataset_from_index / dataset_to_index reflect new frame ranges

Video files are copied as-is and per-episode video timestamps are shifted forward
for trimmed episodes.

Episodes selected for trimming whose length is <= trim_frames are dropped from the
output dataset.

Args:
dataset: The source LeRobotDataset.
seconds: Number of seconds to remove from episode starts.
episode_indices: Optional list of episode indices to trim. If None, trims all episodes.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
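
Example:
A minimal usage sketch; the repo_id "user/my_dataset" is illustrative:

>>> ds = LeRobotDataset("user/my_dataset")
>>> trimmed = trim_episode_start(ds, seconds=0.5, episode_indices=[0, 2])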
"""
if seconds <= 0:
raise ValueError(f"seconds must be strictly positive, got {seconds}")

if dataset.meta.episodes is None:
dataset.meta.episodes = load_episodes(dataset.meta.root)

trim_frames = int(seconds * dataset.meta.fps)
if trim_frames <= 0:
raise ValueError(
f"seconds={seconds} corresponds to 0 frames at fps={dataset.meta.fps}. "
"Increase seconds so at least one frame is trimmed."
)

if episode_indices is None:
episode_indices = list(range(dataset.meta.total_episodes))

if len(episode_indices) == 0:
raise ValueError("No episodes specified to trim")

episode_indices = sorted(set(episode_indices))
valid_indices = set(range(dataset.meta.total_episodes))
invalid = set(episode_indices) - valid_indices
if invalid:
raise ValueError(f"Invalid episode indices: {invalid}")

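# Episodes whose length is <= trim_frames would be left empty after trimming, so collect them to be dropped.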
too_short = sorted(
ep_idx for ep_idx in episode_indices if int(dataset.meta.episodes[ep_idx]["length"]) <= trim_frames
)
trim_set = set(episode_indices)
skipped_set = set(too_short)
trim_set -= skipped_set

if too_short:
logging.warning(
f"Skipping {len(too_short)} episode(s) that are too short to trim "
f"({trim_frames} frames): {too_short}"
)

episodes_to_keep = [ep_idx for ep_idx in range(dataset.meta.total_episodes) if ep_idx not in skipped_set]
if not episodes_to_keep:
raise ValueError(
"All episodes selected for trimming are too short and would be skipped. "
"Try a smaller trim duration."
)

logging.info(
f"Trimming {len(trim_set)} episode(s) by {seconds}s and keeping {len(episodes_to_keep)} "
f"episode(s) in output"
)

if repo_id is None:
repo_id = f"{dataset.repo_id}_trimmed"
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id

new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=dataset.meta.features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=len(dataset.meta.video_keys) > 0,
chunks_size=dataset.meta.chunks_size,
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
)

if dataset.meta.tasks is not None:
write_tasks(dataset.meta.tasks, new_meta.root)
new_meta.tasks = dataset.meta.tasks.copy()

subtasks_path = dataset.root / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
dst_subtasks_path = new_meta.root / DEFAULT_SUBTASKS_PATH
dst_subtasks_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(subtasks_path, dst_subtasks_path)

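# Map surviving source episode indices to contiguous indices in the new dataset.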
episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)}
trim_duration_s = trim_frames / dataset.meta.fps

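# First pass: compute each surviving episode's new length and its [from, to) range in the global frame index.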
episode_lengths: dict[int, int] = {}
episode_ranges: dict[int, tuple[int, int]] = {}
total_frames = 0
for old_ep_idx in episodes_to_keep:
new_ep_idx = episode_mapping[old_ep_idx]
src_length = int(dataset.meta.episodes[old_ep_idx]["length"])
new_length = src_length - trim_frames if old_ep_idx in trim_set else src_length
episode_lengths[new_ep_idx] = new_length
episode_ranges[new_ep_idx] = (total_frames, total_frames + new_length)
total_frames += new_length

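# Per-episode stats are recomputed for numeric features only; image, video and string features are excluded.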
numeric_features = {
k: v
for k, v in dataset.meta.features.items()
if v["dtype"] not in ["image", "video", "string"]
}
episode_stats_parts: dict[int, list[dict[str, dict]]] = defaultdict(list)
episode_file_metadata: dict[int, dict[str, int]] = {}

data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")

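# Rewrite each data file: drop skipped episodes, remove trimmed frames, and remap all indices.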
for src_path in tqdm(parquet_files, desc="Trimming data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)

if len(df) == 0:
continue

if skipped_set:
keep_mask = ~df["episode_index"].isin(skipped_set)
if not keep_mask.all():
df = df.loc[keep_mask].copy().reset_index(drop=True)

if len(df) == 0:
continue

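# Remove the first trim_frames frames of every episode selected for trimming.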
if trim_set:
trim_mask = df["episode_index"].isin(trim_set) & (df["frame_index"] < trim_frames)
if trim_mask.any():
df = df.loc[~trim_mask].copy().reset_index(drop=True)

if len(df) == 0:
continue

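# Recover chunk/file indices from the relative path, e.g. data/chunk-000/file-000.parquet.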
relative_path = src_path.relative_to(dataset.root)
chunk_idx = int(relative_path.parts[1].split("-")[1])
file_idx = int(relative_path.parts[2].split("-")[1].split(".")[0])

for old_ep_idx in sorted(df["episode_index"].unique().tolist()):
ep_mask = df["episode_index"] == old_ep_idx
new_ep_idx = episode_mapping[old_ep_idx]

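# Re-zero frame_index and timestamp for trimmed episodes, then remap episode_index and rebuild the global index for every episode in this file.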
if old_ep_idx in trim_set:
df.loc[ep_mask, "frame_index"] = df.loc[ep_mask, "frame_index"] - trim_frames
shifted_timestamps = df.loc[ep_mask, "timestamp"].to_numpy(dtype=np.float64) - trim_duration_s
df.loc[ep_mask, "timestamp"] = np.clip(shifted_timestamps, a_min=0.0, a_max=None)

df.loc[ep_mask, "episode_index"] = new_ep_idx

ep_start, _ = episode_ranges[new_ep_idx]
new_indices = ep_start + df.loc[ep_mask, "frame_index"].to_numpy(dtype=np.int64)
df.loc[ep_mask, "index"] = new_indices

if new_ep_idx in episode_file_metadata:
existing = episode_file_metadata[new_ep_idx]
if (
existing["data/chunk_index"] != chunk_idx
or existing["data/file_index"] != file_idx
):
raise ValueError(
f"Episode {old_ep_idx} spans multiple data files. "
"trim_episode_start currently expects one data file per episode."
)
else:
episode_file_metadata[new_ep_idx] = {
"data/chunk_index": chunk_idx,
"data/file_index": file_idx,
}

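# Recompute per-episode statistics from the frames that remain after trimming.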
if numeric_features:
ep_df = df.loc[ep_mask]
episode_data: dict[str, np.ndarray] = {}
episode_feature_spec: dict[str, dict] = {}

for key, feature in numeric_features.items():
if key not in ep_df.columns:
continue

values = ep_df[key].to_numpy()
if len(values) == 0:
continue

first_value = values[0]
if isinstance(first_value, (np.ndarray, list, tuple)):
episode_data[key] = np.stack(values)
else:
episode_data[key] = np.asarray(values)

episode_feature_spec[key] = feature

if episode_data:
episode_stats_parts[new_ep_idx].append(
compute_episode_stats(episode_data, episode_feature_spec)
)

df["index"] = df["index"].astype(np.int64)
if "frame_index" in df.columns:
df["frame_index"] = df["frame_index"].astype(np.int64)

dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)
_write_parquet(df, dst_path, new_meta)

all_episode_stats = []
for old_ep_idx in tqdm(episodes_to_keep, desc="Writing episode metadata"):
new_ep_idx = episode_mapping[old_ep_idx]

if new_ep_idx not in episode_file_metadata:
raise ValueError(f"Missing data file metadata for episode {old_ep_idx}")

from_idx, to_idx = episode_ranges[new_ep_idx]
src_episode = dataset.meta.episodes[old_ep_idx]
ep_data_meta = episode_file_metadata[new_ep_idx]

stats_parts = episode_stats_parts.get(new_ep_idx, [])
ep_stats = aggregate_stats(stats_parts) if len(stats_parts) > 1 else (stats_parts[0] if stats_parts else {})
if ep_stats:
all_episode_stats.append(ep_stats)

episode_meta = {
"data/chunk_index": ep_data_meta["data/chunk_index"],
"data/file_index": ep_data_meta["data/file_index"],
"dataset_from_index": from_idx,
"dataset_to_index": to_idx,
}

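# Videos are not re-encoded; shifting from_timestamp forward by the trim duration makes playback skip the removed frames.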
for video_key in dataset.meta.video_keys:
from_ts = src_episode[f"videos/{video_key}/from_timestamp"]
if old_ep_idx in trim_set:
from_ts += trim_duration_s
episode_meta.update(
{
f"videos/{video_key}/chunk_index": src_episode[f"videos/{video_key}/chunk_index"],
f"videos/{video_key}/file_index": src_episode[f"videos/{video_key}/file_index"],
f"videos/{video_key}/from_timestamp": from_ts,
f"videos/{video_key}/to_timestamp": src_episode[f"videos/{video_key}/to_timestamp"],
}
)

episode_dict = {
"episode_index": new_ep_idx,
"tasks": src_episode["tasks"],
"length": episode_lengths[new_ep_idx],
}
episode_dict.update(episode_meta)
if ep_stats:
episode_dict.update(flatten_dict({"stats": ep_stats}))

new_meta._save_episode_metadata(episode_dict)

new_meta._close_writer()

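# Video files themselves are copied as-is; only the per-episode timestamp metadata was adjusted above.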
if new_meta.video_keys:
_copy_videos(dataset, new_meta)

new_meta.info.update(
{
"total_episodes": len(episodes_to_keep),
"total_frames": total_frames,
"total_tasks": len(new_meta.tasks) if new_meta.tasks is not None else 0,
"splits": {"train": f"0:{len(episodes_to_keep)}"},
}
)

if new_meta.video_keys and dataset.meta.video_keys:
for key in new_meta.video_keys:
if key in dataset.meta.features:
new_meta.info["features"][key]["info"] = dataset.meta.info["features"][key].get("info", {})

write_info(new_meta.info, new_meta.root)

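# Merge the recomputed per-episode stats, falling back to the source dataset stats for any feature (e.g. video keys) that was not recomputed.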
merged_stats = aggregate_stats(all_episode_stats) if all_episode_stats else {}
if dataset.meta.stats:
for key, value in dataset.meta.stats.items():
if key not in merged_stats:
merged_stats[key] = value
if merged_stats:
write_stats(merged_stats, new_meta.root)

return LeRobotDataset(
repo_id=repo_id,
root=output_dir,
image_transforms=dataset.image_transforms,
delta_timestamps=dataset.delta_timestamps,
tolerance_s=dataset.tolerance_s,
)


def split_dataset(
dataset: LeRobotDataset,
splits: dict[str, float | list[int]],