diff --git a/.gitignore b/.gitignore index cf139bc..908494c 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,6 @@ else/ # Test outputs tests/mock_data/output/ tests/mock_data/shards/ + +# devcontainer +.devcontainer/ \ No newline at end of file diff --git a/README.md b/README.md index 953cbf9..1446df4 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,29 @@ For testing and scripts that make use of the library, it is advised to create a ## Example usage +### Running (single command) + +Run the pipeline via the Python CLI. Retry behavior is driven by your YAML config: + +- `execution_params.retry: true` → automatically retries failed shards until completion or `max_retries` +- `execution_params.retry: false` → submits/runs once; you can later trigger retries via `check --retry` + +```bash +python -m mmirage.cli run --config configs/config_mock.yaml +``` + +To check status only: + +```bash +python -m mmirage.cli check --config configs/config_mock.yaml +``` + +To check status and submit retries for failed shards (retries are only submitted when `execution_params.mode` is `slurm`; in local mode `check` just reports status): + +```bash +python -m mmirage.cli check --config configs/config_mock.yaml --retry +``` + ### Text-only: Reformatting dataset Suppose you have a dataset with samples of the following format @@ -58,11 +81,12 @@ processors: max_new_tokens: 384 loading_params: + state_dir: /path/to/state/dir datasets: - path: /path/to/dataset type: loadable output_dir: /path/to/output/shards - num_shards: "$SLURM_ARRAY_TASK_COUNT" + num_shards: 4 shard_id: "$SLURM_ARRAY_TASK_ID" batch_size: 64 @@ -91,17 +115,25 @@ processing_params: - role: assistant content: "{{ formatted_answer }}" modalities: "{{ modalities }}" + +execution_params: + mode: local + retry: false ``` Configuration explanation: - `processors`: List of processor configurations. Currently supports `llm` type for LLM-based generation. - `loading_params`: Parameters for loading and sharding datasets. + - `state_dir`: Optional shared directory for shard status/retry state. Defaults to `~/.cache/MMIRAGE/state_dir`. 
- `datasets`: List of dataset configurations with path, type, and output directory. - `processing_params`: - `inputs`: Variables extracted from the input dataset using JMESPath queries. - `outputs`: Variables created by processors. Prompts use Jinja2 templating (`{{ variable }}`). - `output_schema`: Defines the structure of output samples. +- `execution_params`: + - `mode`: "local" to run shard processing in the current Python environment or "slurm" to run through SLURM by submitting an sbatch array job. + - `retry`: If true, MMIRAGE automatically retries failed shards until they succeed or `max_retries` is reached. If false, the pipeline runs/submits once, and retries can be triggered later via the check/retry CLI commands. ### Multimodal: Processing images with VLMs @@ -121,11 +153,12 @@ processors: max_new_tokens: 768 loading_params: + state_dir: path/to/state/dir datasets: - path: /path/to/image/dataset type: loadable output_dir: /path/to/output/shards - num_shards: "$SLURM_ARRAY_TASK_COUNT" + num_shards: 4 shard_id: "$SLURM_ARRAY_TASK_ID" batch_size: 32 @@ -152,6 +185,10 @@ processing_params: image: "{{ medical_image }}" caption: "{{ enhanced_caption }}" original_caption: "{{ original_caption }}" + +execution_params: + mode: local + retry: false ``` Key multimodal features: diff --git a/configs/config_comprehensive.yaml b/configs/config_comprehensive.yaml new file mode 100644 index 0000000..57291bc --- /dev/null +++ b/configs/config_comprehensive.yaml @@ -0,0 +1,200 @@ +# MMIRAGE Configuration with all parameters +# +# This is a comprehensive example showing all available configuration options. +# You can copy and modify this file for your specific use case. +# +# Parameters are organized into sections: +# 1. processors - LLM and other data transformation processors +# 2. loading_params - Dataset loading and sharding configuration +# 3. processing_params - How to transform/process the data +# 4. 
execution_params - SLURM, retry, and execution settings +# + +# ============================================================================ +# PROCESSORS CONFIGURATION +# ============================================================================ +# Define the processors used to transform your data. +# Common types: llm, vision_llm, etc. + +processors: + - type: llm + server_args: + model_path: Qwen/Qwen3-4B-Instruct-2507 + tp_size: 1 + disable_custom_all_reduce: true + default_sampling_params: + temperature: 0.1 + top_p: 0.9 + max_new_tokens: 1024 + custom_params: + chat_template_kwargs: + enable_thinking: false + + +# ============================================================================ +# LOADING PARAMETERS +# ============================================================================ +# Configure how datasets are loaded, sharded, and processed. + +loading_params: + # Directory to store pipeline state (checkpoints, status, retry tracking) + # Supports environment variables: $VAR or ${VAR} + state_dir: tests/output/data/_pipeline_state + + # Dataset configurations to load + # Each dataset can be separately sharded and output + datasets: + - path: tests/mock_data/data.jsonl + type: JSONL + output_dir: tests/output/data + # image_base_path: /path/to/images # Optional, for vision tasks + + # Total number of shards to split datasets into. + # For SLURM, this determines the array job size. + num_shards: 4 + + # Shard ID for this process (0-indexed). + # In SLURM array jobs, this is set automatically. + shard_id: "$SLURM_ARRAY_TASK_ID" + + # Batch size for processing samples + batch_size: 64 + + +# ============================================================================ +# PROCESSING PARAMETERS +# ============================================================================ +# Define what to extract, transform, and output from each sample. 
+ +processing_params: + # Input variables to extract from source data + inputs: + - name: text + key: text + # For vision examples: + # - name: image + # key: image_path + # type: image + + # Output variables generated by processors + outputs: + - name: formatted_answer + type: llm + output_type: JSON + output_schema: + - question + - answer + prompt: | + Generate one question and its corresponding answer using the following text: + ``` + {{ text }} + ``` + + # Whether to remove original columns from the dataset + remove_columns: true + + # Output schema: how to structure the final dataset + output_schema: + conversations: + - role: "user" + content: "{{ formatted_answer.question }}" + - role: "assistant" + content: "{{ formatted_answer.answer }}" + + +# ============================================================================ +# EXECUTION PARAMETERS +# ============================================================================ +# Configure how to execute the pipeline: locally or on SLURM cluster. +# All parameters here are optional with sensible defaults. + +execution_params: + # Execution mode: "local" or "slurm" + # - local: Run directly on this machine + # - slurm: Submit jobs to SLURM cluster + mode: slurm + + # Whether the canonical `run` command should automatically retry failed shards. 
+ # - false: submit one run only + # - true: submit, wait, and keep retrying failed shards until success or retry budget exhaustion + retry: true + + # Maximum number of times to retry a failed shard (default: 3) + max_retries: 3 + + # ========================================================================== + # SLURM CONFIGURATION (only used when mode: slurm) + # ========================================================================== + + # HPC account/partition to charge jobs to (REQUIRED for SLURM mode) + account: a127 + + # SLURM job name (default: "mmirage-sharded") + job_name: mmirage-sharded + + # Optional SLURM reservation name (leave blank or omit to not use) + # reservation: "sai-a127" + + # Number of nodes (default: 1) + nodes: 1 + + # Number of tasks per node (default: 1) + ntasks_per_node: 1 + + # Number of GPUs per node (default: 4) + gpus: 4 + + # Number of CPUs per task (default: 288) + cpus_per_task: 288 + + # Job time limit in HH:MM:SS format (default: "11:59:59") + time_limit: "11:59:59" + + # ========================================================================== + # PATH CONFIGURATION + # ========================================================================== + # These support environment variables ($VAR or ${VAR}) and home directory (~) + + # Project root directory (used as base for relative paths) + # If not set, uses current working directory + # project_root: "/path/to/project" + + # Directory for SLURM output and error files (default: ~/reports) + report_dir: "/users/${USER}/reports" + + # HuggingFace cache directory (default: ~/hf) + hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf" + + # EDF environment file path for cluster-specific setup + edf_env: "/users/${USER}/.edf/mmirage.toml" + + # ========================================================================== + # JOB MONITORING (for "submit" and retry orchestration) + # ========================================================================== + + # Seconds to 
wait between checking job status (default: 30) + poll_interval_seconds: 30 + + # Seconds to wait after job completes before checking results (default: 60) + # This allows filesystem to settle on distributed systems + settle_time_seconds: 60 + + +# ============================================================================ +# USAGE EXAMPLES +# ============================================================================ +# +# 1. Canonical entrypoint (local or SLURM; retry controlled by config): +# python -m mmirage.cli run --config config.yaml +# +# 2. Submit job to SLURM with wait for completion: +# python -m mmirage.cli submit --config config.yaml --wait +# +# 3. Submit job and get job ID back (for scripting): +# JOB_ID=$(python -m mmirage.cli submit --config config.yaml) +# +# 4. Run a single shard locally: +# python -m mmirage.cli run --config config.yaml --shard-id 0 +# +# 5. Check status of all shards (add --retry to also submit retries for failed shards; SLURM mode only): +# python -m mmirage.cli check --config config.yaml [--retry] diff --git a/configs/config_mock.yaml b/configs/config_mock.yaml index 1f6c533..9c30dbc 100644 --- a/configs/config_mock.yaml +++ b/configs/config_mock.yaml @@ -13,12 +13,13 @@ processors: enable_thinking: false loading_params: + state_dir: tests/output/data/_pipeline_state datasets: - path: tests/mock_data/data.jsonl type: JSONL output_dir: tests/output/data - num_shards: 4 + num_shards: 1 shard_id: 0 batch_size: 64 @@ -47,3 +48,12 @@ processing_params: content: "{{ formatted_answer.question }}" - role: "assistant" content: "{{ formatted_answer.answer }}" + +# Execution configuration (local or SLURM cluster) +# For local testing, use mode: local +# For SLURM cluster, use mode: slurm and specify account +execution_params: + mode: local + retry: false + report_dir: ~/reports + hf_home: ~/hf diff --git a/configs/config_mock_vision.yaml b/configs/config_mock_vision.yaml index c86172c..f90c6a4 100644 --- a/configs/config_mock_vision.yaml +++ b/configs/config_mock_vision.yaml @@ 
-11,13 +11,14 @@ processors: max_new_tokens: 512 loading_params: + state_dir: tests/output/data_vision/_pipeline_state datasets: - path: tests/mock_data_vision/data.jsonl type: JSONL output_dir: tests/output/data_vision image_base_path: tests/mock_data_vision # Base directory where images are stored - num_shards: 4 + num_shards: 1 shard_id: 0 batch_size: 1 @@ -38,3 +39,10 @@ processing_params: output_schema: image: "{{ image_input }}" caption: "{{ caption }}" + +# Execution configuration (local or SLURM cluster) +execution_params: + mode: local + retry: false + report_dir: ~/reports + hf_home: ~/hf diff --git a/pyproject.toml b/pyproject.toml index 446406b..5804d4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mmirage" -version = "0.1.3" +version = "0.1.4" description = "Modular Multimodal Intelligent Reformatting and Augmentation Generation Engine - Advanced platform for processing datasets using generative models including vision-language models." 
readme = "README.md" requires-python = ">=3.10" @@ -49,6 +49,9 @@ dev = [ "pytest", ] +[project.scripts] +mmirage = "mmirage.cli:main" + [tool.hatch.build.targets.wheel] packages = ["src/mmirage"] diff --git a/run.sh b/run.sh deleted file mode 100644 index 5cfc08a..0000000 --- a/run.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=mmirage-sharded -#SBATCH --chdir=/users/$USER/meditron/MMIRAGE/src/mmirage -#SBATCH --output=/users/$USER/reports/R-%x.%A_%a.out -#SBATCH --error=/users/$USER/reports/R-%x.%A_%a.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=288 -#SBATCH --time=11:59:59 -#SBATCH -A a127 -#SBATCH --array=0-3 - -# --- outputs & config --- -export ROOT=$SCRATCH/mmirage_output -export SHARDS_ROOT="$ROOT/shards" -export MERGED_DIR="$ROOT/merged" -export CFG=$MMIRAGE_PATH/configs/config_small.yaml - -# HF cache/home -export HF_HOME=$SCRATCH/hf - -mkdir -p "$SHARDS_ROOT" -mkdir -p "$MERGED_DIR" - -export CMD="python $MMIRAGE_PATH/src/mmirage/shard_process.py --config $CFG" - -SRUN_ARGS=" \ - --cpus-per-task $SLURM_CPUS_PER_TASK \ - --jobid $SLURM_JOB_ID \ - --wait 60 \ - -A a127 \ - --reservation sai-a127 \ - --environment /users/$USER/.edf/mmirage.toml - " -# bash -c is needed for the delayed interpolation of env vars to work -srun $SRUN_ARGS bash -c "$CMD" -echo "END TIME: $(date)" - diff --git a/src/mmirage/__init__.py b/src/mmirage/__init__.py index 000aa27..f255e81 100644 --- a/src/mmirage/__init__.py +++ b/src/mmirage/__init__.py @@ -3,16 +3,12 @@ A platform for processing datasets using generative models including vision-language models (VLMs). 
""" +from __future__ import annotations -__version__ = "0.2.0" +__version__ = "0.1.4" -from mmirage.config import MMirageConfig, ProcessingParams, LoadingParams +from mmirage.config.config import MMirageConfig, ProcessingParams +from mmirage.config.loading import LoadingParams from mmirage.config.utils import load_mmirage_config -__all__ = [ - "MMirageConfig", - "ProcessingParams", - "LoadingParams", - "load_mmirage_config", - "__version__", -] +__all__ = ["MMirageConfig", "ProcessingParams", "LoadingParams", "load_mmirage_config", "__version__"] diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py new file mode 100644 index 0000000..d1c2add --- /dev/null +++ b/src/mmirage/cli.py @@ -0,0 +1,405 @@ +"""Command-line interface for MMIRAGE pipeline.""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import subprocess +import sys +from dataclasses import asdict +from typing import List, Optional + +from mmirage.cli_utils.runtime import setup_runtime, validate_edf_env_path +from mmirage.cli_utils.slurm import require_slurm, submit_slurm_job, wait_for_slurm_job +from mmirage.cli_utils.status import ( + check_failed_shards, + is_retry_budget_exceeded, + shard_state_dir, + get_shard_status, + status_exit_code, + submit_failed_shards, +) +from mmirage.config.config import MMirageConfig +from mmirage.config.utils import load_mmirage_config + + +logger = logging.getLogger(__name__) + + +def run_local(config_path: str, shard_id: Optional[int] = None) -> int: + """Run one shard in the current Python environment. + + Args: + config_path: Absolute path to the MMIRAGE YAML config file. + shard_id: Optional shard id to inject via SLURM_ARRAY_TASK_ID. + + Returns: + Process return code from shard execution. 
+ """ + command = [sys.executable, "-m", "mmirage.shard_process", "--config", config_path] + env = os.environ.copy() + if shard_id is not None: + env["SLURM_ARRAY_TASK_ID"] = str(shard_id) + + logger.info("Running local shard processing: %s", " ".join(command)) + result = subprocess.run(command, env=env, check=False) + return result.returncode + + +def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = False) -> int: + """Launch the pipeline according to execution mode and retry settings. + + Args: + cfg: Parsed MMIRAGE configuration object. + config_path: Absolute path to the MMIRAGE YAML config file. + force_retry: If True, enable retry orchestration regardless of config flag. + + Returns: + Exit code: 0 on success, 1 on failure. + """ + auto_retry = force_retry or cfg.execution_params.retry + + if not cfg.execution_params.is_slurm(): + initial_shard_id = cfg.loading_params.get_shard_id() + if not auto_retry: + exit_code = run_local(config_path, initial_shard_id) + if exit_code == 0: + logger.info("All shards completed successfully") + return exit_code + + shard_ids: List[int] = [initial_shard_id] + attempts_by_shard = {initial_shard_id: 0} + state_root = cfg.loading_params.get_state_root() + while True: + run_exit_codes = {} + for shard_id in shard_ids: + attempts_by_shard[shard_id] = attempts_by_shard.get(shard_id, 0) + 1 + run_exit_codes[shard_id] = run_local(config_path, shard_id) + + failed_shards, summary = check_failed_shards(cfg) + if status_exit_code(failed_shards, summary) == 0: + logger.info("All shards completed successfully") + return 0 + + runtime_failed = [shard_id for shard_id, rc in run_exit_codes.items() if rc != 0] + candidates = sorted(set(failed_shards) | set(runtime_failed)) + retryable_shards: List[int] = [] + for shard_id in candidates: + _, state_attempt_count = get_shard_status(shard_state_dir(state_root, shard_id)) + memory_attempt_count = attempts_by_shard.get(shard_id, 0) + effective_attempt_count = 
def _resolve_log_level(level: str) -> int:
    """Translate a level name into its numeric ``logging`` constant.

    Args:
        level: Level name in any letter case (e.g. ``"debug"`` or ``"DEBUG"``).

    Returns:
        The matching numeric level, or ``logging.INFO`` when the name does not
        resolve to an integer constant on the ``logging`` module.
    """
    # Upper-case first: ``logging.debug`` / ``logging.info`` are *functions*,
    # so a lowercase lookup would hand a callable to ``basicConfig``.
    numeric = getattr(logging, level.upper(), None)
    return numeric if isinstance(numeric, int) else logging.INFO


