Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Run the pipeline via the Python CLI. Retry behavior is driven by your YAML confi

- `execution_params.retry: true` → automatically retries failed shards until completion or `max_retries`
- `execution_params.retry: false` → submits/runs once; you can later trigger retries via `check`
- `execution_params.merge: true` → after a successful run, automatically merges shard outputs

```bash
python -m mmirage.cli run --config configs/config_mock.yaml
Expand All @@ -55,6 +56,31 @@ To check status and submit retries for failed shards:
python -m mmirage.cli check --config configs/config_mock.yaml --retry
```

To merge shards from the CLI directly:

```bash
mmirage merge --config configs/config_mock.yaml
```

To merge shards without a config file (input directory + output directory only):

```bash
mmirage merge-dir --input-dir /path/to/shards --output-dir /path/to/merged
```

`--input-dir` can point either to a single dataset directory that contains `shard_*`
folders, or to a parent directory containing multiple dataset subdirectories.
If `shard_*` folders are present directly in `--input-dir`, MMIRAGE merges that
root dataset directly and ignores nested internal folders.
Comment thread
qchapp marked this conversation as resolved.

For multiple datasets, you can also choose a shared merge root:

```bash
mmirage merge --config configs/config_mock.yaml --output-root /path/to/merged
```

MMIRAGE still keeps datasets separate by creating one subdirectory per dataset under the root.

### Text-only: Reformatting dataset

