5555import yaml
5656from rslearn .config import DatasetConfig
5757from rslearn .dataset .dataset import Dataset as RslearnDataset
58+ from tqdm import tqdm
5859from upath import UPath
5960
6061from olmoearth_pretrain .evals .datasets .rslearn_builder import parse_model_config
@@ -278,20 +279,31 @@ def _copy_from_gcs(
278279 source_path : str ,
279280 dest_path : str ,
280281 source_groups : list [str ] | None = None ,
282+ source_tags : dict [str , str ] | None = None ,
281283) -> str :
282284 """Copy dataset from GCS using gsutil with parallel transfers.
283285
284286 Uses gsutil -m for multi-threaded/multi-processing transfers.
285287 Streams output directly to console for progress visibility.
286288
289+ Note: *source_tags* filtering is not supported for GCS sources.
290+ If tags are specified a ``NotImplementedError`` is raised — download
291+ the dataset locally first or use a local source.
292+
287293 Args:
288294 source_path: GCS path (gs://bucket/path)
289295 dest_path: Local destination path
290296 source_groups: If specified, only copy these groups (subdirs under windows/)
297+ source_tags: Not supported for GCS (raises NotImplementedError).
291298
292299 Returns:
293300 Destination path
294301 """
302+ if source_tags :
303+ raise NotImplementedError (
304+ "Tag-filtered copy is not supported for GCS sources. "
305+ "Download the dataset locally first, then ingest from a local path."
306+ )
295307 logger .info (" Copy method: gsutil (parallel GCS transfer)" )
296308
297309 # Create destination directory
@@ -375,10 +387,128 @@ def _tar_copy_cmd(src: str, dst: str, use_pv: bool) -> str:
375387 return f"tar cf - -C { src } . | tar xf - -C { dst } "
376388
377389
def _window_matches_tags(
    window_metadata_path: Path,
    source_tags: dict[str, str],
) -> bool:
    """Check whether a window's metadata.json matches all required tags.

    Tags are looked up in the ``"options"`` object of the metadata file.
    Any problem reading or interpreting the file (missing file, invalid
    JSON, invalid UTF-8, top-level value that is not a JSON object) is
    treated as "does not match" so that one corrupt window cannot abort a
    scan over the whole dataset.

    Args:
        window_metadata_path: Path to the window's metadata.json
        source_tags: Tags to match. Empty string value means "key exists".

    Returns:
        True if all tags match.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(window_metadata_path, encoding="utf-8") as f:
            meta = json.load(f)
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return False

    if not isinstance(meta, dict):
        # Syntactically valid JSON, but not the expected object layout.
        return False

    options = meta.get("options")
    if not isinstance(options, dict):
        # Missing, null, or malformed "options" behaves like "no tags set".
        options = {}

    for key, value in source_tags.items():
        if key not in options:
            return False
        # An empty required value only asserts that the key is present.
        if value and options[key] != value:
            return False
    return True
416+
417+
def _collect_matching_windows(
    source_path: str,
    source_groups: list[str] | None,
    source_tags: dict[str, str],
) -> list[tuple[str, str]]:
    """Scan source windows and return (group, window_name) pairs matching tags.

    Windows are expected at ``<source_path>/windows/<group>/<window>/`` with a
    ``metadata.json`` inside each window directory. Groups and windows are
    visited in sorted order so the result is deterministic across runs.

    Args:
        source_path: Path to rslearn dataset
        source_groups: If set, only scan these groups
        source_tags: Tags each window must have

    Returns:
        List of (group_name, window_name) tuples that match. Empty if the
        dataset has no ``windows`` directory.
    """
    windows_dir = Path(source_path) / "windows"
    if not windows_dir.exists():
        return []

    groups = source_groups or sorted(
        d.name for d in windows_dir.iterdir() if d.is_dir()
    )
    logger.info(" Scanning groups: %s", groups)

    all_window_dirs: list[tuple[str, Path]] = []
    for group in groups:
        group_dir = windows_dir / group
        if not group_dir.is_dir():
            # A requested group that does not exist is skipped, but say so:
            # a typo here otherwise looks like "no windows matched".
            logger.warning(" Group directory not found, skipping: %s", group_dir)
            continue
        all_window_dirs.extend(
            (group, window_dir)
            for window_dir in sorted(group_dir.iterdir())
            if window_dir.is_dir()
        )

    matched: list[tuple[str, str]] = []
    # Context manager guarantees the bar is closed even if the scan raises.
    with tqdm(all_window_dirs, desc="Scanning windows for tags", unit="win") as pbar:
        for group, window_dir in pbar:
            # No separate exists() check: a missing metadata.json surfaces as
            # OSError inside _window_matches_tags, which reports "no match".
            if _window_matches_tags(window_dir / "metadata.json", source_tags):
                matched.append((group, window_dir.name))
                # refresh=False avoids forcing a redraw per window; the bar
                # picks the new postfix up on its next regular update.
                pbar.set_postfix(matched=len(matched), refresh=False)

    logger.info(
        " Tag scan complete: %d/%d windows matched tags %s",
        len(matched),
        len(all_window_dirs),
        source_tags,
    )
    return matched
465+
466+
def _copy_filtered_windows(
    source_path: str,
    dest_path: str,
    matched_windows: list[tuple[str, str]],
) -> None:
    """Copy only the matched windows from source to destination.

    Uses shutil.copytree per window for simplicity and correctness on Weka.
    Copies run in a thread pool whose size is read from the
    ``OLMOEARTH_INGEST_WORKERS`` environment variable (default 8, clamped to
    at least 1).

    Args:
        source_path: Source dataset path
        dest_path: Destination dataset path
        matched_windows: List of (group, window_name) to copy

    Raises:
        ValueError: If ``OLMOEARTH_INGEST_WORKERS`` is not an integer.
        Exception: The first error raised by any per-window copy is
            re-raised; copies that have not started yet are cancelled.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    raw_workers = os.environ.get("OLMOEARTH_INGEST_WORKERS", "8")
    try:
        # ThreadPoolExecutor rejects max_workers <= 0, so clamp to 1.
        num_workers = max(1, int(raw_workers))
    except ValueError as err:
        raise ValueError(
            f"OLMOEARTH_INGEST_WORKERS must be an integer, got {raw_workers!r}"
        ) from err

    total = len(matched_windows)
    logger.info(" Copying %d matched windows (workers=%d)...", total, num_workers)

    src_windows = Path(source_path) / "windows"
    dst_windows = Path(dest_path) / "windows"

    def _copy_one(group: str, wname: str) -> str:
        src = src_windows / group / wname
        dst = dst_windows / group / wname
        # Several workers may create the same group directory concurrently.
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copytree(str(src), str(dst))
        return wname

    # Context managers guarantee the bar is closed and the pool is shut down
    # even when a copy fails.
    with (
        tqdm(total=total, desc="Copying windows", unit="win") as pbar,
        ThreadPoolExecutor(max_workers=num_workers) as pool,
    ):
        futures = [
            pool.submit(_copy_one, group, wname) for group, wname in matched_windows
        ]
        try:
            for future in as_completed(futures):
                future.result()  # re-raises any exception from the worker
                pbar.update(1)
        except BaseException:
            # Fail fast: drop queued copies instead of running all of them
            # before the error reaches the caller. Running copies finish.
            for pending in futures:
                pending.cancel()
            raise

    logger.info(" Finished copying %d windows", total)