def configure_logging(level: str) -> None:
    """Configure root logging.

    Args:
        level: Root log level name (case-insensitive). Unknown names fall back
            to INFO.
    """
    logging.basicConfig(
        level=_resolve_log_level(level),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )


def add_shared_arguments(parser: argparse.ArgumentParser) -> None:
    """Attach common CLI arguments to a subcommand parser.

    Args:
        parser: Subcommand parser receiving ``--config`` (required) and
            ``--log-level`` (defaults to ``INFO``).
    """
    parser.add_argument("--config", required=True, help="Path to a MMIRAGE YAML config file")
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Log verbosity",
    )


def _add_confirm_argument(parser: argparse.ArgumentParser) -> None:
    """Attach the ``-y/--yes`` flag; ``confirm_mode`` defaults to ``"prompt"``."""
    parser.add_argument(
        "-y",
        "--yes",
        dest="confirm_mode",
        action="store_const",
        const="yes",
        help="Submit retries without prompting.",
    )
    parser.set_defaults(confirm_mode="prompt")


def build_argparser() -> argparse.ArgumentParser:
    """Build the CLI parser.

    Registers the ``submit``, ``check``, ``retry`` and ``run`` subcommands; the
    chosen subcommand name is stored in ``args.command``.

    Returns:
        Configured top-level argparse parser.
    """
    parser = argparse.ArgumentParser(description="MMIRAGE command-line interface")
    subparsers = parser.add_subparsers(dest="command", required=True)

    submit_parser = subparsers.add_parser("submit", help="Submit one SLURM array job")
    add_shared_arguments(submit_parser)
    submit_parser.add_argument(
        "--shard-ids",
        help="Comma-separated shard ids to submit instead of the full array",
    )
    submit_parser.add_argument("--wait", action="store_true", help="Wait for the submitted job")

    check_parser = subparsers.add_parser("check", help="Inspect shard status")
    add_shared_arguments(check_parser)
    check_parser.add_argument(
        "--retry",
        dest="retry",
        action="store_true",
        help="Submit a retry job for failed shards.",
    )
    check_parser.set_defaults(retry=False)
    _add_confirm_argument(check_parser)

    retry_parser = subparsers.add_parser("retry", help="Submit only failed shards")
    add_shared_arguments(retry_parser)
    _add_confirm_argument(retry_parser)

    run_parser = subparsers.add_parser(
        "run",
        help="Run according to execution_params.mode and execution_params.retry",
    )
    add_shared_arguments(run_parser)
    run_parser.add_argument(
        "--force-retry",
        action="store_true",
        help="Enable retry orchestration even if execution_params.retry is false",
    )
    run_parser.add_argument(
        "--shard-id",
        type=int,
        default=None,
        help="Run a single shard locally (overrides execution mode)",
    )

    return parser


