From c3031927624173464e25cf8005b0d1f136eb06a6 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 3 Mar 2026 17:14:24 +0100 Subject: [PATCH 01/47] first version with auto retry to test on real data --- README.md | 17 +++ run_with_retry.sh | 233 +++++++++++++++++++++++++++++++++++ src/mmirage/merge_shards.py | 137 ++++++++++++++++++-- src/mmirage/shard_process.py | 130 ++++++++++++++----- 4 files changed, 480 insertions(+), 37 deletions(-) create mode 100644 run_with_retry.sh diff --git a/README.md b/README.md index 953cbf9..fb175fd 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,23 @@ For testing and scripts that make use of the library, it is advised to create a ## Example usage +### Running with Automatic Retry + +The simplest way to run MMIRAGE with automatic failure detection and retry: + +```bash +# 1. Edit configuration in run_with_retry.sh (set your paths, num shards, etc.) + +# 2. Submit - everything else is automatic +sbatch run_with_retry.sh + +# That's it! 
The system will: +# - Process all shards +# - Detect failures automatically +# - Retry only failed shards +# - Repeat until all succeed or max retries +``` + ### Text-only: Reformatting dataset Suppose you have a dataset with samples of the following format diff --git a/run_with_retry.sh b/run_with_retry.sh new file mode 100644 index 0000000..6934baa --- /dev/null +++ b/run_with_retry.sh @@ -0,0 +1,233 @@ +#!/bin/bash +#SBATCH --job-name=mmirage-auto-retry +#SBATCH --chdir=/users/$USER/meditron/MMIRAGE/src/mmirage +#SBATCH --output=/users/$USER/reports/R-%x.%A_%a.out +#SBATCH --error=/users/$USER/reports/R-%x.%A_%a.err +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=288 +#SBATCH --time=11:59:59 +#SBATCH -A a127 + +############################################################################## +# MMIRAGE with Automatic Retry +# +# This script automatically detects and relaunches failed shards until all +# complete successfully or max retries is reached. +# +# Usage: +# 1. Edit the configuration section below +# 2. Run locally (NOT via sbatch): ./run_with_retry.sh +# 3. That's it - everything else is automatic +# +# NOTE: This script submits jobs to SLURM internally, so run it as a regular +# bash script, not with sbatch. 
+############################################################################## + +# ============================================================================ +# CONFIGURATION - Edit these to match your setup +# ============================================================================ +export USER="username" + +export MMIRAGE_PATH="/users/$USER/meditron/MMIRAGE" + +# Number of shards (will launch array 0 to NUM_SHARDS-1) +export NUM_SHARDS=32 + +# Path to your MMIRAGE config file +export CFG=$MMIRAGE_PATH/configs/config_medtrinity.yaml + +# Output directory for shards +export SHARDS_ROOT=$SCRATCH/mmirage_output/shards + +# HF cache/home +export HF_HOME=$SCRATCH/hf + +# Maximum retry attempts per shard (prevents infinite loops) +export MAX_RETRIES=3 + +# SLURM settings for worker nodes +export WORKER_ACCOUNT="a127" +export WORKER_RESERVATION="sai-a127" +export WORKER_ENVIRONMENT="/users/$USER/.edf/sglang.toml" + +# ============================================================================ +# END CONFIGURATION - Don't edit below unless you know what you're doing +# ============================================================================ + +mkdir -p "$SHARDS_ROOT" + +# Detect which mode we're in based on SLURM variables +if [ -n "$SLURM_ARRAY_TASK_ID" ]; then + # ======================================================================== + # WORKER MODE - Process one shard + # ======================================================================== + echo "==========================================" + echo "Worker: Processing shard $SLURM_ARRAY_TASK_ID" + echo "Started at: $(date)" + echo "==========================================" + + export CMD="python $MMIRAGE_PATH/src/mmirage/shard_process.py --config $CFG" + + SRUN_ARGS=" \ + --cpus-per-task $SLURM_CPUS_PER_TASK \ + --jobid $SLURM_JOB_ID \ + --wait 60 \ + -A $WORKER_ACCOUNT \ + --reservation $WORKER_RESERVATION \ + --environment $WORKER_ENVIRONMENT + " + + srun $SRUN_ARGS bash -c "$CMD" + EXIT_CODE=$? 
+ + echo "END TIME: $(date)" + echo "EXIT CODE: $EXIT_CODE" + + exit $EXIT_CODE + +elif [ -n "$SLURM_JOB_ID" ]; then + # ======================================================================== + # CONTROLLER MODE - Check for failures and resubmit + # ======================================================================== + echo "==========================================" + echo "Controller: Checking for failed shards" + echo "Started at: $(date)" + echo "==========================================" + echo "" + + # Function to check for successful shards + check_shards() { + local failed_shards=() + local success_count=0 + local failed_count=0 + local missing_count=0 + + for i in $(seq 0 $((NUM_SHARDS-1))); do + # Find shard directories (may be nested under dataset dirs) + shard_dirs=$(find "$SHARDS_ROOT" -type d -name "shard_$i" 2>/dev/null) + + if [ -z "$shard_dirs" ]; then + echo "ā“ Shard $i: MISSING (no directory)" + failed_shards+=($i) + ((missing_count++)) + continue + fi + + # Check each shard directory for success marker + shard_success=false + for shard_dir in $shard_dirs; do + if [ -f "$shard_dir/.SUCCESS" ]; then + shard_success=true + break + fi + done + + if [ "$shard_success" = true ]; then + echo "āœ… Shard $i: SUCCESS" + ((success_count++)) + else + # Check retry count + retry_count=0 + for shard_dir in $shard_dirs; do + if [ -f "$shard_dir/.retry_count" ]; then + retry_count=$(cat "$shard_dir/.retry_count") + break + fi + done + + if [ $retry_count -ge $MAX_RETRIES ]; then + echo "šŸ›‘ Shard $i: MAX RETRIES EXCEEDED ($retry_count/$MAX_RETRIES)" + else + echo "āŒ Shard $i: FAILED (retries: $retry_count/$MAX_RETRIES)" + failed_shards+=($i) + ((failed_count++)) + fi + fi + done + + echo "" + echo "==========================================" + echo "šŸ“Š Summary:" + echo " āœ… Successful: $success_count / $NUM_SHARDS" + echo " āŒ Failed/Missing: $failed_count" + echo "==========================================" + echo "" + + # Return failed shards 
as comma-separated list + if [ ${#failed_shards[@]} -eq 0 ]; then + return 0 + else + IFS=',' + echo "${failed_shards[*]}" + return 1 + fi + } + + # Check for failures + FAILED_LIST=$(check_shards) + CHECK_EXIT=$? + + if [ $CHECK_EXIT -eq 0 ]; then + echo "šŸŽ‰ All shards completed successfully!" + echo "You can now merge with:" + echo " python $MMIRAGE_PATH/src/mmirage/merge_shards.py \\" + echo " --input-dir $SHARDS_ROOT \\" + echo " --output-dir \$MERGED_DIR" + exit 0 + fi + + echo "šŸ”„ Relaunching failed shards: $FAILED_LIST" + echo "" + + # Resubmit workers for failed shards + WORKER_JOB=$(sbatch --parsable --array=$FAILED_LIST $0) + echo "āœ… Worker job submitted: $WORKER_JOB" + + # Resubmit controller to check again after workers finish + CONTROLLER_JOB=$(sbatch --parsable --dependency=afterany:$WORKER_JOB $0) + echo "āœ… Controller job submitted: $CONTROLLER_JOB" + + echo "" + echo "Automatic retry chain activated." + echo "Monitor with: squeue -u \$USER | grep mmirage-auto-retry" + +else + # ======================================================================== + # INITIAL MODE - Submit the first job array + # ======================================================================== + echo "==========================================" + echo "Submitting initial MMIRAGE job array" + echo "==========================================" + echo "Config: $CFG" + echo "Shards: $NUM_SHARDS (0-$((NUM_SHARDS-1)))" + echo "Output: $SHARDS_ROOT" + echo "Max retries: $MAX_RETRIES" + echo "" + + # Submit worker array + WORKER_JOB=$(sbatch --parsable --array=0-$((NUM_SHARDS-1)) $0) + echo "āœ… Worker job submitted: $WORKER_JOB" + + # Submit controller to run after workers + CONTROLLER_JOB=$(sbatch --parsable --dependency=afterany:$WORKER_JOB $0) + echo "āœ… Controller job submitted: $CONTROLLER_JOB" + + echo "" + echo "==========================================" + echo "Jobs submitted successfully!" 
+ echo "==========================================" + echo "" + echo "The system will automatically:" + echo " 1. Process all $NUM_SHARDS shards" + echo " 2. Check for failures" + echo " 3. Retry failed shards" + echo " 4. Repeat until all succeed or max retries" + echo "" + echo "Monitor with:" + echo " squeue -u \$USER | grep mmirage-auto-retry" + echo "" + echo "Cancel with:" + echo " scancel -n mmirage-auto-retry" +fi diff --git a/src/mmirage/merge_shards.py b/src/mmirage/merge_shards.py index 9f8c562..99357ca 100644 --- a/src/mmirage/merge_shards.py +++ b/src/mmirage/merge_shards.py @@ -2,7 +2,8 @@ import argparse import os -from typing import Dict, List +import sys +from typing import Dict, List, Set, Tuple from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk @@ -64,6 +65,63 @@ def _shard_key(path: str) -> int: return shard_dirs +def _extract_shard_id(shard_path: str) -> int: + """Extract shard ID from shard directory path.""" + base = os.path.basename(shard_path) + suffix = base.removeprefix("shard_") + return int(suffix) if suffix.isdigit() else -1 + + +def _check_shard_success(shard_dir: str) -> bool: + """Check if a shard completed successfully based on .SUCCESS marker.""" + success_file = os.path.join(shard_dir, ".SUCCESS") + return os.path.exists(success_file) + + +def _check_shard_failed(shard_dir: str) -> bool: + """Check if a shard failed based on .FAILED marker.""" + failed_file = os.path.join(shard_dir, ".FAILED") + return os.path.exists(failed_file) + + +def _analyze_shard_status( + dataset_dir: str, expected_shards: int = None +) -> Tuple[Set[int], Set[int], Set[int], Set[int]]: + """Analyze status of all shards in a dataset directory. 
+ + Args: + dataset_dir: Path to dataset directory containing shards + expected_shards: Expected number of shards (for detecting missing ones) + + Returns: + Tuple of (success_ids, failed_ids, incomplete_ids, missing_ids) + """ + shard_dirs = _list_shard_dirs(dataset_dir) + + success_ids: Set[int] = set() + failed_ids: Set[int] = set() + incomplete_ids: Set[int] = set() + + for shard_dir in shard_dirs: + shard_id = _extract_shard_id(shard_dir) + if shard_id < 0: + continue + + if _check_shard_success(shard_dir): + success_ids.add(shard_id) + elif _check_shard_failed(shard_dir): + failed_ids.add(shard_id) + else: + incomplete_ids.add(shard_id) + + missing_ids: Set[int] = set() + if expected_shards is not None: + existing_ids = success_ids | failed_ids | incomplete_ids + missing_ids = set(range(expected_shards)) - existing_ids + + return success_ids, failed_ids, incomplete_ids, missing_ids + + def _dataset_dirs(input_dir: str) -> List[str]: """Find dataset directories containing shard folders.""" candidates: List[str] = [] @@ -94,6 +152,22 @@ def main(): required=True, help="Directory to write merged datasets into.", ) + ap.add_argument( + "--expected-shards", + type=int, + default=None, + help="Expected number of shards (for detecting missing shards).", + ) + ap.add_argument( + "--fail-on-missing", + action="store_true", + help="Fail if any shards are missing, failed, or incomplete.", + ) + ap.add_argument( + "--check-markers", + action="store_true", + help="Check for .SUCCESS markers and report shard status.", + ) args = ap.parse_args() input_dir = args.input_dir @@ -111,15 +185,60 @@ def main(): ) for dataset_dir in dataset_dirs: + dataset_name = os.path.basename(dataset_dir) + print(f"\n{'='*60}") + print(f"Processing dataset: {dataset_name}") + print(f"{'='*60}") + shard_dirs = _list_shard_dirs(dataset_dir) if not shard_dirs: continue + # Check shard status if requested + if args.check_markers or args.expected_shards is not None: + success_ids, failed_ids, 
incomplete_ids, missing_ids = _analyze_shard_status( + dataset_dir, args.expected_shards + ) + + total_found = len(success_ids) + len(failed_ids) + len(incomplete_ids) + total_expected = args.expected_shards if args.expected_shards else total_found + + print(f"\nšŸ“Š Shard Status Report:") + print(f" āœ… Successful: {len(success_ids)} / {total_expected}") + print(f" āŒ Failed: {len(failed_ids)}") + print(f" āš ļø Incomplete: {len(incomplete_ids)}") + print(f" ā“ Missing: {len(missing_ids)}") + + if failed_ids: + print(f"\n Failed shard IDs: {sorted(failed_ids)}") + if incomplete_ids: + print(f" Incomplete shard IDs: {sorted(incomplete_ids)}") + if missing_ids: + print(f" Missing shard IDs: {sorted(missing_ids)}") + + # Check if we should fail + has_problems = bool(failed_ids or incomplete_ids or missing_ids) + if has_problems and args.fail_on_missing: + print(f"\nāŒ ERROR: Found failed/missing/incomplete shards and --fail-on-missing is set") + sys.exit(1) + elif has_problems: + print(f"\nāš ļø WARNING: Some shards are incomplete - merged dataset may be missing data") + print(f" Consider running failure detection and relaunching failed shards:") + print(f" python src/mmirage/detect_failures.py --input-dir {input_dir} --num-shards {total_expected}") + shard_dsets: List[DatasetLike] = [] skipped_empty_dir = 0 skipped_zero_rows = 0 + skipped_failed = 0 for shard_dir in shard_dirs: + # Skip explicitly failed shards if check-markers is enabled + if args.check_markers and _check_shard_failed(shard_dir): + shard_id = _extract_shard_id(shard_dir) + print(f"āš ļø Skipping failed shard {shard_id}: {shard_dir}") + skipped_failed += 1 + continue + try: ds = load_from_disk(shard_dir) except FileNotFoundError as e: @@ -142,13 +261,14 @@ def main(): raise RuntimeError( f"No non-empty shards found in {dataset_dir}. " f"empty/invalid dirs: {skipped_empty_dir}, " - f"zero-row datasets: {skipped_zero_rows}." 
+ f"zero-row datasets: {skipped_zero_rows}, " + f"failed shards: {skipped_failed}." ) ds_merged = _merge_shards(shard_dsets) n_rows = _count_rows(ds_merged) - total_skipped = skipped_empty_dir + skipped_zero_rows + total_skipped = skipped_empty_dir + skipped_zero_rows + skipped_failed if dataset_dir == input_dir: ds_out_dir = output_dir @@ -161,11 +281,14 @@ def main(): ds_merged.save_to_disk(ds_out_dir) print( - f"āœ… Concatenated {len(shard_dsets)} shards for {dataset_name} " - f"with {n_rows} rows.\n" - f" Skipped shards: {total_skipped} total " - f"(empty/invalid dir: {skipped_empty_dir}, zero rows: {skipped_zero_rows})." + f"\nāœ… Concatenated {len(shard_dsets)} shards for {dataset_name} " + f"with {n_rows} rows." ) + if total_skipped > 0: + print( + f" Skipped shards: {total_skipped} total " + f"(empty/invalid: {skipped_empty_dir}, zero rows: {skipped_zero_rows}, failed: {skipped_failed})" + ) if __name__ == "__main__": diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index f232dc6..c6ee932 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -4,8 +4,11 @@ """ import argparse +from datetime import datetime from functools import reduce import os +import sys +import traceback from typing import Any, Dict, List from datasets import Dataset, DatasetDict @@ -55,6 +58,45 @@ def _remove_columns(ds: DatasetLike, enable: bool) -> List[str]: return ds.column_names +def _get_retry_count(shard_dir: str) -> int: + """Get retry count for a shard from retry marker file.""" + retry_file = os.path.join(shard_dir, ".retry_count") + if not os.path.exists(retry_file): + return 0 + try: + with open(retry_file, "r") as f: + return int(f.read().strip()) + except (ValueError, IOError): + return 0 + + +def _increment_retry_count(shard_dir: str) -> int: + """Increment and write retry count for a shard.""" + count = _get_retry_count(shard_dir) + 1 + retry_file = os.path.join(shard_dir, ".retry_count") + os.makedirs(shard_dir, 
exist_ok=True) + with open(retry_file, "w") as f: + f.write(str(count)) + return count + + +def _write_success_marker(shard_dir: str): + """Write success marker file for a completed shard.""" + marker_file = os.path.join(shard_dir, ".SUCCESS") + os.makedirs(shard_dir, exist_ok=True) + with open(marker_file, "w") as f: + f.write(f"completed_at: {datetime.now().isoformat()}\n") + + +def _write_failure_marker(shard_dir: str, error_msg: str): + """Write failure marker file with error information.""" + marker_file = os.path.join(shard_dir, ".FAILED") + os.makedirs(shard_dir, exist_ok=True) + with open(marker_file, "w") as f: + f.write(f"failed_at: {datetime.now().isoformat()}\n") + f.write(f"error: {error_msg}\n") + + def rewrite_batch( batch: Dict[str, List[Any]], mapper: MMIRAGEMapper, @@ -114,42 +156,70 @@ def main(): if not (0 <= shard_id < num_shards): raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") - ds_all = load_datasets_from_configs(datasets_config) - total_rows = sum(_count_rows(ds) for ds in ds_all) + # Track shard directories for marker files + shard_dirs = [] - ds_all_shard = [_shard_dataset(ds, num_shards, shard_id) for ds in ds_all] - shard_rows = sum(_count_rows(ds) for ds in ds_all_shard) + try: + ds_all = load_datasets_from_configs(datasets_config) + total_rows = sum(_count_rows(ds) for ds in ds_all) - logger.info( - f"Loaded {len(datasets_config)} dataset(s): {datasets_config} " - f"→ {total_rows} total rows; this shard has {shard_rows} rows." 
- ) + ds_all_shard = [_shard_dataset(ds, num_shards, shard_id) for ds in ds_all] + shard_rows = sum(_count_rows(ds) for ds in ds_all_shard) - mapper = MMIRAGEMapper( - cfg.processors, processing_params.inputs, processing_params.outputs - ) - renderer = TemplateRenderer(processing_params.output_schema) - ds_processed_all: List[DatasetLike] = [] - for ds_idx, ds_shard in enumerate(ds_all_shard): - ds_config = datasets_config[ds_idx] - remove_columns = _remove_columns(ds_shard, processing_params.remove_columns) - ds_processed = ds_shard.map( - rewrite_batch, - batched=True, - batch_size=loading_params.get_batch_size(), - load_from_cache_file=False, - desc=f"Shard {shard_id}/{num_shards - 1} dataset {ds_idx}", - fn_kwargs={"mapper": mapper, "renderer": renderer, "image_base_path": ds_config.image_base_path}, - remove_columns=remove_columns, + logger.info( + f"Loaded {len(datasets_config)} dataset(s): {datasets_config} " + f"→ {total_rows} total rows; this shard has {shard_rows} rows." ) - ds_processed_all.append(ds_processed) - for ds_config, ds_processed in zip(datasets_config, ds_processed_all): - out_dir = _dataset_out_dir(shard_id, ds_config) - os.makedirs(out_dir, exist_ok=True) - ds_processed.save_to_disk(out_dir) + # Increment retry count for each shard directory + for ds_config in datasets_config: + shard_dir = _dataset_out_dir(shard_id, ds_config) + retry_count = _increment_retry_count(shard_dir) + shard_dirs.append(shard_dir) + if retry_count > 1: + logger.info(f"Retry attempt #{retry_count} for shard {shard_id}") - logger.info(f"āœ… Saved dataset in: {out_dir}") + mapper = MMIRAGEMapper( + cfg.processors, processing_params.inputs, processing_params.outputs + ) + renderer = TemplateRenderer(processing_params.output_schema) + ds_processed_all: List[DatasetLike] = [] + for ds_idx, ds_shard in enumerate(ds_all_shard): + ds_config = datasets_config[ds_idx] + remove_columns = _remove_columns(ds_shard, processing_params.remove_columns) + ds_processed = 
ds_shard.map( + rewrite_batch, + batched=True, + batch_size=loading_params.get_batch_size(), + load_from_cache_file=False, + desc=f"Shard {shard_id}/{num_shards - 1} dataset {ds_idx}", + fn_kwargs={"mapper": mapper, "renderer": renderer, "image_base_path": ds_config.image_base_path}, + remove_columns=remove_columns, + ) + ds_processed_all.append(ds_processed) + + for ds_config, ds_processed in zip(datasets_config, ds_processed_all): + out_dir = _dataset_out_dir(shard_id, ds_config) + os.makedirs(out_dir, exist_ok=True) + ds_processed.save_to_disk(out_dir) + logger.info(f"āœ… Saved dataset in: {out_dir}") + + # Write success markers for all shards + for shard_dir in shard_dirs: + _write_success_marker(shard_dir) + logger.info(f"āœ… Shard {shard_id} completed successfully") + + except Exception as e: + error_msg = f"{type(e).__name__}: {str(e)}" + logger.error(f"āŒ Shard {shard_id} failed: {error_msg}") + logger.error(traceback.format_exc()) + + # Write failure markers for all shards + for shard_dir in shard_dirs: + _write_failure_marker(shard_dir, error_msg) + + # Re-raise to ensure non-zero exit code + sys.exit(1) if __name__ == "__main__": From c4f32fe4dc1b291ff3a2fa27cecb72b92b244f0e Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 3 Mar 2026 17:19:28 +0100 Subject: [PATCH 02/47] config is needed as well --- configs/config_medtrinity.yaml | 63 ++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 configs/config_medtrinity.yaml diff --git a/configs/config_medtrinity.yaml b/configs/config_medtrinity.yaml new file mode 100644 index 0000000..d02e9fb --- /dev/null +++ b/configs/config_medtrinity.yaml @@ -0,0 +1,63 @@ +processors: + - type: llm + server_args: + model_path: Qwen/Qwen3-4B-Instruct-2507 + tp_size: 4 + disable_custom_all_reduce: true + default_sampling_params: + temperature: 0.1 + top_p: 0.9 + max_new_tokens: 1024 + chat_template_kwargs: + enable_thinking: false + +loading_params: + 
datasets: + - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/dirty/medtrinity_conversations_sampled + type: loadable + output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled + num_shards: "$SLURM_ARRAY_TASK_COUNT" + shard_id: "$SLURM_ARRAY_TASK_ID" + conversations_field: "conversations" + batch_size: 128 + +processing_params: + inputs: + - name: assistant_answer + key: conversations[1].content + - name: user_prompt + key: conversations[0].content + - name: modalities + key: modalities + + outputs: + - name: formatted_answer + type: llm + output_type: plain + prompt: | + You will receive a JSON object with a single field "assistant_text". + + Task: + - Rewrite "assistant_text" as clearer, well-structured **Markdown**. + - You may add headings, bullet points, and simple formatting to improve readability. + - Keep the original meaning; do **not** invent new facts or interpretations. + + Output: + - Return only the rewritten Markdown text as a plain string. + - Do NOT wrap it in JSON, quotes, or code fences. 
+ + Input: + {{ assistant_answer }} + output_schema: + - question + - explanation + - answer + + remove_columns: True + output_schema: + conversations: + - role: user + content: "{{ user_prompt }}" + - role: assistant + content: "{{ formatted_answer }}" + modalities: "{{ modalities }}" \ No newline at end of file From 7e16532493a94aa83b47f2eca3817290a00c75ef Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 3 Mar 2026 17:31:42 +0100 Subject: [PATCH 03/47] simplified scripts to test for first version --- README.md | 18 ++-- retry_failed.sh | 89 +++++++++++++++++ run_with_retry.sh | 245 +++++----------------------------------------- 3 files changed, 120 insertions(+), 232 deletions(-) create mode 100644 retry_failed.sh diff --git a/README.md b/README.md index fb175fd..3443a37 100644 --- a/README.md +++ b/README.md @@ -34,21 +34,21 @@ For testing and scripts that make use of the library, it is advised to create a ### Running with Automatic Retry -The simplest way to run MMIRAGE with automatic failure detection and retry: +Your job scripts now automatically track success/failure with marker files. After your job completes, just run a simple retry script: ```bash -# 1. Edit configuration in run_with_retry.sh (set your paths, num shards, etc.) +# 1. Submit your job normally +sbatch run_with_retry.sh # or run_medtrinity.sh -# 2. Submit - everything else is automatic -sbatch run_with_retry.sh +# 2. After job completes, check for failures and retry +./retry_failed.sh -# That's it! The system will: -# - Process all shards -# - Detect failures automatically -# - Retry only failed shards -# - Repeat until all succeed or max retries +# It will show you which shards failed and ask if you want to relaunch them +# Keep running retry_failed.sh until all shards succeed ``` +See [docs/SIMPLE_RETRY.md](docs/SIMPLE_RETRY.md) for details. 
+ ### Text-only: Reformatting dataset Suppose you have a dataset with samples of the following format diff --git a/retry_failed.sh b/retry_failed.sh new file mode 100644 index 0000000..144e50d --- /dev/null +++ b/retry_failed.sh @@ -0,0 +1,89 @@ +#!/bin/bash +# Check for failed shards and relaunch them +# +# Usage: ./retry_failed.sh + +# Configuration +SHARDS_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled" +NUM_SHARDS=32 +MAX_RETRIES=3 +SCRIPT_PATH="/users/qchapp/meditron/MIRAGE/run_with_retry.sh" + +echo "Checking for failed shards in: $SHARDS_ROOT" +echo "" + +failed_shards=() +success_count=0 + +for i in $(seq 0 $((NUM_SHARDS-1))); do + # Find shard directories (may be nested under dataset dirs) + shard_dirs=$(find "$SHARDS_ROOT" -type d -name "shard_$i" 2>/dev/null) + + if [ -z "$shard_dirs" ]; then + echo "āŒ Shard $i: MISSING" + failed_shards+=($i) + continue + fi + + # Check each shard directory for success marker + shard_success=false + for shard_dir in $shard_dirs; do + if [ -f "$shard_dir/.SUCCESS" ]; then + shard_success=true + break + fi + done + + if [ "$shard_success" = true ]; then + echo "āœ… Shard $i: SUCCESS" + ((success_count++)) + else + # Check retry count + retry_count=0 + for shard_dir in $shard_dirs; do + if [ -f "$shard_dir/.retry_count" ]; then + retry_count=$(cat "$shard_dir/.retry_count") + break + fi + done + + if [ $retry_count -ge $MAX_RETRIES ]; then + echo "šŸ›‘ Shard $i: MAX RETRIES EXCEEDED ($retry_count/$MAX_RETRIES)" + else + echo "āŒ Shard $i: FAILED (retries: $retry_count/$MAX_RETRIES)" + failed_shards+=($i) + fi + fi +done + +echo "" +echo "==========================================" +echo "Summary:" +echo " āœ… Successful: $success_count / $NUM_SHARDS" +echo " āŒ To retry: ${#failed_shards[@]}" +echo "==========================================" +echo "" + +if [ ${#failed_shards[@]} -eq 0 ]; then + echo "šŸŽ‰ All shards completed successfully!" 
+ exit 0 +fi + +# Build array spec +IFS=',' +ARRAY_SPEC="${failed_shards[*]}" +unset IFS + +echo "Failed shards: $ARRAY_SPEC" +echo "" +read -p "Submit retry job for these shards? (y/N) " -n 1 -r +echo + +if [[ $REPLY =~ ^[Yy]$ ]]; then + JOB_ID=$(sbatch --array=$ARRAY_SPEC "$SCRIPT_PATH" | grep -oE '[0-9]+') + echo "āœ… Job submitted: $JOB_ID" + echo "" + echo "Monitor with: squeue -j $JOB_ID" +else + echo "Cancelled." +fi diff --git a/run_with_retry.sh b/run_with_retry.sh index 6934baa..b80a64c 100644 --- a/run_with_retry.sh +++ b/run_with_retry.sh @@ -1,233 +1,32 @@ #!/bin/bash -#SBATCH --job-name=mmirage-auto-retry -#SBATCH --chdir=/users/$USER/meditron/MMIRAGE/src/mmirage -#SBATCH --output=/users/$USER/reports/R-%x.%A_%a.out -#SBATCH --error=/users/$USER/reports/R-%x.%A_%a.err +#SBATCH --job-name=mmirage-medtrinity +#SBATCH --chdir=/users/qchapp/meditron/MIRAGE/src/mmirage +#SBATCH --output=/users/qchapp/reports/R-%x.%A_%a.out +#SBATCH --error=/users/qchapp/reports/R-%x.%A_%a.err #SBATCH --nodes=1 #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:4 #SBATCH --cpus-per-task=288 #SBATCH --time=11:59:59 #SBATCH -A a127 +#SBATCH --array=0-31 -############################################################################## -# MMIRAGE with Automatic Retry -# -# This script automatically detects and relaunches failed shards until all -# complete successfully or max retries is reached. -# -# Usage: -# 1. Edit the configuration section below -# 2. Run locally (NOT via sbatch): ./run_with_retry.sh -# 3. That's it - everything else is automatic -# -# NOTE: This script submits jobs to SLURM internally, so run it as a regular -# bash script, not with sbatch. 
-############################################################################## - -# ============================================================================ -# CONFIGURATION - Edit these to match your setup -# ============================================================================ -export USER="username" - -export MMIRAGE_PATH="/users/$USER/meditron/MMIRAGE" - -# Number of shards (will launch array 0 to NUM_SHARDS-1) -export NUM_SHARDS=32 - -# Path to your MMIRAGE config file -export CFG=$MMIRAGE_PATH/configs/config_medtrinity.yaml - -# Output directory for shards -export SHARDS_ROOT=$SCRATCH/mmirage_output/shards +# --- outputs & config --- +export CFG=/users/qchapp/meditron/MIRAGE/configs/config_medtrinity.yaml # HF cache/home -export HF_HOME=$SCRATCH/hf - -# Maximum retry attempts per shard (prevents infinite loops) -export MAX_RETRIES=3 - -# SLURM settings for worker nodes -export WORKER_ACCOUNT="a127" -export WORKER_RESERVATION="sai-a127" -export WORKER_ENVIRONMENT="/users/$USER/.edf/sglang.toml" - -# ============================================================================ -# END CONFIGURATION - Don't edit below unless you know what you're doing -# ============================================================================ - -mkdir -p "$SHARDS_ROOT" - -# Detect which mode we're in based on SLURM variables -if [ -n "$SLURM_ARRAY_TASK_ID" ]; then - # ======================================================================== - # WORKER MODE - Process one shard - # ======================================================================== - echo "==========================================" - echo "Worker: Processing shard $SLURM_ARRAY_TASK_ID" - echo "Started at: $(date)" - echo "==========================================" - - export CMD="python $MMIRAGE_PATH/src/mmirage/shard_process.py --config $CFG" - - SRUN_ARGS=" \ - --cpus-per-task $SLURM_CPUS_PER_TASK \ - --jobid $SLURM_JOB_ID \ - --wait 60 \ - -A $WORKER_ACCOUNT \ - --reservation 
$WORKER_RESERVATION \ - --environment $WORKER_ENVIRONMENT - " - - srun $SRUN_ARGS bash -c "$CMD" - EXIT_CODE=$? - - echo "END TIME: $(date)" - echo "EXIT CODE: $EXIT_CODE" - - exit $EXIT_CODE - -elif [ -n "$SLURM_JOB_ID" ]; then - # ======================================================================== - # CONTROLLER MODE - Check for failures and resubmit - # ======================================================================== - echo "==========================================" - echo "Controller: Checking for failed shards" - echo "Started at: $(date)" - echo "==========================================" - echo "" - - # Function to check for successful shards - check_shards() { - local failed_shards=() - local success_count=0 - local failed_count=0 - local missing_count=0 - - for i in $(seq 0 $((NUM_SHARDS-1))); do - # Find shard directories (may be nested under dataset dirs) - shard_dirs=$(find "$SHARDS_ROOT" -type d -name "shard_$i" 2>/dev/null) - - if [ -z "$shard_dirs" ]; then - echo "ā“ Shard $i: MISSING (no directory)" - failed_shards+=($i) - ((missing_count++)) - continue - fi - - # Check each shard directory for success marker - shard_success=false - for shard_dir in $shard_dirs; do - if [ -f "$shard_dir/.SUCCESS" ]; then - shard_success=true - break - fi - done - - if [ "$shard_success" = true ]; then - echo "āœ… Shard $i: SUCCESS" - ((success_count++)) - else - # Check retry count - retry_count=0 - for shard_dir in $shard_dirs; do - if [ -f "$shard_dir/.retry_count" ]; then - retry_count=$(cat "$shard_dir/.retry_count") - break - fi - done - - if [ $retry_count -ge $MAX_RETRIES ]; then - echo "šŸ›‘ Shard $i: MAX RETRIES EXCEEDED ($retry_count/$MAX_RETRIES)" - else - echo "āŒ Shard $i: FAILED (retries: $retry_count/$MAX_RETRIES)" - failed_shards+=($i) - ((failed_count++)) - fi - fi - done - - echo "" - echo "==========================================" - echo "šŸ“Š Summary:" - echo " āœ… Successful: $success_count / $NUM_SHARDS" - echo " āŒ 
Failed/Missing: $failed_count" - echo "==========================================" - echo "" - - # Return failed shards as comma-separated list - if [ ${#failed_shards[@]} -eq 0 ]; then - return 0 - else - IFS=',' - echo "${failed_shards[*]}" - return 1 - fi - } - - # Check for failures - FAILED_LIST=$(check_shards) - CHECK_EXIT=$? - - if [ $CHECK_EXIT -eq 0 ]; then - echo "šŸŽ‰ All shards completed successfully!" - echo "You can now merge with:" - echo " python $MMIRAGE_PATH/src/mmirage/merge_shards.py \\" - echo " --input-dir $SHARDS_ROOT \\" - echo " --output-dir \$MERGED_DIR" - exit 0 - fi - - echo "šŸ”„ Relaunching failed shards: $FAILED_LIST" - echo "" - - # Resubmit workers for failed shards - WORKER_JOB=$(sbatch --parsable --array=$FAILED_LIST $0) - echo "āœ… Worker job submitted: $WORKER_JOB" - - # Resubmit controller to check again after workers finish - CONTROLLER_JOB=$(sbatch --parsable --dependency=afterany:$WORKER_JOB $0) - echo "āœ… Controller job submitted: $CONTROLLER_JOB" - - echo "" - echo "Automatic retry chain activated." 
- echo "Monitor with: squeue -u \$USER | grep mmirage-auto-retry" - -else - # ======================================================================== - # INITIAL MODE - Submit the first job array - # ======================================================================== - echo "==========================================" - echo "Submitting initial MMIRAGE job array" - echo "==========================================" - echo "Config: $CFG" - echo "Shards: $NUM_SHARDS (0-$((NUM_SHARDS-1)))" - echo "Output: $SHARDS_ROOT" - echo "Max retries: $MAX_RETRIES" - echo "" - - # Submit worker array - WORKER_JOB=$(sbatch --parsable --array=0-$((NUM_SHARDS-1)) $0) - echo "āœ… Worker job submitted: $WORKER_JOB" - - # Submit controller to run after workers - CONTROLLER_JOB=$(sbatch --parsable --dependency=afterany:$WORKER_JOB $0) - echo "āœ… Controller job submitted: $CONTROLLER_JOB" - - echo "" - echo "==========================================" - echo "Jobs submitted successfully!" - echo "==========================================" - echo "" - echo "The system will automatically:" - echo " 1. Process all $NUM_SHARDS shards" - echo " 2. Check for failures" - echo " 3. Retry failed shards" - echo " 4. 
Repeat until all succeed or max retries" - echo "" - echo "Monitor with:" - echo " squeue -u \$USER | grep mmirage-auto-retry" - echo "" - echo "Cancel with:" - echo " scancel -n mmirage-auto-retry" -fi +export HF_HOME=/capstor/store/cscs/swissai/a127/homes/qchapp/hf + +export CMD="python /users/qchapp/meditron/MIRAGE/src/mmirage/shard_process.py --config $CFG" + +SRUN_ARGS=" \ + --cpus-per-task $SLURM_CPUS_PER_TASK \ + --jobid $SLURM_JOB_ID \ + --wait 60 \ + -A a127 \ + --reservation sai-a127 \ + --environment /users/qchapp/.edf/sglang.toml + " +# bash -c is needed for the delayed interpolation of env vars to work +srun $SRUN_ARGS bash -c "$CMD" +echo "END TIME: $(date)" From a62b1add34df1cc73ee0f511f24a6d301888f231 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Fri, 6 Mar 2026 21:20:52 +0100 Subject: [PATCH 04/47] remove old dataset if we need to clean up --- src/mmirage/shard_process.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index c6ee932..26b2e91 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -6,6 +6,7 @@ import argparse from datetime import datetime from functools import reduce +import glob import os import sys import traceback @@ -80,6 +81,35 @@ def _increment_retry_count(shard_dir: str) -> int: return count +def _cleanup_old_shard_data(shard_dir: str): + """Remove old data files from a shard directory before retry. + + Keeps marker files (.SUCCESS, .FAILED, .retry_count) but removes + arrow files and dataset metadata to prevent duplicates. 
+ """ + if not os.path.exists(shard_dir): + return + + # Patterns for files to remove + patterns_to_remove = [ + "*.arrow", + "dataset_info.json", + "state.json", + ] + + removed_count = 0 + for pattern in patterns_to_remove: + for file_path in glob.glob(os.path.join(shard_dir, pattern)): + try: + os.remove(file_path) + removed_count += 1 + except OSError as e: + logger.warning(f"Failed to remove {file_path}: {e}") + + if removed_count > 0: + logger.info(f"Cleaned up {removed_count} old data files from {shard_dir}") + + def _write_success_marker(shard_dir: str): """Write success marker file for a completed shard.""" marker_file = os.path.join(shard_dir, ".SUCCESS") @@ -178,6 +208,8 @@ def main(): shard_dirs.append(shard_dir) if retry_count > 1: logger.info(f"Retry attempt #{retry_count} for shard {shard_id}") + # Clean up old data files to prevent duplicates + _cleanup_old_shard_data(shard_dir) mapper = MMIRAGEMapper( cfg.processors, processing_params.inputs, processing_params.outputs From a45f795c6d4646bc4937783efa001c9847750a5c Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Fri, 6 Mar 2026 21:41:21 +0100 Subject: [PATCH 05/47] testing something --- retry_failed.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/retry_failed.sh b/retry_failed.sh index 144e50d..107aa8f 100644 --- a/retry_failed.sh +++ b/retry_failed.sh @@ -1,7 +1,7 @@ #!/bin/bash # Check for failed shards and relaunch them # -# Usage: ./retry_failed.sh +# Usage: bash retry_failed.sh # Configuration SHARDS_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled" @@ -80,10 +80,10 @@ read -p "Submit retry job for these shards? 
(y/N) " -n 1 -r echo if [[ $REPLY =~ ^[Yy]$ ]]; then - JOB_ID=$(sbatch --array=$ARRAY_SPEC "$SCRIPT_PATH" | grep -oE '[0-9]+') + JOB_ID=$(sbatch --export=ALL,TOTAL_SHARDS=$NUM_SHARDS --array=$ARRAY_SPEC "$SCRIPT_PATH" | grep -oE '[0-9]+') echo "āœ… Job submitted: $JOB_ID" echo "" - echo "Monitor with: squeue -j $JOB_ID" + echo "Monitor with: squeue -j ${JOB_ID}_" else echo "Cancelled." -fi +fi \ No newline at end of file From 06ea67c26c5a16623725fedbf1e3c74c85a2f996 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Fri, 6 Mar 2026 21:50:46 +0100 Subject: [PATCH 06/47] something missing --- configs/config_medtrinity.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/config_medtrinity.yaml b/configs/config_medtrinity.yaml index d02e9fb..744ddac 100644 --- a/configs/config_medtrinity.yaml +++ b/configs/config_medtrinity.yaml @@ -16,7 +16,7 @@ loading_params: - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/dirty/medtrinity_conversations_sampled type: loadable output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled - num_shards: "$SLURM_ARRAY_TASK_COUNT" + num_shards: "${TOTAL_SHARDS:-$SLURM_ARRAY_TASK_COUNT}" shard_id: "$SLURM_ARRAY_TASK_ID" conversations_field: "conversations" batch_size: 128 From e8aa35a96e60ad952835e10108a4fb438f22e837 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Fri, 6 Mar 2026 21:55:19 +0100 Subject: [PATCH 07/47] fixed maybe --- configs/config_medtrinity.yaml | 2 +- retry_failed.sh | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/configs/config_medtrinity.yaml b/configs/config_medtrinity.yaml index 744ddac..d02e9fb 100644 --- a/configs/config_medtrinity.yaml +++ b/configs/config_medtrinity.yaml @@ -16,7 +16,7 @@ loading_params: - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/dirty/medtrinity_conversations_sampled type: loadable 
output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled - num_shards: "${TOTAL_SHARDS:-$SLURM_ARRAY_TASK_COUNT}" + num_shards: "$SLURM_ARRAY_TASK_COUNT" shard_id: "$SLURM_ARRAY_TASK_ID" conversations_field: "conversations" batch_size: 128 diff --git a/retry_failed.sh b/retry_failed.sh index 107aa8f..496c9f5 100644 --- a/retry_failed.sh +++ b/retry_failed.sh @@ -80,10 +80,11 @@ read -p "Submit retry job for these shards? (y/N) " -n 1 -r echo if [[ $REPLY =~ ^[Yy]$ ]]; then - JOB_ID=$(sbatch --export=ALL,TOTAL_SHARDS=$NUM_SHARDS --array=$ARRAY_SPEC "$SCRIPT_PATH" | grep -oE '[0-9]+') + # Override SLURM_ARRAY_TASK_COUNT to total shards count for retries + JOB_ID=$(sbatch --export=ALL,SLURM_ARRAY_TASK_COUNT=$NUM_SHARDS --array=$ARRAY_SPEC "$SCRIPT_PATH" | grep -oE '[0-9]+') echo "āœ… Job submitted: $JOB_ID" echo "" - echo "Monitor with: squeue -j ${JOB_ID}_" + echo "Monitor with: squeue -j $JOB_ID" else echo "Cancelled." fi \ No newline at end of file From 186e6b5bf7db5e88fed52832e10b5b5ef28bb465 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Sat, 7 Mar 2026 12:23:46 +0100 Subject: [PATCH 08/47] removing check to test something --- src/mmirage/shard_process.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 26b2e91..ee7657b 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -183,8 +183,8 @@ def main(): shard_id = loading_params.get_shard_id() num_shards = loading_params.get_num_shards() - if not (0 <= shard_id < num_shards): - raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") + # if not (0 <= shard_id < num_shards): + # raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") # Track shard directories for marker files shard_dirs = [] From 392552770427ebcf1bec1691507be46094f4f909 Mon Sep 17 00:00:00 2001 From: titi 
<74377782+qchapp@users.noreply.github.com> Date: Sat, 7 Mar 2026 12:43:28 +0100 Subject: [PATCH 09/47] fixed number of shards --- configs/config_medtrinity.yaml | 2 +- retry_failed.sh | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/configs/config_medtrinity.yaml b/configs/config_medtrinity.yaml index d02e9fb..313196d 100644 --- a/configs/config_medtrinity.yaml +++ b/configs/config_medtrinity.yaml @@ -16,7 +16,7 @@ loading_params: - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/dirty/medtrinity_conversations_sampled type: loadable output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled - num_shards: "$SLURM_ARRAY_TASK_COUNT" + num_shards: 32 shard_id: "$SLURM_ARRAY_TASK_ID" conversations_field: "conversations" batch_size: 128 diff --git a/retry_failed.sh b/retry_failed.sh index 496c9f5..3e042f6 100644 --- a/retry_failed.sh +++ b/retry_failed.sh @@ -80,8 +80,7 @@ read -p "Submit retry job for these shards? (y/N) " -n 1 -r echo if [[ $REPLY =~ ^[Yy]$ ]]; then - # Override SLURM_ARRAY_TASK_COUNT to total shards count for retries - JOB_ID=$(sbatch --export=ALL,SLURM_ARRAY_TASK_COUNT=$NUM_SHARDS --array=$ARRAY_SPEC "$SCRIPT_PATH" | grep -oE '[0-9]+') + JOB_ID=$(sbatch --array=$ARRAY_SPEC "$SCRIPT_PATH" | grep -oE '[0-9]+') echo "āœ… Job submitted: $JOB_ID" echo "" echo "Monitor with: squeue -j $JOB_ID" From 5b849dbf4d4bb8c014c47ec8d3607289eded317e Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Sat, 7 Mar 2026 12:58:40 +0100 Subject: [PATCH 10/47] working! --- README.md | 4 ++-- src/mmirage/shard_process.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 3443a37..aaf9bb5 100644 --- a/README.md +++ b/README.md @@ -38,10 +38,10 @@ Your job scripts now automatically track success/failure with marker files. Afte ```bash # 1. 
Submit your job normally -sbatch run_with_retry.sh # or run_medtrinity.sh +sbatch run.sh # 2. After job completes, check for failures and retry -./retry_failed.sh +bash retry_failed.sh # It will show you which shards failed and ask if you want to relaunch them # Keep running retry_failed.sh until all shards succeed diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index ee7657b..26b2e91 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -183,8 +183,8 @@ def main(): shard_id = loading_params.get_shard_id() num_shards = loading_params.get_num_shards() - # if not (0 <= shard_id < num_shards): - # raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") + if not (0 <= shard_id < num_shards): + raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") # Track shard directories for marker files shard_dirs = [] From 83096b3bde86b0b050fb936d0d93f7f47bcad679 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Sat, 7 Mar 2026 13:10:31 +0100 Subject: [PATCH 11/47] removed wrong part of the doc --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index aaf9bb5..539b227 100644 --- a/README.md +++ b/README.md @@ -47,8 +47,6 @@ bash retry_failed.sh # Keep running retry_failed.sh until all shards succeed ``` -See [docs/SIMPLE_RETRY.md](docs/SIMPLE_RETRY.md) for details. 
- ### Text-only: Reformatting dataset Suppose you have a dataset with samples of the following format From 53d7c0442008a9f7b86442b6c27e78dad62bbd8c Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 9 Mar 2026 17:35:36 +0100 Subject: [PATCH 12/47] ready for real medtrinity processing --- configs/config_medtrinity.yaml | 7 +++++-- retry_failed.sh | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/configs/config_medtrinity.yaml b/configs/config_medtrinity.yaml index 313196d..c746854 100644 --- a/configs/config_medtrinity.yaml +++ b/configs/config_medtrinity.yaml @@ -13,9 +13,12 @@ processors: loading_params: datasets: - - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/dirty/medtrinity_conversations_sampled + - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/medtrinity_conversations_1/ type: loadable - output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled + output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled2 + - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/medtrinity_conversations_2/ + type: loadable + output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled2 num_shards: 32 shard_id: "$SLURM_ARRAY_TASK_ID" conversations_field: "conversations" diff --git a/retry_failed.sh b/retry_failed.sh index 3e042f6..fd42ed0 100644 --- a/retry_failed.sh +++ b/retry_failed.sh @@ -4,7 +4,7 @@ # Usage: bash retry_failed.sh # Configuration -SHARDS_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled" +SHARDS_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled2" NUM_SHARDS=32 MAX_RETRIES=3 SCRIPT_PATH="/users/qchapp/meditron/MIRAGE/run_with_retry.sh" From 6dbc1cf54a8aa3316fbdbf9ef9dc93950b9fbeaf Mon Sep 
17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 10 Mar 2026 15:12:27 +0100 Subject: [PATCH 13/47] full run with retry --- run_with_retry.sh | 306 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 275 insertions(+), 31 deletions(-) diff --git a/run_with_retry.sh b/run_with_retry.sh index b80a64c..8baebc1 100644 --- a/run_with_retry.sh +++ b/run_with_retry.sh @@ -1,32 +1,276 @@ #!/bin/bash -#SBATCH --job-name=mmirage-medtrinity -#SBATCH --chdir=/users/qchapp/meditron/MIRAGE/src/mmirage -#SBATCH --output=/users/qchapp/reports/R-%x.%A_%a.out -#SBATCH --error=/users/qchapp/reports/R-%x.%A_%a.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=288 -#SBATCH --time=11:59:59 -#SBATCH -A a127 -#SBATCH --array=0-31 - -# --- outputs & config --- -export CFG=/users/qchapp/meditron/MIRAGE/configs/config_medtrinity.yaml - -# HF cache/home -export HF_HOME=/capstor/store/cscs/swissai/a127/homes/qchapp/hf - -export CMD="python /users/qchapp/meditron/MIRAGE/src/mmirage/shard_process.py --config $CFG" - -SRUN_ARGS=" \ - --cpus-per-task $SLURM_CPUS_PER_TASK \ - --jobid $SLURM_JOB_ID \ - --wait 60 \ - -A a127 \ - --reservation sai-a127 \ - --environment /users/qchapp/.edf/sglang.toml - " -# bash -c is needed for the delayed interpolation of env vars to work -srun $SRUN_ARGS bash -c "$CMD" -echo "END TIME: $(date)" +set -euo pipefail + +# ========================================================= +# MMIRAGE full pipeline wrapper +# - submits shard-processing SLURM array jobs +# - waits for completion +# - checks missing/failed shards +# - retries failed shards up to MAX_RETRIES +# - writes all terminal output to a global log file +# ========================================================= + +# ----------------------------- +# User configuration +# ----------------------------- +JOB_NAME="mmirage-sharded" +ACCOUNT="a127" +RESERVATION="sai-a127" + 
+MMIRAGE_CHDIR="/users/qchapp/meditron/MIRAGE/src/mmirage" +REPORT_DIR="/users/qchapp/reports" +EDF_ENV="/users/qchapp/.edf/mmirage.toml" + +CFG="${MMIRAGE_PATH}/configs/config_medtrinity.yaml" +# HF_HOME="${SCRATCH}/hf" +HF_HOME="/capstor/store/cscs/swissai/a127/homes/qchapp/hf" + +SHARDS_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled" +NUM_SHARDS=32 +MAX_RETRIES=3 + +# SLURM resources +NODES=1 +NTASKS_PER_NODE=1 +GPUS=4 +CPUS_PER_TASK=288 +TIME_LIMIT="11:59:59" + +# Optional: poll interval while waiting for jobs +POLL_SECONDS=30 + +# ----------------------------- +# Logging setup +# ----------------------------- +mkdir -p "$REPORT_DIR" +LOG_FILE="$REPORT_DIR/${JOB_NAME}_logs.out" + +# Send everything to terminal + logfile +exec > >(tee -a "$LOG_FILE") 2>&1 + +echo "==================================================" +echo "Pipeline : $JOB_NAME" +echo "User : $USER" +echo "Host : $(hostname)" +echo "Start Time : $(date)" +echo "Log File : $LOG_FILE" +echo "==================================================" +echo "" + +# ----------------------------- +# Environment +# ----------------------------- +export HF_HOME +export CFG +export TOTAL_SHARDS="$NUM_SHARDS" + +mkdir -p "$HF_HOME" + +echo "[INFO] Environment snapshot" +echo " MMIRAGE_CHDIR : $MMIRAGE_CHDIR" +echo " CFG : $CFG" +echo " HF_HOME : $HF_HOME" +echo " SHARDS_ROOT : $SHARDS_ROOT" +echo " NUM_SHARDS : $NUM_SHARDS" +echo " MAX_RETRIES : $MAX_RETRIES" +echo "" + +# ----------------------------- +# Retry state +# ----------------------------- +declare -A RETRY_COUNTS +for i in $(seq 0 $((NUM_SHARDS - 1))); do + RETRY_COUNTS[$i]=0 +done + +# ----------------------------- +# Submit an array job +# ----------------------------- +submit_array_job() { + local array_spec="$1" + + echo "[INFO] Submitting SLURM array job for shards: $array_spec" + + local job_id + job_id=$( + sbatch --parsable \ + --job-name="$JOB_NAME" \ + --chdir="$MMIRAGE_CHDIR" \ + 
--output="$REPORT_DIR/R-%x.%A_%a.out" \ + --error="$REPORT_DIR/R-%x.%A_%a.err" \ + --nodes="$NODES" \ + --ntasks-per-node="$NTASKS_PER_NODE" \ + --gres="gpu:${GPUS}" \ + --cpus-per-task="$CPUS_PER_TASK" \ + --time="$TIME_LIMIT" \ + -A "$ACCOUNT" \ + --array="$array_spec" \ + --export=ALL,CFG="$CFG",TOTAL_SHARDS="$NUM_SHARDS",HF_HOME="$HF_HOME" \ + --wrap=" + set -euo pipefail + export CFG='$CFG' + export TOTAL_SHARDS='$NUM_SHARDS' + export HF_HOME='$HF_HOME' + + echo 'START TIME: ' \$(date) + echo 'HOST: ' \$(hostname) + echo 'SLURM_JOB_ID: ' \$SLURM_JOB_ID + echo 'SLURM_ARRAY_TASK_ID: ' \$SLURM_ARRAY_TASK_ID + + CMD=\"python \$MMIRAGE_PATH/src/mmirage/shard_process.py --config \$CFG\" + + SRUN_ARGS=\" \ + --cpus-per-task $CPUS_PER_TASK \ + --jobid \$SLURM_JOB_ID \ + --wait 60 \ + -A $ACCOUNT \ + --reservation $RESERVATION \ + --environment $EDF_ENV \ + \" + + echo \"COMMAND: \$CMD\" + srun \$SRUN_ARGS bash -c \"\$CMD\" + + echo 'END TIME: ' \$(date) + " + ) + + echo "[INFO] Submitted job ID: $job_id" + SUBMITTED_JOB_ID="$job_id" +} + +# ----------------------------- +# Wait for job completion +# ----------------------------- +wait_for_job() { + local job_id="$1" + + echo "[INFO] Waiting for job $job_id to finish..." 
+ + while squeue -h -j "$job_id" | grep -q .; do + echo "[INFO] Job $job_id still running or pending at $(date)" + sleep "$POLL_SECONDS" + done + + echo "[INFO] Job $job_id finished at $(date)" +} + +# ----------------------------- +# Check shard results +# Populates the named array passed as $1 +# ----------------------------- +check_failed_shards() { + local -n out_failed_shards="$1" + out_failed_shards=() + + local success_count=0 + local exhausted_count=0 + + echo "" + echo "[INFO] Checking shard status in: $SHARDS_ROOT" + echo "" + + for i in $(seq 0 $((NUM_SHARDS - 1))); do + mapfile -t shard_dirs < <(find "$SHARDS_ROOT" -type d -name "shard_$i" 2>/dev/null) + + if [ "${#shard_dirs[@]}" -eq 0 ]; then + if [ "${RETRY_COUNTS[$i]}" -ge "$MAX_RETRIES" ]; then + echo "šŸ›‘ Shard $i: MISSING, max retries exceeded (${RETRY_COUNTS[$i]}/$MAX_RETRIES)" + ((exhausted_count+=1)) + else + echo "āŒ Shard $i: MISSING" + out_failed_shards+=("$i") + fi + continue + fi + + local shard_success=false + for shard_dir in "${shard_dirs[@]}"; do + if [ -f "$shard_dir/.SUCCESS" ]; then + shard_success=true + break + fi + done + + if [ "$shard_success" = true ]; then + echo "āœ… Shard $i: SUCCESS" + ((success_count+=1)) + else + if [ "${RETRY_COUNTS[$i]}" -ge "$MAX_RETRIES" ]; then + echo "šŸ›‘ Shard $i: FAILED, max retries exceeded (${RETRY_COUNTS[$i]}/$MAX_RETRIES)" + ((exhausted_count+=1)) + else + echo "āŒ Shard $i: FAILED (retries used: ${RETRY_COUNTS[$i]}/$MAX_RETRIES)" + out_failed_shards+=("$i") + fi + fi + done + + echo "" + echo "==================================================" + echo "Summary" + echo " Successful : $success_count / $NUM_SHARDS" + echo " Failed to retry : ${#out_failed_shards[@]}" + echo " Retry budget expired : $exhausted_count" + echo "==================================================" + echo "" +} + +# ----------------------------- +# Main loop +# ----------------------------- +main() { + local failed_shards=() + local retryable_shards=() + local 
array_spec="0-$((NUM_SHARDS - 1))" + local iteration=0 + + while true; do + ((iteration+=1)) + echo "--------------------------------------------------" + echo "[INFO] Iteration $iteration started at $(date)" + echo "--------------------------------------------------" + + submit_array_job "$array_spec" + wait_for_job "$SUBMITTED_JOB_ID" + + check_failed_shards failed_shards + + if [ "${#failed_shards[@]}" -eq 0 ]; then + echo "šŸŽ‰ All shards completed successfully!" + echo "" + echo "==================================================" + echo "Pipeline finished successfully at: $(date)" + echo "==================================================" + exit 0 + fi + + retryable_shards=() + for shard in "${failed_shards[@]}"; do + RETRY_COUNTS[$shard]=$((RETRY_COUNTS[$shard] + 1)) + if [ "${RETRY_COUNTS[$shard]}" -le "$MAX_RETRIES" ]; then + retryable_shards+=("$shard") + fi + done + + if [ "${#retryable_shards[@]}" -eq 0 ]; then + echo "šŸ›‘ No retryable failed shards remain." + echo "" + echo "[INFO] Final retry counters:" + for i in $(seq 0 $((NUM_SHARDS - 1))); do + echo " Shard $i -> ${RETRY_COUNTS[$i]}" + done + echo "" + echo "==================================================" + echo "Pipeline finished with failures at: $(date)" + echo "==================================================" + exit 1 + fi + + array_spec=$(IFS=,; echo "${retryable_shards[*]}") + echo "[INFO] Retrying failed shards: $array_spec" + echo "" + done +} + +main \ No newline at end of file From 2534c2a2d755914b4a1b80119f9b514fbe3a2b03 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 10 Mar 2026 15:14:01 +0100 Subject: [PATCH 14/47] full run with retry changed --- run_with_retry.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run_with_retry.sh b/run_with_retry.sh index 8baebc1..1295320 100644 --- a/run_with_retry.sh +++ b/run_with_retry.sh @@ -21,7 +21,7 @@ MMIRAGE_CHDIR="/users/qchapp/meditron/MIRAGE/src/mmirage" 
REPORT_DIR="/users/qchapp/reports" EDF_ENV="/users/qchapp/.edf/mmirage.toml" -CFG="${MMIRAGE_PATH}/configs/config_medtrinity.yaml" +CFG="/users/qchapp/meditron/MIRAGE/configs/config_medtrinity.yaml" # HF_HOME="${SCRATCH}/hf" HF_HOME="/capstor/store/cscs/swissai/a127/homes/qchapp/hf" @@ -117,7 +117,7 @@ submit_array_job() { echo 'SLURM_JOB_ID: ' \$SLURM_JOB_ID echo 'SLURM_ARRAY_TASK_ID: ' \$SLURM_ARRAY_TASK_ID - CMD=\"python \$MMIRAGE_PATH/src/mmirage/shard_process.py --config \$CFG\" + CMD=\"python /users/qchapp/meditron/MIRAGE/src/mmirage/shard_process.py --config \$CFG\" SRUN_ARGS=\" \ --cpus-per-task $CPUS_PER_TASK \ From 3c6320652c597da195c1943660b23cff1728eb7e Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 10 Mar 2026 15:54:48 +0100 Subject: [PATCH 15/47] trying a more robust approach --- README.md | 15 +- configs/config_medtrinity.yaml | 5 +- retry_failed.sh | 80 +++++---- run_medtrinity.sh | 33 ++++ src/mmirage/shard_process.py | 296 ++++++++++++++++++++------------- 5 files changed, 260 insertions(+), 169 deletions(-) create mode 100644 run_medtrinity.sh diff --git a/README.md b/README.md index 539b227..cb30734 100644 --- a/README.md +++ b/README.md @@ -34,17 +34,20 @@ For testing and scripts that make use of the library, it is advised to create a ### Running with Automatic Retry -Your job scripts now automatically track success/failure with marker files. After your job completes, just run a simple retry script: +Your job scripts now automatically track success/failure with marker files. Run your job as follows: ```bash -# 1. Submit your job normally +bash run_with_retry.sh +``` + +If you want you can still do it separately: + +```bash +# 1. Launch the job sbatch run.sh -# 2. After job completes, check for failures and retry +# 2. 
Look for retries bash retry_failed.sh - -# It will show you which shards failed and ask if you want to relaunch them -# Keep running retry_failed.sh until all shards succeed ``` ### Text-only: Reformatting dataset diff --git a/configs/config_medtrinity.yaml b/configs/config_medtrinity.yaml index c746854..dbe97cb 100644 --- a/configs/config_medtrinity.yaml +++ b/configs/config_medtrinity.yaml @@ -12,13 +12,14 @@ processors: enable_thinking: false loading_params: + state_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/_pipeline_state datasets: - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/medtrinity_conversations_1/ type: loadable - output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled2 + output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled2/part1 - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/medtrinity_conversations_2/ type: loadable - output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled2 + output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled2/part2 num_shards: 32 shard_id: "$SLURM_ARRAY_TASK_ID" conversations_field: "conversations" diff --git a/retry_failed.sh b/retry_failed.sh index fd42ed0..7afbed9 100644 --- a/retry_failed.sh +++ b/retry_failed.sh @@ -1,58 +1,55 @@ #!/bin/bash -# Check for failed shards and relaunch them +# Check for failed logical shards and relaunch them # # Usage: bash retry_failed.sh -# Configuration -SHARDS_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled2" +set -euo pipefail + +STATE_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/_pipeline_state" NUM_SHARDS=32 MAX_RETRIES=3 SCRIPT_PATH="/users/qchapp/meditron/MIRAGE/run_with_retry.sh" -echo "Checking for failed 
shards in: $SHARDS_ROOT" +echo "Checking shard states in: $STATE_ROOT" echo "" failed_shards=() success_count=0 -for i in $(seq 0 $((NUM_SHARDS-1))); do - # Find shard directories (may be nested under dataset dirs) - shard_dirs=$(find "$SHARDS_ROOT" -type d -name "shard_$i" 2>/dev/null) - - if [ -z "$shard_dirs" ]; then - echo "āŒ Shard $i: MISSING" - failed_shards+=($i) +for i in $(seq 0 $((NUM_SHARDS - 1))); do + state_dir="$STATE_ROOT/shard_$i" + status_file="$state_dir/status.json" + + if [ ! -f "$status_file" ]; then + echo "āŒ Shard $i: MISSING STATUS" + failed_shards+=("$i") continue fi - - # Check each shard directory for success marker - shard_success=false - for shard_dir in $shard_dirs; do - if [ -f "$shard_dir/.SUCCESS" ]; then - shard_success=true - break - fi - done - - if [ "$shard_success" = true ]; then + + status=$(python - < int: def _dataset_out_dir(shard_idx: int, ds_config: BaseDataLoaderConfig) -> str: - """Get output directory for a shard of a dataset.""" + """Get dataset-specific output directory for a shard.""" return os.path.join(ds_config.output_dir, f"shard_{shard_idx}") +def _shard_state_dir(shard_idx: int, state_root: str) -> str: + """Get central state directory for a logical shard.""" + return os.path.join(state_root, f"shard_{shard_idx}") + + def _shard_dataset(ds: DatasetLike, num_shards: int, shard_id: int) -> DatasetLike: """Shard a dataset or dataset dict.""" if isinstance(ds, DatasetDict): @@ -59,72 +65,131 @@ def _remove_columns(ds: DatasetLike, enable: bool) -> List[str]: return ds.column_names -def _get_retry_count(shard_dir: str) -> int: - """Get retry count for a shard from retry marker file.""" - retry_file = os.path.join(shard_dir, ".retry_count") - if not os.path.exists(retry_file): - return 0 +def _status_file(state_dir: str) -> str: + """Canonical status file path.""" + return os.path.join(state_dir, "status.json") + + +def _read_status(state_dir: str) -> dict: + """Read status.json if present.""" + path = 
_status_file(state_dir) + if not os.path.exists(path): + return {} try: - with open(retry_file, "r") as f: - return int(f.read().strip()) - except (ValueError, IOError): - return 0 - - -def _increment_retry_count(shard_dir: str) -> int: - """Increment and write retry count for a shard.""" - count = _get_retry_count(shard_dir) + 1 - retry_file = os.path.join(shard_dir, ".retry_count") - os.makedirs(shard_dir, exist_ok=True) - with open(retry_file, "w") as f: - f.write(str(count)) - return count - - -def _cleanup_old_shard_data(shard_dir: str): - """Remove old data files from a shard directory before retry. - - Keeps marker files (.SUCCESS, .FAILED, .retry_count) but removes - arrow files and dataset metadata to prevent duplicates. - """ - if not os.path.exists(shard_dir): - return - - # Patterns for files to remove - patterns_to_remove = [ - "*.arrow", - "dataset_info.json", - "state.json", - ] - - removed_count = 0 - for pattern in patterns_to_remove: - for file_path in glob.glob(os.path.join(shard_dir, pattern)): + with open(path, "r") as f: + return json.load(f) + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"Failed to read status file {path}: {e}") + return {} + + +def _write_status(state_dir: str, payload: dict): + """Atomically write status.json.""" + os.makedirs(state_dir, exist_ok=True) + tmp_path = _status_file(state_dir) + ".tmp" + with open(tmp_path, "w") as f: + json.dump(payload, f, indent=2, sort_keys=True) + os.replace(tmp_path, _status_file(state_dir)) + + +def _clear_markers(state_dir: str): + """Remove status marker files.""" + for name in (".RUNNING", ".SUCCESS", ".FAILED"): + path = os.path.join(state_dir, name) + if os.path.exists(path): try: - os.remove(file_path) - removed_count += 1 + os.remove(path) except OSError as e: - logger.warning(f"Failed to remove {file_path}: {e}") - - if removed_count > 0: - logger.info(f"Cleaned up {removed_count} old data files from {shard_dir}") + logger.warning(f"Failed to remove marker 
{path}: {e}") + + +def _touch_marker(state_dir: str, name: str): + """Create a marker file.""" + os.makedirs(state_dir, exist_ok=True) + path = os.path.join(state_dir, name) + with open(path, "w") as f: + f.write(f"{datetime.now().isoformat()}\n") + + +def _mark_running( + state_dir: str, + shard_id: int, + datasets_config: List[BaseDataLoaderConfig], +) -> int: + """Mark shard as running and increment retry count.""" + prev = _read_status(state_dir) + retry_count = int(prev.get("retry_count", 0)) + 1 + + payload = { + "status": "running", + "retry_count": retry_count, + "shard_id": shard_id, + "started_at": datetime.now().isoformat(), + "finished_at": None, + "error": None, + "hostname": socket.gethostname(), + "pid": os.getpid(), + "slurm_job_id": os.environ.get("SLURM_JOB_ID"), + "slurm_array_task_id": os.environ.get("SLURM_ARRAY_TASK_ID"), + "datasets": [ + { + "path": ds_config.path, + "output_dir": ds_config.output_dir, + } + for ds_config in datasets_config + ], + } + + _write_status(state_dir, payload) + _clear_markers(state_dir) + _touch_marker(state_dir, ".RUNNING") + return retry_count + + +def _mark_success(state_dir: str): + """Mark shard as successful.""" + prev = _read_status(state_dir) + prev["status"] = "success" + prev["finished_at"] = datetime.now().isoformat() + prev["error"] = None + _write_status(state_dir, prev) + _clear_markers(state_dir) + _touch_marker(state_dir, ".SUCCESS") + + +def _mark_failure(state_dir: str, error_msg: str): + """Mark shard as failed.""" + prev = _read_status(state_dir) + prev["status"] = "failed" + prev["finished_at"] = datetime.now().isoformat() + prev["error"] = error_msg + _write_status(state_dir, prev) + _clear_markers(state_dir) + _touch_marker(state_dir, ".FAILED") + + +def _cleanup_old_shard_data(out_dir: str): + """Remove old dataset shard output before retry.""" + if os.path.exists(out_dir): + shutil.rmtree(out_dir) + logger.info(f"Removed old shard output: {out_dir}") -def _write_success_marker(shard_dir: 
str): - """Write success marker file for a completed shard.""" - marker_file = os.path.join(shard_dir, ".SUCCESS") - os.makedirs(shard_dir, exist_ok=True) - with open(marker_file, "w") as f: - f.write(f"completed_at: {datetime.now().isoformat()}\n") +def _save_dataset_atomic(ds_processed: DatasetLike, out_dir: str): + """Save dataset atomically via temporary directory + rename.""" + parent_dir = os.path.dirname(out_dir) + os.makedirs(parent_dir, exist_ok=True) + tmp_dir = f"{out_dir}.tmp.{os.getpid()}" + if os.path.exists(tmp_dir): + shutil.rmtree(tmp_dir) -def _write_failure_marker(shard_dir: str, error_msg: str): - """Write failure marker file with error information.""" - marker_file = os.path.join(shard_dir, ".FAILED") - os.makedirs(shard_dir, exist_ok=True) - with open(marker_file, "w") as f: - f.write(f"failed_at: {datetime.now().isoformat()}\n") - f.write(f"error: {error_msg}\n") + ds_processed.save_to_disk(tmp_dir) + + if os.path.exists(out_dir): + shutil.rmtree(out_dir) + + os.replace(tmp_dir, out_dir) def rewrite_batch( @@ -133,20 +198,7 @@ def rewrite_batch( renderer: TemplateRenderer, image_base_path: str = None, ) -> Dict[str, List[Any]]: - """Rewrite a batch of samples by applying transformations. - - Args: - batch: Dictionary mapping column names to lists of values. - mapper: MMIRAGEMapper for processing transformations. - renderer: TemplateRenderer for generating output. - image_base_path: Optional base directory for resolving relative image paths. - - Returns: - Dictionary mapping output keys to lists of rendered values. - - Raises: - ValueError: If variables are not computable given the configuration. - """ + """Rewrite a batch of samples by applying transformations.""" if not mapper.validate_vars(): raise ValueError( "Uncomputable variables detected. Verify your configuration and make sure that there is no undefined variables" @@ -157,15 +209,19 @@ def rewrite_batch( return rendered_list -def main(): - """Process a single shard of the dataset. 
+def _get_state_root(cfg) -> str: + """Get the shared pipeline state root from config.""" + state_dir = getattr(cfg.loading_params, "state_dir", None) + if not state_dir: + raise ValueError( + "loading_params.state_dir must be set when using multiple datasets with independent output_dir values" + ) + return state_dir - Loads configuration, datasets, processes the shard using MMIRAGE - transformations (including multimodal), and saves the result to disk. - """ - ap = argparse.ArgumentParser( - "Process dataset shards using MMIRAGE with SGLang." - ) + +def main(): + """Process a single logical shard across all configured datasets.""" + ap = argparse.ArgumentParser("Process dataset shards using MMIRAGE with SGLang.") ap.add_argument( "--config", help="YAML config for MMIRAGE pipeline.", @@ -177,6 +233,7 @@ def main(): loading_params = cfg.loading_params processing_params = cfg.processing_params datasets_config = loading_params.datasets + if not datasets_config: raise ValueError("No datasets provided in config.loading_params.datasets") @@ -186,10 +243,18 @@ def main(): if not (0 <= shard_id < num_shards): raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") - # Track shard directories for marker files - shard_dirs = [] + state_root = _get_state_root(cfg) + state_dir = _shard_state_dir(shard_id, state_root) try: + retry_count = _mark_running(state_dir, shard_id, datasets_config) + logger.info(f"Starting shard {shard_id}/{num_shards - 1} (attempt #{retry_count})") + + if retry_count > 1: + for ds_config in datasets_config: + out_dir = _dataset_out_dir(shard_id, ds_config) + _cleanup_old_shard_data(out_dir) + ds_all = load_datasets_from_configs(datasets_config) total_rows = sum(_count_rows(ds) for ds in ds_all) @@ -198,61 +263,56 @@ def main(): logger.info( f"Loaded {len(datasets_config)} dataset(s): {datasets_config} " - f"→ {total_rows} total rows; this shard has {shard_rows} rows." 
+ f"→ {total_rows} total rows; this logical shard has {shard_rows} rows." ) - # Increment retry count for each shard directory - for ds_config in datasets_config: - shard_dir = _dataset_out_dir(shard_id, ds_config) - retry_count = _increment_retry_count(shard_dir) - shard_dirs.append(shard_dir) - if retry_count > 1: - logger.info(f"Retry attempt #{retry_count} for shard {shard_id}") - # Clean up old data files to prevent duplicates - _cleanup_old_shard_data(shard_dir) - mapper = MMIRAGEMapper( - cfg.processors, processing_params.inputs, processing_params.outputs + cfg.processors, + processing_params.inputs, + processing_params.outputs, ) renderer = TemplateRenderer(processing_params.output_schema) + ds_processed_all: List[DatasetLike] = [] for ds_idx, ds_shard in enumerate(ds_all_shard): ds_config = datasets_config[ds_idx] remove_columns = _remove_columns(ds_shard, processing_params.remove_columns) + + logger.info( + f"Processing dataset {ds_idx} for shard {shard_id}: " + f"path={ds_config.path}, output_dir={ds_config.output_dir}" + ) + ds_processed = ds_shard.map( rewrite_batch, batched=True, batch_size=loading_params.get_batch_size(), load_from_cache_file=False, desc=f"Shard {shard_id}/{num_shards - 1} dataset {ds_idx}", - fn_kwargs={"mapper": mapper, "renderer": renderer, "image_base_path": ds_config.image_base_path}, + fn_kwargs={ + "mapper": mapper, + "renderer": renderer, + "image_base_path": ds_config.image_base_path, + }, remove_columns=remove_columns, ) ds_processed_all.append(ds_processed) - for ds_config, ds_processed in zip(datasets_config, ds_processed_all): + for ds_idx, (ds_config, ds_processed) in enumerate(zip(datasets_config, ds_processed_all)): out_dir = _dataset_out_dir(shard_id, ds_config) - os.makedirs(out_dir, exist_ok=True) - ds_processed.save_to_disk(out_dir) - logger.info(f"āœ… Saved dataset in: {out_dir}") + _save_dataset_atomic(ds_processed, out_dir) + logger.info(f"āœ… Saved dataset {ds_idx} shard in: {out_dir}") - # Write success 
markers for all shards - for shard_dir in shard_dirs: - _write_success_marker(shard_dir) - logger.info(f"āœ… Shard {shard_id} completed successfully") + _mark_success(state_dir) + logger.info(f"āœ… Logical shard {shard_id} completed successfully") except Exception as e: error_msg = f"{type(e).__name__}: {str(e)}" logger.error(f"āŒ Shard {shard_id} failed: {error_msg}") logger.error(traceback.format_exc()) - - # Write failure markers for all shards - for shard_dir in shard_dirs: - _write_failure_marker(shard_dir, error_msg) - - # Re-raise to ensure non-zero exit code + _mark_failure(state_dir, error_msg) sys.exit(1) if __name__ == "__main__": - main() + main() \ No newline at end of file From 6993969559856801ba6014515da0ae32a9de3d9e Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 10 Mar 2026 16:05:57 +0100 Subject: [PATCH 16/47] added state dir in config --- src/mmirage/config/loading.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index ec7faca..c636520 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -1,7 +1,7 @@ """Data loading configuration for MMIRAGE pipeline.""" from dataclasses import dataclass, field -from typing import Union, List, cast +from typing import Union, List, cast, Optional from mmirage.core.loader.base import BaseDataLoaderConfig @@ -15,7 +15,8 @@ class LoadingParams: Attributes: datasets: List of dataset configurations to load. - output_dir: Directory path for saving processed output shards. + state_dir: Shared directory for logical shard state/markers/retry tracking. + output_dir: Legacy top-level output directory. Prefer per-dataset output_dir. num_shards: Total number of shards to split the dataset into. shard_id: ID of this shard (0-indexed). batch_size: Batch size for processing samples. 
@@ -25,6 +26,7 @@ class LoadingParams: """ datasets: List[BaseDataLoaderConfig] = field(default_factory=list) + state_dir: Optional[str] = None output_dir: str = "" num_shards: Union[int, str] = 1 shard_id: Union[int, str] = 0 @@ -38,18 +40,24 @@ def __post_init__(self): raise ValueError() except (ValueError, TypeError): raise ValueError(f"Invalid value for num_shards: {self.num_shards!r}") + if isinstance(self.shard_id, str): try: self.shard_id = int(self.shard_id) except (ValueError, TypeError): raise ValueError(f"Invalid value for shard_id: {self.shard_id!r}") + if isinstance(self.batch_size, str): try: self.batch_size = int(self.batch_size) except (ValueError, TypeError): raise ValueError(f"Invalid value for batch_size: {self.batch_size!r}") + self.batch_size = max(self.batch_size, 1) + if self.state_dir is not None: + self.state_dir = str(self.state_dir).strip() or None + def get_num_shards(self) -> int: """Get the total number of shards. From f9929110320fdc4578a76853ac14e92d9454f3bb Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Thu, 12 Mar 2026 18:16:35 +0100 Subject: [PATCH 17/47] preparing for PR --- README.md | 8 +- configs/config_medtrinity.yaml | 8 +- run_with_retry.sh | 207 ++++++++++++------------------- src/mmirage/config/loading.py | 10 +- src/mmirage/merge_shards.py | 139 ++------------------- src/mmirage/shard_process.py | 218 +++++---------------------------- src/mmirage/shard_utils.py | 186 ++++++++++++++++++++++++++++ 7 files changed, 317 insertions(+), 459 deletions(-) create mode 100644 src/mmirage/shard_utils.py diff --git a/README.md b/README.md index cb30734..c162e8c 100644 --- a/README.md +++ b/README.md @@ -34,19 +34,21 @@ For testing and scripts that make use of the library, it is advised to create a ### Running with Automatic Retry -Your job scripts now automatically track success/failure with marker files. 
Run your job as follows: +Your job scripts now automatically track success and failure using marker files. + +Run your job with automatic retry: ```bash bash run_with_retry.sh ``` -If you want you can still do it separately: +Alternatively, you can run the steps separately: ```bash # 1. Launch the job sbatch run.sh -# 2. Look for retries +# 2. Retry failed jobs bash retry_failed.sh ``` diff --git a/configs/config_medtrinity.yaml b/configs/config_medtrinity.yaml index dbe97cb..775f6aa 100644 --- a/configs/config_medtrinity.yaml +++ b/configs/config_medtrinity.yaml @@ -14,12 +14,12 @@ processors: loading_params: state_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/_pipeline_state datasets: - - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/medtrinity_conversations_1/ + - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/dirty/medtrinity_conversations_1/ type: loadable - output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled2/part1 - - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/arrow/medtrinity_conversations_2/ + output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled2/part1 + - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/dirty/medtrinity_conversations_2/ type: loadable - output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity//medtrinity_conversations_sampled2/part2 + output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled2/part2 num_shards: 32 shard_id: "$SLURM_ARRAY_TASK_ID" conversations_field: "conversations" diff --git a/run_with_retry.sh b/run_with_retry.sh index 1295320..edb8f1a 100644 --- a/run_with_retry.sh +++ b/run_with_retry.sh @@ -2,12 +2,11 @@ set -euo pipefail # ========================================================= -# MMIRAGE full pipeline wrapper -# - submits 
shard-processing SLURM array jobs -# - waits for completion -# - checks missing/failed shards -# - retries failed shards up to MAX_RETRIES -# - writes all terminal output to a global log file +# MMIRAGE pipeline controller +# - submit shard-processing SLURM array +# - wait for completion +# - inspect shard state_dir +# - retry failed shards # ========================================================= # ----------------------------- @@ -19,13 +18,13 @@ RESERVATION="sai-a127" MMIRAGE_CHDIR="/users/qchapp/meditron/MIRAGE/src/mmirage" REPORT_DIR="/users/qchapp/reports" -EDF_ENV="/users/qchapp/.edf/mmirage.toml" +EDF_ENV="/users/qchapp/.edf/sglang.toml" CFG="/users/qchapp/meditron/MIRAGE/configs/config_medtrinity.yaml" -# HF_HOME="${SCRATCH}/hf" HF_HOME="/capstor/store/cscs/swissai/a127/homes/qchapp/hf" -SHARDS_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled" +STATE_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/_pipeline_state" + NUM_SHARDS=32 MAX_RETRIES=3 @@ -36,16 +35,14 @@ GPUS=4 CPUS_PER_TASK=288 TIME_LIMIT="11:59:59" -# Optional: poll interval while waiting for jobs POLL_SECONDS=30 # ----------------------------- -# Logging setup +# Logging # ----------------------------- mkdir -p "$REPORT_DIR" LOG_FILE="$REPORT_DIR/${JOB_NAME}_logs.out" -# Send everything to terminal + logfile exec > >(tee -a "$LOG_FILE") 2>&1 echo "==================================================" @@ -55,7 +52,6 @@ echo "Host : $(hostname)" echo "Start Time : $(date)" echo "Log File : $LOG_FILE" echo "==================================================" -echo "" # ----------------------------- # Environment @@ -66,33 +62,16 @@ export TOTAL_SHARDS="$NUM_SHARDS" mkdir -p "$HF_HOME" -echo "[INFO] Environment snapshot" -echo " MMIRAGE_CHDIR : $MMIRAGE_CHDIR" -echo " CFG : $CFG" -echo " HF_HOME : $HF_HOME" -echo " SHARDS_ROOT : $SHARDS_ROOT" -echo " NUM_SHARDS : $NUM_SHARDS" -echo " MAX_RETRIES : $MAX_RETRIES" -echo "" 
- -# ----------------------------- -# Retry state -# ----------------------------- -declare -A RETRY_COUNTS -for i in $(seq 0 $((NUM_SHARDS - 1))); do - RETRY_COUNTS[$i]=0 -done - # ----------------------------- -# Submit an array job +# Submit SLURM array job # ----------------------------- submit_array_job() { + local array_spec="$1" - echo "[INFO] Submitting SLURM array job for shards: $array_spec" + echo "[INFO] Submitting job array: $array_spec" - local job_id - job_id=$( + SUBMITTED_JOB_ID=$( sbatch --parsable \ --job-name="$JOB_NAME" \ --chdir="$MMIRAGE_CHDIR" \ @@ -108,111 +87,101 @@ submit_array_job() { --export=ALL,CFG="$CFG",TOTAL_SHARDS="$NUM_SHARDS",HF_HOME="$HF_HOME" \ --wrap=" set -euo pipefail - export CFG='$CFG' - export TOTAL_SHARDS='$NUM_SHARDS' - export HF_HOME='$HF_HOME' - - echo 'START TIME: ' \$(date) - echo 'HOST: ' \$(hostname) - echo 'SLURM_JOB_ID: ' \$SLURM_JOB_ID - echo 'SLURM_ARRAY_TASK_ID: ' \$SLURM_ARRAY_TASK_ID - CMD=\"python /users/qchapp/meditron/MIRAGE/src/mmirage/shard_process.py --config \$CFG\" + echo 'START:' \$(date) + echo 'HOST:' \$(hostname) + echo 'TASK:' \$SLURM_ARRAY_TASK_ID - SRUN_ARGS=\" \ + srun \ --cpus-per-task $CPUS_PER_TASK \ --jobid \$SLURM_JOB_ID \ --wait 60 \ -A $ACCOUNT \ --reservation $RESERVATION \ --environment $EDF_ENV \ - \" - - echo \"COMMAND: \$CMD\" - srun \$SRUN_ARGS bash -c \"\$CMD\" + python shard_process.py --config \$CFG - echo 'END TIME: ' \$(date) + echo 'END:' \$(date) " ) - echo "[INFO] Submitted job ID: $job_id" - SUBMITTED_JOB_ID="$job_id" + echo "[INFO] Submitted job ID: $SUBMITTED_JOB_ID" } # ----------------------------- # Wait for job completion # ----------------------------- wait_for_job() { + local job_id="$1" - echo "[INFO] Waiting for job $job_id to finish..." 
+ echo "[INFO] Waiting for job $job_id" while squeue -h -j "$job_id" | grep -q .; do - echo "[INFO] Job $job_id still running or pending at $(date)" sleep "$POLL_SECONDS" done - echo "[INFO] Job $job_id finished at $(date)" + echo "[INFO] Job finished" } # ----------------------------- -# Check shard results -# Populates the named array passed as $1 +# Inspect shard states # ----------------------------- check_failed_shards() { - local -n out_failed_shards="$1" - out_failed_shards=() - local success_count=0 - local exhausted_count=0 + local -n result="$1" + result=() + + success=0 + exhausted=0 echo "" - echo "[INFO] Checking shard status in: $SHARDS_ROOT" + echo "[INFO] Inspecting shard states" echo "" - for i in $(seq 0 $((NUM_SHARDS - 1))); do - mapfile -t shard_dirs < <(find "$SHARDS_ROOT" -type d -name "shard_$i" 2>/dev/null) - - if [ "${#shard_dirs[@]}" -eq 0 ]; then - if [ "${RETRY_COUNTS[$i]}" -ge "$MAX_RETRIES" ]; then - echo "šŸ›‘ Shard $i: MISSING, max retries exceeded (${RETRY_COUNTS[$i]}/$MAX_RETRIES)" - ((exhausted_count+=1)) - else - echo "āŒ Shard $i: MISSING" - out_failed_shards+=("$i") - fi + for i in $(seq 0 $((NUM_SHARDS-1))); do + + status_file="$STATE_ROOT/shard_$i/status.json" + + if [[ ! 
-f "$status_file" ]]; then + echo "āŒ shard $i: missing state" + result+=("$i") continue fi - local shard_success=false - for shard_dir in "${shard_dirs[@]}"; do - if [ -f "$shard_dir/.SUCCESS" ]; then - shard_success=true - break - fi - done - - if [ "$shard_success" = true ]; then - echo "āœ… Shard $i: SUCCESS" - ((success_count+=1)) + status=$(python - < ${RETRY_COUNTS[$i]}" - done - echo "" - echo "==================================================" - echo "Pipeline finished with failures at: $(date)" - echo "==================================================" - exit 1 - fi + array_spec=$(IFS=,; echo "${failed[*]}") - array_spec=$(IFS=,; echo "${retryable_shards[*]}") - echo "[INFO] Retrying failed shards: $array_spec" - echo "" + echo "[INFO] Retrying shards: $array_spec" done } diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index c636520..49ab6cf 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -58,6 +58,14 @@ def __post_init__(self): if self.state_dir is not None: self.state_dir = str(self.state_dir).strip() or None + def get_state_root(self) -> str: + """Get the state root path. + + Returns: + str: State root path + """ + return self.state_dir + def get_num_shards(self) -> int: """Get the total number of shards. @@ -80,4 +88,4 @@ def get_batch_size(self) -> int: Returns: int: Batch size (minimum 1). 
""" - return cast(int, self.batch_size) + return cast(int, self.batch_size) \ No newline at end of file diff --git a/src/mmirage/merge_shards.py b/src/mmirage/merge_shards.py index 99357ca..e290cb7 100644 --- a/src/mmirage/merge_shards.py +++ b/src/mmirage/merge_shards.py @@ -2,8 +2,7 @@ import argparse import os -import sys -from typing import Dict, List, Set, Tuple +from typing import Dict, List from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk @@ -65,63 +64,6 @@ def _shard_key(path: str) -> int: return shard_dirs -def _extract_shard_id(shard_path: str) -> int: - """Extract shard ID from shard directory path.""" - base = os.path.basename(shard_path) - suffix = base.removeprefix("shard_") - return int(suffix) if suffix.isdigit() else -1 - - -def _check_shard_success(shard_dir: str) -> bool: - """Check if a shard completed successfully based on .SUCCESS marker.""" - success_file = os.path.join(shard_dir, ".SUCCESS") - return os.path.exists(success_file) - - -def _check_shard_failed(shard_dir: str) -> bool: - """Check if a shard failed based on .FAILED marker.""" - failed_file = os.path.join(shard_dir, ".FAILED") - return os.path.exists(failed_file) - - -def _analyze_shard_status( - dataset_dir: str, expected_shards: int = None -) -> Tuple[Set[int], Set[int], Set[int], Set[int]]: - """Analyze status of all shards in a dataset directory. 
- - Args: - dataset_dir: Path to dataset directory containing shards - expected_shards: Expected number of shards (for detecting missing ones) - - Returns: - Tuple of (success_ids, failed_ids, incomplete_ids, missing_ids) - """ - shard_dirs = _list_shard_dirs(dataset_dir) - - success_ids: Set[int] = set() - failed_ids: Set[int] = set() - incomplete_ids: Set[int] = set() - - for shard_dir in shard_dirs: - shard_id = _extract_shard_id(shard_dir) - if shard_id < 0: - continue - - if _check_shard_success(shard_dir): - success_ids.add(shard_id) - elif _check_shard_failed(shard_dir): - failed_ids.add(shard_id) - else: - incomplete_ids.add(shard_id) - - missing_ids: Set[int] = set() - if expected_shards is not None: - existing_ids = success_ids | failed_ids | incomplete_ids - missing_ids = set(range(expected_shards)) - existing_ids - - return success_ids, failed_ids, incomplete_ids, missing_ids - - def _dataset_dirs(input_dir: str) -> List[str]: """Find dataset directories containing shard folders.""" candidates: List[str] = [] @@ -152,22 +94,6 @@ def main(): required=True, help="Directory to write merged datasets into.", ) - ap.add_argument( - "--expected-shards", - type=int, - default=None, - help="Expected number of shards (for detecting missing shards).", - ) - ap.add_argument( - "--fail-on-missing", - action="store_true", - help="Fail if any shards are missing, failed, or incomplete.", - ) - ap.add_argument( - "--check-markers", - action="store_true", - help="Check for .SUCCESS markers and report shard status.", - ) args = ap.parse_args() input_dir = args.input_dir @@ -185,60 +111,15 @@ def main(): ) for dataset_dir in dataset_dirs: - dataset_name = os.path.basename(dataset_dir) - print(f"\n{'='*60}") - print(f"Processing dataset: {dataset_name}") - print(f"{'='*60}") - shard_dirs = _list_shard_dirs(dataset_dir) if not shard_dirs: continue - # Check shard status if requested - if args.check_markers or args.expected_shards is not None: - success_ids, failed_ids, 
incomplete_ids, missing_ids = _analyze_shard_status( - dataset_dir, args.expected_shards - ) - - total_found = len(success_ids) + len(failed_ids) + len(incomplete_ids) - total_expected = args.expected_shards if args.expected_shards else total_found - - print(f"\nšŸ“Š Shard Status Report:") - print(f" āœ… Successful: {len(success_ids)} / {total_expected}") - print(f" āŒ Failed: {len(failed_ids)}") - print(f" āš ļø Incomplete: {len(incomplete_ids)}") - print(f" ā“ Missing: {len(missing_ids)}") - - if failed_ids: - print(f"\n Failed shard IDs: {sorted(failed_ids)}") - if incomplete_ids: - print(f" Incomplete shard IDs: {sorted(incomplete_ids)}") - if missing_ids: - print(f" Missing shard IDs: {sorted(missing_ids)}") - - # Check if we should fail - has_problems = bool(failed_ids or incomplete_ids or missing_ids) - if has_problems and args.fail_on_missing: - print(f"\nāŒ ERROR: Found failed/missing/incomplete shards and --fail-on-missing is set") - sys.exit(1) - elif has_problems: - print(f"\nāš ļø WARNING: Some shards are incomplete - merged dataset may be missing data") - print(f" Consider running failure detection and relaunching failed shards:") - print(f" python src/mmirage/detect_failures.py --input-dir {input_dir} --num-shards {total_expected}") - shard_dsets: List[DatasetLike] = [] skipped_empty_dir = 0 skipped_zero_rows = 0 - skipped_failed = 0 for shard_dir in shard_dirs: - # Skip explicitly failed shards if check-markers is enabled - if args.check_markers and _check_shard_failed(shard_dir): - shard_id = _extract_shard_id(shard_dir) - print(f"āš ļø Skipping failed shard {shard_id}: {shard_dir}") - skipped_failed += 1 - continue - try: ds = load_from_disk(shard_dir) except FileNotFoundError as e: @@ -261,14 +142,13 @@ def main(): raise RuntimeError( f"No non-empty shards found in {dataset_dir}. " f"empty/invalid dirs: {skipped_empty_dir}, " - f"zero-row datasets: {skipped_zero_rows}, " - f"failed shards: {skipped_failed}." 
+ f"zero-row datasets: {skipped_zero_rows}." ) ds_merged = _merge_shards(shard_dsets) n_rows = _count_rows(ds_merged) - total_skipped = skipped_empty_dir + skipped_zero_rows + skipped_failed + total_skipped = skipped_empty_dir + skipped_zero_rows if dataset_dir == input_dir: ds_out_dir = output_dir @@ -281,15 +161,12 @@ def main(): ds_merged.save_to_disk(ds_out_dir) print( - f"\nāœ… Concatenated {len(shard_dsets)} shards for {dataset_name} " - f"with {n_rows} rows." + f"āœ… Concatenated {len(shard_dsets)} shards for {dataset_name} " + f"with {n_rows} rows.\n" + f" Skipped shards: {total_skipped} total " + f"(empty/invalid dir: {skipped_empty_dir}, zero rows: {skipped_zero_rows})." ) - if total_skipped > 0: - print( - f" Skipped shards: {total_skipped} total " - f"(empty/invalid: {skipped_empty_dir}, zero rows: {skipped_zero_rows}, failed: {skipped_failed})" - ) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 5b1b7e1..2a7390a 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -4,201 +4,49 @@ """ import argparse -from datetime import datetime -from functools import reduce -import json import logging -import os -import shutil -import socket import sys import traceback from typing import Any, Dict, List -from datasets import DatasetDict - from mmirage.config.utils import load_mmirage_config -from mmirage.core.loader.base import BaseDataLoaderConfig, DatasetLike +from mmirage.core.loader.base import DatasetLike from mmirage.core.loader.utils import load_datasets_from_configs from mmirage.core.process.mapper import MMIRAGEMapper from mmirage.core.writer.renderer import TemplateRenderer +from mmirage.shard_utils import ( + _cleanup_old_shard_data, + _count_rows, + _dataset_out_dir, + _mark_failure, + _mark_running, + _mark_success, + _remove_columns, + _save_dataset_atomic, + _shard_dataset, + _shard_state_dir, +) logger = 
logging.getLogger(__name__) -def _count_rows(ds: DatasetLike) -> int: - """Count total rows in a dataset or dataset dict.""" - if isinstance(ds, DatasetDict): - return sum(len(split) for split in ds.values()) - return len(ds) - - -def _dataset_out_dir(shard_idx: int, ds_config: BaseDataLoaderConfig) -> str: - """Get dataset-specific output directory for a shard.""" - return os.path.join(ds_config.output_dir, f"shard_{shard_idx}") - - -def _shard_state_dir(shard_idx: int, state_root: str) -> str: - """Get central state directory for a logical shard.""" - return os.path.join(state_root, f"shard_{shard_idx}") - - -def _shard_dataset(ds: DatasetLike, num_shards: int, shard_id: int) -> DatasetLike: - """Shard a dataset or dataset dict.""" - if isinstance(ds, DatasetDict): - return DatasetDict( - { - split: split_ds.shard(num_shards=num_shards, index=shard_id) - for split, split_ds in ds.items() - } - ) - return ds.shard(num_shards=num_shards, index=shard_id) - - -def _remove_columns(ds: DatasetLike, enable: bool) -> List[str]: - """Get columns to remove from dataset if enabled.""" - if not enable: - return [] - if isinstance(ds, DatasetDict): - columns_set = [set(split_ds.column_names) for split_ds in ds.values()] - return list(reduce(lambda x, y: x | y, columns_set)) - return ds.column_names - - -def _status_file(state_dir: str) -> str: - """Canonical status file path.""" - return os.path.join(state_dir, "status.json") - - -def _read_status(state_dir: str) -> dict: - """Read status.json if present.""" - path = _status_file(state_dir) - if not os.path.exists(path): - return {} - try: - with open(path, "r") as f: - return json.load(f) - except (json.JSONDecodeError, OSError) as e: - logger.warning(f"Failed to read status file {path}: {e}") - return {} - - -def _write_status(state_dir: str, payload: dict): - """Atomically write status.json.""" - os.makedirs(state_dir, exist_ok=True) - tmp_path = _status_file(state_dir) + ".tmp" - with open(tmp_path, "w") as f: - 
json.dump(payload, f, indent=2, sort_keys=True) - os.replace(tmp_path, _status_file(state_dir)) - - -def _clear_markers(state_dir: str): - """Remove status marker files.""" - for name in (".RUNNING", ".SUCCESS", ".FAILED"): - path = os.path.join(state_dir, name) - if os.path.exists(path): - try: - os.remove(path) - except OSError as e: - logger.warning(f"Failed to remove marker {path}: {e}") - - -def _touch_marker(state_dir: str, name: str): - """Create a marker file.""" - os.makedirs(state_dir, exist_ok=True) - path = os.path.join(state_dir, name) - with open(path, "w") as f: - f.write(f"{datetime.now().isoformat()}\n") - - -def _mark_running( - state_dir: str, - shard_id: int, - datasets_config: List[BaseDataLoaderConfig], -) -> int: - """Mark shard as running and increment retry count.""" - prev = _read_status(state_dir) - retry_count = int(prev.get("retry_count", 0)) + 1 - - payload = { - "status": "running", - "retry_count": retry_count, - "shard_id": shard_id, - "started_at": datetime.now().isoformat(), - "finished_at": None, - "error": None, - "hostname": socket.gethostname(), - "pid": os.getpid(), - "slurm_job_id": os.environ.get("SLURM_JOB_ID"), - "slurm_array_task_id": os.environ.get("SLURM_ARRAY_TASK_ID"), - "datasets": [ - { - "path": ds_config.path, - "output_dir": ds_config.output_dir, - } - for ds_config in datasets_config - ], - } - - _write_status(state_dir, payload) - _clear_markers(state_dir) - _touch_marker(state_dir, ".RUNNING") - return retry_count - - -def _mark_success(state_dir: str): - """Mark shard as successful.""" - prev = _read_status(state_dir) - prev["status"] = "success" - prev["finished_at"] = datetime.now().isoformat() - prev["error"] = None - _write_status(state_dir, prev) - _clear_markers(state_dir) - _touch_marker(state_dir, ".SUCCESS") - - -def _mark_failure(state_dir: str, error_msg: str): - """Mark shard as failed.""" - prev = _read_status(state_dir) - prev["status"] = "failed" - prev["finished_at"] = 
datetime.now().isoformat() - prev["error"] = error_msg - _write_status(state_dir, prev) - _clear_markers(state_dir) - _touch_marker(state_dir, ".FAILED") - - -def _cleanup_old_shard_data(out_dir: str): - """Remove old dataset shard output before retry.""" - if os.path.exists(out_dir): - shutil.rmtree(out_dir) - logger.info(f"Removed old shard output: {out_dir}") - - -def _save_dataset_atomic(ds_processed: DatasetLike, out_dir: str): - """Save dataset atomically via temporary directory + rename.""" - parent_dir = os.path.dirname(out_dir) - os.makedirs(parent_dir, exist_ok=True) - - tmp_dir = f"{out_dir}.tmp.{os.getpid()}" - if os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir) - - ds_processed.save_to_disk(tmp_dir) - - if os.path.exists(out_dir): - shutil.rmtree(out_dir) - - os.replace(tmp_dir, out_dir) - - def rewrite_batch( batch: Dict[str, List[Any]], mapper: MMIRAGEMapper, renderer: TemplateRenderer, image_base_path: str = None, ) -> Dict[str, List[Any]]: - """Rewrite a batch of samples by applying transformations.""" + """Rewrite a batch of samples by applying transformations. + Args: + batch: Dictionary mapping column names to lists of values. + mapper: MMIRAGEMapper for processing transformations. + renderer: TemplateRenderer for generating output. + image_base_path: Optional base directory for resolving relative image paths. + Returns: + Dictionary mapping output keys to lists of rendered values. + Raises: + ValueError: If variables are not computable given the configuration. + """ if not mapper.validate_vars(): raise ValueError( "Uncomputable variables detected. 
Verify your configuration and make sure that there is no undefined variables" @@ -209,18 +57,12 @@ def rewrite_batch( return rendered_list -def _get_state_root(cfg) -> str: - """Get the shared pipeline state root from config.""" - state_dir = getattr(cfg.loading_params, "state_dir", None) - if not state_dir: - raise ValueError( - "loading_params.state_dir must be set when using multiple datasets with independent output_dir values" - ) - return state_dir - - def main(): - """Process a single logical shard across all configured datasets.""" + """ + Process a single shard of the dataset. + Loads configuration, datasets, processes the shard using MMIRAGE + transformations (including multimodal), and saves the result to disk. + """ ap = argparse.ArgumentParser("Process dataset shards using MMIRAGE with SGLang.") ap.add_argument( "--config", @@ -243,7 +85,7 @@ def main(): if not (0 <= shard_id < num_shards): raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") - state_root = _get_state_root(cfg) + state_root = loading_params.get_state_root() state_dir = _shard_state_dir(shard_id, state_root) try: diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py new file mode 100644 index 0000000..2ae1c9d --- /dev/null +++ b/src/mmirage/shard_utils.py @@ -0,0 +1,186 @@ +"""Utility functions for shard processing. + +This module contains helper functions for dataset sharding, state management, +and file operations used in the MMIRAGE shard processing pipeline. 
+""" + +from datetime import datetime +from functools import reduce +import json +import logging +import os +import shutil +import socket +from typing import Any, Dict, List + +from datasets import DatasetDict + +from mmirage.core.loader.base import BaseDataLoaderConfig, DatasetLike + +logger = logging.getLogger(__name__) + + +def _count_rows(ds: DatasetLike) -> int: + """Count total rows in a dataset or dataset dict.""" + if isinstance(ds, DatasetDict): + return sum(len(split) for split in ds.values()) + return len(ds) + + +def _shard_dataset(ds: DatasetLike, num_shards: int, shard_id: int) -> DatasetLike: + """Shard a dataset or dataset dict.""" + if isinstance(ds, DatasetDict): + return DatasetDict( + { + split: split_ds.shard(num_shards=num_shards, index=shard_id) + for split, split_ds in ds.items() + } + ) + return ds.shard(num_shards=num_shards, index=shard_id) + + +def _remove_columns(ds: DatasetLike, enable: bool) -> List[str]: + """Get columns to remove from dataset if enabled.""" + if not enable: + return [] + if isinstance(ds, DatasetDict): + columns_set = [set(split_ds.column_names) for split_ds in ds.values()] + return list(reduce(lambda x, y: x | y, columns_set)) + return ds.column_names + + +def _save_dataset_atomic(ds_processed: DatasetLike, out_dir: str): + """Save dataset atomically via temporary directory + rename.""" + parent_dir = os.path.dirname(out_dir) + os.makedirs(parent_dir, exist_ok=True) + + tmp_dir = f"{out_dir}.tmp.{os.getpid()}" + if os.path.exists(tmp_dir): + shutil.rmtree(tmp_dir) + + ds_processed.save_to_disk(tmp_dir) + + if os.path.exists(out_dir): + shutil.rmtree(out_dir) + + os.replace(tmp_dir, out_dir) + + +def _dataset_out_dir(shard_idx: int, ds_config: BaseDataLoaderConfig) -> str: + """Get dataset-specific output directory for a shard.""" + return os.path.join(ds_config.output_dir, f"shard_{shard_idx}") + + +def _shard_state_dir(shard_idx: int, state_root: str) -> str: + """Get central state directory for a logical 
shard.""" + return os.path.join(state_root, f"shard_{shard_idx}") + + +def _cleanup_old_shard_data(out_dir: str): + """Remove old dataset shard output before retry.""" + if os.path.exists(out_dir): + shutil.rmtree(out_dir) + logger.info(f"Removed old shard output: {out_dir}") + + +def _status_file(state_dir: str) -> str: + """Canonical status file path.""" + return os.path.join(state_dir, "status.json") + + +def _read_status(state_dir: str) -> dict: + """Read status.json if present.""" + path = _status_file(state_dir) + if not os.path.exists(path): + return {} + try: + with open(path, "r") as f: + return json.load(f) + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"Failed to read status file {path}: {e}") + return {} + + +def _write_status(state_dir: str, payload: dict): + """Atomically write status.json.""" + os.makedirs(state_dir, exist_ok=True) + tmp_path = _status_file(state_dir) + ".tmp" + with open(tmp_path, "w") as f: + json.dump(payload, f, indent=2, sort_keys=True) + os.replace(tmp_path, _status_file(state_dir)) + + +def _clear_markers(state_dir: str): + """Remove status marker files.""" + for name in (".RUNNING", ".SUCCESS", ".FAILED"): + path = os.path.join(state_dir, name) + if os.path.exists(path): + try: + os.remove(path) + except OSError as e: + logger.warning(f"Failed to remove marker {path}: {e}") + + +def _touch_marker(state_dir: str, name: str): + """Create a marker file.""" + os.makedirs(state_dir, exist_ok=True) + path = os.path.join(state_dir, name) + with open(path, "w") as f: + f.write(f"{datetime.now().isoformat()}\n") + + +def _mark_running( + state_dir: str, + shard_id: int, + datasets_config: List[BaseDataLoaderConfig], +) -> int: + """Mark shard as running and increment retry count.""" + prev = _read_status(state_dir) + retry_count = int(prev.get("retry_count", 0)) + 1 + + payload = { + "status": "running", + "retry_count": retry_count, + "shard_id": shard_id, + "started_at": datetime.now().isoformat(), + "finished_at": 
None, + "error": None, + "hostname": socket.gethostname(), + "pid": os.getpid(), + "slurm_job_id": os.environ.get("SLURM_JOB_ID"), + "slurm_array_task_id": os.environ.get("SLURM_ARRAY_TASK_ID"), + "datasets": [ + { + "path": ds_config.path, + "output_dir": ds_config.output_dir, + } + for ds_config in datasets_config + ], + } + + _write_status(state_dir, payload) + _clear_markers(state_dir) + _touch_marker(state_dir, ".RUNNING") + return retry_count + + +def _mark_success(state_dir: str): + """Mark shard as successful.""" + prev = _read_status(state_dir) + prev["status"] = "success" + prev["finished_at"] = datetime.now().isoformat() + prev["error"] = None + _write_status(state_dir, prev) + _clear_markers(state_dir) + _touch_marker(state_dir, ".SUCCESS") + + +def _mark_failure(state_dir: str, error_msg: str): + """Mark shard as failed.""" + prev = _read_status(state_dir) + prev["status"] = "failed" + prev["finished_at"] = datetime.now().isoformat() + prev["error"] = error_msg + _write_status(state_dir, prev) + _clear_markers(state_dir) + _touch_marker(state_dir, ".FAILED") From c8b51a975b46c9b2f2c32cfb23de26c253ad363d Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Fri, 13 Mar 2026 17:01:39 +0100 Subject: [PATCH 18/47] preparing for PR --- .gitignore | 3 + configs/config_medtrinity.yaml | 67 ---------- configs/config_mock.yaml | 1 + configs/config_mock_vision.yaml | 1 + retry_failed.sh | 4 +- run.sh | 10 +- run_medtrinity.sh | 33 ----- run_with_retry.sh | 212 ++++++++++++++++---------------- 8 files changed, 118 insertions(+), 213 deletions(-) delete mode 100644 configs/config_medtrinity.yaml delete mode 100644 run_medtrinity.sh diff --git a/.gitignore b/.gitignore index cf139bc..908494c 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,6 @@ else/ # Test outputs tests/mock_data/output/ tests/mock_data/shards/ + +# devcontainer +.devcontainer/ \ No newline at end of file diff --git a/configs/config_medtrinity.yaml 
b/configs/config_medtrinity.yaml deleted file mode 100644 index 775f6aa..0000000 --- a/configs/config_medtrinity.yaml +++ /dev/null @@ -1,67 +0,0 @@ -processors: - - type: llm - server_args: - model_path: Qwen/Qwen3-4B-Instruct-2507 - tp_size: 4 - disable_custom_all_reduce: true - default_sampling_params: - temperature: 0.1 - top_p: 0.9 - max_new_tokens: 1024 - chat_template_kwargs: - enable_thinking: false - -loading_params: - state_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/_pipeline_state - datasets: - - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/dirty/medtrinity_conversations_1/ - type: loadable - output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled2/part1 - - path: /capstor/store/cscs/swissai/a127/meditron/multimediset/dirty/medtrinity_conversations_2/ - type: loadable - output_dir: /capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/medtrinity_conversations_sampled2/part2 - num_shards: 32 - shard_id: "$SLURM_ARRAY_TASK_ID" - conversations_field: "conversations" - batch_size: 128 - -processing_params: - inputs: - - name: assistant_answer - key: conversations[1].content - - name: user_prompt - key: conversations[0].content - - name: modalities - key: modalities - - outputs: - - name: formatted_answer - type: llm - output_type: plain - prompt: | - You will receive a JSON object with a single field "assistant_text". - - Task: - - Rewrite "assistant_text" as clearer, well-structured **Markdown**. - - You may add headings, bullet points, and simple formatting to improve readability. - - Keep the original meaning; do **not** invent new facts or interpretations. - - Output: - - Return only the rewritten Markdown text as a plain string. - - Do NOT wrap it in JSON, quotes, or code fences. 
- - Input: - {{ assistant_answer }} - output_schema: - - question - - explanation - - answer - - remove_columns: True - output_schema: - conversations: - - role: user - content: "{{ user_prompt }}" - - role: assistant - content: "{{ formatted_answer }}" - modalities: "{{ modalities }}" \ No newline at end of file diff --git a/configs/config_mock.yaml b/configs/config_mock.yaml index 1f6c533..a2644c5 100644 --- a/configs/config_mock.yaml +++ b/configs/config_mock.yaml @@ -13,6 +13,7 @@ processors: enable_thinking: false loading_params: + state_dir: tests/output/data/_pipeline_state datasets: - path: tests/mock_data/data.jsonl type: JSONL diff --git a/configs/config_mock_vision.yaml b/configs/config_mock_vision.yaml index c86172c..af08754 100644 --- a/configs/config_mock_vision.yaml +++ b/configs/config_mock_vision.yaml @@ -11,6 +11,7 @@ processors: max_new_tokens: 512 loading_params: + state_dir: tests/output/data_vision/_pipeline_state datasets: - path: tests/mock_data_vision/data.jsonl type: JSONL diff --git a/retry_failed.sh b/retry_failed.sh index 7afbed9..22b6a85 100644 --- a/retry_failed.sh +++ b/retry_failed.sh @@ -5,10 +5,10 @@ set -euo pipefail -STATE_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/_pipeline_state" +STATE_ROOT="/users/$USER/meditron/MMIRAGE/tests/output/data/_pipeline_state" NUM_SHARDS=32 MAX_RETRIES=3 -SCRIPT_PATH="/users/qchapp/meditron/MIRAGE/run_with_retry.sh" +SCRIPT_PATH="/users/$USER/meditron/MMIRAGE/run.sh" echo "Checking shard states in: $STATE_ROOT" echo "" diff --git a/run.sh b/run.sh index 5cfc08a..e5a8606 100644 --- a/run.sh +++ b/run.sh @@ -12,17 +12,12 @@ #SBATCH --array=0-3 # --- outputs & config --- -export ROOT=$SCRATCH/mmirage_output -export SHARDS_ROOT="$ROOT/shards" -export MERGED_DIR="$ROOT/merged" export CFG=$MMIRAGE_PATH/configs/config_small.yaml +export TOTAL_SHARDS=32 # Total number of shards (used for retries) # HF cache/home export HF_HOME=$SCRATCH/hf -mkdir -p "$SHARDS_ROOT" -mkdir -p 
"$MERGED_DIR" - export CMD="python $MMIRAGE_PATH/src/mmirage/shard_process.py --config $CFG" SRUN_ARGS=" \ @@ -35,5 +30,4 @@ SRUN_ARGS=" \ " # bash -c is needed for the delayed interpolation of env vars to work srun $SRUN_ARGS bash -c "$CMD" -echo "END TIME: $(date)" - +echo "END TIME: $(date)" \ No newline at end of file diff --git a/run_medtrinity.sh b/run_medtrinity.sh deleted file mode 100644 index 590d79f..0000000 --- a/run_medtrinity.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=mmirage-medtrinity -#SBATCH --chdir=/users/qchapp/meditron/MIRAGE/src/mmirage -#SBATCH --output=/users/qchapp/reports/R-%x.%A_%a.out -#SBATCH --error=/users/qchapp/reports/R-%x.%A_%a.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=288 -#SBATCH --time=11:59:59 -#SBATCH -A a127 -#SBATCH --array=0-31 - -# --- outputs & config --- -export CFG=/users/qchapp/meditron/MIRAGE/configs/config_medtrinity.yaml - -# HF cache/home -export HF_HOME=/capstor/store/cscs/swissai/a127/homes/qchapp/hf - -export CMD="python /users/qchapp/meditron/MIRAGE/src/mmirage/shard_process.py --config $CFG" - -SRUN_ARGS=" \ - --cpus-per-task $SLURM_CPUS_PER_TASK \ - --jobid $SLURM_JOB_ID \ - --wait 60 \ - -A a127 \ - --reservation sai-a127 \ - --environment /users/qchapp/.edf/sglang.toml - " -# bash -c is needed for the delayed interpolation of env vars to work -srun $SRUN_ARGS bash -c "$CMD" -echo "END TIME: $(date)" - diff --git a/run_with_retry.sh b/run_with_retry.sh index edb8f1a..d3252a8 100644 --- a/run_with_retry.sh +++ b/run_with_retry.sh @@ -1,34 +1,21 @@ #!/bin/bash set -euo pipefail -# ========================================================= -# MMIRAGE pipeline controller -# - submit shard-processing SLURM array -# - wait for completion -# - inspect shard state_dir -# - retry failed shards -# ========================================================= - -# ----------------------------- -# User configuration -# 
----------------------------- JOB_NAME="mmirage-sharded" ACCOUNT="a127" -RESERVATION="sai-a127" +RESERVATION="" # e.g. "sai-a127" if needed -MMIRAGE_CHDIR="/users/qchapp/meditron/MIRAGE/src/mmirage" -REPORT_DIR="/users/qchapp/reports" -EDF_ENV="/users/qchapp/.edf/sglang.toml" +MMIRAGE_CHDIR="/users/$USER/meditron/MMIRAGE/src/mmirage" +REPORT_DIR="/users/$USER/reports" +EDF_ENV="/users/$USER/.edf/mmirage.toml" -CFG="/users/qchapp/meditron/MIRAGE/configs/config_medtrinity.yaml" -HF_HOME="/capstor/store/cscs/swissai/a127/homes/qchapp/hf" - -STATE_ROOT="/capstor/store/cscs/swissai/a127/homes/qchapp/datasets/medtrinity/_pipeline_state" +CFG="/users/$USER/meditron/MMIRAGE/configs/config_mock.yaml" +HF_HOME="/capstor/store/cscs/swissai/a127/homes/$USER/hf" +STATE_ROOT="/users/$USER/meditron/MMIRAGE/tests/output/data/_pipeline_state" NUM_SHARDS=32 MAX_RETRIES=3 -# SLURM resources NODES=1 NTASKS_PER_NODE=1 GPUS=4 @@ -36,15 +23,16 @@ CPUS_PER_TASK=288 TIME_LIMIT="11:59:59" POLL_SECONDS=30 +SETTLE_SECONDS=60 +SETTLE_POLL=10 -# ----------------------------- -# Logging -# ----------------------------- -mkdir -p "$REPORT_DIR" +mkdir -p "$REPORT_DIR" "$HF_HOME" LOG_FILE="$REPORT_DIR/${JOB_NAME}_logs.out" - exec > >(tee -a "$LOG_FILE") 2>&1 +export CFG HF_HOME +export TOTAL_SHARDS="$NUM_SHARDS" + echo "==================================================" echo "Pipeline : $JOB_NAME" echo "User : $USER" @@ -53,26 +41,17 @@ echo "Start Time : $(date)" echo "Log File : $LOG_FILE" echo "==================================================" -# ----------------------------- -# Environment -# ----------------------------- -export HF_HOME -export CFG -export TOTAL_SHARDS="$NUM_SHARDS" - -mkdir -p "$HF_HOME" - -# ----------------------------- -# Submit SLURM array job -# ----------------------------- submit_array_job() { - local array_spec="$1" + local extra=() + + [[ -n "$RESERVATION" ]] && extra+=(--reservation="$RESERVATION") echo "[INFO] Submitting job array: $array_spec" 
SUBMITTED_JOB_ID=$( sbatch --parsable \ + "${extra[@]}" \ --job-name="$JOB_NAME" \ --chdir="$MMIRAGE_CHDIR" \ --output="$REPORT_DIR/R-%x.%A_%a.out" \ @@ -87,122 +66,150 @@ submit_array_job() { --export=ALL,CFG="$CFG",TOTAL_SHARDS="$NUM_SHARDS",HF_HOME="$HF_HOME" \ --wrap=" set -euo pipefail - - echo 'START:' \$(date) - echo 'HOST:' \$(hostname) - echo 'TASK:' \$SLURM_ARRAY_TASK_ID + echo START: \$(date) + echo HOST: \$(hostname) + echo TASK: \$SLURM_ARRAY_TASK_ID srun \ - --cpus-per-task $CPUS_PER_TASK \ - --jobid \$SLURM_JOB_ID \ - --wait 60 \ - -A $ACCOUNT \ - --reservation $RESERVATION \ - --environment $EDF_ENV \ - python shard_process.py --config \$CFG - - echo 'END:' \$(date) + --cpus-per-task=$CPUS_PER_TASK \ + --wait=60 \ + --environment=$EDF_ENV \ + python \"$MMIRAGE_CHDIR/shard_process.py\" --config \$CFG + + echo END: \$(date) " ) - echo "[INFO] Submitted job ID: $SUBMITTED_JOB_ID" + SUBMITTED_JOB_ID="${SUBMITTED_JOB_ID%%;*}" + echo "[INFO] Job ID: $SUBMITTED_JOB_ID" } -# ----------------------------- -# Wait for job completion -# ----------------------------- wait_for_job() { - local job_id="$1" - echo "[INFO] Waiting for job $job_id" + echo "[INFO] Waiting for job array $job_id" - while squeue -h -j "$job_id" | grep -q .; do + while true; do + if [[ -z "$(squeue -j "$job_id" -h 2>/dev/null || true)" ]]; then + break + fi + squeue -j "$job_id" -o "%.18i %.10T %.10M %.20R" sleep "$POLL_SECONDS" done - echo "[INFO] Job finished" + echo "[INFO] Job array $job_id finished" } -# ----------------------------- -# Inspect shard states -# ----------------------------- -check_failed_shards() { - - local -n result="$1" - result=() - - success=0 - exhausted=0 - - echo "" - echo "[INFO] Inspecting shard states" - echo "" - - for i in $(seq 0 $((NUM_SHARDS-1))); do - - status_file="$STATE_ROOT/shard_$i/status.json" +get_status() { + local shard="$1" + local status_file="$STATE_ROOT/shard_$shard/status.json" - if [[ ! 
-f "$status_file" ]]; then - echo "āŒ shard $i: missing state" - result+=("$i") - continue - fi + if [[ ! -f "$status_file" ]]; then + echo "missing" + return + fi - status=$(python - </dev/null || echo "unknown" import json with open("$status_file") as f: - print(json.load(f).get("status")) + print(json.load(f).get("status", "unknown")) PY -) +} + +get_retry_count() { + local shard="$1" + local status_file="$STATE_ROOT/shard_$shard/status.json" + + if [[ ! -f "$status_file" ]]; then + echo 0 + return + fi - retry=$(python - </dev/null || echo 0 import json with open("$status_file") as f: - print(json.load(f).get("retry_count",0)) + print(int(json.load(f).get("retry_count", 0))) PY -) +} + +wait_for_settle() { + local waited=0 + + echo "[INFO] Waiting up to ${SETTLE_SECONDS}s for shard states to settle" + + while (( waited < SETTLE_SECONDS )); do + local running=0 + + for i in $(seq 0 $((NUM_SHARDS - 1))); do + [[ "$(get_status "$i")" == "running" ]] && ((running+=1)) + done + + if (( running == 0 )); then + echo "[INFO] State files settled" + return + fi + + echo "[INFO] $running shard(s) still marked running" + sleep "$SETTLE_POLL" + ((waited+=SETTLE_POLL)) + done + + echo "[INFO] Continuing after settle timeout" +} + +check_failed_shards() { + local -n failed_ref=$1 + failed_ref=() + + local success=0 + local exhausted=0 + local running=0 + + echo + echo "[INFO] Inspecting shard states" + echo + + for i in $(seq 0 $((NUM_SHARDS - 1))); do + local status retry + status="$(get_status "$i")" + retry="$(get_retry_count "$i")" if [[ "$status" == "success" ]]; then echo "āœ… shard $i: success" ((success+=1)) - + elif [[ "$status" == "running" ]]; then + echo "ā³ shard $i: still running in state file (retry=$retry)" + ((running+=1)) elif [[ "$retry" -ge "$MAX_RETRIES" ]]; then echo "šŸ›‘ shard $i: retries exhausted ($retry)" ((exhausted+=1)) - else echo "āŒ shard $i: $status (retry=$retry)" - result+=("$i") + failed_ref+=("$i") fi done - echo "" + echo echo 
"Summary" echo " success: $success / $NUM_SHARDS" - echo " retry: ${#result[@]}" + echo " retry: ${#failed_ref[@]}" + echo " still running: $running" echo " exhausted: $exhausted" - echo "" + echo } -# ----------------------------- -# Main loop -# ----------------------------- main() { - local failed=() - local array_spec="0-$((NUM_SHARDS-1))" + local array_spec="0-$((NUM_SHARDS - 1))" while true; do - echo "--------------------------------------------------" echo "[INFO] Starting iteration at $(date)" echo "--------------------------------------------------" submit_array_job "$array_spec" - wait_for_job "$SUBMITTED_JOB_ID" - + wait_for_settle check_failed_shards failed if [[ ${#failed[@]} -eq 0 ]]; then @@ -211,7 +218,6 @@ main() { fi array_spec=$(IFS=,; echo "${failed[*]}") - echo "[INFO] Retrying shards: $array_spec" done } From 49cfc2c2d9bca0f72ffc69d24ce7fdde6a66203f Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Fri, 13 Mar 2026 17:21:33 +0100 Subject: [PATCH 19/47] maybe now --- run_with_retry.sh | 53 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/run_with_retry.sh b/run_with_retry.sh index d3252a8..27ff65d 100644 --- a/run_with_retry.sh +++ b/run_with_retry.sh @@ -1,5 +1,5 @@ #!/bin/bash -set -euo pipefail +set -uo pipefail JOB_NAME="mmirage-sharded" ACCOUNT="a127" @@ -45,11 +45,14 @@ submit_array_job() { local array_spec="$1" local extra=() - [[ -n "$RESERVATION" ]] && extra+=(--reservation="$RESERVATION") + if [[ -n "$RESERVATION" ]]; then + extra+=(--reservation="$RESERVATION") + fi echo "[INFO] Submitting job array: $array_spec" - SUBMITTED_JOB_ID=$( + local submitted + submitted=$( sbatch --parsable \ "${extra[@]}" \ --job-name="$JOB_NAME" \ @@ -65,7 +68,7 @@ submit_array_job() { --array="$array_spec" \ --export=ALL,CFG="$CFG",TOTAL_SHARDS="$NUM_SHARDS",HF_HOME="$HF_HOME" \ --wrap=" - set -euo pipefail + set -uo pipefail echo START: \$(date) echo HOST: 
\$(hostname) echo TASK: \$SLURM_ARRAY_TASK_ID @@ -76,11 +79,19 @@ submit_array_job() { --environment=$EDF_ENV \ python \"$MMIRAGE_CHDIR/shard_process.py\" --config \$CFG + rc=\$? + echo PYTHON_EXIT: \$rc echo END: \$(date) + exit \$rc " ) - SUBMITTED_JOB_ID="${SUBMITTED_JOB_ID%%;*}" + if [[ -z "$submitted" ]]; then + echo "[ERROR] sbatch submission failed" + exit 1 + fi + + SUBMITTED_JOB_ID="${submitted%%;*}" echo "[INFO] Job ID: $SUBMITTED_JOB_ID" } @@ -90,9 +101,13 @@ wait_for_job() { echo "[INFO] Waiting for job array $job_id" while true; do - if [[ -z "$(squeue -j "$job_id" -h 2>/dev/null || true)" ]]; then + local n + n=$(squeue -j "$job_id" -h 2>/dev/null | wc -l) + + if [[ "$n" -eq 0 ]]; then break fi + squeue -j "$job_id" -o "%.18i %.10T %.10M %.20R" sleep "$POLL_SECONDS" done @@ -109,11 +124,14 @@ get_status() { return fi - python - </dev/null || echo "unknown" + python - </dev/null import json with open("$status_file") as f: print(json.load(f).get("status", "unknown")) PY + if [[ $? -ne 0 ]]; then + echo "unknown" + fi } get_retry_count() { @@ -125,11 +143,14 @@ get_retry_count() { return fi - python - </dev/null || echo 0 + python - </dev/null import json with open("$status_file") as f: print(int(json.load(f).get("retry_count", 0))) PY + if [[ $? 
-ne 0 ]]; then + echo 0 + fi } wait_for_settle() { @@ -137,21 +158,23 @@ wait_for_settle() { echo "[INFO] Waiting up to ${SETTLE_SECONDS}s for shard states to settle" - while (( waited < SETTLE_SECONDS )); do + while [[ "$waited" -lt "$SETTLE_SECONDS" ]]; do local running=0 for i in $(seq 0 $((NUM_SHARDS - 1))); do - [[ "$(get_status "$i")" == "running" ]] && ((running+=1)) + if [[ "$(get_status "$i")" == "running" ]]; then + running=$((running + 1)) + fi done - if (( running == 0 )); then + if [[ "$running" -eq 0 ]]; then echo "[INFO] State files settled" return fi echo "[INFO] $running shard(s) still marked running" sleep "$SETTLE_POLL" - ((waited+=SETTLE_POLL)) + waited=$((waited + SETTLE_POLL)) done echo "[INFO] Continuing after settle timeout" @@ -176,13 +199,13 @@ check_failed_shards() { if [[ "$status" == "success" ]]; then echo "āœ… shard $i: success" - ((success+=1)) + success=$((success + 1)) elif [[ "$status" == "running" ]]; then echo "ā³ shard $i: still running in state file (retry=$retry)" - ((running+=1)) + running=$((running + 1)) elif [[ "$retry" -ge "$MAX_RETRIES" ]]; then echo "šŸ›‘ shard $i: retries exhausted ($retry)" - ((exhausted+=1)) + exhausted=$((exhausted + 1)) else echo "āŒ shard $i: $status (retry=$retry)" failed_ref+=("$i") From 7d0d2c31c71a7cf78bc9b3509f68a114ae48973d Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Sat, 14 Mar 2026 18:13:38 +0100 Subject: [PATCH 20/47] working on the test data --- configs/config_mock.yaml | 2 +- configs/config_mock_vision.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/config_mock.yaml b/configs/config_mock.yaml index a2644c5..1a59d23 100644 --- a/configs/config_mock.yaml +++ b/configs/config_mock.yaml @@ -20,7 +20,7 @@ loading_params: output_dir: tests/output/data num_shards: 4 - shard_id: 0 + shard_id: "$SLURM_ARRAY_TASK_ID" batch_size: 64 processing_params: diff --git a/configs/config_mock_vision.yaml 
b/configs/config_mock_vision.yaml index af08754..c5c2c68 100644 --- a/configs/config_mock_vision.yaml +++ b/configs/config_mock_vision.yaml @@ -19,7 +19,7 @@ loading_params: image_base_path: tests/mock_data_vision # Base directory where images are stored num_shards: 4 - shard_id: 0 + shard_id: "$SLURM_ARRAY_TASK_ID" batch_size: 1 processing_params: From c33cf828b0ad44ede3606b1fc50aeaa420640131 Mon Sep 17 00:00:00 2001 From: Quentin Chappuis Date: Sat, 14 Mar 2026 18:15:18 +0100 Subject: [PATCH 21/47] working on the test data --- run_with_retry.sh | 72 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 53 insertions(+), 19 deletions(-) diff --git a/run_with_retry.sh b/run_with_retry.sh index 27ff65d..4bf9ba5 100644 --- a/run_with_retry.sh +++ b/run_with_retry.sh @@ -3,17 +3,18 @@ set -uo pipefail JOB_NAME="mmirage-sharded" ACCOUNT="a127" -RESERVATION="" # e.g. "sai-a127" if needed +RESERVATION="" # leave empty unless a real reservation exists -MMIRAGE_CHDIR="/users/$USER/meditron/MMIRAGE/src/mmirage" +PROJECT_ROOT="/users/$USER/meditron/MIRAGE" +MMIRAGE_CHDIR="$PROJECT_ROOT/src/mmirage" REPORT_DIR="/users/$USER/reports" -EDF_ENV="/users/$USER/.edf/mmirage.toml" +EDF_ENV="/users/$USER/.edf/sglang.toml" -CFG="/users/$USER/meditron/MMIRAGE/configs/config_mock.yaml" +CFG="$PROJECT_ROOT/configs/config_mock.yaml" HF_HOME="/capstor/store/cscs/swissai/a127/homes/$USER/hf" -STATE_ROOT="/users/$USER/meditron/MMIRAGE/tests/output/data/_pipeline_state" +STATE_ROOT="$PROJECT_ROOT/tests/output/data/_pipeline_state" -NUM_SHARDS=32 +NUM_SHARDS=4 MAX_RETRIES=3 NODES=1 @@ -39,15 +40,14 @@ echo "User : $USER" echo "Host : $(hostname)" echo "Start Time : $(date)" echo "Log File : $LOG_FILE" +echo "State Root : $STATE_ROOT" echo "==================================================" submit_array_job() { local array_spec="$1" local extra=() - if [[ -n "$RESERVATION" ]]; then - extra+=(--reservation="$RESERVATION") - fi + [[ -n "$RESERVATION" ]] && 
extra+=(--reservation="$RESERVATION") echo "[INFO] Submitting job array: $array_spec" @@ -80,7 +80,7 @@ submit_array_job() { python \"$MMIRAGE_CHDIR/shard_process.py\" --config \$CFG rc=\$? - echo PYTHON_EXIT: \$rc + echo EXIT_CODE: \$rc echo END: \$(date) exit \$rc " @@ -101,14 +101,16 @@ wait_for_job() { echo "[INFO] Waiting for job array $job_id" while true; do - local n - n=$(squeue -j "$job_id" -h 2>/dev/null | wc -l) + local active + active=$(squeue -j "$job_id" -h 2>/dev/null | wc -l) + + echo "[INFO] Poll $(date): active entries = $active" - if [[ "$n" -eq 0 ]]; then + if [[ "$active" -eq 0 ]]; then break fi - squeue -j "$job_id" -o "%.18i %.10T %.10M %.20R" + squeue -j "$job_id" -o "%.18i %.10T %.10M %.20R" || true sleep "$POLL_SECONDS" done @@ -153,6 +155,15 @@ PY fi } +count_state_files() { + if [[ ! -d "$STATE_ROOT" ]]; then + echo 0 + return + fi + + find "$STATE_ROOT" -maxdepth 2 -name status.json 2>/dev/null | wc -l +} + wait_for_settle() { local waited=0 @@ -160,6 +171,9 @@ wait_for_settle() { while [[ "$waited" -lt "$SETTLE_SECONDS" ]]; do local running=0 + local present=0 + + present=$(count_state_files) for i in $(seq 0 $((NUM_SHARDS - 1))); do if [[ "$(get_status "$i")" == "running" ]]; then @@ -167,17 +181,17 @@ wait_for_settle() { fi done + echo "[INFO] Settle $(date): status files=$present running_states=$running" + if [[ "$running" -eq 0 ]]; then - echo "[INFO] State files settled" - return + break fi - echo "[INFO] $running shard(s) still marked running" sleep "$SETTLE_POLL" waited=$((waited + SETTLE_POLL)) done - echo "[INFO] Continuing after settle timeout" + echo "[INFO] Settle phase finished" } check_failed_shards() { @@ -187,6 +201,7 @@ check_failed_shards() { local success=0 local exhausted=0 local running=0 + local missing=0 echo echo "[INFO] Inspecting shard states" @@ -200,12 +215,19 @@ check_failed_shards() { if [[ "$status" == "success" ]]; then echo "āœ… shard $i: success" success=$((success + 1)) + elif [[ "$status" == 
"running" ]]; then echo "ā³ shard $i: still running in state file (retry=$retry)" running=$((running + 1)) + + elif [[ "$status" == "missing" ]]; then + echo "āš ļø shard $i: missing state file" + missing=$((missing + 1)) + elif [[ "$retry" -ge "$MAX_RETRIES" ]]; then echo "šŸ›‘ shard $i: retries exhausted ($retry)" exhausted=$((exhausted + 1)) + else echo "āŒ shard $i: $status (retry=$retry)" failed_ref+=("$i") @@ -217,8 +239,20 @@ check_failed_shards() { echo " success: $success / $NUM_SHARDS" echo " retry: ${#failed_ref[@]}" echo " still running: $running" + echo " missing: $missing" echo " exhausted: $exhausted" echo + + local present + present=$(count_state_files) + + if [[ "$present" -eq 0 ]]; then + echo "[ERROR] No shard state files were created in: $STATE_ROOT" + echo "[ERROR] The job finished, but no status.json files were found." + echo "[ERROR] This usually means STATE_ROOT is wrong, or shard_process.py wrote elsewhere." + echo "[ERROR] Refusing to auto-retry all shards." + exit 2 + fi } main() { @@ -245,4 +279,4 @@ main() { done } -main \ No newline at end of file +main From 8ec66130caab8f4248f8933cd265a64f3f8d7f84 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Sat, 14 Mar 2026 18:25:14 +0100 Subject: [PATCH 22/47] updated readme with modified changes --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index c162e8c..a3ccd08 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ processors: max_new_tokens: 384 loading_params: + state_dir: /path/to/state/dir datasets: - path: /path/to/dataset type: loadable @@ -141,6 +142,7 @@ processors: max_new_tokens: 768 loading_params: + state_dir: path/to/state/dir datasets: - path: /path/to/image/dataset type: loadable From d5a852768661ee6bd4cafc529c6e0ae0e928091e Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:48:07 +0100 Subject: [PATCH 23/47] trying with a cli to provide an 
easier way to use --- README.md | 21 +- configs/config_comprehensive.yaml | 203 +++++++ configs/config_mock.yaml | 10 + configs/config_mock_vision.yaml | 8 + pyproject.toml | 3 + retry_failed.sh | 105 ++-- run.sh | 46 +- run_with_retry.sh | 311 ++--------- src/mmirage/__init__.py | 12 +- src/mmirage/cli.py | 495 ++++++++++++++++++ src/mmirage/config/config.py | 78 ++- src/mmirage/config/loading.py | 21 +- src/mmirage/config/utils.py | 9 + src/mmirage/core/loader/__init__.py | 10 - src/mmirage/core/process/__init__.py | 14 +- src/mmirage/core/process/base.py | 26 + .../core/process/processors/llm/config.py | 4 + 17 files changed, 954 insertions(+), 422 deletions(-) create mode 100644 configs/config_comprehensive.yaml create mode 100644 src/mmirage/cli.py diff --git a/README.md b/README.md index a3ccd08..eb6306b 100644 --- a/README.md +++ b/README.md @@ -32,24 +32,27 @@ For testing and scripts that make use of the library, it is advised to create a ## Example usage -### Running with Automatic Retry +### Running (single command) -Your job scripts now automatically track success and failure using marker files. +Run the pipeline via the Python CLI. Retry behavior is driven by your YAML config: -Run your job with automatic retry: +- `execution_params.retry: true` → automatically retries failed shards until completion or `max_retries` +- `execution_params.retry: false` → submits/runs once; you can later trigger retries via `check` ```bash -bash run_with_retry.sh +python -m mmirage.cli run --config configs/config_mock.yaml ``` -Alternatively, you can run the steps separately: +To check status and (optionally) submit retries for failed shards: ```bash -# 1. Launch the job -sbatch run.sh +python -m mmirage.cli check --config configs/config_mock.yaml +``` + +If you only want the status summary (no retry submission): -# 2. 
Retry failed jobs -bash retry_failed.sh +```bash +python -m mmirage.cli check --config configs/config_mock.yaml --summary-only ``` ### Text-only: Reformatting dataset diff --git a/configs/config_comprehensive.yaml b/configs/config_comprehensive.yaml new file mode 100644 index 0000000..4285237 --- /dev/null +++ b/configs/config_comprehensive.yaml @@ -0,0 +1,203 @@ +# MMIRAGE Configuration with all parameters +# +# This is a comprehensive example showing all available configuration options. +# You can copy and modify this file for your specific use case. +# +# Parameters are organized into sections: +# 1. processors - LLM and other data transformation processors +# 2. loading_params - Dataset loading and sharding configuration +# 3. processing_params - How to transform/process the data +# 4. execution_params - SLURM, retry, and execution settings +# + +# ============================================================================ +# PROCESSORS CONFIGURATION +# ============================================================================ +# Define the processors used to transform your data. +# Common types: llm, vision_llm, etc. + +processors: + - type: llm + server_args: + model_path: Qwen/Qwen3-4B-Instruct-2507 + tp_size: 1 + disable_custom_all_reduce: true + default_sampling_params: + temperature: 0.1 + top_p: 0.9 + max_new_tokens: 1024 + custom_params: + chat_template_kwargs: + enable_thinking: false + + +# ============================================================================ +# LOADING PARAMETERS +# ============================================================================ +# Configure how datasets are loaded, sharded, and processed. 
+ +loading_params: + # Directory to store pipeline state (checkpoints, status, retry tracking) + # Supports environment variables: $VAR or ${VAR} + state_dir: tests/output/data/_pipeline_state + + # Dataset configurations to load + # Each dataset can be separately sharded and output + datasets: + - path: tests/mock_data/data.jsonl + type: JSONL + output_dir: tests/output/data + # image_base_path: /path/to/images # Optional, for vision tasks + + # Total number of shards to split datasets into. + # For SLURM, this determines the array job size. + num_shards: 4 + + # Shard ID for this process (0-indexed). + # In SLURM array jobs, this is set automatically. + shard_id: "$SLURM_ARRAY_TASK_ID" + + # Batch size for processing samples + batch_size: 64 + + +# ============================================================================ +# PROCESSING PARAMETERS +# ============================================================================ +# Define what to extract, transform, and output from each sample. 
+ +processing_params: + # Input variables to extract from source data + inputs: + - name: text + key: text + # For vision examples: + # - name: image + # key: image_path + # type: image + + # Output variables generated by processors + outputs: + - name: formatted_answer + type: llm + output_type: JSON + output_schema: + - question + - answer + prompt: | + Generate one question and its corresponding answer using the following text: + ``` + {{ text }} + ``` + + # Whether to remove original columns from the dataset + remove_columns: true + + # Output schema: how to structure the final dataset + output_schema: + conversations: + - role: "user" + content: "{{ formatted_answer.question }}" + - role: "assistant" + content: "{{ formatted_answer.answer }}" + + +# ============================================================================ +# EXECUTION PARAMETERS +# ============================================================================ +# Configure how to execute the pipeline: locally or on SLURM cluster. +# All parameters here are optional with sensible defaults. + +execution_params: + # Execution mode: "local" or "slurm" + # - local: Run directly on this machine + # - slurm: Submit jobs to SLURM cluster + mode: slurm + + # Whether the canonical `run` command should automatically retry failed shards. 
+ # - false: submit one run only + # - true: submit, wait, and keep retrying failed shards until success or retry budget exhaustion + retry: true + + # Maximum number of times to retry a failed shard (default: 3) + max_retries: 3 + + # ========================================================================== + # SLURM CONFIGURATION (only used when mode: slurm) + # ========================================================================== + + # HPC account/partition to charge jobs to (REQUIRED for SLURM mode) + account: a127 + + # SLURM job name (default: "mmirage-sharded") + job_name: mmirage-sharded + + # Optional SLURM reservation name (leave blank or omit to not use) + # reservation: "sai-a127" + + # Number of nodes (default: 1) + nodes: 1 + + # Number of tasks per node (default: 1) + ntasks_per_node: 1 + + # Number of GPUs per node (default: 4) + gpus: 4 + + # Number of CPUs per task (default: 288) + cpus_per_task: 288 + + # Job time limit in HH:MM:SS format (default: "11:59:59") + time_limit: "11:59:59" + + # ========================================================================== + # PATH CONFIGURATION + # ========================================================================== + # These support environment variables ($VAR or ${VAR}) and home directory (~) + + # Project root directory (used as base for relative paths) + # If not set, uses current working directory + # project_root: "/path/to/project" + + # Directory for SLURM output and error files (default: ~/reports) + report_dir: "/users/${USER}/reports" + + # HuggingFace cache directory (default: ~/hf) + hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf" + + # Optional EDF environment file path for cluster-specific setup + # edf_env: "/path/to/.edf/mmirage.toml" + + # ========================================================================== + # JOB MONITORING (for "submit" and retry orchestration) + # ========================================================================== + + # Seconds 
to wait between checking job status (default: 30) + poll_interval_seconds: 30 + + # Seconds to wait after job completes before checking results (default: 60) + # This allows filesystem to settle on distributed systems + settle_time_seconds: 60 + + # Seconds to wait between checks during settle time (default: 10) + settle_poll_interval: 10 + + +# ============================================================================ +# USAGE EXAMPLES +# ============================================================================ +# +# 1. Canonical entrypoint (local or SLURM; retry controlled by config): +# python -m mmirage.cli run --config config.yaml +# +# 2. Submit job to SLURM with wait for completion: +# python -m mmirage.cli submit --config config.yaml --wait +# +# 3. Submit job and get job ID back (for scripting): +# JOB_ID=$(python -m mmirage.cli submit --config config.yaml) +# +# 4. Run a single shard locally: +# python -m mmirage.cli process --config config.yaml --shard-id 0 +# +# 5. Check status of all shards (and optionally submit retries): +# python -m mmirage.cli check --config config.yaml diff --git a/configs/config_mock.yaml b/configs/config_mock.yaml index 1a59d23..0f2bd54 100644 --- a/configs/config_mock.yaml +++ b/configs/config_mock.yaml @@ -48,3 +48,13 @@ processing_params: content: "{{ formatted_answer.question }}" - role: "assistant" content: "{{ formatted_answer.answer }}" + +# Execution configuration (local or SLURM cluster) +# For local testing, use mode: local +# For SLURM cluster, use mode: slurm and specify account +execution_params: + mode: local + retry: false + max_retries: 3 + report_dir: ~/reports + hf_home: ~/hf diff --git a/configs/config_mock_vision.yaml b/configs/config_mock_vision.yaml index c5c2c68..811d61e 100644 --- a/configs/config_mock_vision.yaml +++ b/configs/config_mock_vision.yaml @@ -39,3 +39,11 @@ processing_params: output_schema: image: "{{ image_input }}" caption: "{{ caption }}" + +# Execution configuration (local or SLURM 
cluster) +execution_params: + mode: local + retry: false + max_retries: 3 + report_dir: ~/reports + hf_home: ~/hf diff --git a/pyproject.toml b/pyproject.toml index 446406b..616ed78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,9 @@ dev = [ "pytest", ] +[project.scripts] +mmirage = "mmirage.cli:main" + [tool.hatch.build.targets.wheel] packages = ["src/mmirage"] diff --git a/retry_failed.sh b/retry_failed.sh index 22b6a85..90a97a6 100644 --- a/retry_failed.sh +++ b/retry_failed.sh @@ -1,83 +1,52 @@ #!/bin/bash -# Check for failed logical shards and relaunch them +# MMIRAGE retry failed shards script +# +# Check for failed logical shards and relaunch them interactively. +# +# Usage: +# bash retry_failed.sh [--config path/to/config.yaml] +# +# Configuration: +# Set the CFG environment variable to point to your config file, or +# use the --config argument. Defaults to configs/config_mock.yaml. # -# Usage: bash retry_failed.sh set -euo pipefail - -STATE_ROOT="/users/$USER/meditron/MMIRAGE/tests/output/data/_pipeline_state" -NUM_SHARDS=32 -MAX_RETRIES=3 -SCRIPT_PATH="/users/$USER/meditron/MMIRAGE/run.sh" - -echo "Checking shard states in: $STATE_ROOT" -echo "" - -failed_shards=() -success_count=0 - -for i in $(seq 0 $((NUM_SHARDS - 1))); do - state_dir="$STATE_ROOT/shard_$i" - status_file="$state_dir/status.json" - - if [ ! -f "$status_file" ]; then - echo "āŒ Shard $i: MISSING STATUS" - failed_shards+=("$i") - continue - fi - - status=$(python - < 0 )); do + case "$1" in + --config) + CFG="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" >&2 + exit 1 + ;; + esac done -echo "" -echo "==========================================" -echo "Summary:" -echo " āœ… Successful: $success_count / $NUM_SHARDS" -echo " āŒ To retry: ${#failed_shards[@]}" -echo "==========================================" -echo "" - -if [ ${#failed_shards[@]} -eq 0 ]; then - echo "šŸŽ‰ All shards completed successfully!" - exit 0 +if [[ ! 
-f "$CFG" ]]; then + echo "āŒ Config file not found: $CFG" >&2 + exit 1 fi -ARRAY_SPEC=$(IFS=,; echo "${failed_shards[*]}") +echo "Checking shard states from config: $CFG" +echo "" + +# Use MMIRAGE CLI to check failed shards (summary only; no retry submission) +python -m mmirage.cli check --config "$CFG" --summary-only || true -echo "Failed shards: $ARRAY_SPEC" echo "" -read -p "Submit retry job for these shards? (y/N) " -n 1 -r +read -p "Submit retry job for failed shards? (y/N) " -n 1 -r echo if [[ $REPLY =~ ^[Yy]$ ]]; then - JOB_ID=$(sbatch --array="$ARRAY_SPEC" "$SCRIPT_PATH" | grep -oE '[0-9]+') - echo "āœ… Job submitted: $JOB_ID" - echo "" - echo "Monitor with: squeue -j $JOB_ID" + python -m mmirage.cli retry --config "$CFG" --no-interactive else echo "Cancelled." -fi \ No newline at end of file + exit 1 +fi diff --git a/run.sh b/run.sh index e5a8606..6ac856a 100644 --- a/run.sh +++ b/run.sh @@ -1,33 +1,23 @@ #!/bin/bash -#SBATCH --job-name=mmirage-sharded -#SBATCH --chdir=/users/$USER/meditron/MMIRAGE/src/mmirage -#SBATCH --output=/users/$USER/reports/R-%x.%A_%a.out -#SBATCH --error=/users/$USER/reports/R-%x.%A_%a.err -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:4 -#SBATCH --cpus-per-task=288 -#SBATCH --time=11:59:59 -#SBATCH -A a127 -#SBATCH --array=0-3 +# MMIRAGE launch script. +# +# Launch behavior is driven by the config file: +# - execution_params.retry=false: submit one SLURM array job, or run locally +# - execution_params.retry=true: submit and automatically retry failed shards +# +# Usage: +# bash run.sh +# CFG=configs/config_mock.yaml bash run.sh -# --- outputs & config --- -export CFG=$MMIRAGE_PATH/configs/config_small.yaml -export TOTAL_SHARDS=32 # Total number of shards (used for retries) +set -euo pipefail -# HF cache/home -export HF_HOME=$SCRATCH/hf +CFG="${CFG:-configs/config_mock.yaml}" -export CMD="python $MMIRAGE_PATH/src/mmirage/shard_process.py --config $CFG" +if [[ ! 
-f "$CFG" ]]; then + echo "Config file not found: $CFG" >&2 + exit 1 +fi -SRUN_ARGS=" \ - --cpus-per-task $SLURM_CPUS_PER_TASK \ - --jobid $SLURM_JOB_ID \ - --wait 60 \ - -A a127 \ - --reservation sai-a127 \ - --environment /users/$USER/.edf/mmirage.toml - " -# bash -c is needed for the delayed interpolation of env vars to work -srun $SRUN_ARGS bash -c "$CMD" -echo "END TIME: $(date)" \ No newline at end of file +python -m mmirage.cli run --config "$CFG" + +echo "END TIME: $(date)" diff --git a/run_with_retry.sh b/run_with_retry.sh index 4bf9ba5..e1b116a 100644 --- a/run_with_retry.sh +++ b/run_with_retry.sh @@ -1,282 +1,31 @@ #!/bin/bash -set -uo pipefail - -JOB_NAME="mmirage-sharded" -ACCOUNT="a127" -RESERVATION="" # leave empty unless a real reservation exists - -PROJECT_ROOT="/users/$USER/meditron/MIRAGE" -MMIRAGE_CHDIR="$PROJECT_ROOT/src/mmirage" -REPORT_DIR="/users/$USER/reports" -EDF_ENV="/users/$USER/.edf/sglang.toml" - -CFG="$PROJECT_ROOT/configs/config_mock.yaml" -HF_HOME="/capstor/store/cscs/swissai/a127/homes/$USER/hf" -STATE_ROOT="$PROJECT_ROOT/tests/output/data/_pipeline_state" - -NUM_SHARDS=4 -MAX_RETRIES=3 - -NODES=1 -NTASKS_PER_NODE=1 -GPUS=4 -CPUS_PER_TASK=288 -TIME_LIMIT="11:59:59" - -POLL_SECONDS=30 -SETTLE_SECONDS=60 -SETTLE_POLL=10 - -mkdir -p "$REPORT_DIR" "$HF_HOME" -LOG_FILE="$REPORT_DIR/${JOB_NAME}_logs.out" -exec > >(tee -a "$LOG_FILE") 2>&1 - -export CFG HF_HOME -export TOTAL_SHARDS="$NUM_SHARDS" - -echo "==================================================" -echo "Pipeline : $JOB_NAME" -echo "User : $USER" -echo "Host : $(hostname)" -echo "Start Time : $(date)" -echo "Log File : $LOG_FILE" -echo "State Root : $STATE_ROOT" -echo "==================================================" - -submit_array_job() { - local array_spec="$1" - local extra=() - - [[ -n "$RESERVATION" ]] && extra+=(--reservation="$RESERVATION") - - echo "[INFO] Submitting job array: $array_spec" - - local submitted - submitted=$( - sbatch --parsable \ - "${extra[@]}" \ - 
--job-name="$JOB_NAME" \ - --chdir="$MMIRAGE_CHDIR" \ - --output="$REPORT_DIR/R-%x.%A_%a.out" \ - --error="$REPORT_DIR/R-%x.%A_%a.err" \ - --nodes="$NODES" \ - --ntasks-per-node="$NTASKS_PER_NODE" \ - --gres="gpu:${GPUS}" \ - --cpus-per-task="$CPUS_PER_TASK" \ - --time="$TIME_LIMIT" \ - -A "$ACCOUNT" \ - --array="$array_spec" \ - --export=ALL,CFG="$CFG",TOTAL_SHARDS="$NUM_SHARDS",HF_HOME="$HF_HOME" \ - --wrap=" - set -uo pipefail - echo START: \$(date) - echo HOST: \$(hostname) - echo TASK: \$SLURM_ARRAY_TASK_ID - - srun \ - --cpus-per-task=$CPUS_PER_TASK \ - --wait=60 \ - --environment=$EDF_ENV \ - python \"$MMIRAGE_CHDIR/shard_process.py\" --config \$CFG - - rc=\$? - echo EXIT_CODE: \$rc - echo END: \$(date) - exit \$rc - " - ) - - if [[ -z "$submitted" ]]; then - echo "[ERROR] sbatch submission failed" - exit 1 - fi - - SUBMITTED_JOB_ID="${submitted%%;*}" - echo "[INFO] Job ID: $SUBMITTED_JOB_ID" -} - -wait_for_job() { - local job_id="$1" - - echo "[INFO] Waiting for job array $job_id" - - while true; do - local active - active=$(squeue -j "$job_id" -h 2>/dev/null | wc -l) - - echo "[INFO] Poll $(date): active entries = $active" - - if [[ "$active" -eq 0 ]]; then - break - fi - - squeue -j "$job_id" -o "%.18i %.10T %.10M %.20R" || true - sleep "$POLL_SECONDS" - done - - echo "[INFO] Job array $job_id finished" -} - -get_status() { - local shard="$1" - local status_file="$STATE_ROOT/shard_$shard/status.json" - - if [[ ! -f "$status_file" ]]; then - echo "missing" - return - fi - - python - </dev/null -import json -with open("$status_file") as f: - print(json.load(f).get("status", "unknown")) -PY - if [[ $? -ne 0 ]]; then - echo "unknown" - fi -} - -get_retry_count() { - local shard="$1" - local status_file="$STATE_ROOT/shard_$shard/status.json" - - if [[ ! -f "$status_file" ]]; then - echo 0 - return - fi - - python - </dev/null -import json -with open("$status_file") as f: - print(int(json.load(f).get("retry_count", 0))) -PY - if [[ $? 
-ne 0 ]]; then - echo 0 - fi -} - -count_state_files() { - if [[ ! -d "$STATE_ROOT" ]]; then - echo 0 - return - fi - - find "$STATE_ROOT" -maxdepth 2 -name status.json 2>/dev/null | wc -l -} - -wait_for_settle() { - local waited=0 - - echo "[INFO] Waiting up to ${SETTLE_SECONDS}s for shard states to settle" - - while [[ "$waited" -lt "$SETTLE_SECONDS" ]]; do - local running=0 - local present=0 - - present=$(count_state_files) - - for i in $(seq 0 $((NUM_SHARDS - 1))); do - if [[ "$(get_status "$i")" == "running" ]]; then - running=$((running + 1)) - fi - done - - echo "[INFO] Settle $(date): status files=$present running_states=$running" - - if [[ "$running" -eq 0 ]]; then - break - fi - - sleep "$SETTLE_POLL" - waited=$((waited + SETTLE_POLL)) - done - - echo "[INFO] Settle phase finished" -} - -check_failed_shards() { - local -n failed_ref=$1 - failed_ref=() - - local success=0 - local exhausted=0 - local running=0 - local missing=0 - - echo - echo "[INFO] Inspecting shard states" - echo - - for i in $(seq 0 $((NUM_SHARDS - 1))); do - local status retry - status="$(get_status "$i")" - retry="$(get_retry_count "$i")" - - if [[ "$status" == "success" ]]; then - echo "āœ… shard $i: success" - success=$((success + 1)) - - elif [[ "$status" == "running" ]]; then - echo "ā³ shard $i: still running in state file (retry=$retry)" - running=$((running + 1)) - - elif [[ "$status" == "missing" ]]; then - echo "āš ļø shard $i: missing state file" - missing=$((missing + 1)) - - elif [[ "$retry" -ge "$MAX_RETRIES" ]]; then - echo "šŸ›‘ shard $i: retries exhausted ($retry)" - exhausted=$((exhausted + 1)) - - else - echo "āŒ shard $i: $status (retry=$retry)" - failed_ref+=("$i") - fi - done - - echo - echo "Summary" - echo " success: $success / $NUM_SHARDS" - echo " retry: ${#failed_ref[@]}" - echo " still running: $running" - echo " missing: $missing" - echo " exhausted: $exhausted" - echo - - local present - present=$(count_state_files) - - if [[ "$present" -eq 0 ]]; then - 
echo "[ERROR] No shard state files were created in: $STATE_ROOT" - echo "[ERROR] The job finished, but no status.json files were found." - echo "[ERROR] This usually means STATE_ROOT is wrong, or shard_process.py wrote elsewhere." - echo "[ERROR] Refusing to auto-retry all shards." - exit 2 - fi -} - -main() { - local failed=() - local array_spec="0-$((NUM_SHARDS - 1))" - - while true; do - echo "--------------------------------------------------" - echo "[INFO] Starting iteration at $(date)" - echo "--------------------------------------------------" - - submit_array_job "$array_spec" - wait_for_job "$SUBMITTED_JOB_ID" - wait_for_settle - check_failed_shards failed - - if [[ ${#failed[@]} -eq 0 ]]; then - echo "šŸŽ‰ Pipeline completed successfully" - exit 0 - fi - - array_spec=$(IFS=,; echo "${failed[*]}") - echo "[INFO] Retrying shards: $array_spec" - done -} - -main +# MMIRAGE pipeline orchestration with forced automatic retry. +# +# Usage: +# bash run_with_retry.sh [--config path/to/config.yaml] + +set -euo pipefail +IFS=$'\n\t' + +CFG="${CFG:-configs/config_mock.yaml}" + +while (( $# > 0 )); do + case "$1" in + --config) + CFG="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" >&2 + exit 1 + ;; + esac +done + +if [[ ! -f "$CFG" ]]; then + echo "Config file not found: $CFG" >&2 + exit 1 +fi + +echo "Config: $CFG" +python -m mmirage.cli run --config "$CFG" --force-retry diff --git a/src/mmirage/__init__.py b/src/mmirage/__init__.py index 000aa27..fa7fe08 100644 --- a/src/mmirage/__init__.py +++ b/src/mmirage/__init__.py @@ -3,16 +3,12 @@ A platform for processing datasets using generative models including vision-language models (VLMs). 
""" +from __future__ import annotations __version__ = "0.2.0" -from mmirage.config import MMirageConfig, ProcessingParams, LoadingParams +from mmirage.config.config import MMirageConfig, ProcessingParams +from mmirage.config.loading import LoadingParams from mmirage.config.utils import load_mmirage_config -__all__ = [ - "MMirageConfig", - "ProcessingParams", - "LoadingParams", - "load_mmirage_config", - "__version__", -] +__all__ = ["MMirageConfig", "ProcessingParams", "LoadingParams", "load_mmirage_config", "__version__"] diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py new file mode 100644 index 0000000..0744bbc --- /dev/null +++ b/src/mmirage/cli.py @@ -0,0 +1,495 @@ +"""Command-line interface for MMIRAGE pipeline.""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import shlex +import subprocess +import sys +import time +from pathlib import Path +from typing import List, Optional, Sequence, Tuple + +from mmirage.config.config import MMirageConfig +from mmirage.config.utils import load_mmirage_config + + +logger = logging.getLogger(__name__) + + +def expand_path(path: str, project_root: Optional[str] = None) -> str: + """Expand environment variables, user home and relative paths.""" + expanded = os.path.expanduser(os.path.expandvars(path)) + if not os.path.isabs(expanded) and project_root: + expanded = os.path.join(project_root, expanded) + return os.path.abspath(expanded) + + +def get_project_root(cfg: MMirageConfig) -> str: + """Return the configured project root, or the current working directory.""" + project_root = cfg.execution_params.project_root + if project_root: + return expand_path(project_root) + return os.getcwd() + + +def create_directories(paths: Sequence[str]) -> None: + """Create directories if they do not already exist.""" + for path in paths: + Path(path).mkdir(parents=True, exist_ok=True) + + +def validate_paths(cfg: MMirageConfig) -> None: + """Validate pre-existing execution paths.""" + 
project_root = get_project_root(cfg) + if cfg.execution_params.edf_env: + edf_env = expand_path(cfg.execution_params.edf_env, project_root) + if not os.path.exists(edf_env): + raise FileNotFoundError(f"EDF environment file not found: {edf_env}") + + +def get_shard_state_dir(state_root: str, shard_id: int) -> str: + """Return the state directory for a shard.""" + return os.path.join(state_root, f"shard_{shard_id}") + + +def get_shard_status(state_dir: str) -> Tuple[str, int]: + """Read the current status and retry count for a shard.""" + status_file = os.path.join(state_dir, "status.json") + if not os.path.exists(status_file): + return ("missing", 0) + + try: + with open(status_file, "r", encoding="utf-8") as handle: + data = json.load(handle) + except (OSError, json.JSONDecodeError) as exc: + logger.warning("Failed to read shard status from %s: %s", status_file, exc) + return ("unknown", 0) + + return (str(data.get("status", "unknown")), int(data.get("retry_count", 0))) + + +def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], dict]: + """Return retryable failed shards and a compact summary.""" + state_root = cfg.loading_params.get_state_root() + if not state_root: + raise ValueError("loading_params.state_dir is required to check shard status") + + num_shards = cfg.loading_params.get_num_shards() + max_retries = cfg.execution_params.max_retries + failed_shards: List[int] = [] + success_count = 0 + exhausted_count = 0 + + for shard_id in range(num_shards): + status, retry_count = get_shard_status(get_shard_state_dir(state_root, shard_id)) + if status == "success": + success_count += 1 + continue + + if retry_count >= max_retries: + exhausted_count += 1 + logger.warning( + "Shard %s exceeded retry budget (%s/%s)", + shard_id, + retry_count, + max_retries, + ) + continue + + failed_shards.append(shard_id) + + summary = { + "total": num_shards, + "successful": success_count, + "failed": len(failed_shards), + "max_retries_exceeded": exhausted_count, + } + 
return failed_shards, summary + + +def run_local(config_path: str, shard_id: Optional[int] = None) -> int: + """Run one shard in the current Python environment.""" + command = [sys.executable, "-m", "mmirage.shard_process", "--config", config_path] + env = os.environ.copy() + if shard_id is not None: + env["SLURM_ARRAY_TASK_ID"] = str(shard_id) + + logger.info("Running local shard processing: %s", " ".join(command)) + result = subprocess.run(command, env=env, check=False) + return result.returncode + + +def build_sbatch_script(cfg: MMirageConfig, config_path: str) -> str: + """Build the sbatch payload executed for each array task.""" + project_root = get_project_root(cfg) + hf_home = expand_path(cfg.execution_params.hf_home, project_root) + state_root = expand_path(cfg.loading_params.get_state_root(), project_root) + + lines = [ + "#!/bin/bash", + "set -euo pipefail", + f"export HF_HOME={shlex.quote(hf_home)}", + f"export MMIRAGE_CONFIG={shlex.quote(config_path)}", + f"mkdir -p {shlex.quote(hf_home)}", + f"mkdir -p {shlex.quote(state_root)}", + "srun_args=(--cpus-per-task ${SLURM_CPUS_PER_TASK:-1} --wait 60)", + ] + + if cfg.execution_params.edf_env: + edf_env = expand_path(cfg.execution_params.edf_env, project_root) + lines.append(f"srun_args+=(--environment={shlex.quote(edf_env)})") + + lines.extend( + [ + f"srun \"${{srun_args[@]}}\" {shlex.quote(sys.executable)} -m mmirage.shard_process --config \"$MMIRAGE_CONFIG\"", + "echo \"Shard ${SLURM_ARRAY_TASK_ID:-0} completed\"", + ] + ) + return "\n".join(lines) + "\n" + + +def submit_slurm_job( + cfg: MMirageConfig, + config_path: str, + shard_ids: Optional[Sequence[int]] = None, +) -> Optional[int]: + """Submit a SLURM array job and return its job ID.""" + project_root = get_project_root(cfg) + report_dir = expand_path(cfg.execution_params.report_dir, project_root) + create_directories([report_dir]) + + command = [ + "sbatch", + "--parsable", + f"--job-name={cfg.execution_params.job_name}", + 
f"--chdir={project_root}", + f"--output={os.path.join(report_dir, 'R-%x.%A_%a.out')}", + f"--error={os.path.join(report_dir, 'R-%x.%A_%a.err')}", + f"--nodes={cfg.execution_params.nodes}", + f"--ntasks-per-node={cfg.execution_params.ntasks_per_node}", + f"--gres=gpu:{cfg.execution_params.gpus}", + f"--cpus-per-task={cfg.execution_params.cpus_per_task}", + f"--time={cfg.execution_params.time_limit}", + f"--account={cfg.execution_params.account}", + ] + + if cfg.execution_params.reservation: + command.append(f"--reservation={cfg.execution_params.reservation}") + + requested_shards = list(shard_ids or []) + if requested_shards: + command.append(f"--array={','.join(str(shard_id) for shard_id in requested_shards)}") + else: + num_shards = cfg.loading_params.get_num_shards() + command.append(f"--array=0-{num_shards - 1}") + + logger.info("Submitting SLURM job: %s", " ".join(command)) + result = subprocess.run( + command, + input=build_sbatch_script(cfg, config_path), + text=True, + capture_output=True, + check=False, + ) + + if result.returncode != 0: + logger.error("sbatch failed: %s", result.stderr.strip()) + return None + + raw_job_id = result.stdout.strip().split(";", 1)[0] + try: + return int(raw_job_id) + except ValueError: + logger.error("Unable to parse job id from sbatch output: %s", result.stdout.strip()) + return None + + +def wait_for_slurm_job(job_id: int, cfg: MMirageConfig) -> None: + """Wait for a SLURM job array to leave the queue.""" + logger.info("Waiting for SLURM job %s", job_id) + while True: + result = subprocess.run( + ["squeue", "-h", "-j", str(job_id)], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0 and not result.stdout.strip(): + break + time.sleep(cfg.execution_params.poll_interval_seconds) + + if cfg.execution_params.settle_time_seconds > 0: + logger.info("Waiting %ss for state files to settle", cfg.execution_params.settle_time_seconds) + time.sleep(cfg.execution_params.settle_time_seconds) + + +def 
launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = False) -> int: + """Launch the pipeline according to execution mode and retry settings.""" + if not cfg.execution_params.is_slurm(): + return run_local(config_path, cfg.loading_params.get_shard_id()) + + auto_retry = force_retry or cfg.execution_params.retry + shard_ids: List[int] = [] + + while True: + job_id = submit_slurm_job(cfg, config_path, shard_ids) + if job_id is None: + return 1 + + print(job_id) + + if not auto_retry: + return 0 + + wait_for_slurm_job(job_id, cfg) + failed_shards, summary = check_failed_shards(cfg) + + if not failed_shards and summary["max_retries_exceeded"] == 0: + logger.info("All shards completed successfully") + return 0 + + if not failed_shards: + logger.error("Pipeline ended with shards that exceeded max retries") + return 1 + + logger.warning("Retrying failed shards: %s", ",".join(map(str, failed_shards))) + shard_ids = failed_shards + + +def configure_logging(level: str) -> None: + """Configure root logging.""" + logging.basicConfig( + level=getattr(logging, level, logging.INFO), + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + +def add_shared_arguments(parser: argparse.ArgumentParser) -> None: + """Attach common CLI arguments to a subcommand parser.""" + parser.add_argument("--config", required=True, help="Path to a MMIRAGE YAML config file") + parser.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Log verbosity", + ) + + +def build_argparser() -> argparse.ArgumentParser: + """Build the CLI parser.""" + parser = argparse.ArgumentParser(description="MMIRAGE command-line interface") + subparsers = parser.add_subparsers(dest="command", required=True) + + process_parser = subparsers.add_parser("process", help="Run a shard locally") + add_shared_arguments(process_parser) + process_parser.add_argument("--shard-id", type=int, default=None, help="Shard id override") + + submit_parser = 
subparsers.add_parser("submit", help="Submit one SLURM array job") + add_shared_arguments(submit_parser) + submit_parser.add_argument( + "--shard-ids", + help="Comma-separated shard ids to submit instead of the full array", + ) + submit_parser.add_argument("--wait", action="store_true", help="Wait for the submitted job") + + check_parser = subparsers.add_parser("check", help="Inspect shard status") + add_shared_arguments(check_parser) + check_parser.add_argument( + "--summary-only", + action="store_true", + help="Only print status summary; do not submit retries.", + ) + check_retry_group = check_parser.add_mutually_exclusive_group() + check_retry_group.add_argument( + "--retry", + dest="retry", + action="store_true", + help="Submit a retry job for failed shards (default unless --summary-only).", + ) + check_retry_group.add_argument( + "--no-retry", + dest="retry", + action="store_false", + help="Do not submit retries (same as --summary-only).", + ) + check_parser.set_defaults(retry=True) + check_interactive_group = check_parser.add_mutually_exclusive_group() + check_interactive_group.add_argument( + "--interactive", + dest="interactive", + action="store_true", + help="Prompt before submitting retry jobs (default).", + ) + check_interactive_group.add_argument( + "--no-interactive", + dest="interactive", + action="store_false", + help="Submit retry jobs without prompting.", + ) + check_parser.set_defaults(interactive=True) + + retry_parser = subparsers.add_parser("retry", help="Submit only failed shards") + add_shared_arguments(retry_parser) + retry_group = retry_parser.add_mutually_exclusive_group() + retry_group.add_argument("--interactive", dest="interactive", action="store_true") + retry_group.add_argument("--no-interactive", dest="interactive", action="store_false") + retry_parser.set_defaults(interactive=True) + + launch_parser = subparsers.add_parser( + "launch", + help="(Deprecated) Use 'run'. 
Launch according to execution_params.mode and execution_params.retry", + ) + add_shared_arguments(launch_parser) + launch_parser.add_argument( + "--force-retry", + action="store_true", + help="Enable retry orchestration even if execution_params.retry is false", + ) + + run_parser = subparsers.add_parser( + "run", + help="Run according to execution_params.mode and execution_params.retry", + ) + add_shared_arguments(run_parser) + run_parser.add_argument( + "--force-retry", + action="store_true", + help="Enable retry orchestration even if execution_params.retry is false", + ) + + return parser + + +def _maybe_submit_retry_job( + *, + cfg: MMirageConfig, + config_path: str, + failed_shards: Sequence[int], + interactive: bool, +) -> int: + if not cfg.execution_params.is_slurm(): + logger.error("Retry submission requires execution_params.mode=slurm") + return 1 + + if not failed_shards: + return 0 + + if interactive: + if not sys.stdin.isatty(): + logger.error("Non-interactive input detected; re-run with --no-interactive to auto-submit retries") + return 1 + response = input(f"Retry {len(failed_shards)} shard(s)? 
(y/N) ") + if response.strip().lower() != "y": + print("Cancelled.") + return 1 + + job_id = submit_slurm_job(cfg, config_path, failed_shards) + if job_id is None: + return 1 + print(job_id) + return 0 + + +def parse_shard_ids(raw_value: Optional[str]) -> List[int]: + """Parse a comma-separated shard id list.""" + if not raw_value: + return [] + return [int(value.strip()) for value in raw_value.split(",") if value.strip()] + + +def main() -> None: + """CLI entry point.""" + parser = build_argparser() + args = parser.parse_args() + configure_logging(args.log_level) + + try: + config_path = os.path.abspath(args.config) + cfg = load_mmirage_config(config_path) + validate_paths(cfg) + + if args.command == "process": + sys.exit(run_local(config_path, args.shard_id)) + + if args.command == "submit": + if not cfg.execution_params.is_slurm(): + logger.error("submit requires execution_params.mode=slurm") + sys.exit(1) + + job_id = submit_slurm_job(cfg, config_path, parse_shard_ids(args.shard_ids)) + if job_id is None: + sys.exit(1) + + print(job_id) + if args.wait: + wait_for_slurm_job(job_id, cfg) + failed_shards, summary = check_failed_shards(cfg) + sys.exit(0 if not failed_shards and summary["max_retries_exceeded"] == 0 else 1) + sys.exit(0) + + if args.command == "check": + failed_shards, summary = check_failed_shards(cfg) + print(json.dumps(summary, indent=2)) + + if not cfg.execution_params.is_slurm(): + sys.exit(0 if not failed_shards and summary["max_retries_exceeded"] == 0 else 1) + + if args.summary_only or not args.retry: + sys.exit(0 if not failed_shards and summary["max_retries_exceeded"] == 0 else 1) + + if not failed_shards: + sys.exit(0 if summary["max_retries_exceeded"] == 0 else 1) + + sys.exit( + _maybe_submit_retry_job( + cfg=cfg, + config_path=config_path, + failed_shards=failed_shards, + interactive=bool(args.interactive), + ) + ) + + if args.command == "retry": + if not cfg.execution_params.is_slurm(): + logger.error("retry requires 
execution_params.mode=slurm") + sys.exit(1) + + failed_shards, summary = check_failed_shards(cfg) + print(json.dumps(summary, indent=2)) + + if not failed_shards: + if summary["max_retries_exceeded"] > 0: + logger.error("No retryable shards remain") + sys.exit(1) + print("All shards already succeeded.") + sys.exit(0) + + if args.interactive: + response = input(f"Retry {len(failed_shards)} shard(s)? (y/N) ") + if response.strip().lower() != "y": + print("Cancelled.") + sys.exit(1) + + job_id = submit_slurm_job(cfg, config_path, failed_shards) + if job_id is None: + sys.exit(1) + + print(job_id) + sys.exit(0) + + if args.command in {"launch", "run"}: + if args.command == "launch": + logger.warning("'launch' is deprecated; use 'run' instead") + sys.exit(launch_pipeline(cfg, config_path, force_retry=args.force_retry)) + + except Exception as exc: + logger.error("Error: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG)) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/mmirage/config/config.py b/src/mmirage/config/config.py index 8ef3cc7..ccfeb65 100644 --- a/src/mmirage/config/config.py +++ b/src/mmirage/config/config.py @@ -1,13 +1,81 @@ """Configuration dataclasses for MMIRAGE pipeline.""" -from dataclasses import dataclass -from typing import Any, Dict, List +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional from mmirage.config.loading import LoadingParams from mmirage.core.process.base import BaseProcessorConfig from mmirage.core.process.variables import InputVar, OutputVar +@dataclass +class ExecutionParams: + """Parameters for executing the MMIRAGE pipeline. + + Defines how the pipeline is executed, including local or SLURM-based + distributed execution, retry logic, and resource allocation. + + Attributes: + mode: Execution mode: "local" or "slurm". Defaults to "local". + max_retries: Maximum number of retries for failed shards. Defaults to 3. 
+ poll_interval_seconds: Seconds to wait between polling job status. Defaults to 30. + settle_time_seconds: Seconds to wait after job completes before checking results. Defaults to 60. + settle_poll_interval: Seconds between polls during settle time. Defaults to 10. + + # SLURM-specific parameters + account: HPC account/partition to charge. Required for SLURM mode. + job_name: SLURM job name. Defaults to "mmirage-sharded". + reservation: Optional SLURM reservation name. + nodes: Number of nodes. Defaults to 1. + ntasks_per_node: Number of tasks per node. Defaults to 1. + gpus: Number of GPUs per node. Defaults to 4. + cpus_per_task: Number of CPUs per task. Defaults to 288. + time_limit: Job time limit (HH:MM:SS). Defaults to "11:59:59". + + # Paths + project_root: Base project directory. Can use environment variables with ${VAR}. + report_dir: Directory for SLURM output/error files. Defaults to ~/reports. + hf_home: HuggingFace cache directory. Defaults to ~/hf. + edf_env: Optional EDF environment file path. + """ + + mode: str = "local" + retry: bool = False + max_retries: int = 3 + poll_interval_seconds: int = 30 + settle_time_seconds: int = 60 + settle_poll_interval: int = 10 + + # SLURM parameters + account: Optional[str] = None + job_name: str = "mmirage-sharded" + reservation: Optional[str] = None + nodes: int = 1 + ntasks_per_node: int = 1 + gpus: int = 4 + cpus_per_task: int = 288 + time_limit: str = "11:59:59" + + # Paths (can contain environment variables like ${VAR} or $VAR) + project_root: Optional[str] = None + report_dir: str = "~/reports" + hf_home: str = "~/hf" + edf_env: Optional[str] = None + + def __post_init__(self): + """Validate execution parameters.""" + if self.mode not in ("local", "slurm"): + raise ValueError(f"Invalid execution mode: {self.mode!r}. 
Must be 'local' or 'slurm'.") + if self.mode == "slurm" and not self.account: + raise ValueError("account is required when mode='slurm'") + if self.max_retries < 0: + raise ValueError(f"max_retries must be >= 0, got {self.max_retries}") + + def is_slurm(self) -> bool: + """Check if execution mode is SLURM.""" + return self.mode == "slurm" + + @dataclass class ProcessingParams: """Parameters for processing dataset samples. @@ -33,15 +101,17 @@ class MMirageConfig: """Main configuration class for MMIRAGE pipeline. Contains all configuration needed to run a MMIRAGE processing pipeline, - including processor configurations, dataset loading parameters, and - processing parameters. + including processor configurations, dataset loading parameters, processing + parameters, and execution parameters. Attributes: processors: List of processor configurations for data transformation. loading_params: Parameters for loading input datasets. processing_params: Parameters for processing dataset samples. + execution_params: Parameters for executing the pipeline (local/SLURM). 
""" processors: List[BaseProcessorConfig] loading_params: LoadingParams processing_params: ProcessingParams + execution_params: ExecutionParams = field(default_factory=ExecutionParams) diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index 49ab6cf..f58ca41 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -1,5 +1,6 @@ """Data loading configuration for MMIRAGE pipeline.""" +import re from dataclasses import dataclass, field from typing import Union, List, cast, Optional @@ -39,13 +40,19 @@ def __post_init__(self): if self.num_shards < 1: raise ValueError() except (ValueError, TypeError): - raise ValueError(f"Invalid value for num_shards: {self.num_shards!r}") + if _is_unresolved_env_var(self.num_shards): + self.num_shards = 1 + else: + raise ValueError(f"Invalid value for num_shards: {self.num_shards!r}") if isinstance(self.shard_id, str): try: self.shard_id = int(self.shard_id) except (ValueError, TypeError): - raise ValueError(f"Invalid value for shard_id: {self.shard_id!r}") + if _is_unresolved_env_var(self.shard_id): + self.shard_id = 0 + else: + raise ValueError(f"Invalid value for shard_id: {self.shard_id!r}") if isinstance(self.batch_size, str): try: @@ -88,4 +95,12 @@ def get_batch_size(self) -> int: Returns: int: Batch size (minimum 1). 
""" - return cast(int, self.batch_size) \ No newline at end of file + return cast(int, self.batch_size) + + +_UNRESOLVED_ENV_VAR_PATTERN = re.compile(r"^\$(?:\{[A-Za-z_][A-Za-z0-9_]*\}|[A-Za-z_][A-Za-z0-9_]*)$") + + +def _is_unresolved_env_var(value: str) -> bool: + """Check whether a string still looks like an unresolved shell env var.""" + return bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(value.strip())) \ No newline at end of file diff --git a/src/mmirage/config/utils.py b/src/mmirage/config/utils.py index 77e8bcb..93ae3b6 100644 --- a/src/mmirage/config/utils.py +++ b/src/mmirage/config/utils.py @@ -9,6 +9,15 @@ from mmirage.core.process.base import BaseProcessorConfig, ProcessorRegistry, OutputVar from mmirage.core.loader.base import BaseDataLoaderConfig, DataLoaderRegistry +# Register built-in processors/loaders. +# +# We import configuration modules (lightweight) here so the registries know how +# to construct config/output-var objects from YAML without importing heavy +# processor implementations (e.g. torch/transformers). +import mmirage.core.process.processors.llm.config # noqa: F401 +import mmirage.core.loader.jsonl # noqa: F401 +import mmirage.core.loader.local_hf # noqa: F401 + EnvValue: TypeAlias = Union[str, List["EnvValue"], Dict[str, "EnvValue"]] diff --git a/src/mmirage/core/loader/__init__.py b/src/mmirage/core/loader/__init__.py index c27714d..217d53d 100644 --- a/src/mmirage/core/loader/__init__.py +++ b/src/mmirage/core/loader/__init__.py @@ -7,13 +7,3 @@ All loaders inherit from BaseDataLoader and are registered with DataLoaderRegistry for dynamic instantiation based on configuration. 
""" - -from mmirage.core.loader.jsonl import JSONLDataConfig, JSONLDataLoader -from mmirage.core.loader.local_hf import LocalHFConfig, LocalHFDataLoader - -__all__ = [ - "JSONLDataConfig", - "JSONLDataLoader", - "LocalHFDataLoader", - "LocalHFConfig", -] diff --git a/src/mmirage/core/process/__init__.py b/src/mmirage/core/process/__init__.py index b031213..3002d4c 100644 --- a/src/mmirage/core/process/__init__.py +++ b/src/mmirage/core/process/__init__.py @@ -1,15 +1,7 @@ """Processing module for MMIRAGE pipeline. -This module provides the core processing infrastructure: -- Base classes for processors and variables -- MMIRAGEMapper for orchestrating transformations -- LLM processor implementation for generative tasks (including multimodal) +Important: keep this package import lightweight. -Processors are responsible for generating new output variables from -existing variables, enabling flexible data transformations. +The LLM processor implementation depends on heavy libraries (e.g. torch). +We intentionally do not import processor implementations from here. """ - -from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig -from mmirage.core.process.processors.llm.llm_processor import LLMProcessor - -__all__ = ["LLMOutputVar", "SGLangLLMConfig", "LLMProcessor"] diff --git a/src/mmirage/core/process/base.py b/src/mmirage/core/process/base.py index f374e12..988bae7 100644 --- a/src/mmirage/core/process/base.py +++ b/src/mmirage/core/process/base.py @@ -1,6 +1,7 @@ """Base classes and registry for processors in MMIRAGE.""" import abc +from importlib import import_module from dataclasses import dataclass from typing import Callable, Generic, List, Type, TypeVar @@ -80,6 +81,28 @@ class ProcessorRegistry: _config_registry = dict() _output_var_registry = dict() + # Import processor implementations lazily because they may depend on heavy + # libraries (torch/transformers). 
Config/output-var types are registered via + # mmirage.config.utils importing the relevant config modules. + _lazy_processor_imports = {"llm": "mmirage.core.process.processors.llm.llm_processor"} + + @classmethod + def register_types( + cls, + name: str, + config_cls: Type[BaseProcessorConfig], + output_var_cls: Type[OutputVar], + ) -> None: + """Register config/output-var types without importing processor implementations.""" + cls._config_registry[name] = config_cls + cls._output_var_registry[name] = output_var_cls + + @classmethod + def _maybe_import_processor(cls, name: str) -> None: + module = cls._lazy_processor_imports.get(name) + if module: + import_module(module) + @classmethod def register( cls, @@ -118,6 +141,9 @@ def get_processor(cls, name: str) -> Type[BaseProcessor]: Raises: ValueError: If no processor is registered under the given name. """ + if name not in cls._registry: + cls._maybe_import_processor(name) + if name not in cls._registry: raise ValueError( f"Processor {name} not registered. 
Available processors are {list(cls._registry.keys())}" diff --git a/src/mmirage/core/process/processors/llm/config.py b/src/mmirage/core/process/processors/llm/config.py index 4c195af..e53c2c6 100644 --- a/src/mmirage/core/process/processors/llm/config.py +++ b/src/mmirage/core/process/processors/llm/config.py @@ -9,6 +9,7 @@ from mmirage.core.process.variables import BaseVar, OutputVar from mmirage.core.process.base import BaseProcessorConfig +from mmirage.core.process.base import ProcessorRegistry from jinja2 import Environment, meta logger = logging.getLogger(__name__) @@ -103,3 +104,6 @@ def is_computable(self, vars: Sequence[BaseVar]) -> bool: return False return True + + +ProcessorRegistry.register_types("llm", SGLangLLMConfig, LLMOutputVar) From e4cb7ac8c66cc416a9325b8aa53129b15862c8b9 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Thu, 19 Mar 2026 16:40:11 +0100 Subject: [PATCH 24/47] big refactor for the PR --- configs/config_comprehensive.yaml | 6 +- configs/config_mock.yaml | 1 - configs/config_mock_vision.yaml | 1 - src/mmirage/cli.py | 435 ++++++++---------------------- src/mmirage/cli_utils/__init__.py | 1 + src/mmirage/cli_utils/runtime.py | 77 ++++++ src/mmirage/cli_utils/slurm.py | 168 ++++++++++++ src/mmirage/cli_utils/status.py | 123 +++++++++ 8 files changed, 491 insertions(+), 321 deletions(-) create mode 100644 src/mmirage/cli_utils/__init__.py create mode 100644 src/mmirage/cli_utils/runtime.py create mode 100644 src/mmirage/cli_utils/slurm.py create mode 100644 src/mmirage/cli_utils/status.py diff --git a/configs/config_comprehensive.yaml b/configs/config_comprehensive.yaml index 4285237..e017747 100644 --- a/configs/config_comprehensive.yaml +++ b/configs/config_comprehensive.yaml @@ -165,8 +165,8 @@ execution_params: # HuggingFace cache directory (default: ~/hf) hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf" - # Optional EDF environment file path for cluster-specific setup - # edf_env: 
"/path/to/.edf/mmirage.toml" + # EDF environment file path for cluster-specific setup + edf_env: "/users/${USER}/.edf/mmirage.toml" # ========================================================================== # JOB MONITORING (for "submit" and retry orchestration) @@ -197,7 +197,7 @@ execution_params: # JOB_ID=$(python -m mmirage.cli submit --config config.yaml) # # 4. Run a single shard locally: -# python -m mmirage.cli process --config config.yaml --shard-id 0 +# python -m mmirage.cli run --config config.yaml --shard-id 0 # # 5. Check status of all shards (and optionally submit retries): # python -m mmirage.cli check --config config.yaml diff --git a/configs/config_mock.yaml b/configs/config_mock.yaml index 0f2bd54..4f62ff7 100644 --- a/configs/config_mock.yaml +++ b/configs/config_mock.yaml @@ -55,6 +55,5 @@ processing_params: execution_params: mode: local retry: false - max_retries: 3 report_dir: ~/reports hf_home: ~/hf diff --git a/configs/config_mock_vision.yaml b/configs/config_mock_vision.yaml index 811d61e..46965dc 100644 --- a/configs/config_mock_vision.yaml +++ b/configs/config_mock_vision.yaml @@ -44,6 +44,5 @@ processing_params: execution_params: mode: local retry: false - max_retries: 3 report_dir: ~/reports hf_home: ~/hf diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index 0744bbc..51b9db7 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -6,13 +6,18 @@ import json import logging import os -import shlex import subprocess import sys -import time -from pathlib import Path -from typing import List, Optional, Sequence, Tuple - +from dataclasses import asdict +from typing import List, Optional + +from mmirage.cli_utils.runtime import setup_runtime, validate_paths +from mmirage.cli_utils.slurm import require_slurm, submit_slurm_job, wait_for_slurm_job +from mmirage.cli_utils.status import ( + check_failed_shards, + status_exit_code, + submit_failed_shards, +) from mmirage.config.config import MMirageConfig from mmirage.config.utils import 
load_mmirage_config @@ -20,97 +25,6 @@ logger = logging.getLogger(__name__) -def expand_path(path: str, project_root: Optional[str] = None) -> str: - """Expand environment variables, user home and relative paths.""" - expanded = os.path.expanduser(os.path.expandvars(path)) - if not os.path.isabs(expanded) and project_root: - expanded = os.path.join(project_root, expanded) - return os.path.abspath(expanded) - - -def get_project_root(cfg: MMirageConfig) -> str: - """Return the configured project root, or the current working directory.""" - project_root = cfg.execution_params.project_root - if project_root: - return expand_path(project_root) - return os.getcwd() - - -def create_directories(paths: Sequence[str]) -> None: - """Create directories if they do not already exist.""" - for path in paths: - Path(path).mkdir(parents=True, exist_ok=True) - - -def validate_paths(cfg: MMirageConfig) -> None: - """Validate pre-existing execution paths.""" - project_root = get_project_root(cfg) - if cfg.execution_params.edf_env: - edf_env = expand_path(cfg.execution_params.edf_env, project_root) - if not os.path.exists(edf_env): - raise FileNotFoundError(f"EDF environment file not found: {edf_env}") - - -def get_shard_state_dir(state_root: str, shard_id: int) -> str: - """Return the state directory for a shard.""" - return os.path.join(state_root, f"shard_{shard_id}") - - -def get_shard_status(state_dir: str) -> Tuple[str, int]: - """Read the current status and retry count for a shard.""" - status_file = os.path.join(state_dir, "status.json") - if not os.path.exists(status_file): - return ("missing", 0) - - try: - with open(status_file, "r", encoding="utf-8") as handle: - data = json.load(handle) - except (OSError, json.JSONDecodeError) as exc: - logger.warning("Failed to read shard status from %s: %s", status_file, exc) - return ("unknown", 0) - - return (str(data.get("status", "unknown")), int(data.get("retry_count", 0))) - - -def check_failed_shards(cfg: MMirageConfig) -> 
Tuple[List[int], dict]: - """Return retryable failed shards and a compact summary.""" - state_root = cfg.loading_params.get_state_root() - if not state_root: - raise ValueError("loading_params.state_dir is required to check shard status") - - num_shards = cfg.loading_params.get_num_shards() - max_retries = cfg.execution_params.max_retries - failed_shards: List[int] = [] - success_count = 0 - exhausted_count = 0 - - for shard_id in range(num_shards): - status, retry_count = get_shard_status(get_shard_state_dir(state_root, shard_id)) - if status == "success": - success_count += 1 - continue - - if retry_count >= max_retries: - exhausted_count += 1 - logger.warning( - "Shard %s exceeded retry budget (%s/%s)", - shard_id, - retry_count, - max_retries, - ) - continue - - failed_shards.append(shard_id) - - summary = { - "total": num_shards, - "successful": success_count, - "failed": len(failed_shards), - "max_retries_exceeded": exhausted_count, - } - return failed_shards, summary - - def run_local(config_path: str, shard_id: Optional[int] = None) -> int: """Run one shard in the current Python environment.""" command = [sys.executable, "-m", "mmirage.shard_process", "--config", config_path] @@ -123,110 +37,6 @@ def run_local(config_path: str, shard_id: Optional[int] = None) -> int: return result.returncode -def build_sbatch_script(cfg: MMirageConfig, config_path: str) -> str: - """Build the sbatch payload executed for each array task.""" - project_root = get_project_root(cfg) - hf_home = expand_path(cfg.execution_params.hf_home, project_root) - state_root = expand_path(cfg.loading_params.get_state_root(), project_root) - - lines = [ - "#!/bin/bash", - "set -euo pipefail", - f"export HF_HOME={shlex.quote(hf_home)}", - f"export MMIRAGE_CONFIG={shlex.quote(config_path)}", - f"mkdir -p {shlex.quote(hf_home)}", - f"mkdir -p {shlex.quote(state_root)}", - "srun_args=(--cpus-per-task ${SLURM_CPUS_PER_TASK:-1} --wait 60)", - ] - - if cfg.execution_params.edf_env: - edf_env = 
expand_path(cfg.execution_params.edf_env, project_root) - lines.append(f"srun_args+=(--environment={shlex.quote(edf_env)})") - - lines.extend( - [ - f"srun \"${{srun_args[@]}}\" {shlex.quote(sys.executable)} -m mmirage.shard_process --config \"$MMIRAGE_CONFIG\"", - "echo \"Shard ${SLURM_ARRAY_TASK_ID:-0} completed\"", - ] - ) - return "\n".join(lines) + "\n" - - -def submit_slurm_job( - cfg: MMirageConfig, - config_path: str, - shard_ids: Optional[Sequence[int]] = None, -) -> Optional[int]: - """Submit a SLURM array job and return its job ID.""" - project_root = get_project_root(cfg) - report_dir = expand_path(cfg.execution_params.report_dir, project_root) - create_directories([report_dir]) - - command = [ - "sbatch", - "--parsable", - f"--job-name={cfg.execution_params.job_name}", - f"--chdir={project_root}", - f"--output={os.path.join(report_dir, 'R-%x.%A_%a.out')}", - f"--error={os.path.join(report_dir, 'R-%x.%A_%a.err')}", - f"--nodes={cfg.execution_params.nodes}", - f"--ntasks-per-node={cfg.execution_params.ntasks_per_node}", - f"--gres=gpu:{cfg.execution_params.gpus}", - f"--cpus-per-task={cfg.execution_params.cpus_per_task}", - f"--time={cfg.execution_params.time_limit}", - f"--account={cfg.execution_params.account}", - ] - - if cfg.execution_params.reservation: - command.append(f"--reservation={cfg.execution_params.reservation}") - - requested_shards = list(shard_ids or []) - if requested_shards: - command.append(f"--array={','.join(str(shard_id) for shard_id in requested_shards)}") - else: - num_shards = cfg.loading_params.get_num_shards() - command.append(f"--array=0-{num_shards - 1}") - - logger.info("Submitting SLURM job: %s", " ".join(command)) - result = subprocess.run( - command, - input=build_sbatch_script(cfg, config_path), - text=True, - capture_output=True, - check=False, - ) - - if result.returncode != 0: - logger.error("sbatch failed: %s", result.stderr.strip()) - return None - - raw_job_id = result.stdout.strip().split(";", 1)[0] - try: - 
return int(raw_job_id) - except ValueError: - logger.error("Unable to parse job id from sbatch output: %s", result.stdout.strip()) - return None - - -def wait_for_slurm_job(job_id: int, cfg: MMirageConfig) -> None: - """Wait for a SLURM job array to leave the queue.""" - logger.info("Waiting for SLURM job %s", job_id) - while True: - result = subprocess.run( - ["squeue", "-h", "-j", str(job_id)], - capture_output=True, - text=True, - check=False, - ) - if result.returncode == 0 and not result.stdout.strip(): - break - time.sleep(cfg.execution_params.poll_interval_seconds) - - if cfg.execution_params.settle_time_seconds > 0: - logger.info("Waiting %ss for state files to settle", cfg.execution_params.settle_time_seconds) - time.sleep(cfg.execution_params.settle_time_seconds) - - def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = False) -> int: """Launch the pipeline according to execution mode and retry settings.""" if not cfg.execution_params.is_slurm(): @@ -248,7 +58,7 @@ def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = Fa wait_for_slurm_job(job_id, cfg) failed_shards, summary = check_failed_shards(cfg) - if not failed_shards and summary["max_retries_exceeded"] == 0: + if status_exit_code(failed_shards, summary) == 0: logger.info("All shards completed successfully") return 0 @@ -284,10 +94,6 @@ def build_argparser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="MMIRAGE command-line interface") subparsers = parser.add_subparsers(dest="command", required=True) - process_parser = subparsers.add_parser("process", help="Run a shard locally") - add_shared_arguments(process_parser) - process_parser.add_argument("--shard-id", type=int, default=None, help="Shard id override") - submit_parser = subparsers.add_parser("submit", help="Submit one SLURM array job") add_shared_arguments(submit_parser) submit_parser.add_argument( @@ -339,17 +145,6 @@ def build_argparser() -> argparse.ArgumentParser: 
retry_group.add_argument("--no-interactive", dest="interactive", action="store_false") retry_parser.set_defaults(interactive=True) - launch_parser = subparsers.add_parser( - "launch", - help="(Deprecated) Use 'run'. Launch according to execution_params.mode and execution_params.retry", - ) - add_shared_arguments(launch_parser) - launch_parser.add_argument( - "--force-retry", - action="store_true", - help="Enable retry orchestration even if execution_params.retry is false", - ) - run_parser = subparsers.add_parser( "run", help="Run according to execution_params.mode and execution_params.retry", @@ -360,45 +155,112 @@ def build_argparser() -> argparse.ArgumentParser: action="store_true", help="Enable retry orchestration even if execution_params.retry is false", ) + run_parser.add_argument( + "--shard-id", + type=int, + default=None, + help="Run a single shard locally (overrides execution mode)", + ) return parser -def _maybe_submit_retry_job( - *, - cfg: MMirageConfig, - config_path: str, - failed_shards: Sequence[int], - interactive: bool, -) -> int: - if not cfg.execution_params.is_slurm(): - logger.error("Retry submission requires execution_params.mode=slurm") - return 1 +def parse_shard_ids(raw_value: Optional[str], num_shards: Optional[int] = None) -> List[int]: + """Parse a comma-separated shard id list.""" + if not raw_value: + return [] - if not failed_shards: - return 0 + shard_ids: List[int] = [] + for raw_shard_id in raw_value.split(","): + candidate = raw_shard_id.strip() + if not candidate: + continue - if interactive: - if not sys.stdin.isatty(): - logger.error("Non-interactive input detected; re-run with --no-interactive to auto-submit retries") - return 1 - response = input(f"Retry {len(failed_shards)} shard(s)? 
(y/N) ") - if response.strip().lower() != "y": - print("Cancelled.") - return 1 + try: + shard_id = int(candidate) + except ValueError as exc: + raise ValueError(f"Invalid shard id {candidate!r}; expected integers") from exc + + if shard_id < 0: + raise ValueError(f"Invalid shard id {shard_id}; expected non-negative integer") + if num_shards is not None and shard_id >= num_shards: + raise ValueError(f"Invalid shard id {shard_id}; expected 0 <= shard_id < {num_shards}") + + shard_ids.append(shard_id) - job_id = submit_slurm_job(cfg, config_path, failed_shards) + return shard_ids + + +def handle_run(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int: + """Handle the canonical run command.""" + if args.shard_id is not None: + return run_local(config_path, args.shard_id) + return launch_pipeline(cfg, config_path, force_retry=args.force_retry) + + +def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int: + """Submit a SLURM array job and optionally wait.""" + if require_slurm(cfg, "submit") != 0: + return 1 + + shard_ids = parse_shard_ids(args.shard_ids, cfg.loading_params.get_num_shards()) + job_id = submit_slurm_job(cfg, config_path, shard_ids) if job_id is None: return 1 + print(job_id) - return 0 + if not args.wait: + return 0 + wait_for_slurm_job(job_id, cfg) + failed_shards, summary = check_failed_shards(cfg) + return status_exit_code(failed_shards, summary) + + +def handle_check(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int: + """Inspect shard status and optionally submit retries.""" + failed_shards, summary = check_failed_shards(cfg) + print(json.dumps(asdict(summary), indent=2)) + + status_code = status_exit_code(failed_shards, summary) + if not cfg.execution_params.is_slurm(): + return status_code + + if args.summary_only or not args.retry: + return status_code + + if not failed_shards: + return status_code + + return submit_failed_shards( + cfg=cfg, + config_path=config_path, + 
failed_shards=failed_shards, + interactive=bool(args.interactive), + ) -def parse_shard_ids(raw_value: Optional[str]) -> List[int]: - """Parse a comma-separated shard id list.""" - if not raw_value: - return [] - return [int(value.strip()) for value in raw_value.split(",") if value.strip()] + +def handle_retry(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int: + """Submit retries for failed shards only.""" + if require_slurm(cfg, "retry") != 0: + return 1 + + failed_shards, summary = check_failed_shards(cfg) + print(json.dumps(asdict(summary), indent=2)) + + if not failed_shards: + if summary.max_retries_exceeded > 0: + logger.error("No retryable shards remain") + return 1 + print("All shards already succeeded.") + return 0 + + return submit_failed_shards( + cfg=cfg, + config_path=config_path, + failed_shards=failed_shards, + interactive=bool(args.interactive), + ) def main() -> None: @@ -410,81 +272,22 @@ def main() -> None: try: config_path = os.path.abspath(args.config) cfg = load_mmirage_config(config_path) + + setup_runtime(cfg, args.log_level) validate_paths(cfg) - if args.command == "process": - sys.exit(run_local(config_path, args.shard_id)) - - if args.command == "submit": - if not cfg.execution_params.is_slurm(): - logger.error("submit requires execution_params.mode=slurm") - sys.exit(1) - - job_id = submit_slurm_job(cfg, config_path, parse_shard_ids(args.shard_ids)) - if job_id is None: - sys.exit(1) - - print(job_id) - if args.wait: - wait_for_slurm_job(job_id, cfg) - failed_shards, summary = check_failed_shards(cfg) - sys.exit(0 if not failed_shards and summary["max_retries_exceeded"] == 0 else 1) - sys.exit(0) - - if args.command == "check": - failed_shards, summary = check_failed_shards(cfg) - print(json.dumps(summary, indent=2)) - - if not cfg.execution_params.is_slurm(): - sys.exit(0 if not failed_shards and summary["max_retries_exceeded"] == 0 else 1) - - if args.summary_only or not args.retry: - sys.exit(0 if not 
failed_shards and summary["max_retries_exceeded"] == 0 else 1) - - if not failed_shards: - sys.exit(0 if summary["max_retries_exceeded"] == 0 else 1) - - sys.exit( - _maybe_submit_retry_job( - cfg=cfg, - config_path=config_path, - failed_shards=failed_shards, - interactive=bool(args.interactive), - ) - ) - - if args.command == "retry": - if not cfg.execution_params.is_slurm(): - logger.error("retry requires execution_params.mode=slurm") - sys.exit(1) - - failed_shards, summary = check_failed_shards(cfg) - print(json.dumps(summary, indent=2)) - - if not failed_shards: - if summary["max_retries_exceeded"] > 0: - logger.error("No retryable shards remain") - sys.exit(1) - print("All shards already succeeded.") - sys.exit(0) - - if args.interactive: - response = input(f"Retry {len(failed_shards)} shard(s)? (y/N) ") - if response.strip().lower() != "y": - print("Cancelled.") - sys.exit(1) - - job_id = submit_slurm_job(cfg, config_path, failed_shards) - if job_id is None: - sys.exit(1) - - print(job_id) - sys.exit(0) - - if args.command in {"launch", "run"}: - if args.command == "launch": - logger.warning("'launch' is deprecated; use 'run' instead") - sys.exit(launch_pipeline(cfg, config_path, force_retry=args.force_retry)) + handlers = { + "run": handle_run, + "submit": handle_submit, + "check": handle_check, + "retry": handle_retry, + } + handler = handlers.get(args.command) + if handler is None: + logger.error("Unknown command: %s", args.command) + sys.exit(2) + + sys.exit(handler(args, cfg, config_path)) except Exception as exc: logger.error("Error: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG)) diff --git a/src/mmirage/cli_utils/__init__.py b/src/mmirage/cli_utils/__init__.py new file mode 100644 index 0000000..ecab471 --- /dev/null +++ b/src/mmirage/cli_utils/__init__.py @@ -0,0 +1 @@ +"""Internal utility modules for the MMIRAGE CLI.""" diff --git a/src/mmirage/cli_utils/runtime.py b/src/mmirage/cli_utils/runtime.py new file mode 100644 index 
0000000..4cc9fcc --- /dev/null +++ b/src/mmirage/cli_utils/runtime.py @@ -0,0 +1,77 @@ +"""Runtime/path helpers for the MMIRAGE CLI.""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import Optional, Sequence + +from mmirage.config.config import MMirageConfig + + +logger = logging.getLogger(__name__) + + +def expand_path(path: str, project_root: Optional[str] = None) -> str: + """Expand environment variables, user home and relative paths.""" + expanded = os.path.expanduser(os.path.expandvars(path)) + if not os.path.isabs(expanded) and project_root: + expanded = os.path.join(project_root, expanded) + return os.path.abspath(expanded) + + +def get_project_root(cfg: MMirageConfig) -> str: + """Return the configured project root, or the current working directory.""" + project_root = cfg.execution_params.project_root + if project_root: + return expand_path(project_root) + return os.getcwd() + + +def create_directories(paths: Sequence[str]) -> None: + """Create directories if they do not already exist.""" + for path in paths: + Path(path).mkdir(parents=True, exist_ok=True) + + +def validate_paths(cfg: MMirageConfig) -> None: + """Validate pre-existing execution paths.""" + project_root = get_project_root(cfg) + if cfg.execution_params.edf_env: + edf_env = expand_path(cfg.execution_params.edf_env, project_root) + if not os.path.exists(edf_env): + raise FileNotFoundError(f"EDF environment file not found: {edf_env}") + + +def add_file_logging(log_file: str, level: str) -> None: + """Add a file handler so logs are also written to disk.""" + expanded_log_file = os.path.abspath(os.path.expanduser(os.path.expandvars(log_file))) + try: + create_directories([str(Path(expanded_log_file).parent)]) + except OSError as exc: + logger.warning("Unable to create log directory for %s: %s", expanded_log_file, exc) + return + + root_logger = logging.getLogger() + for handler in root_logger.handlers: + if isinstance(handler, 
logging.FileHandler) and os.path.abspath(handler.baseFilename) == expanded_log_file: + return + + try: + file_handler = logging.FileHandler(expanded_log_file, mode="a", encoding="utf-8") + except OSError as exc: + logger.warning("Unable to open log file %s: %s", expanded_log_file, exc) + return + file_handler.setLevel(getattr(logging, level, logging.INFO)) + file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) + root_logger.addHandler(file_handler) + + +def setup_runtime(cfg: MMirageConfig, log_level: str) -> None: + """Initialize runtime-level logging.""" + project_root = get_project_root(cfg) + report_dir = expand_path(cfg.execution_params.report_dir, project_root) + global_log_file = os.path.join(report_dir, f"{cfg.execution_params.job_name}.out") + add_file_logging(global_log_file, log_level) + logger.info("Writing logs to %s", global_log_file) diff --git a/src/mmirage/cli_utils/slurm.py b/src/mmirage/cli_utils/slurm.py new file mode 100644 index 0000000..6f050c7 --- /dev/null +++ b/src/mmirage/cli_utils/slurm.py @@ -0,0 +1,168 @@ +"""SLURM helpers for the MMIRAGE CLI.""" + +from __future__ import annotations + +import logging +import os +import shlex +import subprocess +import time +from typing import Optional, Sequence + +from mmirage.config.config import MMirageConfig +from mmirage.cli_utils.runtime import create_directories, expand_path, get_project_root + + +logger = logging.getLogger(__name__) + + +def _bash_double_quote(value: str) -> str: + """Return a double-quoted bash string literal. + + We intentionally do NOT escape '$' so that $VARS from config can expand on + compute nodes (e.g. $SCRATCH). This matches typical SLURM job scripts. + """ + escaped = value.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' + + +def _shell_path(value: str, project_root: str) -> str: + """Expand user home and make relative paths project-rooted. 
+ + If the path starts with '$' we assume it will expand on the compute node and + therefore do not attempt to join it with project_root. + """ + raw = (value or "").strip() + if not raw: + return raw + + raw = os.path.expanduser(raw) + if raw.startswith("$"): + return raw + if not os.path.isabs(raw): + raw = os.path.join(project_root, raw) + return raw + + +def build_sbatch_script(cfg: MMirageConfig, config_path: str) -> str: + """Build the sbatch payload executed for each array task.""" + project_root = get_project_root(cfg) + hf_home = _shell_path(cfg.execution_params.hf_home, project_root) + state_root = _shell_path(cfg.loading_params.get_state_root(), project_root) + src_root = os.path.join(project_root, "src") + shard_process_path = os.path.join(src_root, "mmirage", "shard_process.py") + + lines = [ + "#!/bin/bash", + "set -euo pipefail", + f"export PYTHONPATH={_bash_double_quote(src_root)}:${{PYTHONPATH:-}}", + f"export SHARD_PROCESS={_bash_double_quote(shard_process_path)}", + f"export HF_HOME={_bash_double_quote(hf_home)}", + f"export MMIRAGE_CONFIG={_bash_double_quote(config_path)}", + f"mkdir -p {_bash_double_quote(hf_home)}", + f"mkdir -p {_bash_double_quote(state_root)}", + "srun_args=(--cpus-per-task ${SLURM_CPUS_PER_TASK:-1} --wait 60)", + ] + + if cfg.execution_params.edf_env: + edf_env = expand_path(cfg.execution_params.edf_env, project_root) + lines.append(f"srun_args+=(--environment={shlex.quote(edf_env)})") + + account = cfg.execution_params.account + if not account: + raise ValueError("execution_params.account must be set in slurm mode") + lines.append(f"srun_args+=(-A {shlex.quote(account)})") + + if cfg.execution_params.reservation: + lines.append(f"srun_args+=(--reservation={shlex.quote(cfg.execution_params.reservation)})") + + lines.extend( + [ + "srun \"${srun_args[@]}\" bash -c 'if command -v python3 >/dev/null 2>&1; then PYTHON_CMD=python3; elif command -v python >/dev/null 2>&1; then PYTHON_CMD=python; else echo \"python3/python not 
found in PATH\" >&2; exit 127; fi; echo \"Using Python: ${PYTHON_CMD} ($(${PYTHON_CMD} --version 2>&1))\"; ${PYTHON_CMD} -c \"import sys; raise SystemExit(0 if sys.version_info >= (3, 10) else 2)\" || { echo \"MMIRAGE requires Python >= 3.10 on compute nodes\" >&2; exit 2; }; exec ${PYTHON_CMD} \"$SHARD_PROCESS\" --config \"$MMIRAGE_CONFIG\"'", + "echo \"Shard ${SLURM_ARRAY_TASK_ID:-0} completed\"", + ] + ) + return "\n".join(lines) + "\n" + + +def submit_slurm_job( + cfg: MMirageConfig, + config_path: str, + shard_ids: Optional[Sequence[int]] = None, +) -> Optional[int]: + """Submit a SLURM array job and return its job ID.""" + project_root = get_project_root(cfg) + report_dir = expand_path(cfg.execution_params.report_dir, project_root) + create_directories([report_dir]) + + command = [ + "sbatch", + "--parsable", + f"--job-name={cfg.execution_params.job_name}", + f"--chdir={project_root}", + f"--output={os.path.join(report_dir, 'R-%x.%A_%a.out')}", + f"--error={os.path.join(report_dir, 'R-%x.%A_%a.err')}", + f"--nodes={cfg.execution_params.nodes}", + f"--ntasks-per-node={cfg.execution_params.ntasks_per_node}", + f"--gres=gpu:{cfg.execution_params.gpus}", + f"--cpus-per-task={cfg.execution_params.cpus_per_task}", + f"--time={cfg.execution_params.time_limit}", + f"--account={cfg.execution_params.account}", + ] + + if cfg.execution_params.reservation: + command.append(f"--reservation={cfg.execution_params.reservation}") + + requested_shards = list(shard_ids or []) + if requested_shards: + command.append(f"--array={','.join(str(shard_id) for shard_id in requested_shards)}") + else: + num_shards = cfg.loading_params.get_num_shards() + command.append(f"--array=0-{num_shards - 1}") + + logger.info("Submitting SLURM job: %s", " ".join(command)) + result = subprocess.run( + command, + input=build_sbatch_script(cfg, config_path), + text=True, + capture_output=True, + check=False, + ) + + if result.returncode != 0: + logger.error("sbatch failed: %s", result.stderr.strip()) 
+ return None + + raw_job_id = result.stdout.strip().split(";", 1)[0] + try: + return int(raw_job_id) + except ValueError: + logger.error("Unable to parse job id from sbatch output: %s", result.stdout.strip()) + return None + + +def wait_for_slurm_job(job_id: int, cfg: MMirageConfig) -> None: + """Wait for a SLURM job array to leave the queue.""" + logger.info("Waiting for SLURM job %s", job_id) + while True: + result = subprocess.run( + ["squeue", "-h", "-j", str(job_id)], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0 and not result.stdout.strip(): + break + time.sleep(cfg.execution_params.poll_interval_seconds) + + if cfg.execution_params.settle_time_seconds > 0: + logger.info("Waiting %ss for state files to settle", cfg.execution_params.settle_time_seconds) + time.sleep(cfg.execution_params.settle_time_seconds) + + +def require_slurm(cfg: MMirageConfig, command_name: str) -> int: + """Ensure command can only run in SLURM mode.""" + if cfg.execution_params.is_slurm(): + return 0 + logger.error("%s requires execution_params.mode=slurm", command_name) + return 1 diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py new file mode 100644 index 0000000..d6e242b --- /dev/null +++ b/src/mmirage/cli_utils/status.py @@ -0,0 +1,123 @@ +"""Shard status and retry helpers for the MMIRAGE CLI.""" + +from __future__ import annotations + +import json +import logging +import os +import sys +from dataclasses import dataclass +from typing import List, Sequence, Tuple + +from mmirage.config.config import MMirageConfig +from mmirage.cli_utils.slurm import submit_slurm_job + + +logger = logging.getLogger(__name__) + + +@dataclass +class ShardSummary: + """Compact status summary for shard execution.""" + + total: int + successful: int + failed: int + max_retries_exceeded: int + + +def get_shard_state_dir(state_root: str, shard_id: int) -> str: + """Return the state directory for a shard.""" + return os.path.join(state_root, 
f"shard_{shard_id}") + + +def get_shard_status(state_dir: str) -> Tuple[str, int]: + """Read the current status and retry count for a shard.""" + status_file = os.path.join(state_dir, "status.json") + if not os.path.exists(status_file): + return ("missing", 0) + + try: + with open(status_file, "r", encoding="utf-8") as handle: + data = json.load(handle) + except (OSError, json.JSONDecodeError) as exc: + logger.warning("Failed to read shard status from %s: %s", status_file, exc) + return ("unknown", 0) + + return (str(data.get("status", "unknown")), int(data.get("retry_count", 0))) + + +def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: + """Return retryable failed shards and a compact summary.""" + state_root = cfg.loading_params.get_state_root() + if not state_root: + raise ValueError("loading_params.state_dir is required to check shard status") + + num_shards = cfg.loading_params.get_num_shards() + max_retries = cfg.execution_params.max_retries + failed_shards: List[int] = [] + success_count = 0 + exhausted_count = 0 + + for shard_id in range(num_shards): + status, retry_count = get_shard_status(get_shard_state_dir(state_root, shard_id)) + if status == "success": + success_count += 1 + continue + + if retry_count >= max_retries: + exhausted_count += 1 + logger.warning( + "Shard %s exceeded retry budget (%s/%s)", + shard_id, + retry_count, + max_retries, + ) + continue + + failed_shards.append(shard_id) + + summary = ShardSummary( + total=num_shards, + successful=success_count, + failed=len(failed_shards), + max_retries_exceeded=exhausted_count, + ) + return failed_shards, summary + + +def confirm_retry(count: int, interactive: bool) -> bool: + """Return whether retry submission is confirmed.""" + if not interactive: + return True + if not sys.stdin.isatty(): + logger.error("Non-interactive input detected; use --no-interactive") + return False + response = input(f"Retry {count} shard(s)? 
(y/N) ") + return response.strip().lower() == "y" + + +def status_exit_code(failed_shards: Sequence[int], summary: ShardSummary) -> int: + """Map shard status to an exit code.""" + return 0 if not failed_shards and summary.max_retries_exceeded == 0 else 1 + + +def submit_failed_shards( + cfg: MMirageConfig, + config_path: str, + failed_shards: Sequence[int], + interactive: bool, +) -> int: + """Submit retry jobs for failed shards when requested.""" + if not failed_shards: + return 0 + + if not confirm_retry(len(failed_shards), interactive): + return 1 + + job_id = submit_slurm_job(cfg, config_path, failed_shards) + if job_id is None: + return 1 + + print(job_id) + return 0 From 70799acc4fa4262f77090d2a47c9fda32697de92 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Thu, 19 Mar 2026 16:57:57 +0100 Subject: [PATCH 25/47] fixing number of shards for local configs --- README.md | 4 ++-- configs/config_mock.yaml | 4 ++-- configs/config_mock_vision.yaml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index eb6306b..8238f1d 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ loading_params: - path: /path/to/dataset type: loadable output_dir: /path/to/output/shards - num_shards: "$SLURM_ARRAY_TASK_COUNT" + num_shards: 4 shard_id: "$SLURM_ARRAY_TASK_ID" batch_size: 64 @@ -150,7 +150,7 @@ loading_params: - path: /path/to/image/dataset type: loadable output_dir: /path/to/output/shards - num_shards: "$SLURM_ARRAY_TASK_COUNT" + num_shards: 4 shard_id: "$SLURM_ARRAY_TASK_ID" batch_size: 32 diff --git a/configs/config_mock.yaml b/configs/config_mock.yaml index 4f62ff7..9c30dbc 100644 --- a/configs/config_mock.yaml +++ b/configs/config_mock.yaml @@ -19,8 +19,8 @@ loading_params: type: JSONL output_dir: tests/output/data - num_shards: 4 - shard_id: "$SLURM_ARRAY_TASK_ID" + num_shards: 1 + shard_id: 0 batch_size: 64 processing_params: diff --git a/configs/config_mock_vision.yaml 
b/configs/config_mock_vision.yaml index 46965dc..f90c6a4 100644 --- a/configs/config_mock_vision.yaml +++ b/configs/config_mock_vision.yaml @@ -18,8 +18,8 @@ loading_params: output_dir: tests/output/data_vision image_base_path: tests/mock_data_vision # Base directory where images are stored - num_shards: 4 - shard_id: "$SLURM_ARRAY_TASK_ID" + num_shards: 1 + shard_id: 0 batch_size: 1 processing_params: From 34f6595a52e110b8b61650db6c079f52f57f6b3b Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Thu, 19 Mar 2026 17:09:13 +0100 Subject: [PATCH 26/47] added retry to local mode as well --- src/mmirage/cli.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index 51b9db7..b8c71a5 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -39,10 +39,36 @@ def run_local(config_path: str, shard_id: Optional[int] = None) -> int: def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = False) -> int: """Launch the pipeline according to execution mode and retry settings.""" + auto_retry = force_retry or cfg.execution_params.retry + if not cfg.execution_params.is_slurm(): - return run_local(config_path, cfg.loading_params.get_shard_id()) + initial_shard_id = cfg.loading_params.get_shard_id() + if not auto_retry: + return run_local(config_path, initial_shard_id) + + if not cfg.loading_params.get_state_root(): + logger.warning( + "Local retry requires loading_params.state_dir; running once without orchestration" + ) + return run_local(config_path, initial_shard_id) + + shard_ids: List[int] = [initial_shard_id] + while True: + for shard_id in shard_ids: + run_local(config_path, shard_id) + + failed_shards, summary = check_failed_shards(cfg) + if status_exit_code(failed_shards, summary) == 0: + logger.info("All shards completed successfully") + return 0 + + if not failed_shards: + logger.error("Pipeline ended with shards that exceeded 
max retries") + return 1 + + logger.warning("Retrying failed shards locally: %s", ",".join(map(str, failed_shards))) + shard_ids = failed_shards - auto_retry = force_retry or cfg.execution_params.retry shard_ids: List[int] = [] while True: From 995125987228b244fe03bf01b34ac17778e172aa Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Thu, 19 Mar 2026 17:11:54 +0100 Subject: [PATCH 27/47] added config changes to readme as well --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 8238f1d..3dec545 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,10 @@ processing_params: - role: assistant content: "{{ formatted_answer }}" modalities: "{{ modalities }}" + +execution_params: + mode: local + retry: false ``` Configuration explanation: @@ -177,6 +181,10 @@ processing_params: image: "{{ medical_image }}" caption: "{{ enhanced_caption }}" original_caption: "{{ original_caption }}" + +execution_params: + mode: local + retry: false ``` Key multimodal features: From 19bf689ce2f7f0ffe9c1ad88bcbb44ec655f16aa Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Sun, 22 Mar 2026 19:55:02 +0100 Subject: [PATCH 28/47] ready for PR --- retry_failed.sh | 52 ---------------------------------------------- run.sh | 2 +- run_with_retry.sh | 31 --------------------------- src/mmirage/cli.py | 19 +++++++++++++---- 4 files changed, 16 insertions(+), 88 deletions(-) delete mode 100644 retry_failed.sh delete mode 100644 run_with_retry.sh diff --git a/retry_failed.sh b/retry_failed.sh deleted file mode 100644 index 90a97a6..0000000 --- a/retry_failed.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash -# MMIRAGE retry failed shards script -# -# Check for failed logical shards and relaunch them interactively. 
-# -# Usage: -# bash retry_failed.sh [--config path/to/config.yaml] -# -# Configuration: -# Set the CFG environment variable to point to your config file, or -# use the --config argument. Defaults to configs/config_mock.yaml. -# - -set -euo pipefail -IFS=$'\n\t' - -# Parse command line arguments -CFG="${CFG:-configs/config_mock.yaml}" -while (( $# > 0 )); do - case "$1" in - --config) - CFG="$2" - shift 2 - ;; - *) - echo "Unknown option: $1" >&2 - exit 1 - ;; - esac -done - -if [[ ! -f "$CFG" ]]; then - echo "❌ Config file not found: $CFG" >&2 - exit 1 -fi - -echo "Checking shard states from config: $CFG" -echo "" - -# Use MMIRAGE CLI to check failed shards (summary only; no retry submission) -python -m mmirage.cli check --config "$CFG" --summary-only || true - -echo "" -read -p "Submit retry job for failed shards? (y/N) " -n 1 -r -echo - -if [[ $REPLY =~ ^[Yy]$ ]]; then - python -m mmirage.cli retry --config "$CFG" --no-interactive -else - echo "Cancelled." - exit 1 -fi diff --git a/run.sh b/run.sh index 6ac856a..a35fb08 100644 --- a/run.sh +++ b/run.sh @@ -1,5 +1,5 @@ #!/bin/bash -# MMIRAGE launch script. +# MMIRAGE launch script. MMIRAGE should now be ran with a single command, and the behavior is driven by the config file. # # Launch behavior is driven by the config file: # - execution_params.retry=false: submit one SLURM array job, or run locally diff --git a/run_with_retry.sh b/run_with_retry.sh deleted file mode 100644 index e1b116a..0000000 --- a/run_with_retry.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash -# MMIRAGE pipeline orchestration with forced automatic retry. -# -# Usage: -# bash run_with_retry.sh [--config path/to/config.yaml] - -set -euo pipefail -IFS=$'\n\t' - -CFG="${CFG:-configs/config_mock.yaml}" - -while (( $# > 0 )); do - case "$1" in - --config) - CFG="$2" - shift 2 - ;; - *) - echo "Unknown option: $1" >&2 - exit 1 - ;; - esac -done - -if [[ ! 
-f "$CFG" ]]; then - echo "Config file not found: $CFG" >&2 - exit 1 -fi - -echo "Config: $CFG" -python -m mmirage.cli run --config "$CFG" --force-retry diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index b8c71a5..ee343d8 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -53,21 +53,32 @@ def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = Fa return run_local(config_path, initial_shard_id) shard_ids: List[int] = [initial_shard_id] + attempts_by_shard = {initial_shard_id: 0} while True: + run_exit_codes = {} for shard_id in shard_ids: - run_local(config_path, shard_id) + attempts_by_shard[shard_id] = attempts_by_shard.get(shard_id, 0) + 1 + run_exit_codes[shard_id] = run_local(config_path, shard_id) failed_shards, summary = check_failed_shards(cfg) if status_exit_code(failed_shards, summary) == 0: logger.info("All shards completed successfully") return 0 - if not failed_shards: + runtime_failed = [shard_id for shard_id, rc in run_exit_codes.items() if rc != 0] + candidates = sorted(set(failed_shards) | set(runtime_failed)) + retryable_shards = [ + shard_id + for shard_id in candidates + if attempts_by_shard.get(shard_id, 0) < cfg.execution_params.max_retries + ] + + if not retryable_shards: logger.error("Pipeline ended with shards that exceeded max retries") return 1 - logger.warning("Retrying failed shards locally: %s", ",".join(map(str, failed_shards))) - shard_ids = failed_shards + logger.warning("Retrying failed shards locally: %s", ",".join(map(str, retryable_shards))) + shard_ids = retryable_shards shard_ids: List[int] = [] From f613fa1d6759c7e12138db281158759b318c1a85 Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Sun, 22 Mar 2026 20:09:53 +0100 Subject: [PATCH 29/47] Update src/mmirage/config/loading.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/mmirage/config/loading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index f58ca41..e71bbaa 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -65,11 +65,11 @@ def __post_init__(self): if self.state_dir is not None: self.state_dir = str(self.state_dir).strip() or None - def get_state_root(self) -> str: + def get_state_root(self) -> Optional[str]: """Get the state root path. Returns: - str: State root path + Optional[str]: State root path, or None if no state directory is configured. """ return self.state_dir From 555405b98e5429d8b5cbd29232c299466434a6c9 Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Sun, 22 Mar 2026 20:10:14 +0100 Subject: [PATCH 30/47] Update src/mmirage/shard_process.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/mmirage/shard_process.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 2a7390a..48f178d 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -86,6 +86,11 @@ def main(): raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") state_root = loading_params.get_state_root() + if state_root is None: + raise ValueError( + "loading_params.state_dir is not set. Please configure " + "`config.loading_params.state_dir` to enable shard state tracking." 
+ ) state_dir = _shard_state_dir(shard_id, state_root) try: From ce41693d71547644d6cc6df3c7fdd9153246189f Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Sun, 22 Mar 2026 20:11:42 +0100 Subject: [PATCH 31/47] Update src/mmirage/cli_utils/status.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/mmirage/cli_utils/status.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index d6e242b..99a7ad8 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ -65,12 +65,16 @@ def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: success_count += 1 continue - if retry_count >= max_retries: + # `retry_count` includes the initial attempt (starts at 1 on first run), + # while `max_retries` is configured as "number of retries after the first run". + # Compute the number of retries already used so the budget is applied correctly. + retries_used = max(retry_count - 1, 0) + if retries_used >= max_retries: exhausted_count += 1 logger.warning( "Shard %s exceeded retry budget (%s/%s)", shard_id, - retry_count, + retries_used, max_retries, ) continue From c4c23c93047118bac5c66981bb17350dda09ff2a Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Sun, 22 Mar 2026 20:19:04 +0100 Subject: [PATCH 32/47] Update src/mmirage/cli_utils/slurm.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/mmirage/cli_utils/slurm.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/mmirage/cli_utils/slurm.py b/src/mmirage/cli_utils/slurm.py index 6f050c7..6358ae1 100644 --- a/src/mmirage/cli_utils/slurm.py +++ b/src/mmirage/cli_utils/slurm.py @@ -21,7 +21,17 @@ def _bash_double_quote(value: str) -> str: We intentionally do NOT escape '$' so that $VARS from config can expand on compute nodes (e.g. $SCRATCH). 
This matches typical SLURM job scripts. + + To avoid command injection, we reject values containing shell command + substitution syntax such as ``$(...)`` or backticks. Variable expansion + using ``$VAR`` or ``${VAR}`` is still allowed. """ + # Disallow command substitution while still allowing $VAR expansion. + if "`" in value or "$(" in value: + raise ValueError( + "Config value contains unsupported shell command substitution " + "(` or '$('). Command substitution is not allowed in SLURM-generated scripts." + ) escaped = value.replace("\\", "\\\\").replace('"', '\\"') return f'"{escaped}"' From b5c094f1bf8b6567ace5c48168b341fa74e79f65 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Sun, 22 Mar 2026 20:23:33 +0100 Subject: [PATCH 33/47] implemented changes proposed by copilot --- configs/config_comprehensive.yaml | 3 --- src/mmirage/cli.py | 19 ++++++++++++++----- src/mmirage/cli_utils/status.py | 5 +---- src/mmirage/config/config.py | 2 -- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/configs/config_comprehensive.yaml b/configs/config_comprehensive.yaml index e017747..57291bc 100644 --- a/configs/config_comprehensive.yaml +++ b/configs/config_comprehensive.yaml @@ -179,9 +179,6 @@ execution_params: # This allows filesystem to settle on distributed systems settle_time_seconds: 60 - # Seconds to wait between checks during settle time (default: 10) - settle_poll_interval: 10 - # ============================================================================ # USAGE EXAMPLES diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index ee343d8..808c62b 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -15,6 +15,8 @@ from mmirage.cli_utils.slurm import require_slurm, submit_slurm_job, wait_for_slurm_job from mmirage.cli_utils.status import ( check_failed_shards, + get_shard_state_dir, + get_shard_status, status_exit_code, submit_failed_shards, ) @@ -54,6 +56,10 @@ def launch_pipeline(cfg: MMirageConfig, 
config_path: str, force_retry: bool = Fa shard_ids: List[int] = [initial_shard_id] attempts_by_shard = {initial_shard_id: 0} + state_root = cfg.loading_params.get_state_root() + if state_root is None: + logger.error("loading_params.state_dir is required for local retry orchestration") + return 1 while True: run_exit_codes = {} for shard_id in shard_ids: @@ -67,11 +73,14 @@ def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = Fa runtime_failed = [shard_id for shard_id, rc in run_exit_codes.items() if rc != 0] candidates = sorted(set(failed_shards) | set(runtime_failed)) - retryable_shards = [ - shard_id - for shard_id in candidates - if attempts_by_shard.get(shard_id, 0) < cfg.execution_params.max_retries - ] + retryable_shards: List[int] = [] + for shard_id in candidates: + _, retry_count = get_shard_status(get_shard_state_dir(state_root, shard_id)) + retries_from_state = max(retry_count - 1, 0) + retries_from_memory = max(attempts_by_shard.get(shard_id, 0) - 1, 0) + retries_used = max(retries_from_state, retries_from_memory) + if retries_used < cfg.execution_params.max_retries: + retryable_shards.append(shard_id) if not retryable_shards: logger.error("Pipeline ended with shards that exceeded max retries") diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index 99a7ad8..844a724 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ -64,10 +64,7 @@ def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: if status == "success": success_count += 1 continue - - # `retry_count` includes the initial attempt (starts at 1 on first run), - # while `max_retries` is configured as "number of retries after the first run". - # Compute the number of retries already used so the budget is applied correctly. 
+ retries_used = max(retry_count - 1, 0) if retries_used >= max_retries: exhausted_count += 1 diff --git a/src/mmirage/config/config.py b/src/mmirage/config/config.py index ccfeb65..3a2993d 100644 --- a/src/mmirage/config/config.py +++ b/src/mmirage/config/config.py @@ -20,7 +20,6 @@ class ExecutionParams: max_retries: Maximum number of retries for failed shards. Defaults to 3. poll_interval_seconds: Seconds to wait between polling job status. Defaults to 30. settle_time_seconds: Seconds to wait after job completes before checking results. Defaults to 60. - settle_poll_interval: Seconds between polls during settle time. Defaults to 10. # SLURM-specific parameters account: HPC account/partition to charge. Required for SLURM mode. @@ -44,7 +43,6 @@ class ExecutionParams: max_retries: int = 3 poll_interval_seconds: int = 30 settle_time_seconds: int = 60 - settle_poll_interval: int = 10 # SLURM parameters account: Optional[str] = None From d90abf9d4f437a5c58532d31b22771e20b7135b0 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Sun, 22 Mar 2026 20:43:49 +0100 Subject: [PATCH 34/47] implemented more changes proposed by copilot --- src/mmirage/__init__.py | 2 +- src/mmirage/cli_utils/slurm.py | 5 ++++- src/mmirage/cli_utils/status.py | 16 +++++++++++++++- src/mmirage/config/config.py | 1 + src/mmirage/config/loading.py | 15 +++++++++------ src/mmirage/core/loader/utils.py | 30 ++++++++++++++---------------- 6 files changed, 44 insertions(+), 25 deletions(-) diff --git a/src/mmirage/__init__.py b/src/mmirage/__init__.py index fa7fe08..9d64cca 100644 --- a/src/mmirage/__init__.py +++ b/src/mmirage/__init__.py @@ -5,7 +5,7 @@ """ from __future__ import annotations -__version__ = "0.2.0" +__version__ = "0.1.3" from mmirage.config.config import MMirageConfig, ProcessingParams from mmirage.config.loading import LoadingParams diff --git a/src/mmirage/cli_utils/slurm.py b/src/mmirage/cli_utils/slurm.py index 6358ae1..beaa4e6 100644 --- 
a/src/mmirage/cli_utils/slurm.py +++ b/src/mmirage/cli_utils/slurm.py @@ -58,7 +58,10 @@ def build_sbatch_script(cfg: MMirageConfig, config_path: str) -> str: """Build the sbatch payload executed for each array task.""" project_root = get_project_root(cfg) hf_home = _shell_path(cfg.execution_params.hf_home, project_root) - state_root = _shell_path(cfg.loading_params.get_state_root(), project_root) + state_root_cfg = cfg.loading_params.get_state_root() + if not state_root_cfg: + raise ValueError("loading_params.state_dir must be set in slurm mode") + state_root = _shell_path(state_root_cfg, project_root) src_root = os.path.join(project_root, "src") shard_process_path = os.path.join(src_root, "mmirage", "shard_process.py") diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index 844a724..f602e4a 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ -22,6 +22,7 @@ class ShardSummary: total: int successful: int + running: int failed: int max_retries_exceeded: int @@ -57,6 +58,7 @@ def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: max_retries = cfg.execution_params.max_retries failed_shards: List[int] = [] success_count = 0 + running_count = 0 exhausted_count = 0 for shard_id in range(num_shards): @@ -64,6 +66,10 @@ def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: if status == "success": success_count += 1 continue + + if status == "running": + running_count += 1 + continue retries_used = max(retry_count - 1, 0) if retries_used >= max_retries: @@ -81,6 +87,7 @@ def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: summary = ShardSummary( total=num_shards, successful=success_count, + running=running_count, failed=len(failed_shards), max_retries_exceeded=exhausted_count, ) @@ -100,7 +107,14 @@ def confirm_retry(count: int, interactive: bool) -> bool: def status_exit_code(failed_shards: Sequence[int], summary: ShardSummary) -> int: 
"""Map shard status to an exit code.""" - return 0 if not failed_shards and summary.max_retries_exceeded == 0 else 1 + return ( + 0 + if not failed_shards + and summary.max_retries_exceeded == 0 + and summary.running == 0 + and summary.successful == summary.total + else 1 + ) def submit_failed_shards( diff --git a/src/mmirage/config/config.py b/src/mmirage/config/config.py index 3a2993d..f6cc34e 100644 --- a/src/mmirage/config/config.py +++ b/src/mmirage/config/config.py @@ -17,6 +17,7 @@ class ExecutionParams: Attributes: mode: Execution mode: "local" or "slurm". Defaults to "local". + retry: Whether automatic retry orchestration is enabled. Defaults to False. max_retries: Maximum number of retries for failed shards. Defaults to 3. poll_interval_seconds: Seconds to wait between polling job status. Defaults to 30. settle_time_seconds: Seconds to wait after job completes before checking results. Defaults to 60. diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index e71bbaa..8c3fa6a 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -2,7 +2,7 @@ import re from dataclasses import dataclass, field -from typing import Union, List, cast, Optional +from typing import Union, List, cast from mmirage.core.loader.base import BaseDataLoaderConfig @@ -27,7 +27,7 @@ class LoadingParams: """ datasets: List[BaseDataLoaderConfig] = field(default_factory=list) - state_dir: Optional[str] = None + state_dir: str = "" output_dir: str = "" num_shards: Union[int, str] = 1 shard_id: Union[int, str] = 0 @@ -62,14 +62,17 @@ def __post_init__(self): self.batch_size = max(self.batch_size, 1) - if self.state_dir is not None: - self.state_dir = str(self.state_dir).strip() or None + self.state_dir = str(self.state_dir).strip() + if not self.state_dir: + raise ValueError( + "loading_params.state_dir is required to enable shard state tracking" + ) - def get_state_root(self) -> Optional[str]: + def get_state_root(self) -> str: """Get the 
state root path. Returns: - Optional[str]: State root path, or None if no state directory is configured. + str: State root path. """ return self.state_dir diff --git a/src/mmirage/core/loader/utils.py b/src/mmirage/core/loader/utils.py index 0b43727..c039727 100644 --- a/src/mmirage/core/loader/utils.py +++ b/src/mmirage/core/loader/utils.py @@ -31,23 +31,21 @@ def load_datasets_from_configs(configs: List[BaseDataLoaderConfig]) -> List[Data RuntimeError: If no datasets could be loaded successfully. """ - config_per_type = {} - for ds_config in configs: - config_per_type[ds_config.type] = config_per_type.get(ds_config.type, []) + [ - ds_config - ] - valid_ds: List[DatasetLike] = [] - for config_type, config_list in config_per_type.items(): - loader = AutoDataLoader.from_name(config_type)() - for ds_config in config_list: - try: - ds = loader.from_config(ds_config) - if ds is None: - continue - valid_ds.append(ds) - except Exception as e: - logger.warning(f"āš ļø Dataset loading failed with error: {e}. Skipping") + loader_by_type = {} + for ds_config in configs: + loader = loader_by_type.get(ds_config.type) + if loader is None: + loader = AutoDataLoader.from_name(ds_config.type)() + loader_by_type[ds_config.type] = loader + + try: + ds = loader.from_config(ds_config) + if ds is None: + continue + valid_ds.append(ds) + except Exception as e: + logger.warning(f"Dataset loading failed with error: {e}. 
Skipping") if not valid_ds: raise RuntimeError("No valid datasets loaded from the provided configs.") From 6eb3d2925f0607259e686e7ec15e334b10248cc6 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 23 Mar 2026 17:54:47 +0100 Subject: [PATCH 35/47] deleted run.sh and implemented a default state_dir value --- README.md | 1 + run.sh | 23 ----------------------- src/mmirage/cli.py | 9 --------- src/mmirage/cli_utils/slurm.py | 5 +---- src/mmirage/cli_utils/status.py | 2 -- src/mmirage/config/loading.py | 13 ++++++++----- src/mmirage/shard_process.py | 8 +------- 7 files changed, 11 insertions(+), 50 deletions(-) delete mode 100644 run.sh diff --git a/README.md b/README.md index 3dec545..931ecb4 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,7 @@ Configuration explanation: - `processors`: List of processor configurations. Currently supports `llm` type for LLM-based generation. - `loading_params`: Parameters for loading and sharding datasets. + - `state_dir`: Optional shared directory for shard status/retry state. Defaults to `~/state_dir`. - `datasets`: List of dataset configurations with path, type, and output directory. - `processing_params`: - `inputs`: Variables extracted from the input dataset using JMESPath queries. diff --git a/run.sh b/run.sh deleted file mode 100644 index a35fb08..0000000 --- a/run.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash -# MMIRAGE launch script. MMIRAGE should now be ran with a single command, and the behavior is driven by the config file. -# -# Launch behavior is driven by the config file: -# - execution_params.retry=false: submit one SLURM array job, or run locally -# - execution_params.retry=true: submit and automatically retry failed shards -# -# Usage: -# bash run.sh -# CFG=configs/config_mock.yaml bash run.sh - -set -euo pipefail - -CFG="${CFG:-configs/config_mock.yaml}" - -if [[ ! 
-f "$CFG" ]]; then - echo "Config file not found: $CFG" >&2 - exit 1 -fi - -python -m mmirage.cli run --config "$CFG" - -echo "END TIME: $(date)" diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index 808c62b..cc33102 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -48,18 +48,9 @@ def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = Fa if not auto_retry: return run_local(config_path, initial_shard_id) - if not cfg.loading_params.get_state_root(): - logger.warning( - "Local retry requires loading_params.state_dir; running once without orchestration" - ) - return run_local(config_path, initial_shard_id) - shard_ids: List[int] = [initial_shard_id] attempts_by_shard = {initial_shard_id: 0} state_root = cfg.loading_params.get_state_root() - if state_root is None: - logger.error("loading_params.state_dir is required for local retry orchestration") - return 1 while True: run_exit_codes = {} for shard_id in shard_ids: diff --git a/src/mmirage/cli_utils/slurm.py b/src/mmirage/cli_utils/slurm.py index beaa4e6..6358ae1 100644 --- a/src/mmirage/cli_utils/slurm.py +++ b/src/mmirage/cli_utils/slurm.py @@ -58,10 +58,7 @@ def build_sbatch_script(cfg: MMirageConfig, config_path: str) -> str: """Build the sbatch payload executed for each array task.""" project_root = get_project_root(cfg) hf_home = _shell_path(cfg.execution_params.hf_home, project_root) - state_root_cfg = cfg.loading_params.get_state_root() - if not state_root_cfg: - raise ValueError("loading_params.state_dir must be set in slurm mode") - state_root = _shell_path(state_root_cfg, project_root) + state_root = _shell_path(cfg.loading_params.get_state_root(), project_root) src_root = os.path.join(project_root, "src") shard_process_path = os.path.join(src_root, "mmirage", "shard_process.py") diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index f602e4a..31da10e 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ 
-51,8 +51,6 @@ def get_shard_status(state_dir: str) -> Tuple[str, int]: def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: """Return retryable failed shards and a compact summary.""" state_root = cfg.loading_params.get_state_root() - if not state_root: - raise ValueError("loading_params.state_dir is required to check shard status") num_shards = cfg.loading_params.get_num_shards() max_retries = cfg.execution_params.max_retries diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index 8c3fa6a..fe8140e 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -1,11 +1,13 @@ """Data loading configuration for MMIRAGE pipeline.""" +import os import re from dataclasses import dataclass, field from typing import Union, List, cast from mmirage.core.loader.base import BaseDataLoaderConfig +DEFAULT_STATE_DIR = "~/.cache/MMIRAGE/state_dir" @dataclass class LoadingParams: @@ -27,7 +29,7 @@ class LoadingParams: """ datasets: List[BaseDataLoaderConfig] = field(default_factory=list) - state_dir: str = "" + state_dir: str = DEFAULT_STATE_DIR output_dir: str = "" num_shards: Union[int, str] = 1 shard_id: Union[int, str] = 0 @@ -62,11 +64,12 @@ def __post_init__(self): self.batch_size = max(self.batch_size, 1) - self.state_dir = str(self.state_dir).strip() + raw_state_dir = "" if self.state_dir is None else str(self.state_dir) + self.state_dir = raw_state_dir.strip() if not self.state_dir: - raise ValueError( - "loading_params.state_dir is required to enable shard state tracking" - ) + self.state_dir = DEFAULT_STATE_DIR + + self.state_dir = os.path.expanduser(self.state_dir) def get_state_root(self) -> str: """Get the state root path. 
diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 48f178d..fec0e47 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -85,13 +85,7 @@ def main(): if not (0 <= shard_id < num_shards): raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") - state_root = loading_params.get_state_root() - if state_root is None: - raise ValueError( - "loading_params.state_dir is not set. Please configure " - "`config.loading_params.state_dir` to enable shard state tracking." - ) - state_dir = _shard_state_dir(shard_id, state_root) + state_dir = _shard_state_dir(shard_id, loading_params.get_state_root()) try: retry_count = _mark_running(state_dir, shard_id, datasets_config) From b64711e289e92515e63774d135179cbc5c15aa01 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 23 Mar 2026 17:56:01 +0100 Subject: [PATCH 36/47] forgot to update the readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 931ecb4..87d7f05 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,7 @@ Configuration explanation: - `processors`: List of processor configurations. Currently supports `llm` type for LLM-based generation. - `loading_params`: Parameters for loading and sharding datasets. - - `state_dir`: Optional shared directory for shard status/retry state. Defaults to `~/state_dir`. + - `state_dir`: Optional shared directory for shard status/retry state. Defaults to `~/.cache/MMIRAGE/state_dir`. - `datasets`: List of dataset configurations with path, type, and output directory. - `processing_params`: - `inputs`: Variables extracted from the input dataset using JMESPath queries. 
From c966a1b72ca8c7cb7663d1b71f692db6012006b7 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:09:50 +0100 Subject: [PATCH 37/47] changes suggested in the PR --- README.md | 9 +- pyproject.toml | 2 +- src/mmirage/__init__.py | 2 +- src/mmirage/cli.py | 179 +++++++++++++++++++++---------- src/mmirage/cli_utils/runtime.py | 48 +++++---- src/mmirage/cli_utils/slurm.py | 5 +- src/mmirage/cli_utils/status.py | 69 +++++++----- src/mmirage/config/config.py | 12 +-- src/mmirage/config/loading.py | 10 +- src/mmirage/config/utils.py | 7 ++ src/mmirage/merge_shards.py | 8 +- src/mmirage/shard_process.py | 9 +- src/mmirage/shard_utils.py | 142 ++++++++++++++++++------ 13 files changed, 331 insertions(+), 171 deletions(-) diff --git a/README.md b/README.md index 87d7f05..1446df4 100644 --- a/README.md +++ b/README.md @@ -43,16 +43,16 @@ Run the pipeline via the Python CLI. Retry behavior is driven by your YAML confi python -m mmirage.cli run --config configs/config_mock.yaml ``` -To check status and (optionally) submit retries for failed shards: +To check status only: ```bash python -m mmirage.cli check --config configs/config_mock.yaml ``` -If you only want the status summary (no retry submission): +To check status and submit retries for failed shards: ```bash -python -m mmirage.cli check --config configs/config_mock.yaml --summary-only +python -m mmirage.cli check --config configs/config_mock.yaml --retry ``` ### Text-only: Reformatting dataset @@ -131,6 +131,9 @@ Configuration explanation: - `inputs`: Variables extracted from the input dataset using JMESPath queries. - `outputs`: Variables created by processors. Prompts use Jinja2 templating (`{{ variable }}`). - `output_schema`: Defines the structure of output samples. +- `execution_params`: + - `mode`: "local" to run shard processing in the current Python environment or "slurm" to run through SLURM by submitting an sbatch array job. 
+ - `retry`: If true, MMIRAGE automatically retries failed shards until they succeed or `max_retries` is reached. If false, the pipeline runs/submits once, and retries can be triggered later via the check/retry CLI commands. ### Multimodal: Processing images with VLMs diff --git a/pyproject.toml b/pyproject.toml index 616ed78..5804d4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mmirage" -version = "0.1.3" +version = "0.1.4" description = "Modular Multimodal Intelligent Reformatting and Augmentation Generation Engine - Advanced platform for processing datasets using generative models including vision-language models." readme = "README.md" requires-python = ">=3.10" diff --git a/src/mmirage/__init__.py b/src/mmirage/__init__.py index 9d64cca..f255e81 100644 --- a/src/mmirage/__init__.py +++ b/src/mmirage/__init__.py @@ -5,7 +5,7 @@ """ from __future__ import annotations -__version__ = "0.1.3" +__version__ = "0.1.4" from mmirage.config.config import MMirageConfig, ProcessingParams from mmirage.config.loading import LoadingParams diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index cc33102..00d8b54 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -11,11 +11,12 @@ from dataclasses import asdict from typing import List, Optional -from mmirage.cli_utils.runtime import setup_runtime, validate_paths +from mmirage.cli_utils.runtime import setup_runtime, validate_edf_env_path from mmirage.cli_utils.slurm import require_slurm, submit_slurm_job, wait_for_slurm_job from mmirage.cli_utils.status import ( check_failed_shards, - get_shard_state_dir, + is_retry_budget_exceeded, + shard_state_dir, get_shard_status, status_exit_code, submit_failed_shards, @@ -28,7 +29,15 @@ def run_local(config_path: str, shard_id: Optional[int] = None) -> int: - """Run one shard in the current Python environment.""" + """Run one shard in the current Python environment. 
+ + Args: + config_path: Absolute path to the MMIRAGE YAML config file. + shard_id: Optional shard id to inject via SLURM_ARRAY_TASK_ID. + + Returns: + Process return code from shard execution. + """ command = [sys.executable, "-m", "mmirage.shard_process", "--config", config_path] env = os.environ.copy() if shard_id is not None: @@ -40,7 +49,16 @@ def run_local(config_path: str, shard_id: Optional[int] = None) -> int: def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = False) -> int: - """Launch the pipeline according to execution mode and retry settings.""" + """Launch the pipeline according to execution mode and retry settings. + + Args: + cfg: Parsed MMIRAGE configuration object. + config_path: Absolute path to the MMIRAGE YAML config file. + force_retry: If True, enable retry orchestration regardless of config flag. + + Returns: + Exit code: 0 on success, 1 on failure. + """ auto_retry = force_retry or cfg.execution_params.retry if not cfg.execution_params.is_slurm(): @@ -66,11 +84,14 @@ def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = Fa candidates = sorted(set(failed_shards) | set(runtime_failed)) retryable_shards: List[int] = [] for shard_id in candidates: - _, retry_count = get_shard_status(get_shard_state_dir(state_root, shard_id)) - retries_from_state = max(retry_count - 1, 0) - retries_from_memory = max(attempts_by_shard.get(shard_id, 0) - 1, 0) - retries_used = max(retries_from_state, retries_from_memory) - if retries_used < cfg.execution_params.max_retries: + _, state_attempt_count = get_shard_status(shard_state_dir(state_root, shard_id)) + memory_attempt_count = attempts_by_shard.get(shard_id, 0) + effective_attempt_count = max(state_attempt_count, memory_attempt_count) + + if not is_retry_budget_exceeded( + effective_attempt_count, + cfg.execution_params.max_retries, + ): retryable_shards.append(shard_id) if not retryable_shards: @@ -108,7 +129,11 @@ def launch_pipeline(cfg: MMirageConfig, 
config_path: str, force_retry: bool = Fa def configure_logging(level: str) -> None: - """Configure root logging.""" + """Configure root logging. + + Args: + level: Root log level name. + """ logging.basicConfig( level=getattr(logging, level, logging.INFO), format="%(asctime)s %(levelname)s %(name)s: %(message)s", @@ -116,7 +141,11 @@ def configure_logging(level: str) -> None: def add_shared_arguments(parser: argparse.ArgumentParser) -> None: - """Attach common CLI arguments to a subcommand parser.""" + """Attach common CLI arguments to a subcommand parser. + + Args: + parser: Subcommand parser receiving shared arguments. + """ parser.add_argument("--config", required=True, help="Path to a MMIRAGE YAML config file") parser.add_argument( "--log-level", @@ -127,7 +156,11 @@ def add_shared_arguments(parser: argparse.ArgumentParser) -> None: def build_argparser() -> argparse.ArgumentParser: - """Build the CLI parser.""" + """Build the CLI parser. + + Returns: + Configured top-level argparse parser. 
+ """ parser = argparse.ArgumentParser(description="MMIRAGE command-line interface") subparsers = parser.add_subparsers(dest="command", required=True) @@ -142,45 +175,33 @@ def build_argparser() -> argparse.ArgumentParser: check_parser = subparsers.add_parser("check", help="Inspect shard status") add_shared_arguments(check_parser) check_parser.add_argument( - "--summary-only", - action="store_true", - help="Only print status summary; do not submit retries.", - ) - check_retry_group = check_parser.add_mutually_exclusive_group() - check_retry_group.add_argument( "--retry", dest="retry", action="store_true", - help="Submit a retry job for failed shards (default unless --summary-only).", - ) - check_retry_group.add_argument( - "--no-retry", - dest="retry", - action="store_false", - help="Do not submit retries (same as --summary-only).", - ) - check_parser.set_defaults(retry=True) - check_interactive_group = check_parser.add_mutually_exclusive_group() - check_interactive_group.add_argument( - "--interactive", - dest="interactive", - action="store_true", - help="Prompt before submitting retry jobs (default).", + help="Submit a retry job for failed shards.", ) - check_interactive_group.add_argument( - "--no-interactive", - dest="interactive", - action="store_false", - help="Submit retry jobs without prompting.", + check_parser.set_defaults(retry=False) + check_parser.add_argument( + "-y", + "--yes", + dest="confirm_mode", + action="store_const", + const="yes", + help="Submit retries without prompting.", ) - check_parser.set_defaults(interactive=True) + check_parser.set_defaults(confirm_mode="prompt") retry_parser = subparsers.add_parser("retry", help="Submit only failed shards") add_shared_arguments(retry_parser) - retry_group = retry_parser.add_mutually_exclusive_group() - retry_group.add_argument("--interactive", dest="interactive", action="store_true") - retry_group.add_argument("--no-interactive", dest="interactive", action="store_false") - 
retry_parser.set_defaults(interactive=True) + retry_parser.add_argument( + "-y", + "--yes", + dest="confirm_mode", + action="store_const", + const="yes", + help="Submit retries without prompting.", + ) + retry_parser.set_defaults(confirm_mode="prompt") run_parser = subparsers.add_parser( "run", @@ -203,7 +224,15 @@ def build_argparser() -> argparse.ArgumentParser: def parse_shard_ids(raw_value: Optional[str], num_shards: Optional[int] = None) -> List[int]: - """Parse a comma-separated shard id list.""" + """Parse a comma-separated shard id list. + + Args: + raw_value: Comma-separated shard ids, or None/empty for full array. + num_shards: Optional upper bound used for range validation. + + Returns: + Parsed shard ids. + """ if not raw_value: return [] @@ -213,13 +242,11 @@ def parse_shard_ids(raw_value: Optional[str], num_shards: Optional[int] = None) if not candidate: continue - try: + if shard_id.isdigit(): shard_id = int(candidate) - except ValueError as exc: - raise ValueError(f"Invalid shard id {candidate!r}; expected integers") from exc + else: + raise ValueError(f"Invalid shard id {candidate!r}; expected integers") - if shard_id < 0: - raise ValueError(f"Invalid shard id {shard_id}; expected non-negative integer") if num_shards is not None and shard_id >= num_shards: raise ValueError(f"Invalid shard id {shard_id}; expected 0 <= shard_id < {num_shards}") @@ -229,14 +256,32 @@ def parse_shard_ids(raw_value: Optional[str], num_shards: Optional[int] = None) def handle_run(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int: - """Handle the canonical run command.""" + """Handle the canonical run command. + + Args: + args: Parsed CLI namespace. + cfg: Parsed MMIRAGE configuration object. + config_path: Absolute path to the MMIRAGE YAML config file. + + Returns: + Exit code for the run operation. 
+ """ if args.shard_id is not None: return run_local(config_path, args.shard_id) return launch_pipeline(cfg, config_path, force_retry=args.force_retry) def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int: - """Submit a SLURM array job and optionally wait.""" + """Submit a SLURM array job and optionally wait. + + Args: + args: Parsed CLI namespace. + cfg: Parsed MMIRAGE configuration object. + config_path: Absolute path to the MMIRAGE YAML config file. + + Returns: + Exit code for submission/wait outcome. + """ if require_slurm(cfg, "submit") != 0: return 1 @@ -255,7 +300,16 @@ def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str def handle_check(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int: - """Inspect shard status and optionally submit retries.""" + """Inspect shard status and optionally submit retries. + + Args: + args: Parsed CLI namespace. + cfg: Parsed MMIRAGE configuration object. + config_path: Absolute path to the MMIRAGE YAML config file. + + Returns: + Exit code based on shard status and optional retry submission. + """ failed_shards, summary = check_failed_shards(cfg) print(json.dumps(asdict(summary), indent=2)) @@ -263,7 +317,7 @@ def handle_check(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) if not cfg.execution_params.is_slurm(): return status_code - if args.summary_only or not args.retry: + if not args.retry: return status_code if not failed_shards: @@ -273,12 +327,21 @@ def handle_check(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) cfg=cfg, config_path=config_path, failed_shards=failed_shards, - interactive=bool(args.interactive), + confirm_mode=args.confirm_mode, ) def handle_retry(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int: - """Submit retries for failed shards only.""" + """Submit retries for failed shards only. + + Args: + args: Parsed CLI namespace. 
+ cfg: Parsed MMIRAGE configuration object. + config_path: Absolute path to the MMIRAGE YAML config file. + + Returns: + Exit code for retry submission outcome. + """ if require_slurm(cfg, "retry") != 0: return 1 @@ -289,14 +352,14 @@ def handle_retry(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) if summary.max_retries_exceeded > 0: logger.error("No retryable shards remain") return 1 - print("All shards already succeeded.") + logger.info("All shards already succeeded.") return 0 return submit_failed_shards( cfg=cfg, config_path=config_path, failed_shards=failed_shards, - interactive=bool(args.interactive), + confirm_mode=args.confirm_mode, ) @@ -311,7 +374,7 @@ def main() -> None: cfg = load_mmirage_config(config_path) setup_runtime(cfg, args.log_level) - validate_paths(cfg) + validate_edf_env_path(cfg) handlers = { "run": handle_run, diff --git a/src/mmirage/cli_utils/runtime.py b/src/mmirage/cli_utils/runtime.py index 4cc9fcc..64f6ae4 100644 --- a/src/mmirage/cli_utils/runtime.py +++ b/src/mmirage/cli_utils/runtime.py @@ -15,10 +15,10 @@ def expand_path(path: str, project_root: Optional[str] = None) -> str: """Expand environment variables, user home and relative paths.""" - expanded = os.path.expanduser(os.path.expandvars(path)) - if not os.path.isabs(expanded) and project_root: - expanded = os.path.join(project_root, expanded) - return os.path.abspath(expanded) + expanded = Path(os.path.expandvars(os.path.expanduser(path))) + if not expanded.is_absolute() and project_root: + expanded = Path(project_root) / expanded + return str(expanded.resolve()) def get_project_root(cfg: MMirageConfig) -> str: @@ -35,43 +35,45 @@ def create_directories(paths: Sequence[str]) -> None: Path(path).mkdir(parents=True, exist_ok=True) -def validate_paths(cfg: MMirageConfig) -> None: - """Validate pre-existing execution paths.""" - project_root = get_project_root(cfg) - if cfg.execution_params.edf_env: - edf_env = expand_path(cfg.execution_params.edf_env, 
project_root) - if not os.path.exists(edf_env): - raise FileNotFoundError(f"EDF environment file not found: {edf_env}") +def validate_edf_env_path(cfg: MMirageConfig) -> None: + """Validate the optional EDF environment file path.""" + edf_env = cfg.execution_params.edf_env + if not edf_env: + return + + resolved = expand_path(edf_env, get_project_root(cfg)) + if not Path(resolved).is_file(): + raise FileNotFoundError(f"EDF environment file not found: {resolved}") def add_file_logging(log_file: str, level: str) -> None: """Add a file handler so logs are also written to disk.""" - expanded_log_file = os.path.abspath(os.path.expanduser(os.path.expandvars(log_file))) + resolved_log_file = Path(expand_path(log_file)) try: - create_directories([str(Path(expanded_log_file).parent)]) + resolved_log_file.parent.mkdir(parents=True, exist_ok=True) except OSError as exc: - logger.warning("Unable to create log directory for %s: %s", expanded_log_file, exc) + logger.warning("Unable to create log directory for %s: %s", resolved_log_file, exc) return root_logger = logging.getLogger() for handler in root_logger.handlers: - if isinstance(handler, logging.FileHandler) and os.path.abspath(handler.baseFilename) == expanded_log_file: + if isinstance(handler, logging.FileHandler) and Path(handler.baseFilename).resolve() == resolved_log_file: return try: - file_handler = logging.FileHandler(expanded_log_file, mode="a", encoding="utf-8") + file_handler = logging.FileHandler(resolved_log_file, mode="a", encoding="utf-8") except OSError as exc: - logger.warning("Unable to open log file %s: %s", expanded_log_file, exc) + logger.warning("Unable to open log file %s: %s", resolved_log_file, exc) return - file_handler.setLevel(getattr(logging, level, logging.INFO)) + + file_handler.setLevel(getattr(logging, level.upper(), logging.INFO)) file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) root_logger.addHandler(file_handler) def setup_runtime(cfg: 
MMirageConfig, log_level: str) -> None: """Initialize runtime-level logging.""" - project_root = get_project_root(cfg) - report_dir = expand_path(cfg.execution_params.report_dir, project_root) - global_log_file = os.path.join(report_dir, f"{cfg.execution_params.job_name}.out") - add_file_logging(global_log_file, log_level) - logger.info("Writing logs to %s", global_log_file) + report_dir = Path(expand_path(cfg.execution_params.report_dir, get_project_root(cfg))) + log_file = report_dir / f"{cfg.execution_params.job_name}.out" + add_file_logging(str(log_file), log_level) + logger.info("Writing logs to %s", log_file) diff --git a/src/mmirage/cli_utils/slurm.py b/src/mmirage/cli_utils/slurm.py index 6358ae1..ad69d45 100644 --- a/src/mmirage/cli_utils/slurm.py +++ b/src/mmirage/cli_utils/slurm.py @@ -42,7 +42,7 @@ def _shell_path(value: str, project_root: str) -> str: If the path starts with '$' we assume it will expand on the compute node and therefore do not attempt to join it with project_root. 
""" - raw = (value or "").strip() + raw = value.strip() if not raw: return raw @@ -128,7 +128,8 @@ def submit_slurm_job( command.append(f"--array={','.join(str(shard_id) for shard_id in requested_shards)}") else: num_shards = cfg.loading_params.get_num_shards() - command.append(f"--array=0-{num_shards - 1}") + last_shard_id = num_shards - 1 + command.append(f"--array=0-{last_shard_id}") logger.info("Submitting SLURM job: %s", " ".join(command)) result = subprocess.run( diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index 31da10e..838099a 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ -7,10 +7,11 @@ import os import sys from dataclasses import dataclass -from typing import List, Sequence, Tuple +from typing import List, Literal, Sequence, Tuple from mmirage.config.config import MMirageConfig from mmirage.cli_utils.slurm import submit_slurm_job +from mmirage.shard_utils import ShardStatus logger = logging.getLogger(__name__) @@ -27,13 +28,26 @@ class ShardSummary: max_retries_exceeded: int -def get_shard_state_dir(state_root: str, shard_id: int) -> str: +def max_allowed_attempts(max_retries: int) -> int: + """Return max allowed total attempts for a shard. + + Total attempts = initial attempt + max_retries. 
+ """ + return max_retries + 1 + + +def is_retry_budget_exceeded(attempt_count: int, max_retries: int) -> bool: + """Return whether a shard has exceeded the retry budget.""" + return attempt_count > max_allowed_attempts(max_retries) + + +def shard_state_dir(state_root: str, shard_id: int) -> str: """Return the state directory for a shard.""" return os.path.join(state_root, f"shard_{shard_id}") def get_shard_status(state_dir: str) -> Tuple[str, int]: - """Read the current status and retry count for a shard.""" + """Read the current status and attempt counter for a shard.""" status_file = os.path.join(state_dir, "status.json") if not os.path.exists(status_file): return ("missing", 0) @@ -41,11 +55,15 @@ def get_shard_status(state_dir: str) -> Tuple[str, int]: try: with open(status_file, "r", encoding="utf-8") as handle: data = json.load(handle) + if not isinstance(data, dict): + logger.warning("Invalid shard status format in %s; expected object", status_file) + return ("unknown", 0) except (OSError, json.JSONDecodeError) as exc: logger.warning("Failed to read shard status from %s: %s", status_file, exc) return ("unknown", 0) - return (str(data.get("status", "unknown")), int(data.get("retry_count", 0))) + parsed = ShardStatus.from_dict(data) + return (parsed.status, parsed.retry_count) def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: @@ -58,29 +76,24 @@ def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: success_count = 0 running_count = 0 exhausted_count = 0 + allowed_attempts = max_allowed_attempts(max_retries) for shard_id in range(num_shards): - status, retry_count = get_shard_status(get_shard_state_dir(state_root, shard_id)) + status, attempt_count = get_shard_status(shard_state_dir(state_root, shard_id)) if status == "success": success_count += 1 - continue - - if status == "running": + elif status == "running": running_count += 1 - continue - - retries_used = max(retry_count - 1, 0) - if retries_used >= 
max_retries: + elif is_retry_budget_exceeded(attempt_count, max_retries): exhausted_count += 1 logger.warning( - "Shard %s exceeded retry budget (%s/%s)", + "Shard %s exceeded retry budget (attempts=%s, max_allowed_attempts=%s)", shard_id, - retries_used, - max_retries, + attempt_count, + allowed_attempts, ) - continue - - failed_shards.append(shard_id) + else: + failed_shards.append(shard_id) summary = ShardSummary( total=num_shards, @@ -92,13 +105,20 @@ def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: return failed_shards, summary -def confirm_retry(count: int, interactive: bool) -> bool: - """Return whether retry submission is confirmed.""" - if not interactive: +def confirm_retry(count: int, confirm_mode: Literal["prompt", "yes"]) -> bool: + """Return whether retry submission is confirmed. + + Modes: + - prompt: ask the user interactively + - yes: submit without prompting + """ + if confirm_mode == "yes": return True + if not sys.stdin.isatty(): - logger.error("Non-interactive input detected; use --no-interactive") + logger.error("Interactive confirmation requested but stdin is not a TTY; use --yes") return False + response = input(f"Retry {count} shard(s)? 
(y/N) ") return response.strip().lower() == "y" @@ -119,18 +139,17 @@ def submit_failed_shards( cfg: MMirageConfig, config_path: str, failed_shards: Sequence[int], - interactive: bool, + confirm_mode: Literal["prompt", "yes"], ) -> int: """Submit retry jobs for failed shards when requested.""" if not failed_shards: return 0 - if not confirm_retry(len(failed_shards), interactive): + if not confirm_retry(len(failed_shards), confirm_mode): return 1 job_id = submit_slurm_job(cfg, config_path, failed_shards) if job_id is None: return 1 - print(job_id) return 0 diff --git a/src/mmirage/config/config.py b/src/mmirage/config/config.py index f6cc34e..6063f7a 100644 --- a/src/mmirage/config/config.py +++ b/src/mmirage/config/config.py @@ -45,6 +45,12 @@ class ExecutionParams: poll_interval_seconds: int = 30 settle_time_seconds: int = 60 + # Paths (can contain environment variables like ${VAR} or $VAR) + project_root: Optional[str] = None + report_dir: str = "~/reports" + hf_home: str = "~/hf" + edf_env: Optional[str] = None + # SLURM parameters account: Optional[str] = None job_name: str = "mmirage-sharded" @@ -55,12 +61,6 @@ class ExecutionParams: cpus_per_task: int = 288 time_limit: str = "11:59:59" - # Paths (can contain environment variables like ${VAR} or $VAR) - project_root: Optional[str] = None - report_dir: str = "~/reports" - hf_home: str = "~/hf" - edf_env: Optional[str] = None - def __post_init__(self): """Validate execution parameters.""" if self.mode not in ("local", "slurm"): diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index fe8140e..22ff207 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -1,11 +1,11 @@ """Data loading configuration for MMIRAGE pipeline.""" import os -import re from dataclasses import dataclass, field from typing import Union, List, cast from mmirage.core.loader.base import BaseDataLoaderConfig +from mmirage.config.utils import _is_unresolved_env_var DEFAULT_STATE_DIR = 
"~/.cache/MMIRAGE/state_dir" @@ -102,11 +102,3 @@ def get_batch_size(self) -> int: int: Batch size (minimum 1). """ return cast(int, self.batch_size) - - -_UNRESOLVED_ENV_VAR_PATTERN = re.compile(r"^\$(?:\{[A-Za-z_][A-Za-z0-9_]*\}|[A-Za-z_][A-Za-z0-9_]*)$") - - -def _is_unresolved_env_var(value: str) -> bool: - """Check whether a string still looks like an unresolved shell env var.""" - return bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(value.strip())) \ No newline at end of file diff --git a/src/mmirage/config/utils.py b/src/mmirage/config/utils.py index 93ae3b6..4a73910 100644 --- a/src/mmirage/config/utils.py +++ b/src/mmirage/config/utils.py @@ -4,6 +4,7 @@ from dacite import Config, from_dict import yaml import os +import re from mmirage.config.config import MMirageConfig from mmirage.core.process.base import BaseProcessorConfig, ProcessorRegistry, OutputVar @@ -122,3 +123,9 @@ def output_var_hook(data: Dict[str, Any]) -> OutputVar: cfg_obj = from_dict(MMirageConfig, cast(dict, cfg), config=config) return cfg_obj + +def _is_unresolved_env_var(value: str) -> bool: + """Check whether a string still looks like an unresolved shell env var.""" + _UNRESOLVED_ENV_VAR_PATTERN = re.compile(r"^\$(?:\{[A-Za-z_][A-Za-z0-9_]*\}|[A-Za-z_][A-Za-z0-9_]*)$") + + return bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(value.strip())) \ No newline at end of file diff --git a/src/mmirage/merge_shards.py b/src/mmirage/merge_shards.py index e290cb7..433feb5 100644 --- a/src/mmirage/merge_shards.py +++ b/src/mmirage/merge_shards.py @@ -7,13 +7,7 @@ from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk from mmirage.core.loader.base import DatasetLike - - -def _count_rows(ds: DatasetLike) -> int: - """Count total rows in a dataset or dataset dict.""" - if isinstance(ds, DatasetDict): - return sum(len(split) for split in ds.values()) - return len(ds) +from mmirage.shard_utils import _count_rows def _merge_datasetdict(shard_dsets: List[DatasetDict]) -> DatasetDict: 
diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index fec0e47..ac6dc53 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -81,6 +81,7 @@ def main(): shard_id = loading_params.get_shard_id() num_shards = loading_params.get_num_shards() + last_shard_id = num_shards - 1 if not (0 <= shard_id < num_shards): raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") @@ -89,7 +90,7 @@ def main(): try: retry_count = _mark_running(state_dir, shard_id, datasets_config) - logger.info(f"Starting shard {shard_id}/{num_shards - 1} (attempt #{retry_count})") + logger.info(f"Starting shard {shard_id}/{last_shard_id} (attempt #{retry_count})") if retry_count > 1: for ds_config in datasets_config: @@ -117,7 +118,9 @@ def main(): ds_processed_all: List[DatasetLike] = [] for ds_idx, ds_shard in enumerate(ds_all_shard): ds_config = datasets_config[ds_idx] - remove_columns = _remove_columns(ds_shard, processing_params.remove_columns) + if processing_params.remove_columns: + remove_columns = _remove_columns(ds_shard) + else: remove_columns = [] logger.info( f"Processing dataset {ds_idx} for shard {shard_id}: " @@ -129,7 +132,7 @@ def main(): batched=True, batch_size=loading_params.get_batch_size(), load_from_cache_file=False, - desc=f"Shard {shard_id}/{num_shards - 1} dataset {ds_idx}", + desc=f"Shard {shard_id}/{last_shard_id} dataset {ds_idx}", fn_kwargs={ "mapper": mapper, "renderer": renderer, diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index 2ae1c9d..defdffe 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -5,13 +5,14 @@ """ from datetime import datetime +from dataclasses import dataclass from functools import reduce import json import logging import os import shutil import socket -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from datasets import DatasetDict @@ -20,6 +21,80 @@ logger = logging.getLogger(__name__) +@dataclass 
+class ShardStatus: + """Typed representation of the shard status.json payload.""" + + status: str = "unknown" + retry_count: int = 0 + shard_id: Optional[int] = None + started_at: Optional[str] = None + finished_at: Optional[str] = None + error: Optional[str] = None + hostname: Optional[str] = None + pid: Optional[int] = None + slurm_job_id: Optional[str] = None + slurm_array_task_id: Optional[str] = None + datasets: Optional[List[Dict[str, Any]]] = None + + @classmethod + def from_dict(cls, payload: Dict[str, Any]) -> "ShardStatus": + """Build a status object from a JSON payload.""" + data = payload or {} + try: + retry_count = int(data.get("retry_count", 0)) + except (TypeError, ValueError): + retry_count = 0 + + shard_id = data.get("shard_id") + if shard_id is not None: + try: + shard_id = int(shard_id) + except (TypeError, ValueError): + shard_id = None + + pid = data.get("pid") + if pid is not None: + try: + pid = int(pid) + except (TypeError, ValueError): + pid = None + + datasets = data.get("datasets") + if not isinstance(datasets, list): + datasets = None + + return cls( + status=str(data.get("status", "unknown")), + retry_count=retry_count, + shard_id=shard_id, + started_at=data.get("started_at"), + finished_at=data.get("finished_at"), + error=data.get("error"), + hostname=data.get("hostname"), + pid=pid, + slurm_job_id=data.get("slurm_job_id"), + slurm_array_task_id=data.get("slurm_array_task_id"), + datasets=datasets, + ) + + def to_dict(self) -> Dict[str, Any]: + """Serialize status to the JSON payload written on disk.""" + return { + "status": self.status, + "retry_count": self.retry_count, + "shard_id": self.shard_id, + "started_at": self.started_at, + "finished_at": self.finished_at, + "error": self.error, + "hostname": self.hostname, + "pid": self.pid, + "slurm_job_id": self.slurm_job_id, + "slurm_array_task_id": self.slurm_array_task_id, + "datasets": self.datasets, + } + + def _count_rows(ds: DatasetLike) -> int: """Count total rows in a dataset 
or dataset dict.""" if isinstance(ds, DatasetDict): @@ -39,13 +114,10 @@ def _shard_dataset(ds: DatasetLike, num_shards: int, shard_id: int) -> DatasetLi return ds.shard(num_shards=num_shards, index=shard_id) -def _remove_columns(ds: DatasetLike, enable: bool) -> List[str]: +def _remove_columns(ds: DatasetLike) -> List[str]: """Get columns to remove from dataset if enabled.""" - if not enable: - return [] if isinstance(ds, DatasetDict): - columns_set = [set(split_ds.column_names) for split_ds in ds.values()] - return list(reduce(lambda x, y: x | y, columns_set)) + return list(set(x for split_ds in ds.values() for x in split_ds.column_names)) return ds.column_names @@ -88,25 +160,29 @@ def _status_file(state_dir: str) -> str: return os.path.join(state_dir, "status.json") -def _read_status(state_dir: str) -> dict: +def _read_status(state_dir: str) -> ShardStatus: """Read status.json if present.""" path = _status_file(state_dir) if not os.path.exists(path): - return {} + return ShardStatus(status="missing") try: with open(path, "r") as f: - return json.load(f) + data = json.load(f) + if not isinstance(data, dict): + logger.warning(f"Invalid status format in {path}; expected object") + return ShardStatus(status="unknown") + return ShardStatus.from_dict(data) except (json.JSONDecodeError, OSError) as e: logger.warning(f"Failed to read status file {path}: {e}") - return {} + return ShardStatus(status="unknown") -def _write_status(state_dir: str, payload: dict): +def _write_status(state_dir: str, payload: ShardStatus): """Atomically write status.json.""" os.makedirs(state_dir, exist_ok=True) tmp_path = _status_file(state_dir) + ".tmp" with open(tmp_path, "w") as f: - json.dump(payload, f, indent=2, sort_keys=True) + json.dump(payload.to_dict(), f, indent=2, sort_keys=True) os.replace(tmp_path, _status_file(state_dir)) @@ -136,27 +212,27 @@ def _mark_running( ) -> int: """Mark shard as running and increment retry count.""" prev = _read_status(state_dir) - retry_count = 
int(prev.get("retry_count", 0)) + 1 - - payload = { - "status": "running", - "retry_count": retry_count, - "shard_id": shard_id, - "started_at": datetime.now().isoformat(), - "finished_at": None, - "error": None, - "hostname": socket.gethostname(), - "pid": os.getpid(), - "slurm_job_id": os.environ.get("SLURM_JOB_ID"), - "slurm_array_task_id": os.environ.get("SLURM_ARRAY_TASK_ID"), - "datasets": [ + retry_count = prev.retry_count + 1 + + payload = ShardStatus( + status="running", + retry_count=retry_count, + shard_id=shard_id, + started_at=datetime.now().isoformat(), + finished_at=None, + error=None, + hostname=socket.gethostname(), + pid=os.getpid(), + slurm_job_id=os.environ.get("SLURM_JOB_ID"), + slurm_array_task_id=os.environ.get("SLURM_ARRAY_TASK_ID"), + datasets=[ { "path": ds_config.path, "output_dir": ds_config.output_dir, } for ds_config in datasets_config ], - } + ) _write_status(state_dir, payload) _clear_markers(state_dir) @@ -167,9 +243,9 @@ def _mark_running( def _mark_success(state_dir: str): """Mark shard as successful.""" prev = _read_status(state_dir) - prev["status"] = "success" - prev["finished_at"] = datetime.now().isoformat() - prev["error"] = None + prev.status = "success" + prev.finished_at = datetime.now().isoformat() + prev.error = None _write_status(state_dir, prev) _clear_markers(state_dir) _touch_marker(state_dir, ".SUCCESS") @@ -178,9 +254,9 @@ def _mark_success(state_dir: str): def _mark_failure(state_dir: str, error_msg: str): """Mark shard as failed.""" prev = _read_status(state_dir) - prev["status"] = "failed" - prev["finished_at"] = datetime.now().isoformat() - prev["error"] = error_msg + prev.status = "failed" + prev.finished_at = datetime.now().isoformat() + prev.error = error_msg _write_status(state_dir, prev) _clear_markers(state_dir) _touch_marker(state_dir, ".FAILED") From 09708fa341ee2eb9b0944bafe5c7a9053ad0f875 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:23:29 
+0100 Subject: [PATCH 38/47] removed a changes to avoid circular imports --- src/mmirage/config/loading.py | 8 +++++++- src/mmirage/config/utils.py | 9 +-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index 22ff207..e324e20 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -1,13 +1,19 @@ """Data loading configuration for MMIRAGE pipeline.""" import os +import re from dataclasses import dataclass, field from typing import Union, List, cast from mmirage.core.loader.base import BaseDataLoaderConfig -from mmirage.config.utils import _is_unresolved_env_var DEFAULT_STATE_DIR = "~/.cache/MMIRAGE/state_dir" +_UNRESOLVED_ENV_VAR_PATTERN = re.compile(r"^\$(?:\{[A-Za-z_][A-Za-z0-9_]*\}|[A-Za-z_][A-Za-z0-9_]*)$") + + +def _is_unresolved_env_var(value: str) -> bool: + """Check whether a string still looks like an unresolved shell env var.""" + return bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(value.strip())) @dataclass class LoadingParams: diff --git a/src/mmirage/config/utils.py b/src/mmirage/config/utils.py index 4a73910..e69c2d3 100644 --- a/src/mmirage/config/utils.py +++ b/src/mmirage/config/utils.py @@ -4,7 +4,6 @@ from dacite import Config, from_dict import yaml import os -import re from mmirage.config.config import MMirageConfig from mmirage.core.process.base import BaseProcessorConfig, ProcessorRegistry, OutputVar @@ -122,10 +121,4 @@ def output_var_hook(data: Dict[str, Any]) -> OutputVar: ) cfg_obj = from_dict(MMirageConfig, cast(dict, cfg), config=config) - return cfg_obj - -def _is_unresolved_env_var(value: str) -> bool: - """Check whether a string still looks like an unresolved shell env var.""" - _UNRESOLVED_ENV_VAR_PATTERN = re.compile(r"^\$(?:\{[A-Za-z_][A-Za-z0-9_]*\}|[A-Za-z_][A-Za-z0-9_]*)$") - - return bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(value.strip())) \ No newline at end of file + return cfg_obj \ No newline at end of file From 
a643c7d77bde149652d3873acce61fa86182d9e1 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:33:57 +0100 Subject: [PATCH 39/47] added some logging --- src/mmirage/cli.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index 00d8b54..e1cde7e 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -64,7 +64,10 @@ def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = Fa if not cfg.execution_params.is_slurm(): initial_shard_id = cfg.loading_params.get_shard_id() if not auto_retry: - return run_local(config_path, initial_shard_id) + exit_code = run_local(config_path, initial_shard_id) + if exit_code == 0: + logger.info("All shards completed successfully") + return exit_code shard_ids: List[int] = [initial_shard_id] attempts_by_shard = {initial_shard_id: 0} @@ -296,7 +299,10 @@ def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str wait_for_slurm_job(job_id, cfg) failed_shards, summary = check_failed_shards(cfg) - return status_exit_code(failed_shards, summary) + status_code = status_exit_code(failed_shards, summary) + if status_code == 0: + logger.info("All shards completed successfully") + return status_code def handle_check(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int: From 5dca2b9535c40c1f8a0e52a609db0507af7a1d79 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:52:59 +0100 Subject: [PATCH 40/47] changed print to log for job_id information --- src/mmirage/cli.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index e1cde7e..3b98f12 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -111,7 +111,7 @@ def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = Fa if job_id is None: return 1 - print(job_id) + 
logger.info(f"Submitted SLURM job {job_id} for shard ids: {shard_ids}") if not auto_retry: return 0 @@ -293,7 +293,8 @@ def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str if job_id is None: return 1 - print(job_id) + logger.info(f"Submitted SLURM job {job_id} for shard ids: {shard_ids}") + if not args.wait: return 0 From 73ffb3225932061e2be0bda8c6a63045cfb82166 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 24 Mar 2026 19:10:58 +0100 Subject: [PATCH 41/47] changed logging again --- src/mmirage/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index 3b98f12..b985453 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -111,7 +111,7 @@ def launch_pipeline(cfg: MMirageConfig, config_path: str, force_retry: bool = Fa if job_id is None: return 1 - logger.info(f"Submitted SLURM job {job_id} for shard ids: {shard_ids}") + logger.info(f"Submitted SLURM job {job_id} for shard ids: {shard_ids or 'ALL'}") if not auto_retry: return 0 @@ -293,7 +293,7 @@ def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str if job_id is None: return 1 - logger.info(f"Submitted SLURM job {job_id} for shard ids: {shard_ids}") + logger.info(f"Submitted SLURM job {job_id} for shard ids: {shard_ids or 'ALL'}") if not args.wait: return 0 From c2cca433f1c29bb4b6bdd7d06f01fdce6885bd81 Mon Sep 17 00:00:00 2001 From: Fabrice Nemo Date: Tue, 24 Mar 2026 19:12:03 +0100 Subject: [PATCH 42/47] lambda in __post_init__ to avoid exposing a function used in only one place --- src/mmirage/config/loading.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index e324e20..6bceb9b 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -8,13 +8,8 @@ from mmirage.core.loader.base import BaseDataLoaderConfig DEFAULT_STATE_DIR = 
"~/.cache/MMIRAGE/state_dir" -_UNRESOLVED_ENV_VAR_PATTERN = re.compile(r"^\$(?:\{[A-Za-z_][A-Za-z0-9_]*\}|[A-Za-z_][A-Za-z0-9_]*)$") -def _is_unresolved_env_var(value: str) -> bool: - """Check whether a string still looks like an unresolved shell env var.""" - return bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(value.strip())) - @dataclass class LoadingParams: """Parameters for loading and distributing datasets across shards. @@ -42,6 +37,9 @@ class LoadingParams: batch_size: Union[int, str] = 1 def __post_init__(self): + _UNRESOLVED_ENV_VAR_PATTERN = re.compile(r"^\$(?:\{[A-Za-z_][A-Za-z0-9_]*\}|[A-Za-z_][A-Za-z0-9_]*)$") + _is_unresolved_env_var = lambda s: bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(s.strip())) + if isinstance(self.num_shards, str): try: self.num_shards = int(self.num_shards) From 4e00d8ba58c652a7f1e9586089023393855ccb9a Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Tue, 24 Mar 2026 19:22:59 +0100 Subject: [PATCH 43/47] Update src/mmirage/shard_process.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/mmirage/shard_process.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index ac6dc53..66e8529 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -120,7 +120,8 @@ def main(): ds_config = datasets_config[ds_idx] if processing_params.remove_columns: remove_columns = _remove_columns(ds_shard) - else: remove_columns = [] + else: + remove_columns = [] logger.info( f"Processing dataset {ds_idx} for shard {shard_id}: " From 46aae895dbf9265c31e746ebe59d6b28464918f5 Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Tue, 24 Mar 2026 19:23:17 +0100 Subject: [PATCH 44/47] Update src/mmirage/shard_utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/mmirage/shard_utils.py | 1 - 1 file changed, 1 deletion(-) diff 
--git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index defdffe..1f6be4d 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -6,7 +6,6 @@ from datetime import datetime from dataclasses import dataclass -from functools import reduce import json import logging import os From 4be5403ab4ff3bfdf97a7894d2384ee3e82376cc Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Tue, 24 Mar 2026 19:26:56 +0100 Subject: [PATCH 45/47] fixed small issue --- src/mmirage/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index b985453..d1c2add 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -245,7 +245,7 @@ def parse_shard_ids(raw_value: Optional[str], num_shards: Optional[int] = None) if not candidate: continue - if shard_id.isdigit(): + if candidate.isdigit(): shard_id = int(candidate) else: raise ValueError(f"Invalid shard id {candidate!r}; expected integers") From 4215283006ce17d69c730b771c45306da5f99820 Mon Sep 17 00:00:00 2001 From: Fabrice Nemo Date: Wed, 25 Mar 2026 11:11:18 +0100 Subject: [PATCH 46/47] better style --- src/mmirage/config/loading.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index 6bceb9b..929d76d 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -38,7 +38,8 @@ class LoadingParams: def __post_init__(self): _UNRESOLVED_ENV_VAR_PATTERN = re.compile(r"^\$(?:\{[A-Za-z_][A-Za-z0-9_]*\}|[A-Za-z_][A-Za-z0-9_]*)$") - _is_unresolved_env_var = lambda s: bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(s.strip())) + def is_unresolved_env_var(s: str) -> bool: + return bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(s.strip())) if isinstance(self.num_shards, str): try: @@ -46,7 +47,7 @@ def __post_init__(self): if self.num_shards < 1: raise ValueError() except (ValueError, TypeError): - if 
_is_unresolved_env_var(self.num_shards): + if is_unresolved_env_var(self.num_shards): self.num_shards = 1 else: raise ValueError(f"Invalid value for num_shards: {self.num_shards!r}") @@ -55,7 +56,7 @@ def __post_init__(self): try: self.shard_id = int(self.shard_id) except (ValueError, TypeError): - if _is_unresolved_env_var(self.shard_id): + if is_unresolved_env_var(self.shard_id): self.shard_id = 0 else: raise ValueError(f"Invalid value for shard_id: {self.shard_id!r}") From 6b3994a4b32ad9b823f0070eb35a446b65fe8078 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Wed, 25 Mar 2026 11:45:54 +0100 Subject: [PATCH 47/47] change proposed by copilot --- src/mmirage/shard_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index 1f6be4d..c4dce9a 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -11,6 +11,7 @@ import os import shutil import socket +import uuid from typing import Any, Dict, List, Optional from datasets import DatasetDict @@ -125,7 +126,9 @@ def _save_dataset_atomic(ds_processed: DatasetLike, out_dir: str): parent_dir = os.path.dirname(out_dir) os.makedirs(parent_dir, exist_ok=True) - tmp_dir = f"{out_dir}.tmp.{os.getpid()}" + tmp_dir = ( + f"{out_dir}.tmp.{socket.gethostname()}.{os.getpid()}.{uuid.uuid4().hex}" + ) if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)