505+
506+
378507def _copy_local (
379508 source_path : str ,
380509 dest_path : str ,
381510 source_groups : list [str ] | None = None ,
511+ source_tags : dict [str , str ] | None = None ,
382512) -> str :
383513 """Copy dataset locally using streaming tar pipe.
384514
@@ -388,10 +518,15 @@ def _copy_local(
388518 is preserved because tar archives relative paths from the source and
389519 recreates them at the destination.
390520
521+ When *source_tags* is provided the bulk tar copy is replaced by a
522+ per-window copy that only transfers windows whose ``metadata.json``
523+ matches the requested tags.
524+
391525 Args:
392526 source_path: Local source path
393527 dest_path: Local destination path
394528 source_groups: If specified, only copy these groups (subdirs under windows/)
529+ source_tags: If specified, only copy windows matching these tags.
395530
396531 Returns:
397532 Destination path
@@ -409,17 +544,23 @@ def _copy_local(
409544 # Create destination directory
410545 Path (dest_path ).mkdir (parents = True , exist_ok = True )
411546
412- # TODO: remove pv progress bar once copy performance is validated
413- has_pv = shutil .which ("pv" ) is not None
414-
415- logger .info (
416- " Copy method: streaming tar pipe%s" , " (with pv progress)" if has_pv else ""
417- )
418-
419547 _try_copy_config_json (source_path , dest_path )
420548
421- if source_groups :
422- # Copy only specified groups under windows/
549+ if source_tags :
550+ logger .info (" Copy method: tag-filtered per-window copy" )
551+ matched = _collect_matching_windows (source_path , source_groups , source_tags )
552+ if not matched :
553+ raise ValueError (
554+ f"No windows in { source_path } matched tags { source_tags } . "
555+ "Check that the tag key/values are correct."
556+ )
557+ _copy_filtered_windows (source_path , dest_path , matched )
558+ elif source_groups :
559+ has_pv = shutil .which ("pv" ) is not None
560+ logger .info (
561+ " Copy method: streaming tar pipe%s" ,
562+ " (with pv progress)" if has_pv else "" ,
563+ )
423564 logger .info (f" Copying only groups: { source_groups } " )
424565 for group in source_groups :
425566 group_src = f"{ source_path } /windows/{ group } "
@@ -431,7 +572,11 @@ def _copy_local(
431572 subprocess .run (cmd , shell = True , check = True ) # nosec B602
432573 logger .info (f" Copied group '{ group } '" )
433574 else :
434- # Copy entire directory using streaming tar
575+ has_pv = shutil .which ("pv" ) is not None
576+ logger .info (
577+ " Copy method: streaming tar pipe%s" ,
578+ " (with pv progress)" if has_pv else "" ,
579+ )
435580 cmd = _tar_copy_cmd (source_path , dest_path , has_pv )
436581 logger .info (f" Running: { cmd } " )
437582 subprocess .run (cmd , shell = True , check = True ) # nosec B602
@@ -444,6 +589,7 @@ def _copy_generic(
444589 source_path : str ,
445590 dest_path : str ,
446591 source_groups : list [str ] | None = None ,
592+ source_tags : dict [str , str ] | None = None ,
447593) -> str :
448594 """Fallback copy using UPath for unknown storage backends.
449595
@@ -453,6 +599,7 @@ def _copy_generic(
453599 source_path: Source path (any UPath-compatible)
454600 dest_path: Destination path
455601 source_groups: If specified, only copy these groups (subdirs under windows/)
602+ source_tags: If specified, only copy windows matching these tags.
456603
457604 Returns:
458605 Destination path
@@ -466,6 +613,20 @@ def _copy_generic(
466613
467614 _try_copy_config_json (source_path , dest_path )
468615
616+ # Tag-filtered copy: only works when source is local-like (metadata readable)
617+ if source_tags :
618+ logger .info (" Using tag-filtered copy (generic)" )
619+ matched = _collect_matching_windows (source_path , source_groups , source_tags )
620+ if not matched :
621+ raise ValueError (f"No windows in { source_path } matched tags { source_tags } ." )
622+ for group , wname in matched :
623+ _copy_directory_recursive (
624+ source / "windows" / group / wname ,
625+ dest / "windows" / group / wname ,
626+ )
627+ logger .info (" Copied %d matched windows" , len (matched ))
628+ return dest_path
629+
469630 # Copy windows directory (filtered by groups if specified)
470631 windows_src = source / "windows"
471632 windows_dst = dest / "windows"
@@ -524,6 +685,7 @@ def copy_dataset(
524685 source_path : str ,
525686 name : str ,
526687 source_groups : list [str ] | None = None ,
688+ source_tags : dict [str , str ] | None = None ,
527689 untar_source : bool = False ,
528690) -> str :
529691 """Copy an rslearn dataset to our Weka location.
@@ -534,11 +696,17 @@ def copy_dataset(
534696 - Local/Weka (/weka, /) -> find + xargs -P (parallel local copy)
535697 - Other -> UPath generic copy (fallback)
536698
699+ When *source_tags* is provided, the copy is filtered so that only
700+ windows whose ``metadata.json`` contains the requested tag key/values
701+ are transferred. This avoids copying entire large datasets when only a
702+ subset is needed for evaluation.
703+
537704 Args:
538705 source_path: Path to source rslearn dataset
539706 name: Name for the copied dataset
540707 source_groups: If specified, only copy these groups (subdirs under windows/).
541708 If None, copies everything.
709+ source_tags: If specified, only copy windows matching these tags.
542710 untar_source: If True, source_path is a .tar.gz archive on GCS that
543711 will be streamed and extracted directly to the destination.
544712
@@ -550,10 +718,12 @@ def copy_dataset(
550718 logger .info ("=== Dataset Copy ===" )
551719 logger .info (f" Source: { source_path } " )
552720 logger .info (f" Destination: { dest_path } " )
721+ if source_tags :
722+ logger .info (f" Filtering to tags: { source_tags } " )
553723 if source_groups :
554724 logger .info (f" Filtering to groups: { source_groups } " )
555- else :
556- logger .info (" Copying all groups" )
725+ if not source_groups and not source_tags :
726+ logger .info (" Copying all groups (no tag/group filter) " )
557727
558728 # Check if destination already exists
559729 if Path (dest_path ).exists ():
@@ -565,11 +735,11 @@ def copy_dataset(
565735 if untar_source and source_path .startswith ("gs://" ):
566736 actual_path = _copy_from_gcs_tar (source_path , dest_path )
567737 elif source_path .startswith ("gs://" ):
568- actual_path = _copy_from_gcs (source_path , dest_path , source_groups )
738+ actual_path = _copy_from_gcs (source_path , dest_path , source_groups , source_tags )
569739 elif source_path .startswith ("/weka" ) or source_path .startswith ("/" ):
570- actual_path = _copy_local (source_path , dest_path , source_groups )
740+ actual_path = _copy_local (source_path , dest_path , source_groups , source_tags )
571741 else :
572- actual_path = _copy_generic (source_path , dest_path , source_groups )
742+ actual_path = _copy_generic (source_path , dest_path , source_groups , source_tags )
573743
574744 logger .info (f" Dataset copy complete: { actual_path } " )
575745 return actual_path
@@ -892,6 +1062,7 @@ def ingest_dataset(config: IngestConfig) -> EvalDatasetEntry:
8921062 config .source_path ,
8931063 config .name ,
8941064 config .source_groups ,
1065+ config .source_tags ,
8951066 config .untar_source ,
8961067 )
8971068 logger .info (f"[Step 1/6] Copy complete: { weka_path } " )
0 commit comments