Skip to content

Commit c032eed

Browse files
author
root
committed
added verbose for task ingestion + tag filtering before window copying
1 parent dc49639 commit c032eed

3 files changed

Lines changed: 302 additions & 16 deletions

File tree

olmoearth_pretrain/evals/studio_ingest/cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,12 @@ def main() -> int:
252252
Returns:
253253
Exit code (0 for success, non-zero for failure)
254254
"""
255+
logging.basicConfig(
256+
level=logging.INFO,
257+
format="%(asctime)s %(levelname)-5s %(name)s: %(message)s",
258+
datefmt="%Y-%m-%d %H:%M:%S",
259+
)
260+
255261
parser = argparse.ArgumentParser(
256262
prog="studio_ingest",
257263
description="Ingest Studio datasets into OlmoEarth eval system",

olmoearth_pretrain/evals/studio_ingest/ingest.py

Lines changed: 186 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import yaml
5656
from rslearn.config import DatasetConfig
5757
from rslearn.dataset.dataset import Dataset as RslearnDataset
58+
from tqdm import tqdm
5859
from upath import UPath
5960

6061
from olmoearth_pretrain.evals.datasets.rslearn_builder import parse_model_config
@@ -278,20 +279,31 @@ def _copy_from_gcs(
278279
source_path: str,
279280
dest_path: str,
280281
source_groups: list[str] | None = None,
282+
source_tags: dict[str, str] | None = None,
281283
) -> str:
282284
"""Copy dataset from GCS using gsutil with parallel transfers.
283285
284286
Uses gsutil -m for multi-threaded/multi-processing transfers.
285287
Streams output directly to console for progress visibility.
286288
289+
Note: *source_tags* filtering is not supported for GCS sources.
290+
If tags are specified a ``NotImplementedError`` is raised — download
291+
the dataset locally first or use a local source.
292+
287293
Args:
288294
source_path: GCS path (gs://bucket/path)
289295
dest_path: Local destination path
290296
source_groups: If specified, only copy these groups (subdirs under windows/)
297+
source_tags: Not supported for GCS (raises NotImplementedError).
291298
292299
Returns:
293300
Destination path
294301
"""
302+
if source_tags:
303+
raise NotImplementedError(
304+
"Tag-filtered copy is not supported for GCS sources. "
305+
"Download the dataset locally first, then ingest from a local path."
306+
)
295307
logger.info(" Copy method: gsutil (parallel GCS transfer)")
296308

297309
# Create destination directory
@@ -375,10 +387,128 @@ def _tar_copy_cmd(src: str, dst: str, use_pv: bool) -> str:
375387
return f"tar cf - -C {src} . | tar xf - -C {dst}"
376388

377389

390+
def _window_matches_tags(
391+
window_metadata_path: Path,
392+
source_tags: dict[str, str],
393+
) -> bool:
394+
"""Check whether a window's metadata.json matches all required tags.
395+
396+
Args:
397+
window_metadata_path: Path to the window's metadata.json
398+
source_tags: Tags to match. Empty string value means "key exists".
399+
400+
Returns:
401+
True if all tags match.
402+
"""
403+
try:
404+
with open(window_metadata_path) as f:
405+
meta = json.load(f)
406+
except (json.JSONDecodeError, OSError):
407+
return False
408+
409+
options = meta.get("options", {})
410+
for key, value in source_tags.items():
411+
if key not in options:
412+
return False
413+
if value and options[key] != value:
414+
return False
415+
return True
416+
417+
418+
def _collect_matching_windows(
419+
source_path: str,
420+
source_groups: list[str] | None,
421+
source_tags: dict[str, str],
422+
) -> list[tuple[str, str]]:
423+
"""Scan source windows and return (group, window_name) pairs matching tags.
424+
425+
Args:
426+
source_path: Path to rslearn dataset
427+
source_groups: If set, only scan these groups
428+
source_tags: Tags each window must have
429+
430+
Returns:
431+
List of (group_name, window_name) tuples that match.
432+
"""
433+
windows_dir = Path(source_path) / "windows"
434+
if not windows_dir.exists():
435+
return []
436+
437+
groups = source_groups or [d.name for d in windows_dir.iterdir() if d.is_dir()]
438+
logger.info(" Scanning groups: %s", groups)
439+
440+
all_window_dirs: list[tuple[str, Path]] = []
441+
for group in groups:
442+
group_dir = windows_dir / group
443+
if not group_dir.is_dir():
444+
continue
445+
for window_dir in group_dir.iterdir():
446+
if window_dir.is_dir():
447+
all_window_dirs.append((group, window_dir))
448+
449+
matched: list[tuple[str, str]] = []
450+
pbar = tqdm(all_window_dirs, desc="Scanning windows for tags", unit="win")
451+
for group, window_dir in pbar:
452+
meta_path = window_dir / "metadata.json"
453+
if meta_path.exists() and _window_matches_tags(meta_path, source_tags):
454+
matched.append((group, window_dir.name))
455+
pbar.set_postfix(matched=len(matched))
456+
pbar.close()
457+
458+
logger.info(
459+
" Tag scan complete: %d/%d windows matched tags %s",
460+
len(matched),
461+
len(all_window_dirs),
462+
source_tags,
463+
)
464+
return matched
465+
466+
467+
def _copy_filtered_windows(
468+
source_path: str,
469+
dest_path: str,
470+
matched_windows: list[tuple[str, str]],
471+
) -> None:
472+
"""Copy only the matched windows from source to destination.
473+
474+
Uses shutil.copytree per window for simplicity and correctness on Weka.
475+
476+
Args:
477+
source_path: Source dataset path
478+
dest_path: Destination dataset path
479+
matched_windows: List of (group, window_name) to copy
480+
"""
481+
from concurrent.futures import ThreadPoolExecutor, as_completed
482+
483+
num_workers = int(os.environ.get("OLMOEARTH_INGEST_WORKERS", "8"))
484+
total = len(matched_windows)
485+
logger.info(" Copying %d matched windows (workers=%d)...", total, num_workers)
486+
487+
def _copy_one(group: str, wname: str) -> str:
488+
src = Path(source_path) / "windows" / group / wname
489+
dst = Path(dest_path) / "windows" / group / wname
490+
dst.parent.mkdir(parents=True, exist_ok=True)
491+
shutil.copytree(str(src), str(dst))
492+
return wname
493+
494+
pbar = tqdm(total=total, desc="Copying windows", unit="win")
495+
with ThreadPoolExecutor(max_workers=num_workers) as pool:
496+
futures = [
497+
pool.submit(_copy_one, group, wname) for group, wname in matched_windows
498+
]
499+
for future in as_completed(futures):
500+
future.result()
501+
pbar.update(1)
502+
pbar.close()
503+
504+
logger.info(" Finished copying %d windows", total)
505+
506+
378507
def _copy_local(
379508
source_path: str,
380509
dest_path: str,
381510
source_groups: list[str] | None = None,
511+
source_tags: dict[str, str] | None = None,
382512
) -> str:
383513
"""Copy dataset locally using streaming tar pipe.
384514
@@ -388,10 +518,15 @@ def _copy_local(
388518
is preserved because tar archives relative paths from the source and
389519
recreates them at the destination.
390520
521+
When *source_tags* is provided the bulk tar copy is replaced by a
522+
per-window copy that only transfers windows whose ``metadata.json``
523+
matches the requested tags.
524+
391525
Args:
392526
source_path: Local source path
393527
dest_path: Local destination path
394528
source_groups: If specified, only copy these groups (subdirs under windows/)
529+
source_tags: If specified, only copy windows matching these tags.
395530
396531
Returns:
397532
Destination path
@@ -409,17 +544,23 @@ def _copy_local(
409544
# Create destination directory
410545
Path(dest_path).mkdir(parents=True, exist_ok=True)
411546

412-
# TODO: remove pv progress bar once copy performance is validated
413-
has_pv = shutil.which("pv") is not None
414-
415-
logger.info(
416-
" Copy method: streaming tar pipe%s", " (with pv progress)" if has_pv else ""
417-
)
418-
419547
_try_copy_config_json(source_path, dest_path)
420548

421-
if source_groups:
422-
# Copy only specified groups under windows/
549+
if source_tags:
550+
logger.info(" Copy method: tag-filtered per-window copy")
551+
matched = _collect_matching_windows(source_path, source_groups, source_tags)
552+
if not matched:
553+
raise ValueError(
554+
f"No windows in {source_path} matched tags {source_tags}. "
555+
"Check that the tag key/values are correct."
556+
)
557+
_copy_filtered_windows(source_path, dest_path, matched)
558+
elif source_groups:
559+
has_pv = shutil.which("pv") is not None
560+
logger.info(
561+
" Copy method: streaming tar pipe%s",
562+
" (with pv progress)" if has_pv else "",
563+
)
423564
logger.info(f" Copying only groups: {source_groups}")
424565
for group in source_groups:
425566
group_src = f"{source_path}/windows/{group}"
@@ -431,7 +572,11 @@ def _copy_local(
431572
subprocess.run(cmd, shell=True, check=True) # nosec B602
432573
logger.info(f" Copied group '{group}'")
433574
else:
434-
# Copy entire directory using streaming tar
575+
has_pv = shutil.which("pv") is not None
576+
logger.info(
577+
" Copy method: streaming tar pipe%s",
578+
" (with pv progress)" if has_pv else "",
579+
)
435580
cmd = _tar_copy_cmd(source_path, dest_path, has_pv)
436581
logger.info(f" Running: {cmd}")
437582
subprocess.run(cmd, shell=True, check=True) # nosec B602
@@ -444,6 +589,7 @@ def _copy_generic(
444589
source_path: str,
445590
dest_path: str,
446591
source_groups: list[str] | None = None,
592+
source_tags: dict[str, str] | None = None,
447593
) -> str:
448594
"""Fallback copy using UPath for unknown storage backends.
449595
@@ -453,6 +599,7 @@ def _copy_generic(
453599
source_path: Source path (any UPath-compatible)
454600
dest_path: Destination path
455601
source_groups: If specified, only copy these groups (subdirs under windows/)
602+
source_tags: If specified, only copy windows matching these tags.
456603
457604
Returns:
458605
Destination path
@@ -466,6 +613,20 @@ def _copy_generic(
466613

467614
_try_copy_config_json(source_path, dest_path)
468615

616+
# Tag-filtered copy: only works when source is local-like (metadata readable)
617+
if source_tags:
618+
logger.info(" Using tag-filtered copy (generic)")
619+
matched = _collect_matching_windows(source_path, source_groups, source_tags)
620+
if not matched:
621+
raise ValueError(f"No windows in {source_path} matched tags {source_tags}.")
622+
for group, wname in matched:
623+
_copy_directory_recursive(
624+
source / "windows" / group / wname,
625+
dest / "windows" / group / wname,
626+
)
627+
logger.info(" Copied %d matched windows", len(matched))
628+
return dest_path
629+
469630
# Copy windows directory (filtered by groups if specified)
470631
windows_src = source / "windows"
471632
windows_dst = dest / "windows"
@@ -524,6 +685,7 @@ def copy_dataset(
524685
source_path: str,
525686
name: str,
526687
source_groups: list[str] | None = None,
688+
source_tags: dict[str, str] | None = None,
527689
untar_source: bool = False,
528690
) -> str:
529691
"""Copy an rslearn dataset to our Weka location.
@@ -534,11 +696,17 @@ def copy_dataset(
534696
- Local/Weka (/weka, /) -> find + xargs -P (parallel local copy)
535697
- Other -> UPath generic copy (fallback)
536698
699+
When *source_tags* is provided, the copy is filtered so that only
700+
windows whose ``metadata.json`` contains the requested tag key/values
701+
are transferred. This avoids copying entire large datasets when only a
702+
subset is needed for evaluation.
703+
537704
Args:
538705
source_path: Path to source rslearn dataset
539706
name: Name for the copied dataset
540707
source_groups: If specified, only copy these groups (subdirs under windows/).
541708
If None, copies everything.
709+
source_tags: If specified, only copy windows matching these tags.
542710
untar_source: If True, source_path is a .tar.gz archive on GCS that
543711
will be streamed and extracted directly to the destination.
544712
@@ -550,10 +718,12 @@ def copy_dataset(
550718
logger.info("=== Dataset Copy ===")
551719
logger.info(f" Source: {source_path}")
552720
logger.info(f" Destination: {dest_path}")
721+
if source_tags:
722+
logger.info(f" Filtering to tags: {source_tags}")
553723
if source_groups:
554724
logger.info(f" Filtering to groups: {source_groups}")
555-
else:
556-
logger.info(" Copying all groups")
725+
if not source_groups and not source_tags:
726+
logger.info(" Copying all groups (no tag/group filter)")
557727

558728
# Check if destination already exists
559729
if Path(dest_path).exists():
@@ -565,11 +735,11 @@ def copy_dataset(
565735
if untar_source and source_path.startswith("gs://"):
566736
actual_path = _copy_from_gcs_tar(source_path, dest_path)
567737
elif source_path.startswith("gs://"):
568-
actual_path = _copy_from_gcs(source_path, dest_path, source_groups)
738+
actual_path = _copy_from_gcs(source_path, dest_path, source_groups, source_tags)
569739
elif source_path.startswith("/weka") or source_path.startswith("/"):
570-
actual_path = _copy_local(source_path, dest_path, source_groups)
740+
actual_path = _copy_local(source_path, dest_path, source_groups, source_tags)
571741
else:
572-
actual_path = _copy_generic(source_path, dest_path, source_groups)
742+
actual_path = _copy_generic(source_path, dest_path, source_groups, source_tags)
573743

574744
logger.info(f" Dataset copy complete: {actual_path}")
575745
return actual_path
@@ -892,6 +1062,7 @@ def ingest_dataset(config: IngestConfig) -> EvalDatasetEntry:
8921062
config.source_path,
8931063
config.name,
8941064
config.source_groups,
1065+
config.source_tags,
8951066
config.untar_source,
8961067
)
8971068
logger.info(f"[Step 1/6] Copy complete: {weka_path}")

0 commit comments

Comments
 (0)