def parse_shard_ids(raw_value: Optional[str], num_shards: Optional[int] = None) -> List[int]:
    """Parse a comma-separated shard id list.

    Whitespace around entries is ignored and empty entries (``"0,,1"``) are
    skipped. Order and duplicates are preserved.

    Args:
        raw_value: Comma-separated shard ids, or None/empty for full array.
        num_shards: Optional upper bound used for range validation.

    Returns:
        Parsed shard ids; an empty list when ``raw_value`` is None or empty.

    Raises:
        ValueError: If an entry is not a non-negative ASCII decimal integer, or
            is ``>= num_shards`` when ``num_shards`` is given.
    """
    if not raw_value:
        return []

    shard_ids: List[int] = []
    for raw_shard_id in raw_value.split(","):
        candidate = raw_shard_id.strip()
        if not candidate:
            continue

        # str.isdigit() alone also accepts characters such as "²" that int()
        # rejects with an unrelated message, so require ASCII digits as well.
        if not (candidate.isascii() and candidate.isdigit()):
            raise ValueError(f"Invalid shard id {candidate!r}; expected integers")
        shard_id = int(candidate)

        if num_shards is not None and shard_id >= num_shards:
            raise ValueError(f"Invalid shard id {shard_id}; expected 0 <= shard_id < {num_shards}")

        shard_ids.append(shard_id)

    return shard_ids


