diff --git a/README.md b/README.md index 1446df4..95b8edd 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ Run the pipeline via the Python CLI. Retry behavior is driven by your YAML confi - `execution_params.retry: true` → automatically retries failed shards until completion or `max_retries` - `execution_params.retry: false` → submits/runs once; you can later trigger retries via `check` +- `execution_params.merge: true` → after a successful run, automatically merges shard outputs ```bash python -m mmirage.cli run --config configs/config_mock.yaml @@ -55,6 +56,31 @@ To check status and submit retries for failed shards: python -m mmirage.cli check --config configs/config_mock.yaml --retry ``` +To merge shards from the CLI directly: + +```bash +mmirage merge --config configs/config_mock.yaml +``` + +To merge shards without a config file (input directory + output directory only): + +```bash +mmirage merge-dir --input-dir /path/to/shards --output-dir /path/to/merged +``` + +`--input-dir` can point either to a single dataset directory that contains `shard_*` +folders, or to a parent directory containing multiple dataset subdirectories. +If `shard_*` folders are present directly in `--input-dir`, MMIRAGE merges that +root dataset directly and ignores nested internal folders. + +For multiple datasets, you can also choose a shared merge root: + +```bash +mmirage merge --config configs/config_mock.yaml --output-root /path/to/merged +``` + +MMIRAGE still keeps datasets separate by creating one subdirectory per dataset under the root. + ### Text-only: Reformatting dataset Suppose you have a dataset with samples of the following format @@ -119,6 +145,7 @@ processing_params: execution_params: mode: local retry: false + merge: false ``` Configuration explanation: @@ -134,6 +161,11 @@ Configuration explanation: - `execution_params`: - `mode`: "local" to run shard processing in the current Python environment or "slurm" to run through SLURM by submitting an sbatch array job. - `retry`: If true, MMIRAGE automatically retries failed shards until they succeed or `max_retries` is reached. If false, the pipeline runs/submits once, and retries can be triggered later via the check/retry CLI commands. + - `merge`: If true, MMIRAGE merges shard outputs after a successful `run`. Merged datasets are written under each dataset `output_dir` in a `merged` subdirectory. + +Merge output behavior with multiple datasets: +- Default (`run` with `execution_params.merge: true`, or `merge` without `--output-root`): each dataset is merged to its own `/merged`. +- Shared root (`merge --output-root ...`): one merged subdirectory is created per dataset under the root. ### Multimodal: Processing images with VLMs diff --git a/configs/config_comprehensive.yaml b/configs/config_comprehensive.yaml index 57291bc..f6418d6 100644 --- a/configs/config_comprehensive.yaml +++ b/configs/config_comprehensive.yaml @@ -119,6 +119,11 @@ execution_params: # - true: submit, wait, and keep retrying failed shards until success or retry budget exhaustion retry: true + # Whether to merge shard outputs after a successful run. + # - false: keep shard_* outputs only + # - true: build merged datasets from shard_* outputs + merge: false + # Maximum number of times to retry a failed shard (default: 3) max_retries: 3 diff --git a/configs/config_mock.yaml b/configs/config_mock.yaml index 9c30dbc..9d70885 100644 --- a/configs/config_mock.yaml +++ b/configs/config_mock.yaml @@ -55,5 +55,6 @@ processing_params: execution_params: mode: local retry: false + merge: true report_dir: ~/reports hf_home: ~/hf diff --git a/configs/config_mock_vision.yaml b/configs/config_mock_vision.yaml index f90c6a4..7d881ea 100644 --- a/configs/config_mock_vision.yaml +++ b/configs/config_mock_vision.yaml @@ -44,5 +44,6 @@ processing_params: execution_params: mode: local retry: false + merge: true report_dir: ~/reports hf_home: ~/hf diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index d1c2add..4b40cc0 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -23,6 +23,7 @@ ) from mmirage.config.config import MMirageConfig from mmirage.config.utils import load_mmirage_config +from mmirage.merge_shards import MergeReport, merge_from_config, merge_input_dir logger = logging.getLogger(__name__) @@ -48,13 +49,20 @@ def run_local(config_path: str, shard_id: Optional[int] = None) -> int: return result.returncode -def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = False) -> int: +def launch_pipeline( + cfg: MMirageConfig, + config_path: str, + force_retry: bool = False, + require_completion: bool = False, +) -> int: """Launch the pipeline according to execution mode and retry settings. Args: cfg: Parsed MMIRAGE configuration object. config_path: Absolute path to the MMIRAGE YAML config file. force_retry: If True, enable retry orchestration regardless of config flag. + require_completion: If True, wait for completion and verify shard + status before returning success in SLURM mode when auto-retry is off. Returns: Exit code: 0 on success, 1 on failure. @@ -114,7 +122,17 @@ def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = Fa logger.info(f"Submitted SLURM job {job_id} for shard ids: {shard_ids or 'ALL'}") if not auto_retry: - return 0 + if not require_completion: + return 0 + + wait_for_slurm_job(job_id, cfg) + failed_shards, summary = check_failed_shards(cfg) + status_code = status_exit_code(failed_shards, summary) + if status_code == 0: + logger.info("All shards completed successfully") + else: + logger.error("SLURM run completed with failed shards; merge will not start") + return status_code wait_for_slurm_job(job_id, cfg) failed_shards, summary = check_failed_shards(cfg) @@ -223,9 +241,67 @@ def build_argparser() -> argparse.ArgumentParser: help="Run a single shard locally (overrides execution mode)", ) + merge_parser = subparsers.add_parser( + "merge", + help="Merge shard outputs listed in config.loading_params.datasets", + ) + add_shared_arguments(merge_parser) + merge_parser.add_argument( + "--output-dir", + "--output-root", + dest="output_dir", + default=None, + help=( + "Optional root directory for merged outputs. MMIRAGE creates one " + "subdirectory per configured dataset under this root. If omitted, " + "each dataset is merged into /merged" + ), + ) + + merge_dir_parser = subparsers.add_parser( + "merge-dir", + help="Merge shards directly from an input directory into an output directory", + ) + merge_dir_parser.add_argument( + "--input-dir", + required=True, + help=( + "Input directory containing one dataset with shard_* folders, or " + "multiple dataset subdirectories each containing shard_* folders" + ), + ) + merge_dir_parser.add_argument( + "--output-dir", + required=True, + help="Output directory for merged dataset(s)", + ) + merge_dir_parser.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Log verbosity", + ) + return parser +def log_merge_reports(reports: List[MergeReport]) -> None: + """Log merge summary for one or more datasets.""" + for report in reports: + skipped_total = report.skipped_invalid_dirs + report.skipped_zero_rows + logger.info( + "Merged dataset %s: shards=%d rows=%d output=%s skipped=%d " + "(invalid=%d, zero_rows=%d)", + report.dataset_name, + report.used_shards, + report.merged_rows, + report.output_dir, + skipped_total, + report.skipped_invalid_dirs, + report.skipped_zero_rows, + ) + + def parse_shard_ids(raw_value: Optional[str], num_shards: Optional[int] = None) -> List[int]: """Parse a comma-separated shard id list. @@ -271,7 +347,22 @@ def handle_run(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) - """ if args.shard_id is not None: return run_local(config_path, args.shard_id) - return launch_pipeline(cfg, config_path, force_retry=args.force_retry) + + exit_code = launch_pipeline( + cfg, + config_path, + force_retry=args.force_retry, + require_completion=cfg.execution_params.merge, + ) + if exit_code != 0: + return exit_code + + if cfg.execution_params.merge: + logger.info("Execution_params.merge is true; merging shard outputs") + reports = merge_from_config(cfg) + log_merge_reports(reports) + + return 0 def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int: @@ -370,6 +461,36 @@ def handle_retry(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) ) +def handle_merge(args: argparse.Namespace, cfg: MMirageConfig, _config_path: str) -> int: + """Merge shard outputs defined in config.loading_params.datasets. + + Args: + args: Parsed CLI namespace. + cfg: Parsed MMIRAGE configuration object. + _config_path: Absolute path to the MMIRAGE YAML config file (not needed here). + + Returns: + Exit code for merge outcome. + """ + reports = merge_from_config(cfg, output_root=args.output_dir) + log_merge_reports(reports) + return 0 + + +def handle_merge_dir(args: argparse.Namespace) -> int: + """Merge shard outputs directly from input/output directory arguments. + + Args: + args: Parsed CLI namespace. + + Returns: + Exit code for merge outcome. + """ + reports = merge_input_dir(args.input_dir, args.output_dir) + log_merge_reports(reports) + return 0 + + def main() -> None: """CLI entry point.""" parser = build_argparser() @@ -377,6 +498,9 @@ def main() -> None: configure_logging(args.log_level) try: + if args.command == "merge-dir": + sys.exit(handle_merge_dir(args)) + config_path = os.path.abspath(args.config) cfg = load_mmirage_config(config_path) @@ -388,6 +512,7 @@ def main() -> None: "submit": handle_submit, "check": handle_check, "retry": handle_retry, + "merge": handle_merge, } handler = handlers.get(args.command) if handler is None: diff --git a/src/mmirage/config/config.py b/src/mmirage/config/config.py index 6063f7a..3c31252 100644 --- a/src/mmirage/config/config.py +++ b/src/mmirage/config/config.py @@ -18,6 +18,7 @@ class ExecutionParams: Attributes: mode: Execution mode: "local" or "slurm". Defaults to "local". retry: Whether automatic retry orchestration is enabled. Defaults to False. + merge: Whether to merge shard outputs after a successful run. Defaults to False. max_retries: Maximum number of retries for failed shards. Defaults to 3. poll_interval_seconds: Seconds to wait between polling job status. Defaults to 30. settle_time_seconds: Seconds to wait after job completes before checking results. Defaults to 60. @@ -41,6 +42,7 @@ class ExecutionParams: mode: str = "local" retry: bool = False + merge: bool = False max_retries: int = 3 poll_interval_seconds: int = 30 settle_time_seconds: int = 60 diff --git a/src/mmirage/merge_shards.py b/src/mmirage/merge_shards.py index 433feb5..fbddf10 100644 --- a/src/mmirage/merge_shards.py +++ b/src/mmirage/merge_shards.py @@ -1,13 +1,41 @@ -"""Script to merge processed dataset shards.""" +"""Merge processed dataset shards.""" import argparse import os -from typing import Dict, List +import logging +from typing import Dict, List, Optional from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk +from mmirage.config.config import MMirageConfig from mmirage.core.loader.base import DatasetLike -from mmirage.shard_utils import _count_rows +from mmirage.shard_utils import ( + _count_rows, + _save_dataset_atomic, + _validate_safe_output_dir, + MergeReport, + _list_shard_dirs, + _dataset_dirs, + _validate_input_dir, +) + +logger = logging.getLogger(__name__) + + +def _configure_logging(level: str) -> None: + """Configure logging for direct module execution. + + Keeps existing logging configuration intact when this module is invoked + from another CLI entrypoint that already configured handlers. + """ + root_logger = logging.getLogger() + if root_logger.handlers: + return + + logging.basicConfig( + level=getattr(logging, level.upper(), logging.INFO), + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) def _merge_datasetdict(shard_dsets: List[DatasetDict]) -> DatasetDict: @@ -39,64 +67,91 @@ def _merge_shards(shard_dsets: List[DatasetLike]) -> DatasetLike: ) -def _list_shard_dirs(dataset_dir: str) -> List[str]: - """List shard directories in a dataset directory.""" - shard_dirs: List[str] = [] - for name in os.listdir(dataset_dir): - if not name.startswith("shard_"): - continue - path = os.path.join(dataset_dir, name) - if os.path.isdir(path): - shard_dirs.append(path) - - def _shard_key(path: str) -> int: - base = os.path.basename(path) - suffix = base.removeprefix("shard_") - return int(suffix) if suffix.isdigit() else 0 +def merge_dataset_dir(dataset_dir: str, output_dir: str) -> MergeReport: + """Merge one dataset directory containing shard_* folders. - shard_dirs.sort(key=_shard_key) - return shard_dirs + Args: + dataset_dir: Input directory containing shard_* folders. + output_dir: Destination directory for merged dataset. + Returns: + MergeReport with summary details. + """ + dataset_dir = os.path.abspath(os.path.expandvars(os.path.expanduser(dataset_dir))) + normalized_output_dir = os.path.abspath(os.path.expandvars(os.path.expanduser(output_dir))) + _validate_input_dir(dataset_dir, "dataset_dir") + _validate_safe_output_dir(dataset_dir, normalized_output_dir) + + shard_dirs = _list_shard_dirs(dataset_dir) + if not shard_dirs: + raise RuntimeError(f"No shard_* folders found in {dataset_dir}.") + + shard_dsets: List[DatasetLike] = [] + skipped_invalid_dirs = 0 + skipped_zero_rows = 0 + + for shard_dir in shard_dirs: + try: + ds = load_from_disk(shard_dir) + except FileNotFoundError as e: + logger.warning( + f"{shard_dir} is not a valid HF dataset directory, skipping. " + f"Reason: {e}" + ) + skipped_invalid_dirs += 1 + continue -def _dataset_dirs(input_dir: str) -> List[str]: - """Find dataset directories containing shard folders.""" - candidates: List[str] = [] - for name in os.listdir(input_dir): - path = os.path.join(input_dir, name) - if not os.path.isdir(path): + row_count = _count_rows(ds) + if row_count == 0: + logger.warning(f"Shard dataset has 0 rows, skipping: {shard_dir}") + skipped_zero_rows += 1 continue - if _list_shard_dirs(path): - candidates.append(path) - return sorted(candidates) + logger.info(f"Using {os.path.basename(shard_dir)} with {row_count} rows.") + shard_dsets.append(ds) -def main(): - """Merge processed shard datasets into per-dataset Hugging Face datasets. + if not shard_dsets: + raise RuntimeError( + f"No non-empty shards found in {dataset_dir}. " + f"empty/invalid dirs: {skipped_invalid_dirs}, " + f"zero-row datasets: {skipped_zero_rows}." + ) - Scans --input-dir for dataset subdirectories containing shard_* folders. - For each dataset directory, merges shard datasets and writes to --output-dir - while preserving the dataset directory name. - """ - ap = argparse.ArgumentParser("Merge processed shard datasets into HF datasets.") - ap.add_argument( - "--input-dir", - required=True, - help="Directory containing dataset subdirectories with shard_* folders.", - ) - ap.add_argument( - "--output-dir", - required=True, - help="Directory to write merged datasets into.", + ds_merged = _merge_shards(shard_dsets) + merged_rows = _count_rows(ds_merged) + + _save_dataset_atomic(ds_merged, normalized_output_dir) + + dataset_name = os.path.basename(os.path.normpath(dataset_dir)) + return MergeReport( + dataset_name=dataset_name, + input_dir=dataset_dir, + output_dir=normalized_output_dir, + used_shards=len(shard_dsets), + merged_rows=merged_rows, + skipped_invalid_dirs=skipped_invalid_dirs, + skipped_zero_rows=skipped_zero_rows, ) - args = ap.parse_args() - input_dir = args.input_dir - output_dir = args.output_dir - dataset_dirs = _dataset_dirs(input_dir) +def merge_input_dir(input_dir: str, output_dir: str) -> List[MergeReport]: + """Merge all shard datasets found under an input directory. + + The input can be either: + - one dataset dir containing shard_* folders directly + - a parent dir containing multiple dataset subdirectories, each with shard_* + """ + input_dir = os.path.abspath(os.path.expandvars(os.path.expanduser(input_dir))) + output_dir = os.path.abspath(os.path.expandvars(os.path.expanduser(output_dir))) + _validate_input_dir(input_dir, "input_dir") + root_shards = _list_shard_dirs(input_dir) + dataset_dirs = _dataset_dirs(input_dir) - if not dataset_dirs and root_shards: + # If shards are present at the input root, treat it as a single dataset. + # This avoids accidentally picking internal subdirectories (for example + # pipeline state folders that may also contain shard_* entries). + if root_shards: dataset_dirs = [input_dir] if not dataset_dirs: @@ -104,61 +159,98 @@ def main(): f"No dataset directories with shard_* folders found in {input_dir}." ) + reports: List[MergeReport] = [] for dataset_dir in dataset_dirs: - shard_dirs = _list_shard_dirs(dataset_dir) - if not shard_dirs: - continue + if dataset_dir == input_dir: + ds_output_dir = output_dir + else: + dataset_name = os.path.basename(dataset_dir) + ds_output_dir = os.path.join(output_dir, dataset_name) - shard_dsets: List[DatasetLike] = [] - skipped_empty_dir = 0 - skipped_zero_rows = 0 - - for shard_dir in shard_dirs: - try: - ds = load_from_disk(shard_dir) - except FileNotFoundError as e: - print( - f"⚠️ {shard_dir} is not a valid HF dataset directory, skipping. " - f"Reason: {e}" - ) - skipped_empty_dir += 1 - continue - - if _count_rows(ds) == 0: - print(f"⚠️ Shard dataset has 0 rows, skipping: {shard_dir}") - skipped_zero_rows += 1 - continue - - print(f"✅ Using {os.path.basename(shard_dir)} with {_count_rows(ds)} rows.") - shard_dsets.append(ds) - - if not shard_dsets: - raise RuntimeError( - f"No non-empty shards found in {dataset_dir}. " - f"empty/invalid dirs: {skipped_empty_dir}, " - f"zero-row datasets: {skipped_zero_rows}." - ) + reports.append(merge_dataset_dir(dataset_dir, ds_output_dir)) - ds_merged = _merge_shards(shard_dsets) - n_rows = _count_rows(ds_merged) + return reports - total_skipped = skipped_empty_dir + skipped_zero_rows - if dataset_dir == input_dir: - ds_out_dir = output_dir - dataset_name = os.path.basename(os.path.normpath(input_dir)) +def merge_from_config( + cfg: MMirageConfig, + output_root: Optional[str] = None, +) -> List[MergeReport]: + """Merge shard outputs described in config.loading_params.datasets. + + Args: + cfg: Loaded MMIRAGE config. + output_root: Optional destination root. If omitted, each dataset writes + into /merged. + + Returns: + Merge reports for each dataset entry. + """ + reports: List[MergeReport] = [] + datasets = cfg.loading_params.datasets + if not datasets: + raise RuntimeError("No datasets configured in loading_params.datasets.") + + dataset_names = [ + os.path.basename(os.path.normpath(ds_config.output_dir)) or f"dataset_{index}" + for index, ds_config in enumerate(datasets) + ] + name_counts: Dict[str, int] = {} + for dataset_name in dataset_names: + name_counts[dataset_name] = name_counts.get(dataset_name, 0) + 1 + + for index, ds_config in enumerate(datasets): + dataset_dir = ds_config.output_dir + dataset_name = dataset_names[index] + if output_root is None: + output_dir = os.path.join(dataset_dir, "merged") else: - dataset_name = os.path.basename(dataset_dir) - ds_out_dir = os.path.join(output_dir, dataset_name) + folder_name = dataset_name + if name_counts[dataset_name] > 1: + folder_name = f"{dataset_name}_{index}" + output_dir = os.path.join(output_root, folder_name) + + reports.append(merge_dataset_dir(dataset_dir, output_dir)) - os.makedirs(ds_out_dir, exist_ok=True) - ds_merged.save_to_disk(ds_out_dir) + return reports - print( - f"✅ Concatenated {len(shard_dsets)} shards for {dataset_name} " - f"with {n_rows} rows.\n" - f" Skipped shards: {total_skipped} total " - f"(empty/invalid dir: {skipped_empty_dir}, zero rows: {skipped_zero_rows})." + +def main(): + """CLI entrypoint for directory-based shard merging. + Scans --input-dir for dataset subdirectories containing shard_* folders. + For each dataset directory, merges shard datasets and writes directly to + the provided `--output-dir`. + """ + ap = argparse.ArgumentParser("Merge processed shard datasets into HF datasets.") + ap.add_argument( + "--input-dir", + required=True, + help="Directory containing dataset subdirectories with shard_* folders.", + ) + ap.add_argument( + "--output-dir", + required=True, + help="Directory to write merged datasets into.", + ) + ap.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level for merge summaries.", + ) + args = ap.parse_args() + _configure_logging(args.log_level) + + reports = merge_input_dir(args.input_dir, args.output_dir) + for report in reports: + skipped_total = report.skipped_invalid_dirs + report.skipped_zero_rows + logger.info( + f"Concatenated {report.used_shards} shards for {report.dataset_name} " + f"with {report.merged_rows} rows.\n" + f" Output: {report.output_dir}\n" + f" Skipped shards: {skipped_total} total " + f"(empty/invalid dir: {report.skipped_invalid_dirs}, " + f"zero rows: {report.skipped_zero_rows})." ) diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index c4dce9a..76af6c9 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -1,4 +1,4 @@ -"""Utility functions for shard processing. +"""Utility functions for shard and merge processing. This module contains helper functions for dataset sharding, state management, and file operations used in the MMIRAGE shard processing pipeline. @@ -95,6 +95,19 @@ def to_dict(self) -> Dict[str, Any]: } +@dataclass +class MergeReport: + """Summary of a merge operation for one dataset directory.""" + + dataset_name: str + input_dir: str + output_dir: str + used_shards: int + merged_rows: int + skipped_invalid_dirs: int + skipped_zero_rows: int + + def _count_rows(ds: DatasetLike) -> int: """Count total rows in a dataset or dataset dict.""" if isinstance(ds, DatasetDict): @@ -140,6 +153,35 @@ def _save_dataset_atomic(ds_processed: DatasetLike, out_dir: str): os.replace(tmp_dir, out_dir) +def _validate_safe_output_dir(dataset_dir: str, output_dir: str) -> None: + """Reject output paths that could delete input data. + + We forbid output directories that are the same as, or ancestors of, + the input dataset directory. This prevents accidental deletion when + clearing pre-existing output_dir before writing merged data. + """ + dataset_real = os.path.realpath(os.path.abspath(dataset_dir)) + output_real = os.path.realpath(os.path.abspath(output_dir)) + + if output_real == dataset_real: + raise RuntimeError( + "Unsafe merge output path: output_dir equals dataset_dir " + f"(dataset_dir={dataset_real}, output_dir={output_real})." + ) + + try: + common = os.path.commonpath([dataset_real, output_real]) + except ValueError: + # Different drives (Windows) -> no ancestor relationship possible + return + + if common == output_real: + raise RuntimeError( + "Unsafe merge output path: output_dir contains dataset_dir " + f"(dataset_dir={dataset_real}, output_dir={output_real})." + ) + + def _dataset_out_dir(shard_idx: int, ds_config: BaseDataLoaderConfig) -> str: """Get dataset-specific output directory for a shard.""" return os.path.join(ds_config.output_dir, f"shard_{shard_idx}") @@ -262,3 +304,50 @@ def _mark_failure(state_dir: str, error_msg: str): _write_status(state_dir, prev) _clear_markers(state_dir) _touch_marker(state_dir, ".FAILED") + + +def _list_shard_dirs(dataset_dir: str) -> List[str]: + """List shard directories in a dataset directory.""" + shard_dirs: List[str] = [] + for name in os.listdir(dataset_dir): + if not name.startswith("shard_"): + continue + # Only accept canonical shard directories of the form "shard_" + # and explicitly skip atomic-save temp dirs like + # "shard_0.tmp...". + if ".tmp" in name: + continue + suffix = name[len("shard_") :] + if not suffix.isdigit(): + continue + path = os.path.join(dataset_dir, name) + if os.path.isdir(path): + shard_dirs.append(path) + + def _shard_key(path: str) -> int: + base = os.path.basename(path) + suffix = base.removeprefix("shard_") + return int(suffix) if suffix.isdigit() else 0 + + shard_dirs.sort(key=_shard_key) + return shard_dirs + + +def _dataset_dirs(input_dir: str) -> List[str]: + """Find dataset directories containing shard folders.""" + candidates: List[str] = [] + for name in os.listdir(input_dir): + path = os.path.join(input_dir, name) + if not os.path.isdir(path): + continue + if _list_shard_dirs(path): + candidates.append(path) + return sorted(candidates) + +def _validate_input_dir(path: str, arg_name: str) -> None: + """Ensure a user-provided input path exists and is a directory.""" + normalized = os.path.abspath(os.path.expandvars(os.path.expanduser(path))) + if not os.path.isdir(normalized): + raise RuntimeError( + f"{arg_name} does not exist or is not a directory: {normalized}" + ) \ No newline at end of file