Suppose you have a dataset with samples of the following format
Expand Down Expand Up @@ -119,6 +145,7 @@ processing_params:
execution_params:
mode: local
retry: false
merge: false
```

Configuration explanation:
Expand All @@ -134,6 +161,11 @@ Configuration explanation:
- `execution_params`:
- `mode`: "local" to run shard processing in the current Python environment or "slurm" to run through SLURM by submitting an sbatch array job.
- `retry`: If true, MMIRAGE automatically retries failed shards until they succeed or `max_retries` is reached. If false, the pipeline runs/submits once, and retries can be triggered later via the check/retry CLI commands.
- `merge`: If true, MMIRAGE merges shard outputs after a successful `run`. Merged datasets are written under each dataset `output_dir` in a `merged` subdirectory.

Merge output behavior with multiple datasets:
- Default (`run` with `execution_params.merge: true`, or `merge` without `--output-root`): each dataset is merged to its own `<dataset.output_dir>/merged`.
- Shared root (`merge --output-root ...`): one merged subdirectory is created per dataset under the root.

### Multimodal: Processing images with VLMs

Expand Down
5 changes: 5 additions & 0 deletions configs/config_comprehensive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ execution_params:
# - true: submit, wait, and keep retrying failed shards until success or retry budget exhaustion
retry: true

# Whether to merge shard outputs after a successful run.
# - false: keep shard_* outputs only
# - true: build merged datasets from shard_* outputs
merge: false

# Maximum number of times to retry a failed shard (default: 3)
max_retries: 3

Expand Down
1 change: 1 addition & 0 deletions configs/config_mock.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@ processing_params:
execution_params:
mode: local
retry: false
merge: true
report_dir: ~/reports
hf_home: ~/hf
1 change: 1 addition & 0 deletions configs/config_mock_vision.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,6 @@ processing_params:
execution_params:
mode: local
retry: false
merge: true
report_dir: ~/reports
hf_home: ~/hf
131 changes: 128 additions & 3 deletions src/mmirage/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from mmirage.config.config import MMirageConfig
from mmirage.config.utils import load_mmirage_config
from mmirage.merge_shards import MergeReport, merge_from_config, merge_input_dir


logger = logging.getLogger(__name__)
Expand All @@ -48,13 +49,20 @@ def run_local(config_path: str, shard_id: Optional[int] = None) -> int:
return result.returncode


def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = False) -> int:
def launch_pipeline(
cfg: MMirageConfig,
config_path: str,
force_retry: bool = False,
require_completion: bool = False,
) -> int:
"""Launch the pipeline according to execution mode and retry settings.

Args:
cfg: Parsed MMIRAGE configuration object.
config_path: Absolute path to the MMIRAGE YAML config file.
force_retry: If True, enable retry orchestration regardless of config flag.
require_completion: If True, wait for completion and verify shard
status before returning success in SLURM mode when auto-retry is off.

Returns:
Exit code: 0 on success, 1 on failure.
Expand Down Expand Up @@ -114,7 +122,17 @@ def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = Fa
logger.info(f"Submitted SLURM job {job_id} for shard ids: {shard_ids or 'ALL'}")

if not auto_retry:
return 0
if not require_completion:
return 0

wait_for_slurm_job(job_id, cfg)
failed_shards, summary = check_failed_shards(cfg)
status_code = status_exit_code(failed_shards, summary)
if status_code == 0:
logger.info("All shards completed successfully")
else:
logger.error("SLURM run completed with failed shards; merge will not start")
return status_code

wait_for_slurm_job(job_id, cfg)
failed_shards, summary = check_failed_shards(cfg)
Expand Down Expand Up @@ -223,9 +241,67 @@ def build_argparser() -> argparse.ArgumentParser:
help="Run a single shard locally (overrides execution mode)",
)

merge_parser = subparsers.add_parser(
"merge",
help="Merge shard outputs listed in config.loading_params.datasets",
)
add_shared_arguments(merge_parser)
merge_parser.add_argument(
"--output-dir",
"--output-root",
dest="output_dir",
default=None,
help=(
"Optional root directory for merged outputs. MMIRAGE creates one "
"subdirectory per configured dataset under this root. If omitted, "
"each dataset is merged into <dataset.output_dir>/merged"
),
)

merge_dir_parser = subparsers.add_parser(
"merge-dir",
help="Merge shards directly from an input directory into an output directory",
)
merge_dir_parser.add_argument(
"--input-dir",
required=True,
help=(
"Input directory containing one dataset with shard_* folders, or "
"multiple dataset subdirectories each containing shard_* folders"
),
)
merge_dir_parser.add_argument(
"--output-dir",
required=True,
help="Output directory for merged dataset(s)",
)
merge_dir_parser.add_argument(
"--log-level",
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Log verbosity",
)

return parser


def log_merge_reports(reports: List[MergeReport]) -> None:
"""Log merge summary for one or more datasets."""
for report in reports:
skipped_total = report.skipped_invalid_dirs + report.skipped_zero_rows
logger.info(
"Merged dataset %s: shards=%d rows=%d output=%s skipped=%d "
"(invalid=%d, zero_rows=%d)",
report.dataset_name,
report.used_shards,
report.merged_rows,
report.output_dir,
skipped_total,
report.skipped_invalid_dirs,
report.skipped_zero_rows,
)


def parse_shard_ids(raw_value: Optional[str], num_shards: Optional[int] = None) -> List[int]:
"""Parse a comma-separated shard id list.

Expand Down Expand Up @@ -271,7 +347,22 @@ def handle_run(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -
"""
if args.shard_id is not None:
return run_local(config_path, args.shard_id)
return launch_pipeline(cfg, config_path, force_retry=args.force_retry)

exit_code = launch_pipeline(
cfg,
config_path,
force_retry=args.force_retry,
require_completion=cfg.execution_params.merge,
)
if exit_code != 0:
return exit_code

if cfg.execution_params.merge:
logger.info("Execution_params.merge is true; merging shard outputs")
reports = merge_from_config(cfg)
log_merge_reports(reports)

return 0


def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int:
Expand Down Expand Up @@ -370,13 +461,46 @@ def handle_retry(args: argparse.Namespace, cfg: MMirageConfig, config_path: str)
)


def handle_merge(args: argparse.Namespace, cfg: MMirageConfig, _config_path: str) -> int:
"""Merge shard outputs defined in config.loading_params.datasets.

Args:
args: Parsed CLI namespace.
cfg: Parsed MMIRAGE configuration object.
_config_path: Absolute path to the MMIRAGE YAML config file (not needed here).

Returns:
Exit code for merge outcome.
"""
reports = merge_from_config(cfg, output_root=args.output_dir)
log_merge_reports(reports)
return 0


def handle_merge_dir(args: argparse.Namespace) -> int:
"""Merge shard outputs directly from input/output directory arguments.

Args:
args: Parsed CLI namespace.

Returns:
Exit code for merge outcome.
"""
reports = merge_input_dir(args.input_dir, args.output_dir)
log_merge_reports(reports)
return 0


def main() -> None:
"""CLI entry point."""
parser = build_argparser()
args = parser.parse_args()
configure_logging(args.log_level)

try:
if args.command == "merge-dir":
sys.exit(handle_merge_dir(args))

config_path = os.path.abspath(args.config)
cfg = load_mmirage_config(config_path)

Expand All @@ -388,6 +512,7 @@ def main() -> None:
"submit": handle_submit,
"check": handle_check,
"retry": handle_retry,
"merge": handle_merge,
}
handler = handlers.get(args.command)
if handler is None:
Expand Down
2 changes: 2 additions & 0 deletions src/mmirage/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ExecutionParams:
Attributes:
mode: Execution mode: "local" or "slurm". Defaults to "local".
retry: Whether automatic retry orchestration is enabled. Defaults to False.
merge: Whether to merge shard outputs after a successful run. Defaults to False.
max_retries: Maximum number of retries for failed shards. Defaults to 3.
poll_interval_seconds: Seconds to wait between polling job status. Defaults to 30.
settle_time_seconds: Seconds to wait after job completes before checking results. Defaults to 60.
Expand All @@ -41,6 +42,7 @@ class ExecutionParams:

mode: str = "local"
retry: bool = False
merge: bool = False
max_retries: int = 3
poll_interval_seconds: int = 30
settle_time_seconds: int = 60
Expand Down
Loading
Loading