def handle_run(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int:
    """Handle the canonical run command.

    Args:
        args: Parsed CLI namespace.
        cfg: Parsed MMIRAGE configuration object.
        config_path: Absolute path to the MMIRAGE YAML config file.

    Returns:
        Exit code for the run operation.
    """
    # An explicit --shard-id bypasses mode/retry handling and runs that one
    # shard in the current environment.
    if args.shard_id is not None:
        return run_local(config_path, args.shard_id)
    return launch_pipeline(cfg, config_path, force_retry=args.force_retry)
def handle_check(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int:
    """Inspect shard status and optionally submit retries.

    The status summary is always printed to stdout as indented JSON. Retries
    are only submitted when ``--retry`` was given, the execution mode is SLURM
    and at least one failed shard was reported.

    Args:
        args: Parsed CLI namespace (reads ``retry`` and ``confirm_mode``).
        cfg: Parsed MMIRAGE configuration object.
        config_path: Absolute path to the MMIRAGE YAML config file.

    Returns:
        Exit code based on shard status, or the result of the retry submission
        when one is attempted.
    """
    failed_shards, summary = check_failed_shards(cfg)
    print(json.dumps(asdict(summary), indent=2))

    status_code = status_exit_code(failed_shards, summary)
    if not args.retry:
        return status_code

    if not cfg.execution_params.is_slurm():
        # Previously the flag was dropped silently in local mode; tell the
        # user why nothing is being resubmitted.
        logger.warning(
            "--retry ignored: `check` only submits retries when execution_params.mode "
            "is slurm; use `run --force-retry` to retry locally"
        )
        return status_code

    if not failed_shards:
        return status_code

    return submit_failed_shards(
        cfg=cfg,
        config_path=config_path,
        failed_shards=failed_shards,
        confirm_mode=args.confirm_mode,
    )
+ """ + if require_slurm(cfg, "retry") != 0: + return 1 + + failed_shards, summary = check_failed_shards(cfg) + print(json.dumps(asdict(summary), indent=2)) + + if not failed_shards: + if summary.max_retries_exceeded > 0: + logger.error("No retryable shards remain") + return 1 + logger.info("All shards already succeeded.") + return 0 + + return submit_failed_shards( + cfg=cfg, + config_path=config_path, + failed_shards=failed_shards, + confirm_mode=args.confirm_mode, + ) + + +def main() -> None: + """CLI entry point.""" + parser = build_argparser() + args = parser.parse_args() + configure_logging(args.log_level) + + try: + config_path = os.path.abspath(args.config) + cfg = load_mmirage_config(config_path) + + setup_runtime(cfg, args.log_level) + validate_edf_env_path(cfg) + + handlers = { + "run": handle_run, + "submit": handle_submit, + "check": handle_check, + "retry": handle_retry, + } + handler = handlers.get(args.command) + if handler is None: + logger.error("Unknown command: %s", args.command) + sys.exit(2) + + sys.exit(handler(args, cfg, config_path)) + + except Exception as exc: + logger.error("Error: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG)) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/mmirage/cli_utils/__init__.py b/src/mmirage/cli_utils/__init__.py new file mode 100644 index 0000000..ecab471 --- /dev/null +++ b/src/mmirage/cli_utils/__init__.py @@ -0,0 +1 @@ +"""Internal utility modules for the MMIRAGE CLI.""" diff --git a/src/mmirage/cli_utils/runtime.py b/src/mmirage/cli_utils/runtime.py new file mode 100644 index 0000000..64f6ae4 --- /dev/null +++ b/src/mmirage/cli_utils/runtime.py @@ -0,0 +1,79 @@ +"""Runtime/path helpers for the MMIRAGE CLI.""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import Optional, Sequence + +from mmirage.config.config import MMirageConfig + + +logger = logging.getLogger(__name__) + + +def expand_path(path: str, project_root: 
def get_project_root(cfg: MMirageConfig) -> str:
    """Return the configured project root, or the current working directory.

    Args:
        cfg: Parsed MMIRAGE configuration object.

    Returns:
        ``execution_params.project_root`` passed through ``expand_path`` when
        it is set, otherwise ``os.getcwd()``.
    """
    configured_root = cfg.execution_params.project_root
    return expand_path(configured_root) if configured_root else os.getcwd()


def create_directories(paths: Sequence[str]) -> None:
    """Create directories if they do not already exist.

    Args:
        paths: Directory paths to create; missing parents are created too.
    """
    for directory in map(Path, paths):
        directory.mkdir(parents=True, exist_ok=True)


def validate_edf_env_path(cfg: MMirageConfig) -> None:
    """Validate the optional EDF environment file path.

    Args:
        cfg: Parsed MMIRAGE configuration object.

    Raises:
        FileNotFoundError: If ``execution_params.edf_env`` is set but does not
            resolve to an existing file.
    """
    edf_env = cfg.execution_params.edf_env
    if edf_env:
        # Relative entries are resolved against the project root.
        resolved = expand_path(edf_env, get_project_root(cfg))
        if not Path(resolved).is_file():
            raise FileNotFoundError(f"EDF environment file not found: {resolved}")
def _bash_double_quote(value: str) -> str:
    """Return ``value`` wrapped as a double-quoted bash string literal.

    ``$`` is deliberately left unescaped so ``$VAR`` / ``${VAR}`` references
    from the config still expand when the generated script runs. Backslashes
    and double quotes are escaped.

    Args:
        value: Raw config value to embed in the generated script.

    Returns:
        The escaped value surrounded by double quotes.

    Raises:
        ValueError: If ``value`` contains a backtick or ``$(``, i.e. shell
            command substitution syntax.
    """
    # Command substitution is refused outright; plain variable expansion stays.
    if any(marker in value for marker in ("`", "$(")):
        raise ValueError(
            "Config value contains unsupported shell command substitution "
            "(` or '$('). Command substitution is not allowed in SLURM-generated scripts."
        )
    escape_table = str.maketrans({"\\": "\\\\", '"': '\\"'})
    return '"' + value.translate(escape_table) + '"'


def _shell_path(value: str, project_root: str) -> str:
    """Expand the user home and anchor relative paths at ``project_root``.

    Args:
        value: Path taken from the configuration; surrounding whitespace is
            stripped.
        project_root: Directory that relative paths are joined onto.

    Returns:
        An empty string for blank input; the ``~``-expanded path unchanged
        when it starts with ``$`` (left for the shell to expand) or is already
        absolute; otherwise the path joined onto ``project_root``.
    """
    stripped = value.strip()
    if stripped == "":
        return stripped

    expanded = os.path.expanduser(stripped)
    if expanded.startswith("$") or os.path.isabs(expanded):
        return expanded
    return os.path.join(project_root, expanded)
PATH\" >&2; exit 127; fi; echo \"Using Python: ${PYTHON_CMD} ($(${PYTHON_CMD} --version 2>&1))\"; ${PYTHON_CMD} -c \"import sys; raise SystemExit(0 if sys.version_info >= (3, 10) else 2)\" || { echo \"MMIRAGE requires Python >= 3.10 on compute nodes\" >&2; exit 2; }; exec ${PYTHON_CMD} \"$SHARD_PROCESS\" --config \"$MMIRAGE_CONFIG\"'", + "echo \"Shard ${SLURM_ARRAY_TASK_ID:-0} completed\"", + ] + ) + return "\n".join(lines) + "\n" + + +def submit_slurm_job( + cfg: MMirageConfig, + config_path: str, + shard_ids: Optional[Sequence[int]] = None, +) -> Optional[int]: + """Submit a SLURM array job and return its job ID.""" + project_root = get_project_root(cfg) + report_dir = expand_path(cfg.execution_params.report_dir, project_root) + create_directories([report_dir]) + + command = [ + "sbatch", + "--parsable", + f"--job-name={cfg.execution_params.job_name}", + f"--chdir={project_root}", + f"--output={os.path.join(report_dir, 'R-%x.%A_%a.out')}", + f"--error={os.path.join(report_dir, 'R-%x.%A_%a.err')}", + f"--nodes={cfg.execution_params.nodes}", + f"--ntasks-per-node={cfg.execution_params.ntasks_per_node}", + f"--gres=gpu:{cfg.execution_params.gpus}", + f"--cpus-per-task={cfg.execution_params.cpus_per_task}", + f"--time={cfg.execution_params.time_limit}", + f"--account={cfg.execution_params.account}", + ] + + if cfg.execution_params.reservation: + command.append(f"--reservation={cfg.execution_params.reservation}") + + requested_shards = list(shard_ids or []) + if requested_shards: + command.append(f"--array={','.join(str(shard_id) for shard_id in requested_shards)}") + else: + num_shards = cfg.loading_params.get_num_shards() + last_shard_id = num_shards - 1 + command.append(f"--array=0-{last_shard_id}") + + logger.info("Submitting SLURM job: %s", " ".join(command)) + result = subprocess.run( + command, + input=build_sbatch_script(cfg, config_path), + text=True, + capture_output=True, + check=False, + ) + + if result.returncode != 0: + logger.error("sbatch failed: %s", 
def wait_for_slurm_job(job_id: int, cfg: MMirageConfig) -> None:
    """Block until a SLURM job array has left the queue.

    Polls ``squeue -h -j <job_id>`` every
    ``execution_params.poll_interval_seconds`` seconds. The job is considered
    finished when squeue succeeds with empty output, or when squeue reports
    the job id as invalid: once the controller has purged a completed job's
    record, squeue exits non-zero for that id, which previously made this loop
    spin forever. Any other squeue failure is logged and polling continues,
    since it may be a transient controller problem.

    After the job is gone, sleeps ``execution_params.settle_time_seconds``
    (when positive) before returning.

    Args:
        job_id: SLURM job id as returned by ``sbatch --parsable``.
        cfg: Loaded MMIRAGE configuration (only ``execution_params`` is read).
    """
    logger.info("Waiting for SLURM job %s", job_id)
    poll_interval = cfg.execution_params.poll_interval_seconds

    while True:
        result = subprocess.run(
            ["squeue", "-h", "-j", str(job_id)],
            capture_output=True,
            text=True,
            check=False,
        )
        if result.returncode == 0:
            # -h suppresses the header, so empty output means no task is left.
            if not result.stdout.strip():
                break
        elif "invalid job id" in result.stderr.lower():
            # The job record has already been purged: the job is over.
            logger.info("SLURM no longer knows job %s; treating it as finished", job_id)
            break
        else:
            logger.warning(
                "squeue failed for job %s (exit code %s): %s; will retry",
                job_id,
                result.returncode,
                result.stderr.strip(),
            )
        time.sleep(poll_interval)

    settle_time = cfg.execution_params.settle_time_seconds
    if settle_time > 0:
        logger.info("Waiting %ss for state files to settle", settle_time)
        time.sleep(settle_time)
def is_retry_budget_exceeded(attempt_count: int, max_retries: int) -> bool:
    """Return whether a shard has used up its retry budget.

    A shard may run ``max_retries + 1`` times in total (the initial attempt
    plus ``max_retries`` retries). Once ``attempt_count`` has reached that
    total, no further retry is allowed.

    The previous comparison (``attempt_count > max_retries + 1``) still
    reported a shard that had consumed every allowed attempt as retryable,
    granting one retry too many (e.g. one retry even with ``max_retries=0``).

    Args:
        attempt_count: Number of attempts already started for the shard.
        max_retries: Configured maximum number of retries.

    Returns:
        True when no retry may be submitted any more, False otherwise.
    """
    # Equivalent to ``attempt_count >= max_retries + 1``.
    return attempt_count > max_retries
def confirm_retry(count: int, confirm_mode: Literal["prompt", "yes"]) -> bool:
    """Return whether retry submission is confirmed.

    Modes:
    - prompt: ask the user interactively
    - yes: submit without prompting

    In prompt mode the answer is case-insensitive and both ``y`` and ``yes``
    confirm; anything else (including an empty answer or end-of-input via
    Ctrl-D) declines.

    Args:
        count: Number of shards that would be retried (shown in the prompt).
        confirm_mode: ``"yes"`` to skip the prompt, ``"prompt"`` to ask.

    Returns:
        True if the retry should be submitted, False otherwise. False is also
        returned when a prompt is required but stdin is not a TTY.
    """
    if confirm_mode == "yes":
        return True

    if not sys.stdin.isatty():
        logger.error("Interactive confirmation requested but stdin is not a TTY; use --yes")
        return False

    try:
        response = input(f"Retry {count} shard(s)? (y/N) ")
    except EOFError:
        # Ctrl-D at the prompt: the default answer is "no", not a traceback.
        return False

    return response.strip().lower() in ("y", "yes")
@dataclass
class ExecutionParams:
    """Parameters for executing the MMIRAGE pipeline.

    Defines how the pipeline is executed, including local or SLURM-based
    distributed execution, retry logic, and resource allocation.

    Attributes:
        mode: Execution mode: "local" or "slurm". Defaults to "local".
        retry: Whether automatic retry orchestration is enabled. Defaults to False.
        max_retries: Maximum number of retries for failed shards. Defaults to 3.
        poll_interval_seconds: Seconds to wait between polling job status.
            Must be >= 0. Defaults to 30.
        settle_time_seconds: Seconds to wait after job completes before
            checking results. Defaults to 60.

        # Paths (can contain environment variables like ${VAR} or $VAR)
        project_root: Base project directory.
        report_dir: Directory for SLURM output/error files. Defaults to ~/reports.
        hf_home: HuggingFace cache directory. Defaults to ~/hf.
        edf_env: Optional EDF environment file path.

        # SLURM-specific parameters
        account: SLURM account to charge. Required for SLURM mode.
        job_name: SLURM job name. Defaults to "mmirage-sharded".
        reservation: Optional SLURM reservation name.
        nodes: Number of nodes. Defaults to 1.
        ntasks_per_node: Number of tasks per node. Defaults to 1.
        gpus: Number of GPUs per node. Defaults to 4.
        cpus_per_task: Number of CPUs per task. Defaults to 288.
        time_limit: Job time limit (HH:MM:SS). Defaults to "11:59:59".
    """

    mode: str = "local"
    retry: bool = False
    max_retries: int = 3
    poll_interval_seconds: int = 30
    settle_time_seconds: int = 60

    # Paths (can contain environment variables like ${VAR} or $VAR)
    project_root: Optional[str] = None
    report_dir: str = "~/reports"
    hf_home: str = "~/hf"
    edf_env: Optional[str] = None

    # SLURM parameters
    account: Optional[str] = None
    job_name: str = "mmirage-sharded"
    reservation: Optional[str] = None
    nodes: int = 1
    ntasks_per_node: int = 1
    gpus: int = 4
    cpus_per_task: int = 288
    time_limit: str = "11:59:59"

    def __post_init__(self):
        """Validate execution parameters.

        Raises:
            ValueError: If ``mode`` is unknown, ``account`` is missing in
                SLURM mode, ``max_retries`` is negative, or
                ``poll_interval_seconds`` is negative.
        """
        if self.mode not in ("local", "slurm"):
            raise ValueError(f"Invalid execution mode: {self.mode!r}. Must be 'local' or 'slurm'.")
        if self.mode == "slurm" and not self.account:
            raise ValueError("account is required when mode='slurm'")
        if self.max_retries < 0:
            raise ValueError(f"max_retries must be >= 0, got {self.max_retries}")
        # A negative interval would otherwise only blow up inside time.sleep(),
        # i.e. after the job has already been submitted.
        if self.poll_interval_seconds < 0:
            raise ValueError(
                f"poll_interval_seconds must be >= 0, got {self.poll_interval_seconds}"
            )

    def is_slurm(self) -> bool:
        """Check if execution mode is SLURM."""
        return self.mode == "slurm"
""" processors: List[BaseProcessorConfig] loading_params: LoadingParams processing_params: ProcessingParams + execution_params: ExecutionParams = field(default_factory=ExecutionParams) diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index ec7faca..929d76d 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -1,10 +1,14 @@ """Data loading configuration for MMIRAGE pipeline.""" +import os +import re from dataclasses import dataclass, field from typing import Union, List, cast from mmirage.core.loader.base import BaseDataLoaderConfig +DEFAULT_STATE_DIR = "~/.cache/MMIRAGE/state_dir" + @dataclass class LoadingParams: @@ -15,7 +19,8 @@ class LoadingParams: Attributes: datasets: List of dataset configurations to load. - output_dir: Directory path for saving processed output shards. + state_dir: Shared directory for logical shard state/markers/retry tracking. + output_dir: Legacy top-level output directory. Prefer per-dataset output_dir. num_shards: Total number of shards to split the dataset into. shard_id: ID of this shard (0-indexed). batch_size: Batch size for processing samples. 
@@ -25,31 +30,60 @@ class LoadingParams: """ datasets: List[BaseDataLoaderConfig] = field(default_factory=list) + state_dir: str = DEFAULT_STATE_DIR output_dir: str = "" num_shards: Union[int, str] = 1 shard_id: Union[int, str] = 0 batch_size: Union[int, str] = 1 def __post_init__(self): + _UNRESOLVED_ENV_VAR_PATTERN = re.compile(r"^\$(?:\{[A-Za-z_][A-Za-z0-9_]*\}|[A-Za-z_][A-Za-z0-9_]*)$") + def is_unresolved_env_var(s: str) -> bool: + return bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(s.strip())) + if isinstance(self.num_shards, str): try: self.num_shards = int(self.num_shards) if self.num_shards < 1: raise ValueError() except (ValueError, TypeError): - raise ValueError(f"Invalid value for num_shards: {self.num_shards!r}") + if is_unresolved_env_var(self.num_shards): + self.num_shards = 1 + else: + raise ValueError(f"Invalid value for num_shards: {self.num_shards!r}") + if isinstance(self.shard_id, str): try: self.shard_id = int(self.shard_id) except (ValueError, TypeError): - raise ValueError(f"Invalid value for shard_id: {self.shard_id!r}") + if is_unresolved_env_var(self.shard_id): + self.shard_id = 0 + else: + raise ValueError(f"Invalid value for shard_id: {self.shard_id!r}") + if isinstance(self.batch_size, str): try: self.batch_size = int(self.batch_size) except (ValueError, TypeError): raise ValueError(f"Invalid value for batch_size: {self.batch_size!r}") + self.batch_size = max(self.batch_size, 1) + raw_state_dir = "" if self.state_dir is None else str(self.state_dir) + self.state_dir = raw_state_dir.strip() + if not self.state_dir: + self.state_dir = DEFAULT_STATE_DIR + + self.state_dir = os.path.expanduser(self.state_dir) + + def get_state_root(self) -> str: + """Get the state root path. + + Returns: + str: State root path. + """ + return self.state_dir + def get_num_shards(self) -> int: """Get the total number of shards. 
diff --git a/src/mmirage/config/utils.py b/src/mmirage/config/utils.py index 77e8bcb..e69c2d3 100644 --- a/src/mmirage/config/utils.py +++ b/src/mmirage/config/utils.py @@ -9,6 +9,15 @@ from mmirage.core.process.base import BaseProcessorConfig, ProcessorRegistry, OutputVar from mmirage.core.loader.base import BaseDataLoaderConfig, DataLoaderRegistry +# Register built-in processors/loaders. +# +# We import configuration modules (lightweight) here so the registries know how +# to construct config/output-var objects from YAML without importing heavy +# processor implementations (e.g. torch/transformers). +import mmirage.core.process.processors.llm.config # noqa: F401 +import mmirage.core.loader.jsonl # noqa: F401 +import mmirage.core.loader.local_hf # noqa: F401 + EnvValue: TypeAlias = Union[str, List["EnvValue"], Dict[str, "EnvValue"]] @@ -112,4 +121,4 @@ def output_var_hook(data: Dict[str, Any]) -> OutputVar: ) cfg_obj = from_dict(MMirageConfig, cast(dict, cfg), config=config) - return cfg_obj + return cfg_obj \ No newline at end of file diff --git a/src/mmirage/core/loader/__init__.py b/src/mmirage/core/loader/__init__.py index c27714d..217d53d 100644 --- a/src/mmirage/core/loader/__init__.py +++ b/src/mmirage/core/loader/__init__.py @@ -7,13 +7,3 @@ All loaders inherit from BaseDataLoader and are registered with DataLoaderRegistry for dynamic instantiation based on configuration. 
""" - -from mmirage.core.loader.jsonl import JSONLDataConfig, JSONLDataLoader -from mmirage.core.loader.local_hf import LocalHFConfig, LocalHFDataLoader - -__all__ = [ - "JSONLDataConfig", - "JSONLDataLoader", - "LocalHFDataLoader", - "LocalHFConfig", -] diff --git a/src/mmirage/core/loader/utils.py b/src/mmirage/core/loader/utils.py index 0b43727..c039727 100644 --- a/src/mmirage/core/loader/utils.py +++ b/src/mmirage/core/loader/utils.py @@ -31,23 +31,21 @@ def load_datasets_from_configs(configs: List[BaseDataLoaderConfig]) -> List[Data RuntimeError: If no datasets could be loaded successfully. """ - config_per_type = {} - for ds_config in configs: - config_per_type[ds_config.type] = config_per_type.get(ds_config.type, []) + [ - ds_config - ] - valid_ds: List[DatasetLike] = [] - for config_type, config_list in config_per_type.items(): - loader = AutoDataLoader.from_name(config_type)() - for ds_config in config_list: - try: - ds = loader.from_config(ds_config) - if ds is None: - continue - valid_ds.append(ds) - except Exception as e: - logger.warning(f"⚠️ Dataset loading failed with error: {e}. Skipping") + loader_by_type = {} + for ds_config in configs: + loader = loader_by_type.get(ds_config.type) + if loader is None: + loader = AutoDataLoader.from_name(ds_config.type)() + loader_by_type[ds_config.type] = loader + + try: + ds = loader.from_config(ds_config) + if ds is None: + continue + valid_ds.append(ds) + except Exception as e: + logger.warning(f"Dataset loading failed with error: {e}. Skipping") if not valid_ds: raise RuntimeError("No valid datasets loaded from the provided configs.") diff --git a/src/mmirage/core/process/__init__.py b/src/mmirage/core/process/__init__.py index b031213..3002d4c 100644 --- a/src/mmirage/core/process/__init__.py +++ b/src/mmirage/core/process/__init__.py @@ -1,15 +1,7 @@ """Processing module for MMIRAGE pipeline. 
-This module provides the core processing infrastructure: -- Base classes for processors and variables -- MMIRAGEMapper for orchestrating transformations -- LLM processor implementation for generative tasks (including multimodal) +Important: keep this package import lightweight. -Processors are responsible for generating new output variables from -existing variables, enabling flexible data transformations. +The LLM processor implementation depends on heavy libraries (e.g. torch). +We intentionally do not import processor implementations from here. """ - -from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig -from mmirage.core.process.processors.llm.llm_processor import LLMProcessor - -__all__ = ["LLMOutputVar", "SGLangLLMConfig", "LLMProcessor"] diff --git a/src/mmirage/core/process/base.py b/src/mmirage/core/process/base.py index f374e12..988bae7 100644 --- a/src/mmirage/core/process/base.py +++ b/src/mmirage/core/process/base.py @@ -1,6 +1,7 @@ """Base classes and registry for processors in MMIRAGE.""" import abc +from importlib import import_module from dataclasses import dataclass from typing import Callable, Generic, List, Type, TypeVar @@ -80,6 +81,28 @@ class ProcessorRegistry: _config_registry = dict() _output_var_registry = dict() + # Import processor implementations lazily because they may depend on heavy + # libraries (torch/transformers). Config/output-var types are registered via + # mmirage.config.utils importing the relevant config modules. 
+ _lazy_processor_imports = {"llm": "mmirage.core.process.processors.llm.llm_processor"} + + @classmethod + def register_types( + cls, + name: str, + config_cls: Type[BaseProcessorConfig], + output_var_cls: Type[OutputVar], + ) -> None: + """Register config/output-var types without importing processor implementations.""" + cls._config_registry[name] = config_cls + cls._output_var_registry[name] = output_var_cls + + @classmethod + def _maybe_import_processor(cls, name: str) -> None: + module = cls._lazy_processor_imports.get(name) + if module: + import_module(module) + @classmethod def register( cls, @@ -118,6 +141,9 @@ def get_processor(cls, name: str) -> Type[BaseProcessor]: Raises: ValueError: If no processor is registered under the given name. """ + if name not in cls._registry: + cls._maybe_import_processor(name) + if name not in cls._registry: raise ValueError( f"Processor {name} not registered. Available processors are {list(cls._registry.keys())}" diff --git a/src/mmirage/core/process/processors/llm/config.py b/src/mmirage/core/process/processors/llm/config.py index 4c195af..e53c2c6 100644 --- a/src/mmirage/core/process/processors/llm/config.py +++ b/src/mmirage/core/process/processors/llm/config.py @@ -9,6 +9,7 @@ from mmirage.core.process.variables import BaseVar, OutputVar from mmirage.core.process.base import BaseProcessorConfig +from mmirage.core.process.base import ProcessorRegistry from jinja2 import Environment, meta logger = logging.getLogger(__name__) @@ -103,3 +104,6 @@ def is_computable(self, vars: Sequence[BaseVar]) -> bool: return False return True + + +ProcessorRegistry.register_types("llm", SGLangLLMConfig, LLMOutputVar) diff --git a/src/mmirage/merge_shards.py b/src/mmirage/merge_shards.py index 9f8c562..433feb5 100644 --- a/src/mmirage/merge_shards.py +++ b/src/mmirage/merge_shards.py @@ -7,13 +7,7 @@ from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk from mmirage.core.loader.base import DatasetLike - - -def 
_count_rows(ds: DatasetLike) -> int: - """Count total rows in a dataset or dataset dict.""" - if isinstance(ds, DatasetDict): - return sum(len(split) for split in ds.values()) - return len(ds) +from mmirage.shard_utils import _count_rows def _merge_datasetdict(shard_dsets: List[DatasetDict]) -> DatasetDict: @@ -169,4 +163,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index f232dc6..66e8529 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -4,57 +4,32 @@ """ import argparse -from functools import reduce -import os +import logging +import sys +import traceback from typing import Any, Dict, List -from datasets import Dataset, DatasetDict - -from mmirage.core.loader.base import BaseDataLoaderConfig, DatasetLike -from mmirage.core.process.mapper import MMIRAGEMapper - from mmirage.config.utils import load_mmirage_config -from mmirage.core.writer.renderer import TemplateRenderer +from mmirage.core.loader.base import DatasetLike from mmirage.core.loader.utils import load_datasets_from_configs -import logging +from mmirage.core.process.mapper import MMIRAGEMapper +from mmirage.core.writer.renderer import TemplateRenderer +from mmirage.shard_utils import ( + _cleanup_old_shard_data, + _count_rows, + _dataset_out_dir, + _mark_failure, + _mark_running, + _mark_success, + _remove_columns, + _save_dataset_atomic, + _shard_dataset, + _shard_state_dir, +) logger = logging.getLogger(__name__) -def _count_rows(ds: DatasetLike) -> int: - """Count total rows in a dataset or dataset dict.""" - if isinstance(ds, DatasetDict): - return sum(len(split) for split in ds.values()) - return len(ds) - - -def _dataset_out_dir(shard_idx: int, ds_config: BaseDataLoaderConfig) -> str: - """Get output directory for a shard of a dataset.""" - return os.path.join(ds_config.output_dir, f"shard_{shard_idx}") - - -def _shard_dataset(ds: DatasetLike, 
num_shards: int, shard_id: int) -> DatasetLike: - """Shard a dataset or dataset dict.""" - if isinstance(ds, DatasetDict): - return DatasetDict( - { - split: split_ds.shard(num_shards=num_shards, index=shard_id) - for split, split_ds in ds.items() - } - ) - return ds.shard(num_shards=num_shards, index=shard_id) - - -def _remove_columns(ds: DatasetLike, enable: bool) -> List[str]: - """Get columns to remove from dataset if enabled.""" - if not enable: - return [] - if isinstance(ds, DatasetDict): - columns_set = [set(split_ds.column_names) for split_ds in ds.values()] - return list(reduce(lambda x, y: x | y, columns_set)) - return ds.column_names - - def rewrite_batch( batch: Dict[str, List[Any]], mapper: MMIRAGEMapper, @@ -62,16 +37,13 @@ def rewrite_batch( image_base_path: str = None, ) -> Dict[str, List[Any]]: """Rewrite a batch of samples by applying transformations. - Args: batch: Dictionary mapping column names to lists of values. mapper: MMIRAGEMapper for processing transformations. renderer: TemplateRenderer for generating output. image_base_path: Optional base directory for resolving relative image paths. - Returns: Dictionary mapping output keys to lists of rendered values. - Raises: ValueError: If variables are not computable given the configuration. """ @@ -86,14 +58,12 @@ def rewrite_batch( def main(): - """Process a single shard of the dataset. - - Loads configuration, datasets, processes the shard using MMIRAGE - transformations (including multimodal), and saves the result to disk. """ - ap = argparse.ArgumentParser( - "Process dataset shards using MMIRAGE with SGLang." - ) + Process a single shard of the dataset. + Loads configuration, datasets, processes the shard using MMIRAGE + transformations (including multimodal), and saves the result to disk. 
+ """ + ap = argparse.ArgumentParser("Process dataset shards using MMIRAGE with SGLang.") ap.add_argument( "--config", help="YAML config for MMIRAGE pipeline.", @@ -105,52 +75,89 @@ def main(): loading_params = cfg.loading_params processing_params = cfg.processing_params datasets_config = loading_params.datasets + if not datasets_config: raise ValueError("No datasets provided in config.loading_params.datasets") shard_id = loading_params.get_shard_id() num_shards = loading_params.get_num_shards() + last_shard_id = num_shards - 1 if not (0 <= shard_id < num_shards): raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") - ds_all = load_datasets_from_configs(datasets_config) - total_rows = sum(_count_rows(ds) for ds in ds_all) + state_dir = _shard_state_dir(shard_id, loading_params.get_state_root()) - ds_all_shard = [_shard_dataset(ds, num_shards, shard_id) for ds in ds_all] - shard_rows = sum(_count_rows(ds) for ds in ds_all_shard) + try: + retry_count = _mark_running(state_dir, shard_id, datasets_config) + logger.info(f"Starting shard {shard_id}/{last_shard_id} (attempt #{retry_count})") - logger.info( - f"Loaded {len(datasets_config)} dataset(s): {datasets_config} " - f"→ {total_rows} total rows; this shard has {shard_rows} rows." 
- ) + if retry_count > 1: + for ds_config in datasets_config: + out_dir = _dataset_out_dir(shard_id, ds_config) + _cleanup_old_shard_data(out_dir) - mapper = MMIRAGEMapper( - cfg.processors, processing_params.inputs, processing_params.outputs - ) - renderer = TemplateRenderer(processing_params.output_schema) - ds_processed_all: List[DatasetLike] = [] - for ds_idx, ds_shard in enumerate(ds_all_shard): - ds_config = datasets_config[ds_idx] - remove_columns = _remove_columns(ds_shard, processing_params.remove_columns) - ds_processed = ds_shard.map( - rewrite_batch, - batched=True, - batch_size=loading_params.get_batch_size(), - load_from_cache_file=False, - desc=f"Shard {shard_id}/{num_shards - 1} dataset {ds_idx}", - fn_kwargs={"mapper": mapper, "renderer": renderer, "image_base_path": ds_config.image_base_path}, - remove_columns=remove_columns, - ) - ds_processed_all.append(ds_processed) + ds_all = load_datasets_from_configs(datasets_config) + total_rows = sum(_count_rows(ds) for ds in ds_all) - for ds_config, ds_processed in zip(datasets_config, ds_processed_all): - out_dir = _dataset_out_dir(shard_id, ds_config) - os.makedirs(out_dir, exist_ok=True) - ds_processed.save_to_disk(out_dir) + ds_all_shard = [_shard_dataset(ds, num_shards, shard_id) for ds in ds_all] + shard_rows = sum(_count_rows(ds) for ds in ds_all_shard) - logger.info(f"✅ Saved dataset in: {out_dir}") + logger.info( + f"Loaded {len(datasets_config)} dataset(s): {datasets_config} " + f"→ {total_rows} total rows; this logical shard has {shard_rows} rows." 
+ ) + + mapper = MMIRAGEMapper( + cfg.processors, + processing_params.inputs, + processing_params.outputs, + ) + renderer = TemplateRenderer(processing_params.output_schema) + + ds_processed_all: List[DatasetLike] = [] + for ds_idx, ds_shard in enumerate(ds_all_shard): + ds_config = datasets_config[ds_idx] + if processing_params.remove_columns: + remove_columns = _remove_columns(ds_shard) + else: + remove_columns = [] + + logger.info( + f"Processing dataset {ds_idx} for shard {shard_id}: " + f"path={ds_config.path}, output_dir={ds_config.output_dir}" + ) + + ds_processed = ds_shard.map( + rewrite_batch, + batched=True, + batch_size=loading_params.get_batch_size(), + load_from_cache_file=False, + desc=f"Shard {shard_id}/{last_shard_id} dataset {ds_idx}", + fn_kwargs={ + "mapper": mapper, + "renderer": renderer, + "image_base_path": ds_config.image_base_path, + }, + remove_columns=remove_columns, + ) + ds_processed_all.append(ds_processed) + + for ds_idx, (ds_config, ds_processed) in enumerate(zip(datasets_config, ds_processed_all)): + out_dir = _dataset_out_dir(shard_id, ds_config) + _save_dataset_atomic(ds_processed, out_dir) + logger.info(f"✅ Saved dataset {ds_idx} shard in: {out_dir}") + + _mark_success(state_dir) + logger.info(f"✅ Logical shard {shard_id} completed successfully") + + except Exception as e: + error_msg = f"{type(e).__name__}: {str(e)}" + logger.error(f"❌ Shard {shard_id} failed: {error_msg}") + logger.error(traceback.format_exc()) + _mark_failure(state_dir, error_msg) + sys.exit(1) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py new file mode 100644 index 0000000..c4dce9a --- /dev/null +++ b/src/mmirage/shard_utils.py @@ -0,0 +1,264 @@ +"""Utility functions for shard processing. + +This module contains helper functions for dataset sharding, state management, +and file operations used in the MMIRAGE shard processing pipeline. 
@dataclass
class ShardStatus:
    """Typed representation of the shard ``status.json`` payload.

    Attributes:
        status: State string stored in the file; "unknown" when absent.
        retry_count: Attempt counter stored in the file (never negative).
        shard_id: Logical shard index, if recorded.
        started_at: Start timestamp string, if recorded.
        finished_at: Finish timestamp string, if recorded.
        error: Error message of the last failure, if any.
        hostname: Host that ran the attempt, if recorded.
        pid: Process id of the attempt, if recorded.
        slurm_job_id: SLURM job id, if recorded.
        slurm_array_task_id: SLURM array task id, if recorded.
        datasets: List of per-dataset dicts, if recorded.
    """

    status: str = "unknown"
    retry_count: int = 0
    shard_id: Optional[int] = None
    started_at: Optional[str] = None
    finished_at: Optional[str] = None
    error: Optional[str] = None
    hostname: Optional[str] = None
    pid: Optional[int] = None
    slurm_job_id: Optional[str] = None
    slurm_array_task_id: Optional[str] = None
    datasets: Optional[List[Dict[str, Any]]] = None

    @staticmethod
    def _optional_int(value: Any) -> Optional[int]:
        """Return ``int(value)``, or None if *value* is None or not convertible."""
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    @classmethod
    def from_dict(cls, payload: Dict[str, Any]) -> "ShardStatus":
        """Build a status object from a JSON payload.

        Tolerates missing, ``null`` and malformed entries so that a damaged
        status file never raises here:

        - ``status`` that is missing, ``null`` or empty becomes "unknown"
          (``null`` used to turn into the literal string "None").
        - ``retry_count`` that is missing, malformed or negative becomes 0.
        - ``shard_id`` / ``pid`` that cannot be parsed become None.
        - ``datasets`` that is not a list becomes None.

        Args:
            payload: Decoded JSON object; a falsy value is treated as ``{}``.

        Returns:
            The populated ``ShardStatus``.
        """
        data = payload or {}

        raw_status = data.get("status")
        status = "unknown" if raw_status in (None, "") else str(raw_status)

        # A negative counter can only come from a corrupt or hand-edited file
        # and would otherwise distort retry accounting.
        retry_count = max(cls._optional_int(data.get("retry_count", 0)) or 0, 0)

        datasets = data.get("datasets")
        if not isinstance(datasets, list):
            datasets = None

        return cls(
            status=status,
            retry_count=retry_count,
            shard_id=cls._optional_int(data.get("shard_id")),
            started_at=data.get("started_at"),
            finished_at=data.get("finished_at"),
            error=data.get("error"),
            hostname=data.get("hostname"),
            pid=cls._optional_int(data.get("pid")),
            slurm_job_id=data.get("slurm_job_id"),
            slurm_array_task_id=data.get("slurm_array_task_id"),
            datasets=datasets,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize status to the JSON payload written on disk."""
        return {
            "status": self.status,
            "retry_count": self.retry_count,
            "shard_id": self.shard_id,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
            "error": self.error,
            "hostname": self.hostname,
            "pid": self.pid,
            "slurm_job_id": self.slurm_job_id,
            "slurm_array_task_id": self.slurm_array_task_id,
            "datasets": self.datasets,
        }
_cleanup_old_shard_data(out_dir: str): + """Remove old dataset shard output before retry.""" + if os.path.exists(out_dir): + shutil.rmtree(out_dir) + logger.info(f"Removed old shard output: {out_dir}") + + +def _status_file(state_dir: str) -> str: + """Canonical status file path.""" + return os.path.join(state_dir, "status.json") + + +def _read_status(state_dir: str) -> ShardStatus: + """Read status.json if present.""" + path = _status_file(state_dir) + if not os.path.exists(path): + return ShardStatus(status="missing") + try: + with open(path, "r") as f: + data = json.load(f) + if not isinstance(data, dict): + logger.warning(f"Invalid status format in {path}; expected object") + return ShardStatus(status="unknown") + return ShardStatus.from_dict(data) + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"Failed to read status file {path}: {e}") + return ShardStatus(status="unknown") + + +def _write_status(state_dir: str, payload: ShardStatus): + """Atomically write status.json.""" + os.makedirs(state_dir, exist_ok=True) + tmp_path = _status_file(state_dir) + ".tmp" + with open(tmp_path, "w") as f: + json.dump(payload.to_dict(), f, indent=2, sort_keys=True) + os.replace(tmp_path, _status_file(state_dir)) + + +def _clear_markers(state_dir: str): + """Remove status marker files.""" + for name in (".RUNNING", ".SUCCESS", ".FAILED"): + path = os.path.join(state_dir, name) + if os.path.exists(path): + try: + os.remove(path) + except OSError as e: + logger.warning(f"Failed to remove marker {path}: {e}") + + +def _touch_marker(state_dir: str, name: str): + """Create a marker file.""" + os.makedirs(state_dir, exist_ok=True) + path = os.path.join(state_dir, name) + with open(path, "w") as f: + f.write(f"{datetime.now().isoformat()}\n") + + +def _mark_running( + state_dir: str, + shard_id: int, + datasets_config: List[BaseDataLoaderConfig], +) -> int: + """Mark shard as running and increment retry count.""" + prev = _read_status(state_dir) + retry_count = 
prev.retry_count + 1 + + payload = ShardStatus( + status="running", + retry_count=retry_count, + shard_id=shard_id, + started_at=datetime.now().isoformat(), + finished_at=None, + error=None, + hostname=socket.gethostname(), + pid=os.getpid(), + slurm_job_id=os.environ.get("SLURM_JOB_ID"), + slurm_array_task_id=os.environ.get("SLURM_ARRAY_TASK_ID"), + datasets=[ + { + "path": ds_config.path, + "output_dir": ds_config.output_dir, + } + for ds_config in datasets_config + ], + ) + + _write_status(state_dir, payload) + _clear_markers(state_dir) + _touch_marker(state_dir, ".RUNNING") + return retry_count + + +def _mark_success(state_dir: str): + """Mark shard as successful.""" + prev = _read_status(state_dir) + prev.status = "success" + prev.finished_at = datetime.now().isoformat() + prev.error = None + _write_status(state_dir, prev) + _clear_markers(state_dir) + _touch_marker(state_dir, ".SUCCESS") + + +def _mark_failure(state_dir: str, error_msg: str): + """Mark shard as failed.""" + prev = _read_status(state_dir) + prev.status = "failed" + prev.finished_at = datetime.now().isoformat() + prev.error = error_msg + _write_status(state_dir, prev) + _clear_markers(state_dir) + _touch_marker(state_dir, ".FAILED")