From 968e729de251bef6ec2799787f6fc89a264ea82e Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:04:56 +0200 Subject: [PATCH 01/24] trying a benchmark --- README.md | 78 ++++++ configs/config_benchmark_datatrove.yaml | 78 ++++++ src/mmirage/cli.py | 52 +++- src/mmirage/cli_utils/slurm.py | 11 +- src/mmirage/cli_utils/status.py | 128 +++++++++- src/mmirage/core/process/mapper.py | 18 ++ .../process/processors/llm/llm_processor.py | 26 ++ src/mmirage/shard_process.py | 39 ++- src/mmirage/shard_utils.py | 228 +++++++++++++++++- 9 files changed, 642 insertions(+), 16 deletions(-) create mode 100644 configs/config_benchmark_datatrove.yaml diff --git a/README.md b/README.md index 2660832..4730dc4 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,84 @@ mmirage merge --config configs/config_mock.yaml --output-root /path/to/merged MMIRAGE still keeps datasets separate by creating one subdirectory per dataset under the root. +### Benchmarking shard performance + +Pass `--stats` to `run` or `submit` to enable per-shard benchmarking. This activates GPU +utilization polling and throughput tracking on compute nodes — disabled by default to +avoid unnecessary overhead. + +```bash +# Local run with stats collection +mmirage run --config configs/config_mock.yaml --stats + +# SLURM submission with stats collection +mmirage submit --config configs/config_mock.yaml --stats +``` + +After the run completes, inspect the results with: + +```bash +mmirage stats --config configs/config_mock.yaml +``` + +This prints a JSON report with per-shard details and an aggregate summary: + +```json +{ + "per_shard": [ + { + "shard_id": 0, + "status": "success", + "started_at": "2026-04-30T10:00:00", + "finished_at": "2026-04-30T10:01:05", + "stats": { + "runtime_seconds": 65.2, + "runtime_human": "1m 5s", + "rows_processed": 1024, + "throughput_rows_per_sec": 15.7, + "gpu_util_mean": 88.4, + "gpu_util_min": 72.0, + "gpu_util_max": 98.0, + "gpu_util_samples": 13, + "input_tokens": 512000, + "output_tokens": 196608, + "num_gpus": 4, + "tokens_per_sec_per_gpu": 753.1, + "gpu_days_per_billion_tokens": 0.0015 + } + } + ], + "aggregate": { + "total_shards": 4, + "completed_shards": 4, + "total_rows_processed": 4096, + "wall_clock_runtime_seconds": 68.1, + "wall_clock_runtime_human": "1m 8s", + "sum_shard_runtime_seconds": 261.4, + "sum_shard_runtime_human": "4m 21s", + "min_shard_runtime_seconds": 62.3, + "min_shard_runtime_human": "1m 2s", + "max_shard_runtime_seconds": 69.7, + "max_shard_runtime_human": "1m 9s", + "overall_throughput_rows_per_sec": 60.1, + "mean_gpu_util_pct": 87.9, + "num_gpus": 4, + "total_input_tokens": 2048000, + "total_output_tokens": 786432, + "tokens_per_sec_per_gpu": 750.8, + "gpu_days_per_billion_tokens": 0.0015 + } +} +``` + +Key metrics: +- **`runtime_seconds`** / **`runtime_human`**: time from when the shard started on the cluster (after dispatch), excluding queue wait time. +- **`overall_throughput_rows_per_sec`**: total rows / wall-clock time across all shards running in parallel. +- **`mean_gpu_util_pct`**: weighted average GPU utilization across shards (weighted by rows processed). +- **`tokens_per_sec_per_gpu`**: output tokens generated per second per GPU — the primary throughput metric used by frameworks such as [DataTrove](https://github.com/huggingface/datatrove). +- **`gpu_days_per_billion_tokens`**: total GPU-days consumed to generate 1 billion output tokens — useful for cost and scaling comparisons across different hardware configurations. +- Token metrics are `null` when no LLM processor was active, and GPU stats are `null` when `nvidia-smi` is unavailable or `--stats` was not passed. + ### Text-only: Reformatting dataset Suppose you have a dataset with samples of the following format diff --git a/configs/config_benchmark_datatrove.yaml b/configs/config_benchmark_datatrove.yaml new file mode 100644 index 0000000..ca4affb --- /dev/null +++ b/configs/config_benchmark_datatrove.yaml @@ -0,0 +1,78 @@ +# MMIRAGE — DataTrove-compatible throughput benchmark +# +# Mirrors the conditions used in the DataTrove inference benchmark +# (https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark): +# +# dataset : simplescaling/s1K-1.1 (train split, 1 000 samples) +# prompt : raw `question` field, no system prompt +# output : up to 1 024 tokens per sample +# context : 2 048-token model max context +# model : Qwen/Qwen3-4B (DataTrove baseline: tp=1 on a single GPU) +# +# Download the dataset before running: +# +# python -c " +# from datasets import load_dataset +# ds = load_dataset('simplescaling/s1K-1.1', split='train') +# ds.save_to_disk('data/s1K-1.1') +# " +# +# Then run with stats collection enabled: +# +# mmirage run --config configs/config_benchmark_datatrove.yaml --stats +# +# Inspect results: +# +# mmirage stats --config configs/config_benchmark_datatrove.yaml +# +# Key metrics to compare against DataTrove: +# tokens_per_sec_per_gpu (DataTrove: output_tps_per_gpu) +# gpu_days_per_billion_tokens (DataTrove: gpu_days_to_process_1b_tokens) + +processors: + - type: llm + server_args: + model_path: Qwen/Qwen3-4B # same model family as DataTrove baseline + tp_size: 1 # DataTrove baseline: tp=1 + trust_remote_code: true + disable_custom_all_reduce: true + default_sampling_params: + temperature: 0.0 # greedy — maximises reproducible throughput + max_new_tokens: 1024 # DataTrove: max-tokens=1024 + # Keep context to 2 048 tokens to match DataTrove model-max-context=2048. + # SGLang enforces this via max_total_tokens on the server; set it in + # server_args if your SGLang version supports it: + # max_total_tokens: 2048 + +loading_params: + state_dir: /users/qchapp/data/benchmark_s1k/_pipeline_state + datasets: + - path: /users/qchapp/data/s1K-1.1 # save_to_disk() target above + type: loadable + output_dir: /users/qchapp/data/benchmark_s1k/output + num_shards: 4 # adjust to your GPU/node count + shard_id: "$SLURM_ARRAY_TASK_ID" # use 0 for a local single-shard run + batch_size: 64 + +processing_params: + inputs: + - name: question + key: question # DataTrove: prompt-column=question + + outputs: + - name: answer + type: llm + output_type: plain + prompt: "{{ question }}" # bare question, no system prompt (DataTrove baseline) + + remove_columns: false + output_schema: + question: "{{ question }}" + answer: "{{ answer }}" + +execution_params: + mode: local # switch to slurm for multi-node + retry: false + merge: true + report_dir: /users/qchapp/reports + hf_home: /users/qchapp/hf diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index 4b40cc0..a160542 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -15,6 +15,7 @@ from mmirage.cli_utils.slurm import require_slurm, submit_slurm_job, wait_for_slurm_job from mmirage.cli_utils.status import ( check_failed_shards, + collect_bench_stats, is_retry_budget_exceeded, shard_state_dir, get_shard_status, @@ -29,12 +30,13 @@ logger = logging.getLogger(__name__) -def run_local(config_path: str, shard_id: Optional[int] = None) -> int: +def run_local(config_path: str, shard_id: Optional[int] = None, collect_stats: bool = False) -> int: """Run one shard in the current Python environment. Args: config_path: Absolute path to the MMIRAGE YAML config file. shard_id: Optional shard id to inject via SLURM_ARRAY_TASK_ID. + collect_stats: If True, enable GPU utilization polling in the shard process. Returns: Process return code from shard execution. @@ -43,6 +45,8 @@ def run_local(config_path: str, shard_id: Optional[int] = None) -> int: env = os.environ.copy() if shard_id is not None: env["SLURM_ARRAY_TASK_ID"] = str(shard_id) + if collect_stats: + env["MMIRAGE_COLLECT_STATS"] = "1" logger.info("Running local shard processing: %s", " ".join(command)) result = subprocess.run(command, env=env, check=False) @@ -54,6 +58,7 @@ def launch_pipeline( config_path: str, force_retry: bool = False, require_completion: bool = False, + collect_stats: bool = False, ) -> int: """Launch the pipeline according to execution mode and retry settings. @@ -63,6 +68,7 @@ def launch_pipeline( force_retry: If True, enable retry orchestration regardless of config flag. require_completion: If True, wait for completion and verify shard status before returning success in SLURM mode when auto-retry is off. + collect_stats: If True, enable GPU utilization polling on compute nodes. Returns: Exit code: 0 on success, 1 on failure. @@ -72,7 +78,7 @@ def launch_pipeline( if not cfg.execution_params.is_slurm(): initial_shard_id = cfg.loading_params.get_shard_id() if not auto_retry: - exit_code = run_local(config_path, initial_shard_id) + exit_code = run_local(config_path, initial_shard_id, collect_stats=collect_stats) if exit_code == 0: logger.info("All shards completed successfully") return exit_code @@ -84,7 +90,7 @@ def launch_pipeline( run_exit_codes = {} for shard_id in shard_ids: attempts_by_shard[shard_id] = attempts_by_shard.get(shard_id, 0) + 1 - run_exit_codes[shard_id] = run_local(config_path, shard_id) + run_exit_codes[shard_id] = run_local(config_path, shard_id, collect_stats=collect_stats) failed_shards, summary = check_failed_shards(cfg) if status_exit_code(failed_shards, summary) == 0: @@ -115,7 +121,7 @@ def launch_pipeline( shard_ids: List[int] = [] while True: - job_id = submit_slurm_job(cfg, config_path, shard_ids) + job_id = submit_slurm_job(cfg, config_path, shard_ids, collect_stats=collect_stats) if job_id is None: return 1 @@ -192,6 +198,11 @@ def build_argparser() -> argparse.ArgumentParser: help="Comma-separated shard ids to submit instead of the full array", ) submit_parser.add_argument("--wait", action="store_true", help="Wait for the submitted job") + submit_parser.add_argument( + "--stats", + action="store_true", + help="Enable GPU utilization and throughput collection on compute nodes", + ) check_parser = subparsers.add_parser("check", help="Inspect shard status") add_shared_arguments(check_parser) @@ -240,6 +251,11 @@ def build_argparser() -> argparse.ArgumentParser: default=None, help="Run a single shard locally (overrides execution mode)", ) + run_parser.add_argument( + "--stats", + action="store_true", + help="Enable GPU utilization and throughput collection during shard execution", + ) merge_parser = subparsers.add_parser( "merge", @@ -282,6 +298,12 @@ def build_argparser() -> argparse.ArgumentParser: help="Log verbosity", ) + stats_parser = subparsers.add_parser( + "stats", + help="Show per-shard benchmark statistics (runtime, throughput, GPU utilization)", + ) + add_shared_arguments(stats_parser) + return parser @@ -346,13 +368,14 @@ def handle_run(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) - Exit code for the run operation. """ if args.shard_id is not None: - return run_local(config_path, args.shard_id) + return run_local(config_path, args.shard_id, collect_stats=args.stats) exit_code = launch_pipeline( cfg, config_path, force_retry=args.force_retry, require_completion=cfg.execution_params.merge, + collect_stats=args.stats, ) if exit_code != 0: return exit_code @@ -380,7 +403,7 @@ def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str return 1 shard_ids = parse_shard_ids(args.shard_ids, cfg.loading_params.get_num_shards()) - job_id = submit_slurm_job(cfg, config_path, shard_ids) + job_id = submit_slurm_job(cfg, config_path, shard_ids, collect_stats=args.stats) if job_id is None: return 1 @@ -461,6 +484,22 @@ def handle_retry(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) ) +def handle_stats(args: argparse.Namespace, cfg: MMirageConfig, _config_path: str) -> int: + """Print per-shard benchmark statistics and aggregate totals. + + Args: + args: Parsed CLI namespace. + cfg: Parsed MMIRAGE configuration object. + _config_path: Absolute path to the MMIRAGE YAML config file (not needed here). + + Returns: + Exit code: 0 always (stats are informational). + """ + report = collect_bench_stats(cfg) + print(json.dumps(report, indent=2)) + return 0 + + def handle_merge(args: argparse.Namespace, cfg: MMirageConfig, _config_path: str) -> int: """Merge shard outputs defined in config.loading_params.datasets. @@ -513,6 +552,7 @@ def main() -> None: "check": handle_check, "retry": handle_retry, "merge": handle_merge, + "stats": handle_stats, } handler = handlers.get(args.command) if handler is None: diff --git a/src/mmirage/cli_utils/slurm.py b/src/mmirage/cli_utils/slurm.py index ad69d45..bb2c7d4 100644 --- a/src/mmirage/cli_utils/slurm.py +++ b/src/mmirage/cli_utils/slurm.py @@ -54,7 +54,7 @@ def _shell_path(value: str, project_root: str) -> str: return raw -def build_sbatch_script(cfg: MMirageConfig, config_path: str) -> str: +def build_sbatch_script(cfg: MMirageConfig, config_path: str, collect_stats: bool = False) -> str: """Build the sbatch payload executed for each array task.""" project_root = get_project_root(cfg) hf_home = _shell_path(cfg.execution_params.hf_home, project_root) @@ -69,10 +69,14 @@ def build_sbatch_script(cfg: MMirageConfig, config_path: str) -> str: f"export SHARD_PROCESS={_bash_double_quote(shard_process_path)}", f"export HF_HOME={_bash_double_quote(hf_home)}", f"export MMIRAGE_CONFIG={_bash_double_quote(config_path)}", + ] + if collect_stats: + lines.append("export MMIRAGE_COLLECT_STATS=1") + lines.extend([ f"mkdir -p {_bash_double_quote(hf_home)}", f"mkdir -p {_bash_double_quote(state_root)}", "srun_args=(--cpus-per-task ${SLURM_CPUS_PER_TASK:-1} --wait 60)", - ] + ]) if cfg.execution_params.edf_env: edf_env = expand_path(cfg.execution_params.edf_env, project_root) @@ -99,6 +103,7 @@ def submit_slurm_job( cfg: MMirageConfig, config_path: str, shard_ids: Optional[Sequence[int]] = None, + collect_stats: bool = False, ) -> Optional[int]: """Submit a SLURM array job and return its job ID.""" project_root = get_project_root(cfg) @@ -134,7 +139,7 @@ def submit_slurm_job( logger.info("Submitting SLURM job: %s", " ".join(command)) result = subprocess.run( command, - input=build_sbatch_script(cfg, config_path), + input=build_sbatch_script(cfg, config_path, collect_stats=collect_stats), text=True, capture_output=True, check=False, diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index 838099a..e78b2c7 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ -7,11 +7,11 @@ import os import sys from dataclasses import dataclass -from typing import List, Literal, Sequence, Tuple +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple from mmirage.config.config import MMirageConfig from mmirage.cli_utils.slurm import submit_slurm_job -from mmirage.shard_utils import ShardStatus +from mmirage.shard_utils import ShardStatus, _format_duration, _read_status, _shard_state_dir logger = logging.getLogger(__name__) @@ -153,3 +153,127 @@ def submit_failed_shards( return 1 return 0 + + +def collect_bench_stats(cfg: MMirageConfig) -> Dict[str, Any]: + """Collect per-shard benchmark statistics and compute aggregate totals. + + Returns a dict with two keys: + + - ``per_shard``: list of dicts, one per shard, each containing the full + :class:`~mmirage.shard_utils.ShardStatus` payload plus a flattened + ``stats`` sub-dict. + - ``aggregate``: rolled-up totals across all completed shards. + + Shards without ``stats`` (e.g. still running or from older runs) are + included in ``per_shard`` but excluded from aggregate calculations. + """ + state_root = cfg.loading_params.get_state_root() + num_shards = cfg.loading_params.get_num_shards() + + per_shard: List[Dict[str, Any]] = [] + + total_rows: int = 0 + sum_runtime: float = 0.0 + runtimes: List[float] = [] + gpu_util_weighted: List[float] = [] # util * rows for weighted mean + gpu_total_rows_for_weight: int = 0 + earliest_start: Optional[str] = None + latest_finish: Optional[str] = None + # Token-level aggregates (DataTrove-compatible benchmark format). + total_input_tokens: int = 0 + total_output_tokens: int = 0 + has_token_data: bool = False + num_gpus: Optional[int] = None # taken from first shard that has it + + for shard_id in range(num_shards): + state_dir = _shard_state_dir(shard_id, state_root) + status = _read_status(state_dir) + entry: Dict[str, Any] = status.to_dict() + per_shard.append(entry) + + if status.status != "success" or status.stats is None: + continue + + s = status.stats + if s.runtime_seconds is not None: + sum_runtime += s.runtime_seconds + runtimes.append(s.runtime_seconds) + if s.rows_processed is not None: + total_rows += s.rows_processed + if s.gpu_util_mean is not None and s.rows_processed: + gpu_util_weighted.append(s.gpu_util_mean * s.rows_processed) + gpu_total_rows_for_weight += s.rows_processed + + # Accumulate token counts. + if s.input_tokens is not None: + total_input_tokens += s.input_tokens + has_token_data = True + if s.output_tokens is not None: + total_output_tokens += s.output_tokens + has_token_data = True + if num_gpus is None and s.num_gpus is not None: + num_gpus = s.num_gpus + + # Track earliest start / latest finish for wall-clock runtime. + if status.started_at: + if earliest_start is None or status.started_at < earliest_start: + earliest_start = status.started_at + if status.finished_at: + if latest_finish is None or status.finished_at > latest_finish: + latest_finish = status.finished_at + + # Wall-clock runtime: time from first shard start to last shard finish. + wall_clock: Optional[float] = None + if earliest_start and latest_finish: + try: + from datetime import datetime as _dt + wall_clock = round( + (_dt.fromisoformat(latest_finish) - _dt.fromisoformat(earliest_start)).total_seconds(), + 3, + ) + except (ValueError, TypeError): + pass + + overall_throughput: Optional[float] = None + if total_rows > 0 and wall_clock and wall_clock > 0: + overall_throughput = round(total_rows / wall_clock, 2) + + mean_gpu_util: Optional[float] = None + if gpu_util_weighted and gpu_total_rows_for_weight > 0: + mean_gpu_util = round(sum(gpu_util_weighted) / gpu_total_rows_for_weight, 1) + + # Aggregate token-throughput metrics (DataTrove-compatible benchmark format). + # Uses sum of shard runtimes (total GPU-time) to compute a per-GPU token rate. + agg_tokens_per_sec_per_gpu: Optional[float] = None + agg_gpu_days_per_billion_tokens: Optional[float] = None + if has_token_data and total_output_tokens > 0 and runtimes and num_gpus and num_gpus > 0: + total_gpu_seconds = sum_runtime * num_gpus + agg_tokens_per_sec_per_gpu = round(total_output_tokens / total_gpu_seconds, 2) + total_gpu_days = total_gpu_seconds / 86_400 + agg_gpu_days_per_billion_tokens = round(total_gpu_days / (total_output_tokens / 1e9), 4) + + aggregate: Dict[str, Any] = { + "total_shards": num_shards, + "completed_shards": sum(1 for e in per_shard if e.get("status") == "success"), + "total_rows_processed": total_rows if total_rows > 0 else None, + "wall_clock_runtime_seconds": wall_clock, + "wall_clock_runtime_human": _format_duration(wall_clock), + "sum_shard_runtime_seconds": round(sum_runtime, 3) if runtimes else None, + "sum_shard_runtime_human": _format_duration(round(sum_runtime, 3) if runtimes else None), + "min_shard_runtime_seconds": round(min(runtimes), 3) if runtimes else None, + "min_shard_runtime_human": _format_duration(round(min(runtimes), 3) if runtimes else None), + "max_shard_runtime_seconds": round(max(runtimes), 3) if runtimes else None, + "max_shard_runtime_human": _format_duration(round(max(runtimes), 3) if runtimes else None), + "overall_throughput_rows_per_sec": overall_throughput, + "mean_gpu_util_pct": mean_gpu_util, + # Token-level benchmark metrics (DataTrove-compatible). + "num_gpus": num_gpus, + "total_input_tokens": total_input_tokens if has_token_data else None, + "total_output_tokens": total_output_tokens if has_token_data else None, + "tokens_per_sec_per_gpu": agg_tokens_per_sec_per_gpu, + "gpu_days_per_billion_tokens": agg_gpu_days_per_billion_tokens, + } + + return {"per_shard": per_shard, "aggregate": aggregate} + diff --git a/src/mmirage/core/process/mapper.py b/src/mmirage/core/process/mapper.py index 5310150..f9d8844 100644 --- a/src/mmirage/core/process/mapper.py +++ b/src/mmirage/core/process/mapper.py @@ -103,3 +103,21 @@ def rewrite_batch( ) return batch_environment + + def get_token_counts(self) -> Dict[str, int]: + """Return cumulative token counts aggregated across all LLM processors. + + Sums ``input_tokens`` and ``output_tokens`` from every processor that + exposes a ``get_token_counts()`` method (i.e., ``LLMProcessor``). + + Returns: + Dict with ``input_tokens`` and ``output_tokens`` keys. + """ + total_input = 0 + total_output = 0 + for proc in self.processors.values(): + if hasattr(proc, "get_token_counts"): + counts = proc.get_token_counts() + total_input += counts.get("input_tokens", 0) + total_output += counts.get("output_tokens", 0) + return {"input_tokens": total_input, "output_tokens": total_output} diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 56afd17..928e033 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -65,6 +65,28 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: ) self.sampling_params = engine_args.default_sampling_params self.chat_template = engine_args.chat_template + # Cumulative token counts across all generate() calls in this processor's lifetime. + self._total_input_tokens: int = 0 + self._total_output_tokens: int = 0 + + def get_token_counts(self) -> dict: + """Return cumulative token counts for this processor. + + Returns: + Dict with ``input_tokens`` (prompt tokens) and ``output_tokens`` + (completion tokens) accumulated since this processor was created. + """ + return { + "input_tokens": self._total_input_tokens, + "output_tokens": self._total_output_tokens, + } + + def _accumulate_tokens(self, outputs: list) -> None: + """Add token counts from a list of SGLang generate() outputs.""" + for out in outputs: + meta = out.get("meta_info") or {} + self._total_input_tokens += int(meta.get("prompt_tokens") or 0) + self._total_output_tokens += int(meta.get("completion_tokens") or 0) def build_prompt( self, prompt_template: str, vars_samples: List[VariableEnvironment] @@ -190,6 +212,8 @@ def batch_process_sample( f"{len(text_only_outputs) if isinstance(text_only_outputs, list) else 'non-list'}" ) + self._accumulate_tokens(text_only_outputs) + for local_idx, global_i in enumerate(text_only_indices): value = text_only_outputs[local_idx].get("text", "").strip() if output_var.output_type == "JSON": @@ -252,6 +276,8 @@ def batch_process_sample( f"{len(multimodal_outputs) if isinstance(multimodal_outputs, list) else 'non-list'}" ) + self._accumulate_tokens(multimodal_outputs) + for local_idx, global_i in enumerate(multimodal_indices): value = multimodal_outputs[local_idx].get("text", "").strip() if output_var.output_type == "JSON": diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 66e8529..e9ce55a 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -5,9 +5,10 @@ import argparse import logging +import os import sys import traceback -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from mmirage.config.utils import load_mmirage_config from mmirage.core.loader.base import DatasetLike @@ -15,6 +16,8 @@ from mmirage.core.process.mapper import MMIRAGEMapper from mmirage.core.writer.renderer import TemplateRenderer from mmirage.shard_utils import ( + GpuUtilizationPoller, + ShardStats, _cleanup_old_shard_data, _count_rows, _dataset_out_dir, @@ -88,10 +91,15 @@ def main(): state_dir = _shard_state_dir(shard_id, loading_params.get_state_root()) + collect_stats = os.environ.get("MMIRAGE_COLLECT_STATS", "") == "1" + gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller(interval_seconds=5.0) try: retry_count = _mark_running(state_dir, shard_id, datasets_config) logger.info(f"Starting shard {shard_id}/{last_shard_id} (attempt #{retry_count})") + if collect_stats: + gpu_poller.start() + if retry_count > 1: for ds_config in datasets_config: out_dir = _dataset_out_dir(shard_id, ds_config) @@ -148,13 +156,40 @@ def main(): _save_dataset_atomic(ds_processed, out_dir) logger.info(f"✅ Saved dataset {ds_idx} shard in: {out_dir}") - _mark_success(state_dir) + gpu_info = gpu_poller.stop() if collect_stats else {"mean": None, "min": None, "max": None, "samples": 0} + + # Collect token counts accumulated by LLM processor(s). + token_counts = mapper.get_token_counts() + input_tokens = token_counts["input_tokens"] or None + output_tokens = token_counts["output_tokens"] or None + + # Resolve num_gpus from the first processor config that exposes tp_size. + num_gpus: Optional[int] = None + for proc_cfg in cfg.processors: + tp = getattr(getattr(proc_cfg, "server_args", None), "tp_size", None) + if tp and tp > 0: + num_gpus = int(tp) + break + + stats = ShardStats( + rows_processed=shard_rows, + gpu_util_mean=gpu_info["mean"], + gpu_util_min=gpu_info["min"], + gpu_util_max=gpu_info["max"], + gpu_util_samples=gpu_info["samples"], + input_tokens=input_tokens, + output_tokens=output_tokens, + num_gpus=num_gpus, + ) + _mark_success(state_dir, stats=stats) logger.info(f"✅ Logical shard {shard_id} completed successfully") except Exception as e: error_msg = f"{type(e).__name__}: {str(e)}" logger.error(f"❌ Shard {shard_id} failed: {error_msg}") logger.error(traceback.format_exc()) + if collect_stats: + gpu_poller.stop() _mark_failure(state_dir, error_msg) sys.exit(1) diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index 76af6c9..ca09939 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -11,6 +11,8 @@ import os import shutil import socket +import subprocess +import threading import uuid from typing import Any, Dict, List, Optional @@ -21,6 +23,191 @@ logger = logging.getLogger(__name__) +def _format_duration(seconds: Optional[float]) -> Optional[str]: + """Format a duration given in seconds as a human-readable string. + + Examples:: + + _format_duration(45.3) -> "45s" + _format_duration(125.0) -> "2m 5s" + _format_duration(3725.0) -> "1h 2m 5s" + """ + if seconds is None: + return None + total = int(seconds) + hours, remainder = divmod(total, 3600) + minutes, secs = divmod(remainder, 60) + if hours: + return f"{hours}h {minutes}m {secs}s" + if minutes: + return f"{minutes}m {secs}s" + return f"{secs}s" + + +@dataclass +class ShardStats: + """Per-shard benchmark statistics recorded at completion.""" + + runtime_seconds: Optional[float] = None + rows_processed: Optional[int] = None + throughput_rows_per_sec: Optional[float] = None + gpu_util_mean: Optional[float] = None + gpu_util_min: Optional[float] = None + gpu_util_max: Optional[float] = None + gpu_util_samples: Optional[int] = None + # Token-level throughput metrics (DataTrove-compatible benchmark format). + # input_tokens: total prompt tokens consumed across all LLM calls in this shard. + # output_tokens: total completion tokens generated across all LLM calls in this shard. + # num_gpus: number of GPUs used (tensor-parallel size from the LLM processor config). + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None + num_gpus: Optional[int] = None + + @classmethod + def from_dict(cls, data: Optional[Dict[str, Any]]) -> Optional["ShardStats"]: + """Build a ShardStats from a JSON payload, or return None if data is missing.""" + if not isinstance(data, dict): + return None + + def _opt_float(v: Any) -> Optional[float]: + try: + return float(v) if v is not None else None + except (TypeError, ValueError): + return None + + def _opt_int(v: Any) -> Optional[int]: + try: + return int(v) if v is not None else None + except (TypeError, ValueError): + return None + + return cls( + runtime_seconds=_opt_float(data.get("runtime_seconds")), + rows_processed=_opt_int(data.get("rows_processed")), + throughput_rows_per_sec=_opt_float(data.get("throughput_rows_per_sec")), + gpu_util_mean=_opt_float(data.get("gpu_util_mean")), + gpu_util_min=_opt_float(data.get("gpu_util_min")), + gpu_util_max=_opt_float(data.get("gpu_util_max")), + gpu_util_samples=_opt_int(data.get("gpu_util_samples")), + input_tokens=_opt_int(data.get("input_tokens")), + output_tokens=_opt_int(data.get("output_tokens")), + num_gpus=_opt_int(data.get("num_gpus")), + ) + + def to_dict(self) -> Dict[str, Any]: + # Derived token-throughput metrics (DataTrove-compatible benchmark format). + # tokens_per_sec_per_gpu = output_tokens / (runtime_seconds * num_gpus) + # gpu_days_per_billion_tokens = (num_gpus * runtime_seconds / 86_400) / (output_tokens / 1e9) + tokens_per_sec_per_gpu: Optional[float] = None + gpu_days_per_billion_tokens: Optional[float] = None + if ( + self.output_tokens is not None + and self.output_tokens > 0 + and self.runtime_seconds is not None + and self.runtime_seconds > 0 + and self.num_gpus is not None + and self.num_gpus > 0 + ): + tokens_per_sec_per_gpu = round( + self.output_tokens / (self.runtime_seconds * self.num_gpus), 2 + ) + gpu_days_per_billion_tokens = round( + (self.num_gpus * self.runtime_seconds / 86_400) / (self.output_tokens / 1e9), 4 + ) + + return { + "runtime_seconds": self.runtime_seconds, + "runtime_human": _format_duration(self.runtime_seconds), + "rows_processed": self.rows_processed, + "throughput_rows_per_sec": self.throughput_rows_per_sec, + "gpu_util_mean": self.gpu_util_mean, + "gpu_util_min": self.gpu_util_min, + "gpu_util_max": self.gpu_util_max, + "gpu_util_samples": self.gpu_util_samples, + "input_tokens": self.input_tokens, + "output_tokens": self.output_tokens, + "num_gpus": self.num_gpus, + "tokens_per_sec_per_gpu": tokens_per_sec_per_gpu, + "gpu_days_per_billion_tokens": gpu_days_per_billion_tokens, + } + + +class GpuUtilizationPoller: + """Polls ``nvidia-smi`` in a background daemon thread. + + Usage:: + + poller = GpuUtilizationPoller() + poller.start() + # ... do work ... + gpu_info = poller.stop() # {"mean": 85.2, "min": 70.0, "max": 98.0, "samples": 24} + + If ``nvidia-smi`` is unavailable all values are ``None`` and samples is 0. + """ + + def __init__(self, interval_seconds: float = 5.0) -> None: + self._interval = interval_seconds + self._samples: List[float] = [] + self._thread: Optional[threading.Thread] = None + self._stop_event = threading.Event() + + def start(self) -> None: + """Start background polling.""" + self._stop_event.clear() + self._samples = [] + self._thread = threading.Thread(target=self._poll_loop, daemon=True) + self._thread.start() + + def stop(self) -> Dict[str, Any]: + """Stop polling and return a summary dict.""" + self._stop_event.set() + if self._thread is not None: + self._thread.join(timeout=self._interval + 2.0) + samples = self._samples + if not samples: + return {"mean": None, "min": None, "max": None, "samples": 0} + return { + "mean": round(sum(samples) / len(samples), 1), + "min": float(min(samples)), + "max": float(max(samples)), + "samples": len(samples), + } + + def _poll_loop(self) -> None: + while not self._stop_event.wait(timeout=self._interval): + util = self._query_gpu_util() + if util is not None: + self._samples.append(util) + + def _query_gpu_util(self) -> Optional[float]: + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=utilization.gpu", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + if result.returncode == 0: + values = [] + for line in result.stdout.strip().splitlines(): + line = line.strip() + if line: + try: + values.append(float(line)) + except ValueError: + pass + if values: + return sum(values) / len(values) + except Exception: + pass + return None + + @dataclass class ShardStatus: """Typed representation of the shard status.json payload.""" @@ -36,6 +223,7 @@ class ShardStatus: slurm_job_id: Optional[str] = None slurm_array_task_id: Optional[str] = None datasets: Optional[List[Dict[str, Any]]] = None + stats: Optional[ShardStats] = None @classmethod def from_dict(cls, payload: Dict[str, Any]) -> "ShardStatus": @@ -64,6 +252,8 @@ def from_dict(cls, payload: Dict[str, Any]) -> "ShardStatus": if not isinstance(datasets, list): datasets = None + stats = ShardStats.from_dict(data.get("stats")) + return cls( status=str(data.get("status", "unknown")), retry_count=retry_count, @@ -76,6 +266,7 @@ def from_dict(cls, payload: Dict[str, Any]) -> "ShardStatus": slurm_job_id=data.get("slurm_job_id"), slurm_array_task_id=data.get("slurm_array_task_id"), datasets=datasets, + stats=stats, ) def to_dict(self) -> Dict[str, Any]: @@ -92,6 +283,7 @@ def to_dict(self) -> Dict[str, Any]: "slurm_job_id": self.slurm_job_id, "slurm_array_task_id": self.slurm_array_task_id, "datasets": self.datasets, + "stats": self.stats.to_dict() if self.stats is not None else None, } @@ -284,12 +476,42 @@ def _mark_running( return retry_count -def _mark_success(state_dir: str): - """Mark shard as successful.""" +def _mark_success(state_dir: str, stats: Optional[ShardStats] = None): + """Mark shard as successful and record benchmark statistics. + + Args: + state_dir: Shard state directory. + stats: Optional benchmark stats; ``runtime_seconds`` and + ``throughput_rows_per_sec`` are computed from the stored timestamps + when not already set. + """ prev = _read_status(state_dir) prev.status = "success" - prev.finished_at = datetime.now().isoformat() + now = datetime.now() + prev.finished_at = now.isoformat() prev.error = None + + if stats is not None: + # Derive runtime from stored start timestamp when not already supplied. + if stats.runtime_seconds is None and prev.started_at: + try: + started = datetime.fromisoformat(prev.started_at) + stats.runtime_seconds = round((now - started).total_seconds(), 3) + except (ValueError, TypeError): + pass + + # Derive throughput once we have both rows and runtime. + if ( + stats.throughput_rows_per_sec is None + and stats.rows_processed is not None + and stats.runtime_seconds is not None + and stats.runtime_seconds > 0 + ): + stats.throughput_rows_per_sec = round( + stats.rows_processed / stats.runtime_seconds, 2 + ) + + prev.stats = stats _write_status(state_dir, prev) _clear_markers(state_dir) _touch_marker(state_dir, ".SUCCESS") From 7da36b4535e941035cdb06f95df2eaa089a2fd07 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 4 May 2026 14:26:50 +0200 Subject: [PATCH 02/24] fixed stats --- src/mmirage/shard_utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index ca09939..991b8db 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -181,12 +181,18 @@ def _poll_loop(self) -> None: def _query_gpu_util(self) -> Optional[float]: try: + cmd = [ + "nvidia-smi", + "--query-gpu=utilization.gpu", + "--format=csv,noheader,nounits", + ] + # Restrict to GPUs visible to this process so we don't dilute + # utilization by averaging over idle GPUs on the same node. + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): + cmd += [f"--id={cuda_visible}"] result = subprocess.run( - [ - "nvidia-smi", - "--query-gpu=utilization.gpu", - "--format=csv,noheader,nounits", - ], + cmd, capture_output=True, text=True, timeout=5, From 3e1a9bc96e36275c1a16bfd56bd637bc6919aaf7 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 4 May 2026 14:52:48 +0200 Subject: [PATCH 03/24] small test --- src/mmirage/shard_process.py | 16 +++++++++++++++- src/mmirage/shard_utils.py | 19 +++++++++++++------ 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index e9ce55a..be71e21 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -92,7 +92,21 @@ def main(): state_dir = _shard_state_dir(shard_id, loading_params.get_state_root()) collect_stats = os.environ.get("MMIRAGE_COLLECT_STATS", "") == "1" - gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller(interval_seconds=5.0) + # Determine which physical GPU indices SGLang will use so the poller + # measures only the active GPU(s) — not all GPUs on the node. + # In SLURM mode CUDA_VISIBLE_DEVICES is already set by the scheduler; + # in local mode we derive indices from tp_size (SGLang defaults to 0..N-1). + gpu_indices_for_polling: Optional[list] = None + if not os.environ.get("CUDA_VISIBLE_DEVICES"): + for proc_cfg in cfg.processors: + tp = getattr(getattr(proc_cfg, "server_args", None), "tp_size", None) + if tp and int(tp) > 0: + gpu_indices_for_polling = list(range(int(tp))) + break + + gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller( + interval_seconds=5.0, gpu_indices=gpu_indices_for_polling + ) try: retry_count = _mark_running(state_dir, shard_id, datasets_config) logger.info(f"Starting shard {shard_id}/{last_shard_id} (attempt #{retry_count})") diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index 991b8db..d340b4d 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -145,11 +145,14 @@ class GpuUtilizationPoller: If ``nvidia-smi`` is unavailable all values are ``None`` and samples is 0. """ - def __init__(self, interval_seconds: float = 5.0) -> None: + def __init__(self, interval_seconds: float = 5.0, gpu_indices: Optional[List[int]] = None) -> None: self._interval = interval_seconds self._samples: List[float] = [] self._thread: Optional[threading.Thread] = None self._stop_event = threading.Event() + # Explicit GPU indices take priority over CUDA_VISIBLE_DEVICES. + # Pass the indices SGLang will use (0..tp_size-1 in local mode). + self._gpu_indices = gpu_indices def start(self) -> None: """Start background polling.""" @@ -186,11 +189,15 @@ def _query_gpu_util(self) -> Optional[float]: "--query-gpu=utilization.gpu", "--format=csv,noheader,nounits", ] - # Restrict to GPUs visible to this process so we don't dilute - # utilization by averaging over idle GPUs on the same node. - cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") - if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): - cmd += [f"--id={cuda_visible}"] + # Restrict to the GPUs this process actually uses so we don't + # dilute utilization by averaging over idle GPUs on the same node. + # Priority: explicit gpu_indices > CUDA_VISIBLE_DEVICES > all GPUs. + if self._gpu_indices is not None: + cmd += [f"--id={','.join(str(i) for i in self._gpu_indices)}"] + else: + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): + cmd += [f"--id={cuda_visible}"] result = subprocess.run( cmd, capture_output=True, From a1386f79658ecb3f09f052b54328cf8b71931e05 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 4 May 2026 15:28:37 +0200 Subject: [PATCH 04/24] testing something --- configs/config_benchmark_datatrove.yaml | 25 +++++++++++-------- .../core/process/processors/llm/config.py | 9 +++++++ .../process/processors/llm/llm_processor.py | 5 +++- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/configs/config_benchmark_datatrove.yaml b/configs/config_benchmark_datatrove.yaml index ca4affb..f16ac3d 100644 --- a/configs/config_benchmark_datatrove.yaml +++ b/configs/config_benchmark_datatrove.yaml @@ -36,13 +36,16 @@ processors: tp_size: 1 # DataTrove baseline: tp=1 trust_remote_code: true disable_custom_all_reduce: true + # SGLang engine tuning — equivalents of DataTrove's vLLM mns/mnbt knobs. + # Tune these to close the gap vs the DataTrove baseline (7 919 tok/s/GPU). + # Uncomment and adjust for your hardware: + extra_engine_args: + max_running_requests: 512 # vLLM mns=512 equivalent + chunked_prefill_size: 32768 # vLLM mnbt=32768 equivalent + # mem_fraction_static: 0.90 # increase if GPU has headroom default_sampling_params: temperature: 0.0 # greedy — maximises reproducible throughput max_new_tokens: 1024 # DataTrove: max-tokens=1024 - # Keep context to 2 048 tokens to match DataTrove model-max-context=2048. - # SGLang enforces this via max_total_tokens on the server; set it in - # server_args if your SGLang version supports it: - # max_total_tokens: 2048 loading_params: state_dir: /users/qchapp/data/benchmark_s1k/_pipeline_state @@ -50,9 +53,9 @@ loading_params: - path: /users/qchapp/data/s1K-1.1 # save_to_disk() target above type: loadable output_dir: /users/qchapp/data/benchmark_s1k/output - num_shards: 4 # adjust to your GPU/node count - shard_id: "$SLURM_ARRAY_TASK_ID" # use 0 for a local single-shard run - batch_size: 64 + num_shards: 1 + shard_id: "$SLURM_ARRAY_TASK_ID" + batch_size: 1000 processing_params: inputs: @@ -71,8 +74,8 @@ processing_params: answer: "{{ answer }}" execution_params: - mode: local # switch to slurm for multi-node + mode: local retry: false - merge: true - report_dir: /users/qchapp/reports - hf_home: /users/qchapp/hf + merge: false + report_dir: ~/reports + hf_home: ~/hf diff --git a/src/mmirage/core/process/processors/llm/config.py b/src/mmirage/core/process/processors/llm/config.py index f323599..dde3029 100644 --- a/src/mmirage/core/process/processors/llm/config.py +++ b/src/mmirage/core/process/processors/llm/config.py @@ -59,12 +59,21 @@ class SGLangServerArgs: tp_size: Tensor parallelism size. trust_remote_code: Whether to trust remote code from HuggingFace. disable_custom_all_reduce: Whether to disable custom all reduce. + extra_engine_args: Any additional keyword arguments forwarded verbatim + to ``sgl.Engine``. Use this to pass SGLang-specific options that + are not listed above, e.g.:: + + extra_engine_args: + max_running_requests: 512 + chunked_prefill_size: 32768 + mem_fraction_static: 0.88 """ model_path: str = "none" tp_size: int = field(default_factory=_parse_tp_size_from_env) trust_remote_code: bool = True disable_custom_all_reduce: bool = False + extra_engine_args: Dict[str, Any] = field(default_factory=dict) @dataclass diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 928e033..2e3b0e9 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -58,7 +58,10 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: **kwargs: Additional arguments passed to base class. """ super().__init__(engine_args, **kwargs) - self.llm = sgl.Engine(**asdict(engine_args.server_args)) + server_kwargs = asdict(engine_args.server_args) + extra = server_kwargs.pop("extra_engine_args", {}) or {} + server_kwargs.update(extra) + self.llm = sgl.Engine(**server_kwargs) self.tokenizer = AutoTokenizer.from_pretrained( engine_args.server_args.model_path, trust_remote_code=getattr(engine_args.server_args, "trust_remote_code", False), From 5181248d37eef70bb35120ba143e785723a76c77 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 4 May 2026 15:38:59 +0200 Subject: [PATCH 05/24] slurm config --- configs/config_benchmark_datatrove.yaml | 77 +++++++++++++++++++++++-- 1 file changed, 73 insertions(+), 4 deletions(-) diff --git a/configs/config_benchmark_datatrove.yaml b/configs/config_benchmark_datatrove.yaml index f16ac3d..7e7d0f0 100644 --- a/configs/config_benchmark_datatrove.yaml +++ b/configs/config_benchmark_datatrove.yaml @@ -53,7 +53,7 @@ loading_params: - path: /users/qchapp/data/s1K-1.1 # save_to_disk() target above type: loadable output_dir: /users/qchapp/data/benchmark_s1k/output - num_shards: 1 + num_shards: 4 shard_id: "$SLURM_ARRAY_TASK_ID" batch_size: 1000 @@ -74,8 +74,77 @@ processing_params: answer: "{{ answer }}" execution_params: - mode: local + # Execution mode: "local" or "slurm" + # - local: Run directly on this machine + # - slurm: Submit jobs to SLURM cluster + mode: slurm + + # Whether the canonical `run` command should automatically retry failed shards. + # - false: submit one run only + # - true: submit, wait, and keep retrying failed shards until success or retry budget exhaustion retry: false + + # Whether to merge shard outputs after a successful run. + # - false: keep shard_* outputs only + # - true: build merged datasets from shard_* outputs merge: false - report_dir: ~/reports - hf_home: ~/hf + + # Maximum number of times to retry a failed shard (default: 3) + max_retries: 3 + + # ========================================================================== + # SLURM CONFIGURATION (only used when mode: slurm) + # ========================================================================== + + # HPC account/partition to charge jobs to (REQUIRED for SLURM mode) + account: a127 + + # SLURM job name (default: "mmirage-sharded") + job_name: mmirage-sharded + + # Optional SLURM reservation name (leave blank or omit to not use) + # reservation: "sai-a127" + + # Number of nodes (default: 1) + nodes: 1 + + # Number of tasks per node (default: 1) + ntasks_per_node: 1 + + # Number of GPUs per node (default: 4) + gpus: 4 + + # Number of CPUs per task (default: 288) + cpus_per_task: 288 + + # Job time limit in HH:MM:SS format (default: "11:59:59") + time_limit: "11:59:59" + + # ========================================================================== + # PATH CONFIGURATION + # ========================================================================== + # These support environment variables ($VAR or ${VAR}) and home directory (~) + + # Project root directory (used as base for relative paths) + # If not set, uses current working directory + # project_root: "/path/to/project" + + # Directory for SLURM output and error files (default: ~/reports) + report_dir: "/users/${USER}/reports" + + # HuggingFace cache directory (default: ~/hf) + hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf" + + # EDF environment file path for cluster-specific setup + edf_env: "/users/${USER}/.edf/sglang.toml" + + # ========================================================================== + # JOB MONITORING (for "submit" and retry orchestration) + # ========================================================================== + + # Seconds to wait between checking job status (default: 30) + poll_interval_seconds: 30 + + # Seconds to wait after job completes before checking results (default: 60) + # This allows filesystem to settle on distributed systems + settle_time_seconds: 60 From 4af933ddf78ae46d9e55b1453a614b0052d8f132 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 4 May 2026 15:44:39 +0200 Subject: [PATCH 06/24] display issue with multiple nodes --- src/mmirage/shard_process.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index be71e21..622c435 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -94,15 +94,24 @@ def main(): collect_stats = os.environ.get("MMIRAGE_COLLECT_STATS", "") == "1" # Determine which physical GPU indices SGLang will use so the poller # measures only the active GPU(s) — not all GPUs on the node. - # In SLURM mode CUDA_VISIBLE_DEVICES is already set by the scheduler; - # in local mode we derive indices from tp_size (SGLang defaults to 0..N-1). - gpu_indices_for_polling: Optional[list] = None - if not os.environ.get("CUDA_VISIBLE_DEVICES"): - for proc_cfg in cfg.processors: - tp = getattr(getattr(proc_cfg, "server_args", None), "tp_size", None) - if tp and int(tp) > 0: - gpu_indices_for_polling = list(range(int(tp))) - break + # SLURM may allocate more GPUs than tp_size (e.g. gpus=4, tp_size=1). + # We take only the first tp_size entries from CUDA_VISIBLE_DEVICES so + # nvidia-smi --id receives exactly the GPUs SGLang is using. + tp_size = 1 + for proc_cfg in cfg.processors: + tp = getattr(getattr(proc_cfg, "server_args", None), "tp_size", None) + if tp and int(tp) > 0: + tp_size = int(tp) + break + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): + try: + all_visible = [int(x.strip()) for x in cuda_visible.split(",") if x.strip()] + gpu_indices_for_polling: Optional[list] = all_visible[:tp_size] + except ValueError: + gpu_indices_for_polling = list(range(tp_size)) + else: + gpu_indices_for_polling = list(range(tp_size)) gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller( interval_seconds=5.0, gpu_indices=gpu_indices_for_polling From 5398ab20ddb5f060287e0b4b3b652f6a83af2dd4 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 4 May 2026 15:53:34 +0200 Subject: [PATCH 07/24] small test again --- configs/config_benchmark_datatrove.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/config_benchmark_datatrove.yaml b/configs/config_benchmark_datatrove.yaml index 7e7d0f0..70eaeaf 100644 --- a/configs/config_benchmark_datatrove.yaml +++ b/configs/config_benchmark_datatrove.yaml @@ -46,6 +46,7 @@ processors: default_sampling_params: temperature: 0.0 # greedy — maximises reproducible throughput max_new_tokens: 1024 # DataTrove: max-tokens=1024 + enable_thinking: false # disable Qwen3 chain-of-thought; matches DataTrove baseline loading_params: state_dir: /users/qchapp/data/benchmark_s1k/_pipeline_state @@ -53,7 +54,7 @@ loading_params: - path: /users/qchapp/data/s1K-1.1 # save_to_disk() target above type: loadable output_dir: /users/qchapp/data/benchmark_s1k/output - num_shards: 4 + num_shards: 1 shard_id: "$SLURM_ARRAY_TASK_ID" batch_size: 1000 From aafb230f796c80e557d700a92b5833aebf3e503e Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 4 May 2026 16:22:08 +0200 Subject: [PATCH 08/24] trying again --- configs/config_benchmark_datatrove.yaml | 6 ++++-- src/mmirage/core/process/processors/llm/llm_processor.py | 8 ++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/configs/config_benchmark_datatrove.yaml b/configs/config_benchmark_datatrove.yaml index 70eaeaf..aa222d7 100644 --- a/configs/config_benchmark_datatrove.yaml +++ b/configs/config_benchmark_datatrove.yaml @@ -46,7 +46,6 @@ processors: default_sampling_params: temperature: 0.0 # greedy — maximises reproducible throughput max_new_tokens: 1024 # DataTrove: max-tokens=1024 - enable_thinking: false # disable Qwen3 chain-of-thought; matches DataTrove baseline loading_params: state_dir: /users/qchapp/data/benchmark_s1k/_pipeline_state @@ -67,7 +66,10 @@ processing_params: - name: answer type: llm output_type: plain - prompt: "{{ question }}" # bare question, no system prompt (DataTrove baseline) + # Qwen3 thinking is disabled by embedding an empty block in the prompt. + # This is equivalent to passing enable_thinking=False to the chat template and + # avoids any dependency on SGLang sampling-param support for that flag. + prompt: "<|im_start|>user\n{{ question }}\n<|im_end|>\n<|im_start|>assistant\n\n\n\n" remove_columns: false output_schema: diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 2e3b0e9..1761f67 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -230,9 +230,7 @@ def batch_process_sample( logger.error( f"Batch generation failed for text-only samples in output '{output_var.name}': {e}" ) - for global_i in text_only_indices: - empty_val = {} if output_var.output_type == "JSON" else "" - results[global_i] = batch[global_i].with_variable(output_var.name, empty_val) + raise # Multimodal batch if multimodal_indices: @@ -294,9 +292,7 @@ def batch_process_sample( logger.error( f"Batch generation failed for multimodal samples in output '{output_var.name}': {e}" ) - for global_i in multimodal_indices: - empty_val = {} if output_var.output_type == "JSON" else "" - results[global_i] = batch[global_i].with_variable(output_var.name, empty_val) + raise return [results[i] for i in range(nb_samples)] From 79ca2fe4b105b54fd268f4694f1f39d9fbbbc917 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 4 May 2026 22:58:18 +0200 Subject: [PATCH 09/24] excluding cold start --- configs/config_benchmark_datatrove.yaml | 3 +-- src/mmirage/cli_utils/status.py | 19 +++++++++++----- src/mmirage/core/process/mapper.py | 8 +++++++ .../process/processors/llm/llm_processor.py | 7 ++++++ src/mmirage/shard_process.py | 2 ++ src/mmirage/shard_utils.py | 22 ++++++++++++++----- 6 files changed, 48 insertions(+), 13 deletions(-) diff --git a/configs/config_benchmark_datatrove.yaml b/configs/config_benchmark_datatrove.yaml index aa222d7..0630ed1 100644 --- a/configs/config_benchmark_datatrove.yaml +++ b/configs/config_benchmark_datatrove.yaml @@ -40,8 +40,7 @@ processors: # Tune these to close the gap vs the DataTrove baseline (7 919 tok/s/GPU). # Uncomment and adjust for your hardware: extra_engine_args: - max_running_requests: 512 # vLLM mns=512 equivalent - chunked_prefill_size: 32768 # vLLM mnbt=32768 equivalent + max_running_requests: 1000 # vLLM mns=512 equivalent # mem_fraction_static: 0.90 # increase if GPU has headroom default_sampling_params: temperature: 0.0 # greedy — maximises reproducible throughput diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index e78b2c7..1286b55 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ -184,6 +184,7 @@ def collect_bench_stats(cfg: MMirageConfig) -> Dict[str, Any]: total_input_tokens: int = 0 total_output_tokens: int = 0 has_token_data: bool = False + sum_model_load_seconds: float = 0.0 num_gpus: Optional[int] = None # taken from first shard that has it for shard_id in range(num_shards): @@ -212,6 +213,8 @@ def collect_bench_stats(cfg: MMirageConfig) -> Dict[str, Any]: if s.output_tokens is not None: total_output_tokens += s.output_tokens has_token_data = True + if s.model_load_seconds is not None: + sum_model_load_seconds += s.model_load_seconds if num_gpus is None and s.num_gpus is not None: num_gpus = s.num_gpus @@ -244,14 +247,18 @@ def collect_bench_stats(cfg: MMirageConfig) -> Dict[str, Any]: mean_gpu_util = round(sum(gpu_util_weighted) / gpu_total_rows_for_weight, 1) # Aggregate token-throughput metrics (DataTrove-compatible benchmark format). - # Uses sum of shard runtimes (total GPU-time) to compute a per-GPU token rate. + # Uses sum of inference runtimes (total minus model loading) for a per-GPU token rate + # that excludes one-time model initialisation overhead. agg_tokens_per_sec_per_gpu: Optional[float] = None agg_gpu_days_per_billion_tokens: Optional[float] = None + agg_inference_runtime: Optional[float] = None if has_token_data and total_output_tokens > 0 and runtimes and num_gpus and num_gpus > 0: - total_gpu_seconds = sum_runtime * num_gpus - agg_tokens_per_sec_per_gpu = round(total_output_tokens / total_gpu_seconds, 2) - total_gpu_days = total_gpu_seconds / 86_400 - agg_gpu_days_per_billion_tokens = round(total_gpu_days / (total_output_tokens / 1e9), 4) + agg_inference_runtime = max(0.0, sum_runtime - sum_model_load_seconds) + if agg_inference_runtime > 0: + total_gpu_seconds = agg_inference_runtime * num_gpus + agg_tokens_per_sec_per_gpu = round(total_output_tokens / total_gpu_seconds, 2) + total_gpu_days = total_gpu_seconds / 86_400 + agg_gpu_days_per_billion_tokens = round(total_gpu_days / (total_output_tokens / 1e9), 4) aggregate: Dict[str, Any] = { "total_shards": num_shards, @@ -271,6 +278,8 @@ def collect_bench_stats(cfg: MMirageConfig) -> Dict[str, Any]: "num_gpus": num_gpus, "total_input_tokens": total_input_tokens if has_token_data else None, "total_output_tokens": total_output_tokens if has_token_data else None, + "sum_model_load_seconds": round(sum_model_load_seconds, 3) if sum_model_load_seconds else None, + "sum_inference_runtime_seconds": round(agg_inference_runtime, 3) if agg_inference_runtime is not None else None, "tokens_per_sec_per_gpu": agg_tokens_per_sec_per_gpu, "gpu_days_per_billion_tokens": agg_gpu_days_per_billion_tokens, } diff --git a/src/mmirage/core/process/mapper.py b/src/mmirage/core/process/mapper.py index f9d8844..c8d8a63 100644 --- a/src/mmirage/core/process/mapper.py +++ b/src/mmirage/core/process/mapper.py @@ -121,3 +121,11 @@ def get_token_counts(self) -> Dict[str, int]: total_input += counts.get("input_tokens", 0) total_output += counts.get("output_tokens", 0) return {"input_tokens": total_input, "output_tokens": total_output} + + def get_load_time(self) -> float: + """Return total model-loading time (seconds) summed across all LLM processors.""" + total = 0.0 + for proc in self.processors.values(): + if hasattr(proc, "get_load_time"): + total += proc.get_load_time() + return total diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 1761f67..55209d8 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -5,6 +5,7 @@ from dataclasses import asdict import json import logging +import time from typing import Any, List, Tuple import jinja2 @@ -61,7 +62,9 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: server_kwargs = asdict(engine_args.server_args) extra = server_kwargs.pop("extra_engine_args", {}) or {} server_kwargs.update(extra) + _load_start = time.monotonic() self.llm = sgl.Engine(**server_kwargs) + self._model_load_seconds: float = time.monotonic() - _load_start self.tokenizer = AutoTokenizer.from_pretrained( engine_args.server_args.model_path, trust_remote_code=getattr(engine_args.server_args, "trust_remote_code", False), @@ -72,6 +75,10 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: self._total_input_tokens: int = 0 self._total_output_tokens: int = 0 + def get_load_time(self) -> float: + """Return the wall-clock seconds spent initializing the SGLang engine.""" + return self._model_load_seconds + def get_token_counts(self) -> dict: """Return cumulative token counts for this processor. diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 622c435..298c8f6 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -185,6 +185,7 @@ def main(): token_counts = mapper.get_token_counts() input_tokens = token_counts["input_tokens"] or None output_tokens = token_counts["output_tokens"] or None + model_load_seconds = mapper.get_load_time() or None # Resolve num_gpus from the first processor config that exposes tp_size. num_gpus: Optional[int] = None @@ -203,6 +204,7 @@ def main(): input_tokens=input_tokens, output_tokens=output_tokens, num_gpus=num_gpus, + model_load_seconds=model_load_seconds, ) _mark_success(state_dir, stats=stats) logger.info(f"✅ Logical shard {shard_id} completed successfully") diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index d340b4d..bb8741a 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -62,6 +62,7 @@ class ShardStats: input_tokens: Optional[int] = None output_tokens: Optional[int] = None num_gpus: Optional[int] = None + model_load_seconds: Optional[float] = None @classmethod def from_dict(cls, data: Optional[Dict[str, Any]]) -> Optional["ShardStats"]: @@ -92,32 +93,41 @@ def _opt_int(v: Any) -> Optional[int]: input_tokens=_opt_int(data.get("input_tokens")), output_tokens=_opt_int(data.get("output_tokens")), num_gpus=_opt_int(data.get("num_gpus")), + model_load_seconds=_opt_float(data.get("model_load_seconds")), ) def to_dict(self) -> Dict[str, Any]: # Derived token-throughput metrics (DataTrove-compatible benchmark format). - # tokens_per_sec_per_gpu = output_tokens / (runtime_seconds * num_gpus) - # gpu_days_per_billion_tokens = (num_gpus * runtime_seconds / 86_400) / (output_tokens / 1e9) + # Use inference_runtime (total minus model loading) so metrics reflect + # pure generation speed, excluding one-time model initialisation overhead. tokens_per_sec_per_gpu: Optional[float] = None gpu_days_per_billion_tokens: Optional[float] = None + inference_runtime: Optional[float] = None + if self.runtime_seconds is not None: + if self.model_load_seconds is not None: + inference_runtime = max(0.0, self.runtime_seconds - self.model_load_seconds) + else: + inference_runtime = self.runtime_seconds if ( self.output_tokens is not None and self.output_tokens > 0 - and self.runtime_seconds is not None - and self.runtime_seconds > 0 + and inference_runtime is not None + and inference_runtime > 0 and self.num_gpus is not None and self.num_gpus > 0 ): tokens_per_sec_per_gpu = round( - self.output_tokens / (self.runtime_seconds * self.num_gpus), 2 + self.output_tokens / (inference_runtime * self.num_gpus), 2 ) gpu_days_per_billion_tokens = round( - (self.num_gpus * self.runtime_seconds / 86_400) / (self.output_tokens / 1e9), 4 + (self.num_gpus * inference_runtime / 86_400) / (self.output_tokens / 1e9), 4 ) return { "runtime_seconds": self.runtime_seconds, "runtime_human": _format_duration(self.runtime_seconds), + "model_load_seconds": round(self.model_load_seconds, 3) if self.model_load_seconds is not None else None, + "inference_runtime_seconds": round(inference_runtime, 3) if inference_runtime is not None else None, "rows_processed": self.rows_processed, "throughput_rows_per_sec": self.throughput_rows_per_sec, "gpu_util_mean": self.gpu_util_mean, From 3ba5c976bd1ab3d2259e4e9d8e27dc002c1d22aa Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Mon, 4 May 2026 23:25:07 +0200 Subject: [PATCH 10/24] now same for gpu --- src/mmirage/shard_process.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 298c8f6..72c70a8 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -120,9 +120,6 @@ def main(): retry_count = _mark_running(state_dir, shard_id, datasets_config) logger.info(f"Starting shard {shard_id}/{last_shard_id} (attempt #{retry_count})") - if collect_stats: - gpu_poller.start() - if retry_count > 1: for ds_config in datasets_config: out_dir = _dataset_out_dir(shard_id, ds_config) @@ -146,6 +143,11 @@ def main(): ) renderer = TemplateRenderer(processing_params.output_schema) + # Start GPU polling after model loading so utilisation samples reflect + # inference only, not weight transfers during sgl.Engine() init. + if collect_stats: + gpu_poller.start() + ds_processed_all: List[DatasetLike] = [] for ds_idx, ds_shard in enumerate(ds_all_shard): ds_config = datasets_config[ds_idx] From 19d7d8a2d67de64dc1a4d5382b9b29d311e2135f Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Wed, 6 May 2026 12:41:13 +0200 Subject: [PATCH 11/24] small corrections --- README.md | 161 +++++++++--------- configs/config_benchmark_datatrove.yaml | 79 +-------- .../process/processors/llm/llm_processor.py | 8 +- src/mmirage/shard_process.py | 46 ++--- 4 files changed, 119 insertions(+), 175 deletions(-) diff --git a/README.md b/README.md index 4730dc4..b6adaaa 100644 --- a/README.md +++ b/README.md @@ -87,84 +87,6 @@ mmirage merge --config configs/config_mock.yaml --output-root /path/to/merged MMIRAGE still keeps datasets separate by creating one subdirectory per dataset under the root. -### Benchmarking shard performance - -Pass `--stats` to `run` or `submit` to enable per-shard benchmarking. This activates GPU -utilization polling and throughput tracking on compute nodes — disabled by default to -avoid unnecessary overhead. - -```bash -# Local run with stats collection -mmirage run --config configs/config_mock.yaml --stats - -# SLURM submission with stats collection -mmirage submit --config configs/config_mock.yaml --stats -``` - -After the run completes, inspect the results with: - -```bash -mmirage stats --config configs/config_mock.yaml -``` - -This prints a JSON report with per-shard details and an aggregate summary: - -```json -{ - "per_shard": [ - { - "shard_id": 0, - "status": "success", - "started_at": "2026-04-30T10:00:00", - "finished_at": "2026-04-30T10:01:05", - "stats": { - "runtime_seconds": 65.2, - "runtime_human": "1m 5s", - "rows_processed": 1024, - "throughput_rows_per_sec": 15.7, - "gpu_util_mean": 88.4, - "gpu_util_min": 72.0, - "gpu_util_max": 98.0, - "gpu_util_samples": 13, - "input_tokens": 512000, - "output_tokens": 196608, - "num_gpus": 4, - "tokens_per_sec_per_gpu": 753.1, - "gpu_days_per_billion_tokens": 0.0015 - } - } - ], - "aggregate": { - "total_shards": 4, - "completed_shards": 4, - "total_rows_processed": 4096, - "wall_clock_runtime_seconds": 68.1, - "wall_clock_runtime_human": "1m 8s", - "sum_shard_runtime_seconds": 261.4, - "sum_shard_runtime_human": "4m 21s", - "min_shard_runtime_seconds": 62.3, - "min_shard_runtime_human": "1m 2s", - "max_shard_runtime_seconds": 69.7, - "max_shard_runtime_human": "1m 9s", - "overall_throughput_rows_per_sec": 60.1, - "mean_gpu_util_pct": 87.9, - "num_gpus": 4, - "total_input_tokens": 2048000, - "total_output_tokens": 786432, - "tokens_per_sec_per_gpu": 750.8, - "gpu_days_per_billion_tokens": 0.0015 - } -} -``` - -Key metrics: -- **`runtime_seconds`** / **`runtime_human`**: time from when the shard started on the cluster (after dispatch), excluding queue wait time. -- **`overall_throughput_rows_per_sec`**: total rows / wall-clock time across all shards running in parallel. -- **`mean_gpu_util_pct`**: weighted average GPU utilization across shards (weighted by rows processed). -- **`tokens_per_sec_per_gpu`**: output tokens generated per second per GPU — the primary throughput metric used by frameworks such as [DataTrove](https://github.com/huggingface/datatrove). -- **`gpu_days_per_billion_tokens`**: total GPU-days consumed to generate 1 billion output tokens — useful for cost and scaling comparisons across different hardware configurations. -- Token metrics are `null` when no LLM processor was active, and GPU stats are `null` when `nvidia-smi` is unavailable or `--stats` was not passed. - ### Text-only: Reformatting dataset Suppose you have a dataset with samples of the following format @@ -313,6 +235,88 @@ Key multimodal features: - `image_base_path`: Base directory for resolving relative image paths - Supports PIL Images, URLs, and file paths +### Benchmarking shard performance + +Pass `--stats` to `run` or `submit` to enable per-shard benchmarking. This activates GPU +utilization polling and throughput tracking on compute nodes — disabled by default to +avoid unnecessary overhead. + +```bash +# Local run with stats collection +mmirage run --config configs/config_mock.yaml --stats + +``` + +After the run completes, inspect the results with: + +```bash +mmirage stats --config configs/config_mock.yaml +``` + +This prints a JSON report with per-shard details and an aggregate summary: + +```json +{ + "per_shard": [ + { + "shard_id": 0, + "status": "success", + "started_at": "2026-04-30T10:00:00", + "finished_at": "2026-04-30T10:01:05", + "stats": { + "runtime_seconds": 65.2, + "runtime_human": "1m 5s", + "rows_processed": 1024, + "throughput_rows_per_sec": 15.7, + "gpu_util_mean": 88.4, + "gpu_util_min": 72.0, + "gpu_util_max": 98.0, + "gpu_util_samples": 13, + "input_tokens": 512000, + "output_tokens": 196608, + "num_gpus": 4, + "tokens_per_sec_per_gpu": 753.1, + "gpu_days_per_billion_tokens": 0.0015 + } + } + ], + "aggregate": { + "total_shards": 1, + "completed_shards": 1, + "total_rows_processed": 1000, + "wall_clock_runtime_seconds": 133.04, + "wall_clock_runtime_human": "2m 13s", + "sum_shard_runtime_seconds": 133.04, + "sum_shard_runtime_human": "2m 13s", + "min_shard_runtime_seconds": 133.04, + "min_shard_runtime_human": "2m 13s", + "max_shard_runtime_seconds": 133.04, + "max_shard_runtime_human": "2m 13s", + "overall_throughput_rows_per_sec": 7.52, + "mean_gpu_util_pct": 86.2, + "num_gpus": 1, + "total_input_tokens": 146214, + "total_output_tokens": 1022046, + "sum_model_load_seconds": 38.272, + "sum_inference_runtime_seconds": 94.768, + "tokens_per_sec_per_gpu": 10784.72, + "gpu_days_per_billion_tokens": 1.0732 + } +} +``` + +Key metrics: +- **`runtime_seconds`** / **`runtime_human`**: time from when the shard started on the cluster (after dispatch), excluding queue wait time. +- **`overall_throughput_rows_per_sec`**: total rows / wall-clock time across all shards running in parallel. +- **`mean_gpu_util_pct`**: mean percentage GPU utilization across shards. +- **`tokens_per_sec_per_gpu`**: output tokens generated per second per GPU — the primary throughput metric used by frameworks such as [DataTrove](https://github.com/huggingface/datatrove). +- **`gpu_days_per_billion_tokens`**: total GPU-days consumed to generate 1 billion output tokens — useful for cost and scaling comparisons across different hardware configurations. +- Token metrics are `null` when no LLM processor was active, and GPU stats are `null` when `nvidia-smi` is unavailable or `--stats` was not passed. + +Reference benchmark: +- [DataTrove Benchmark](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark) +- `mmirage run --config configs/config_bencmark_datatrove.yaml --stats` + ## Architecture MMIRAGE uses a modular architecture: @@ -336,3 +340,4 @@ mmirage/ - JMESPath for JSON queries: [link](https://jmespath.org/) - SGLang for fast inference: [link](https://github.com/sgl-project/sglang) - Performance paper: [link](https://arxiv.org/abs/2408.02442) +- DataTrove Benchmark: [link](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark) diff --git a/configs/config_benchmark_datatrove.yaml b/configs/config_benchmark_datatrove.yaml index 0630ed1..d44e87d 100644 --- a/configs/config_benchmark_datatrove.yaml +++ b/configs/config_benchmark_datatrove.yaml @@ -24,10 +24,6 @@ # Inspect results: # # mmirage stats --config configs/config_benchmark_datatrove.yaml -# -# Key metrics to compare against DataTrove: -# tokens_per_sec_per_gpu (DataTrove: output_tps_per_gpu) -# gpu_days_per_billion_tokens (DataTrove: gpu_days_to_process_1b_tokens) processors: - type: llm @@ -36,22 +32,19 @@ processors: tp_size: 1 # DataTrove baseline: tp=1 trust_remote_code: true disable_custom_all_reduce: true - # SGLang engine tuning — equivalents of DataTrove's vLLM mns/mnbt knobs. - # Tune these to close the gap vs the DataTrove baseline (7 919 tok/s/GPU). - # Uncomment and adjust for your hardware: + # SGLang engine tuning — equivalents of DataTrove's vLLM mns/mnbt knobs extra_engine_args: - max_running_requests: 1000 # vLLM mns=512 equivalent - # mem_fraction_static: 0.90 # increase if GPU has headroom + max_running_requests: 1000 default_sampling_params: - temperature: 0.0 # greedy — maximises reproducible throughput + temperature: 0.0 max_new_tokens: 1024 # DataTrove: max-tokens=1024 loading_params: - state_dir: /users/qchapp/data/benchmark_s1k/_pipeline_state + state_dir: data/benchmark_s1k/_pipeline_state datasets: - - path: /users/qchapp/data/s1K-1.1 # save_to_disk() target above + - path: data/s1K-1.1 # save_to_disk() target above type: loadable - output_dir: /users/qchapp/data/benchmark_s1k/output + output_dir: data/benchmark_s1k/output num_shards: 1 shard_id: "$SLURM_ARRAY_TASK_ID" batch_size: 1000 @@ -76,77 +69,19 @@ processing_params: answer: "{{ answer }}" execution_params: - # Execution mode: "local" or "slurm" - # - local: Run directly on this machine - # - slurm: Submit jobs to SLURM cluster mode: slurm - - # Whether the canonical `run` command should automatically retry failed shards. - # - false: submit one run only - # - true: submit, wait, and keep retrying failed shards until success or retry budget exhaustion retry: false - - # Whether to merge shard outputs after a successful run. - # - false: keep shard_* outputs only - # - true: build merged datasets from shard_* outputs merge: false - - # Maximum number of times to retry a failed shard (default: 3) max_retries: 3 - - # ========================================================================== - # SLURM CONFIGURATION (only used when mode: slurm) - # ========================================================================== - - # HPC account/partition to charge jobs to (REQUIRED for SLURM mode) account: a127 - - # SLURM job name (default: "mmirage-sharded") job_name: mmirage-sharded - - # Optional SLURM reservation name (leave blank or omit to not use) - # reservation: "sai-a127" - - # Number of nodes (default: 1) nodes: 1 - - # Number of tasks per node (default: 1) ntasks_per_node: 1 - - # Number of GPUs per node (default: 4) gpus: 4 - - # Number of CPUs per task (default: 288) cpus_per_task: 288 - - # Job time limit in HH:MM:SS format (default: "11:59:59") time_limit: "11:59:59" - - # ========================================================================== - # PATH CONFIGURATION - # ========================================================================== - # These support environment variables ($VAR or ${VAR}) and home directory (~) - - # Project root directory (used as base for relative paths) - # If not set, uses current working directory - # project_root: "/path/to/project" - - # Directory for SLURM output and error files (default: ~/reports) report_dir: "/users/${USER}/reports" - - # HuggingFace cache directory (default: ~/hf) hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf" - - # EDF environment file path for cluster-specific setup - edf_env: "/users/${USER}/.edf/sglang.toml" - - # ========================================================================== - # JOB MONITORING (for "submit" and retry orchestration) - # ========================================================================== - - # Seconds to wait between checking job status (default: 30) + edf_env: "/users/${USER}/.edf/mmirage.toml" poll_interval_seconds: 30 - - # Seconds to wait after job completes before checking results (default: 60) - # This allows filesystem to settle on distributed systems settle_time_seconds: 60 diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 55209d8..6c686ab 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -237,7 +237,9 @@ def batch_process_sample( logger.error( f"Batch generation failed for text-only samples in output '{output_var.name}': {e}" ) - raise + for global_i in text_only_indices: + empty_val = {} if output_var.output_type == "JSON" else "" + results[global_i] = batch[global_i].with_variable(output_var.name, empty_val) # Multimodal batch if multimodal_indices: @@ -299,7 +301,9 @@ def batch_process_sample( logger.error( f"Batch generation failed for multimodal samples in output '{output_var.name}': {e}" ) - raise + for global_i in multimodal_indices: + empty_val = {} if output_var.output_type == "JSON" else "" + results[global_i] = batch[global_i].with_variable(output_var.name, empty_val) return [results[i] for i in range(nb_samples)] diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 72c70a8..c1f3e6f 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -92,30 +92,30 @@ def main(): state_dir = _shard_state_dir(shard_id, loading_params.get_state_root()) collect_stats = os.environ.get("MMIRAGE_COLLECT_STATS", "") == "1" - # Determine which physical GPU indices SGLang will use so the poller - # measures only the active GPU(s) — not all GPUs on the node. - # SLURM may allocate more GPUs than tp_size (e.g. gpus=4, tp_size=1). - # We take only the first tp_size entries from CUDA_VISIBLE_DEVICES so - # nvidia-smi --id receives exactly the GPUs SGLang is using. - tp_size = 1 - for proc_cfg in cfg.processors: - tp = getattr(getattr(proc_cfg, "server_args", None), "tp_size", None) - if tp and int(tp) > 0: - tp_size = int(tp) - break - cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") - if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): - try: - all_visible = [int(x.strip()) for x in cuda_visible.split(",") if x.strip()] - gpu_indices_for_polling: Optional[list] = all_visible[:tp_size] - except ValueError: + if collect_stats: + # Determine which physical GPU indices SGLang will use so the poller + # measures only the active GPU(s) — not all GPUs on the node. + # SLURM may allocate more GPUs than tp_size (e.g. gpus=4, tp_size=1). + # We take only the first tp_size entries from CUDA_VISIBLE_DEVICES so + # nvidia-smi --id receives exactly the GPUs SGLang is using. + tp_size = 1 + for proc_cfg in cfg.processors: + tp = getattr(getattr(proc_cfg, "server_args", None), "tp_size", None) + if tp and int(tp) > 0: + tp_size = int(tp) + break + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): + try: + all_visible = [int(x.strip()) for x in cuda_visible.split(",") if x.strip()] + gpu_indices_for_polling: Optional[list] = all_visible[:tp_size] + except ValueError: + gpu_indices_for_polling = list(range(tp_size)) + else: gpu_indices_for_polling = list(range(tp_size)) - else: - gpu_indices_for_polling = list(range(tp_size)) - - gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller( - interval_seconds=5.0, gpu_indices=gpu_indices_for_polling - ) + gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller( + interval_seconds=5.0, gpu_indices=gpu_indices_for_polling + ) try: retry_count = _mark_running(state_dir, shard_id, datasets_config) logger.info(f"Starting shard {shard_id}/{last_shard_id} (attempt #{retry_count})") From 0a46d7d2eb3657d923bc1eff62fb10d92becd3d0 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Wed, 6 May 2026 12:58:33 +0200 Subject: [PATCH 12/24] ready for PR --- src/mmirage/cli.py | 12 ++++++++++++ src/mmirage/cli_utils/status.py | 5 +++-- src/mmirage/shard_utils.py | 12 ++++++++++-- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index a160542..2981b87 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -222,6 +222,11 @@ def build_argparser() -> argparse.ArgumentParser: help="Submit retries without prompting.", ) check_parser.set_defaults(confirm_mode="prompt") + check_parser.add_argument( + "--stats", + action="store_true", + help="Enable GPU utilization and throughput collection on retried compute nodes", + ) retry_parser = subparsers.add_parser("retry", help="Submit only failed shards") add_shared_arguments(retry_parser) @@ -234,6 +239,11 @@ def build_argparser() -> argparse.ArgumentParser: help="Submit retries without prompting.", ) retry_parser.set_defaults(confirm_mode="prompt") + retry_parser.add_argument( + "--stats", + action="store_true", + help="Enable GPU utilization and throughput collection on retried compute nodes", + ) run_parser = subparsers.add_parser( "run", @@ -449,6 +459,7 @@ def handle_check(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) config_path=config_path, failed_shards=failed_shards, confirm_mode=args.confirm_mode, + collect_stats=args.stats, ) @@ -481,6 +492,7 @@ def handle_retry(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) config_path=config_path, failed_shards=failed_shards, confirm_mode=args.confirm_mode, + collect_stats=args.stats, ) diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index 1286b55..101d6d9 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ -140,6 +140,7 @@ def submit_failed_shards( config_path: str, failed_shards: Sequence[int], confirm_mode: Literal["prompt", "yes"], + collect_stats: bool = False, ) -> int: """Submit retry jobs for failed shards when requested.""" if not failed_shards: @@ -148,7 +149,7 @@ def submit_failed_shards( if not confirm_retry(len(failed_shards), confirm_mode): return 1 - job_id = submit_slurm_job(cfg, config_path, failed_shards) + job_id = submit_slurm_job(cfg, config_path, failed_shards, collect_stats=collect_stats) if job_id is None: return 1 @@ -278,7 +279,7 @@ def collect_bench_stats(cfg: MMirageConfig) -> Dict[str, Any]: "num_gpus": num_gpus, "total_input_tokens": total_input_tokens if has_token_data else None, "total_output_tokens": total_output_tokens if has_token_data else None, - "sum_model_load_seconds": round(sum_model_load_seconds, 3) if sum_model_load_seconds else None, + "sum_model_load_seconds": round(sum_model_load_seconds, 3) if sum_model_load_seconds > 0 else None, "sum_inference_runtime_seconds": round(agg_inference_runtime, 3) if agg_inference_runtime is not None else None, "tokens_per_sec_per_gpu": agg_tokens_per_sec_per_gpu, "gpu_days_per_billion_tokens": agg_gpu_days_per_billion_tokens, diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index bb8741a..1f9d69b 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -524,15 +524,23 @@ def _mark_success(state_dir: str, stats: Optional[ShardStats] = None): pass # Derive throughput once we have both rows and runtime. + # Use inference_runtime (total minus model loading) so the metric + # reflects pure generation speed, consistent with tokens_per_sec_per_gpu. if ( stats.throughput_rows_per_sec is None and stats.rows_processed is not None and stats.runtime_seconds is not None and stats.runtime_seconds > 0 ): - stats.throughput_rows_per_sec = round( - stats.rows_processed / stats.runtime_seconds, 2 + inference_runtime = ( + max(0.0, stats.runtime_seconds - stats.model_load_seconds) + if stats.model_load_seconds is not None + else stats.runtime_seconds ) + if inference_runtime > 0: + stats.throughput_rows_per_sec = round( + stats.rows_processed / inference_runtime, 2 + ) prev.stats = stats _write_status(state_dir, prev) From 716f7b380d4f25bca41ee6dcdc3080aeb158fd8a Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Wed, 6 May 2026 13:04:37 +0200 Subject: [PATCH 13/24] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b6adaaa..1ac2c68 100644 --- a/README.md +++ b/README.md @@ -315,7 +315,7 @@ Key metrics: Reference benchmark: - [DataTrove Benchmark](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark) -- `mmirage run --config configs/config_bencmark_datatrove.yaml --stats` +- `mmirage run --config configs/config_benchmark_datatrove.yaml --stats` ## Architecture From 321828595bff94abe0e6771af907b781c7bb930b Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Wed, 6 May 2026 13:04:47 +0200 Subject: [PATCH 14/24] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/mmirage/shard_process.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index c1f3e6f..e55b7a2 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -193,8 +193,14 @@ def main(): num_gpus: Optional[int] = None for proc_cfg in cfg.processors: tp = getattr(getattr(proc_cfg, "server_args", None), "tp_size", None) - if tp and tp > 0: - num_gpus = int(tp) + if tp is None: + continue + try: + tp_int = int(tp) + except (TypeError, ValueError): + continue + if tp_int > 0: + num_gpus = tp_int break stats = ShardStats( From 1f0b9e0e0b8c172dce40dac99367138076133866 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Wed, 6 May 2026 13:13:32 +0200 Subject: [PATCH 15/24] copilot suggestions --- src/mmirage/cli_utils/status.py | 14 +++++++------- src/mmirage/shard_process.py | 23 +++++++---------------- src/mmirage/shard_utils.py | 16 ++++++++-------- 3 files changed, 22 insertions(+), 31 deletions(-) diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index 101d6d9..2aeba28 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ -11,7 +11,7 @@ from mmirage.config.config import MMirageConfig from mmirage.cli_utils.slurm import submit_slurm_job -from mmirage.shard_utils import ShardStatus, _format_duration, _read_status, _shard_state_dir +from mmirage.shard_utils import ShardStatus, format_duration, read_status logger = logging.getLogger(__name__) @@ -189,8 +189,8 @@ def collect_bench_stats(cfg: MMirageConfig) -> Dict[str, Any]: num_gpus: Optional[int] = None # taken from first shard that has it for shard_id in range(num_shards): - state_dir = _shard_state_dir(shard_id, state_root) - status = _read_status(state_dir) + state_dir = shard_state_dir(state_root, shard_id) + status = read_status(state_dir) entry: Dict[str, Any] = status.to_dict() per_shard.append(entry) @@ -266,13 +266,13 @@ def collect_bench_stats(cfg: MMirageConfig) -> Dict[str, Any]: "completed_shards": sum(1 for e in per_shard if e.get("status") == "success"), "total_rows_processed": total_rows if total_rows > 0 else None, "wall_clock_runtime_seconds": wall_clock, - "wall_clock_runtime_human": _format_duration(wall_clock), + "wall_clock_runtime_human": format_duration(wall_clock), "sum_shard_runtime_seconds": round(sum_runtime, 3) if runtimes else None, - "sum_shard_runtime_human": _format_duration(round(sum_runtime, 3) if runtimes else None), + "sum_shard_runtime_human": format_duration(round(sum_runtime, 3) if runtimes else None), "min_shard_runtime_seconds": round(min(runtimes), 3) if runtimes else None, - "min_shard_runtime_human": _format_duration(round(min(runtimes), 3) if runtimes else None), + "min_shard_runtime_human": format_duration(round(min(runtimes), 3) if runtimes else None), "max_shard_runtime_seconds": round(max(runtimes), 3) if runtimes else None, - "max_shard_runtime_human": _format_duration(round(max(runtimes), 3) if runtimes else None), + "max_shard_runtime_human": format_duration(round(max(runtimes), 3) if runtimes else None), "overall_throughput_rows_per_sec": overall_throughput, "mean_gpu_util_pct": mean_gpu_util, # Token-level benchmark metrics (DataTrove-compatible). diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index e55b7a2..c5ae3d5 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -27,7 +27,7 @@ _remove_columns, _save_dataset_atomic, _shard_dataset, - _shard_state_dir, + shard_state_dir, ) logger = logging.getLogger(__name__) @@ -89,7 +89,7 @@ def main(): if not (0 <= shard_id < num_shards): raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") - state_dir = _shard_state_dir(shard_id, loading_params.get_state_root()) + state_dir = shard_state_dir(shard_id, loading_params.get_state_root()) collect_stats = os.environ.get("MMIRAGE_COLLECT_STATS", "") == "1" if collect_stats: @@ -106,13 +106,10 @@ def main(): break cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): - try: - all_visible = [int(x.strip()) for x in cuda_visible.split(",") if x.strip()] - gpu_indices_for_polling: Optional[list] = all_visible[:tp_size] - except ValueError: - gpu_indices_for_polling = list(range(tp_size)) + all_visible = [x.strip() for x in cuda_visible.split(",") if x.strip()] + gpu_indices_for_polling: Optional[List[str]] = all_visible[:tp_size] else: - gpu_indices_for_polling = list(range(tp_size)) + gpu_indices_for_polling = [str(i) for i in range(tp_size)] gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller( interval_seconds=5.0, gpu_indices=gpu_indices_for_polling ) @@ -193,14 +190,8 @@ def main(): num_gpus: Optional[int] = None for proc_cfg in cfg.processors: tp = getattr(getattr(proc_cfg, "server_args", None), "tp_size", None) - if tp is None: - continue - try: - tp_int = int(tp) - except (TypeError, ValueError): - continue - if tp_int > 0: - num_gpus = tp_int + if tp and tp > 0: + num_gpus = int(tp) break stats = ShardStats( diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index 1f9d69b..2fe41cc 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) -def _format_duration(seconds: Optional[float]) -> Optional[str]: +def format_duration(seconds: Optional[float]) -> Optional[str]: """Format a duration given in seconds as a human-readable string. Examples:: @@ -125,7 +125,7 @@ def to_dict(self) -> Dict[str, Any]: return { "runtime_seconds": self.runtime_seconds, - "runtime_human": _format_duration(self.runtime_seconds), + "runtime_human": format_duration(self.runtime_seconds), "model_load_seconds": round(self.model_load_seconds, 3) if self.model_load_seconds is not None else None, "inference_runtime_seconds": round(inference_runtime, 3) if inference_runtime is not None else None, "rows_processed": self.rows_processed, @@ -155,7 +155,7 @@ class GpuUtilizationPoller: If ``nvidia-smi`` is unavailable all values are ``None`` and samples is 0. """ - def __init__(self, interval_seconds: float = 5.0, gpu_indices: Optional[List[int]] = None) -> None: + def __init__(self, interval_seconds: float = 5.0, gpu_indices: Optional[List[str]] = None) -> None: self._interval = interval_seconds self._samples: List[float] = [] self._thread: Optional[threading.Thread] = None @@ -402,7 +402,7 @@ def _dataset_out_dir(shard_idx: int, ds_config: BaseDataLoaderConfig) -> str: return os.path.join(ds_config.output_dir, f"shard_{shard_idx}") -def _shard_state_dir(shard_idx: int, state_root: str) -> str: +def shard_state_dir(shard_idx: int, state_root: str) -> str: """Get central state directory for a logical shard.""" return os.path.join(state_root, f"shard_{shard_idx}") @@ -419,7 +419,7 @@ def _status_file(state_dir: str) -> str: return os.path.join(state_dir, "status.json") -def _read_status(state_dir: str) -> ShardStatus: +def read_status(state_dir: str) -> ShardStatus: """Read status.json if present.""" path = _status_file(state_dir) if not os.path.exists(path): @@ -470,7 +470,7 @@ def _mark_running( datasets_config: List[BaseDataLoaderConfig], ) -> int: """Mark shard as running and increment retry count.""" - prev = _read_status(state_dir) + prev = read_status(state_dir) retry_count = prev.retry_count + 1 payload = ShardStatus( @@ -508,7 +508,7 @@ def _mark_success(state_dir: str, stats: Optional[ShardStats] = None): ``throughput_rows_per_sec`` are computed from the stored timestamps when not already set. """ - prev = _read_status(state_dir) + prev = read_status(state_dir) prev.status = "success" now = datetime.now() prev.finished_at = now.isoformat() @@ -550,7 +550,7 @@ def _mark_success(state_dir: str, stats: Optional[ShardStats] = None): def _mark_failure(state_dir: str, error_msg: str): """Mark shard as failed.""" - prev = _read_status(state_dir) + prev = read_status(state_dir) prev.status = "failed" prev.finished_at = datetime.now().isoformat() prev.error = error_msg From cb1cc3c587b89ab9b8071db6cded5157528a36b5 Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Wed, 6 May 2026 14:19:52 +0200 Subject: [PATCH 16/24] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/mmirage/shard_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index 2fe41cc..b8ee558 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -28,9 +28,9 @@ def format_duration(seconds: Optional[float]) -> Optional[str]: Examples:: - _format_duration(45.3) -> "45s" - _format_duration(125.0) -> "2m 5s" - _format_duration(3725.0) -> "1h 2m 5s" + format_duration(45.3) -> "45s" + format_duration(125.0) -> "2m 5s" + format_duration(3725.0) -> "1h 2m 5s" """ if seconds is None: return None From 26eeea10fd828f2f52b5a86f446694ece4fbb7c2 Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Wed, 6 May 2026 14:20:37 +0200 Subject: [PATCH 17/24] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1ac2c68..523c2d5 100644 --- a/README.md +++ b/README.md @@ -294,7 +294,7 @@ This prints a JSON report with per-shard details and an aggregate summary: "max_shard_runtime_human": "2m 13s", "overall_throughput_rows_per_sec": 7.52, "mean_gpu_util_pct": 86.2, - "num_gpus": 1, + "num_gpus": 4, "total_input_tokens": 146214, "total_output_tokens": 1022046, "sum_model_load_seconds": 38.272, From dc1c96667b0bac31b78510a804012ea029700125 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Thu, 7 May 2026 11:59:39 +0200 Subject: [PATCH 18/24] function deduplication --- src/mmirage/cli.py | 2 +- src/mmirage/cli_utils/status.py | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index 2981b87..a4e3908 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -101,7 +101,7 @@ def launch_pipeline( candidates = sorted(set(failed_shards) | set(runtime_failed)) retryable_shards: List[int] = [] for shard_id in candidates: - _, state_attempt_count = get_shard_status(shard_state_dir(state_root, shard_id)) + _, state_attempt_count = get_shard_status(shard_state_dir(shard_id, state_root)) memory_attempt_count = attempts_by_shard.get(shard_id, 0) effective_attempt_count = max(state_attempt_count, memory_attempt_count) diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index 2aeba28..410b9f8 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ -11,7 +11,7 @@ from mmirage.config.config import MMirageConfig from mmirage.cli_utils.slurm import submit_slurm_job -from mmirage.shard_utils import ShardStatus, format_duration, read_status +from mmirage.shard_utils import ShardStatus, format_duration, read_status, shard_state_dir logger = logging.getLogger(__name__) @@ -41,11 +41,6 @@ def is_retry_budget_exceeded(attempt_count: int, max_retries: int) -> bool: return attempt_count > max_allowed_attempts(max_retries) -def shard_state_dir(state_root: str, shard_id: int) -> str: - """Return the state directory for a shard.""" - return os.path.join(state_root, f"shard_{shard_id}") - - def get_shard_status(state_dir: str) -> Tuple[str, int]: """Read the current status and attempt counter for a shard.""" status_file = os.path.join(state_dir, "status.json") @@ -79,7 +74,7 @@ def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: allowed_attempts = max_allowed_attempts(max_retries) for shard_id in range(num_shards): - status, attempt_count = get_shard_status(shard_state_dir(state_root, shard_id)) + status, attempt_count = get_shard_status(shard_state_dir(shard_id, state_root)) if status == "success": success_count += 1 elif status == "running": @@ -189,7 +184,7 @@ def collect_bench_stats(cfg: MMirageConfig) -> Dict[str, Any]: num_gpus: Optional[int] = None # taken from first shard that has it for shard_id in range(num_shards): - state_dir = shard_state_dir(state_root, shard_id) + state_dir = shard_state_dir(shard_id, state_root) status = read_status(state_dir) entry: Dict[str, Any] = status.to_dict() per_shard.append(entry) From 1ef07d4996b7a646dcdb4dee697d1213bdc90340 Mon Sep 17 00:00:00 2001 From: Quentin <74377782+qchapp@users.noreply.github.com> Date: Thu, 7 May 2026 13:11:51 +0200 Subject: [PATCH 19/24] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- configs/config_benchmark_datatrove.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/config_benchmark_datatrove.yaml b/configs/config_benchmark_datatrove.yaml index d44e87d..c30e9b3 100644 --- a/configs/config_benchmark_datatrove.yaml +++ b/configs/config_benchmark_datatrove.yaml @@ -77,7 +77,7 @@ execution_params: job_name: mmirage-sharded nodes: 1 ntasks_per_node: 1 - gpus: 4 + gpus: 1 cpus_per_task: 288 time_limit: "11:59:59" report_dir: "/users/${USER}/reports" From dfa7fae6e5544a904bc724a5e13ae2e79921e808 Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Thu, 7 May 2026 13:20:47 +0200 Subject: [PATCH 20/24] copilot changes --- src/mmirage/shard_process.py | 4 +++- src/mmirage/shard_utils.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index c5ae3d5..9db30c4 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -107,7 +107,9 @@ def main(): cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): all_visible = [x.strip() for x in cuda_visible.split(",") if x.strip()] - gpu_indices_for_polling: Optional[List[str]] = all_visible[:tp_size] + # Fall back to range-based indices if CUDA_VISIBLE_DEVICES was set + # but contained only whitespace/empty entries after stripping. + gpu_indices_for_polling: Optional[List[str]] = all_visible[:tp_size] if all_visible else [str(i) for i in range(tp_size)] else: gpu_indices_for_polling = [str(i) for i in range(tp_size)] gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller( diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index b8ee558..0fc56d7 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -203,7 +203,11 @@ def _query_gpu_util(self) -> Optional[float]: # dilute utilization by averaging over idle GPUs on the same node. # Priority: explicit gpu_indices > CUDA_VISIBLE_DEVICES > all GPUs. if self._gpu_indices is not None: - cmd += [f"--id={','.join(str(i) for i in self._gpu_indices)}"] + if not self._gpu_indices: + # Empty list would produce --id= which is invalid; skip filtering. + pass + else: + cmd += [f"--id={','.join(str(i) for i in self._gpu_indices)}"] else: cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): From 7946278ee10d59ff9fad84d1bf04fb4c81a6dbdf Mon Sep 17 00:00:00 2001 From: titi <74377782+qchapp@users.noreply.github.com> Date: Fri, 8 May 2026 17:17:44 +0200 Subject: [PATCH 21/24] implemented changes requested by fabrice --- README.md | 31 ++++++++++++++++++++++++- configs/config_benchmark_datatrove.yaml | 26 +-------------------- pyproject.toml | 1 + src/mmirage/core/process/mapper.py | 15 +++++++++--- src/mmirage/shard_process.py | 6 ++--- src/mmirage/shard_utils.py | 19 +++------------ 6 files changed, 50 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index 523c2d5..d35da70 100644 --- a/README.md +++ b/README.md @@ -315,7 +315,36 @@ Key metrics: Reference benchmark: - [DataTrove Benchmark](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark) -- `mmirage run --config configs/config_benchmark_datatrove.yaml --stats` + +The config `configs/config_benchmark_datatrove.yaml` mirrors the DataTrove inference benchmark conditions: + +| Setting | Value | +|---|---| +| Dataset | `simplescaling/s1K-1.1` (train split, 1 000 samples) | +| Prompt | raw `question` field, no system prompt | +| Output | up to 1 024 tokens per sample | +| Context | 2 048-token model max context | +| Model | `Qwen/Qwen3-4B` (DataTrove baseline: tp=1 on a single GPU) | + +Download the dataset before running: + +```python +from datasets import load_dataset +ds = load_dataset('simplescaling/s1K-1.1', split='train') +ds.save_to_disk('data/s1K-1.1') +``` + +Then run with stats collection enabled: + +```bash +mmirage run --config configs/config_benchmark_datatrove.yaml --stats +``` + +Inspect results: + +```bash +mmirage stats --config configs/config_benchmark_datatrove.yaml +``` ## Architecture diff --git a/configs/config_benchmark_datatrove.yaml b/configs/config_benchmark_datatrove.yaml index c30e9b3..911cd65 100644 --- a/configs/config_benchmark_datatrove.yaml +++ b/configs/config_benchmark_datatrove.yaml @@ -1,29 +1,5 @@ # MMIRAGE — DataTrove-compatible throughput benchmark -# -# Mirrors the conditions used in the DataTrove inference benchmark -# (https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark): -# -# dataset : simplescaling/s1K-1.1 (train split, 1 000 samples) -# prompt : raw `question` field, no system prompt -# output : up to 1 024 tokens per sample -# context : 2 048-token model max context -# model : Qwen/Qwen3-4B (DataTrove baseline: tp=1 on a single GPU) -# -# Download the dataset before running: -# -# python -c " -# from datasets import load_dataset -# ds = load_dataset('simplescaling/s1K-1.1', split='train') -# ds.save_to_disk('data/s1K-1.1') -# " -# -# Then run with stats collection enabled: -# -# mmirage run --config configs/config_benchmark_datatrove.yaml --stats -# -# Inspect results: -# -# mmirage stats --config configs/config_benchmark_datatrove.yaml +# See README.md for setup instructions and benchmark details. processors: - type: llm diff --git a/pyproject.toml b/pyproject.toml index 5804d4e..6ccc56a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "jmespath", "jinja2>=3.0.0", "pillow>=9.0.0", + "humanize>=4.0.0", ] [project.optional-dependencies] diff --git a/src/mmirage/core/process/mapper.py b/src/mmirage/core/process/mapper.py index c8d8a63..86e1073 100644 --- a/src/mmirage/core/process/mapper.py +++ b/src/mmirage/core/process/mapper.py @@ -1,10 +1,19 @@ """Mapper for orchestrating variable transformations.""" +from dataclasses import dataclass from typing import Dict, Any, List, cast from mmirage.core.process.variables import BaseVar, InputVar, OutputVar from mmirage.core.process.base import AutoProcessor, BaseProcessor, BaseProcessorConfig + +@dataclass +class TokenCounts: + """Cumulative token counts from LLM processors.""" + + input_tokens: int + output_tokens: int + import logging from mmirage.core.process.variables import VariableEnvironment @@ -104,14 +113,14 @@ def rewrite_batch( return batch_environment - def get_token_counts(self) -> Dict[str, int]: + def get_token_counts(self) -> TokenCounts: """Return cumulative token counts aggregated across all LLM processors. Sums ``input_tokens`` and ``output_tokens`` from every processor that exposes a ``get_token_counts()`` method (i.e., ``LLMProcessor``). Returns: - Dict with ``input_tokens`` and ``output_tokens`` keys. + TokenCounts with ``input_tokens`` and ``output_tokens`` fields. """ total_input = 0 total_output = 0 @@ -120,7 +129,7 @@ def get_token_counts(self) -> Dict[str, int]: counts = proc.get_token_counts() total_input += counts.get("input_tokens", 0) total_output += counts.get("output_tokens", 0) - return {"input_tokens": total_input, "output_tokens": total_output} + return TokenCounts(input_tokens=total_input, output_tokens=total_output) def get_load_time(self) -> float: """Return total model-loading time (seconds) summed across all LLM processors.""" diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 9db30c4..50fe150 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -109,7 +109,7 @@ def main(): all_visible = [x.strip() for x in cuda_visible.split(",") if x.strip()] # Fall back to range-based indices if CUDA_VISIBLE_DEVICES was set # but contained only whitespace/empty entries after stripping. - gpu_indices_for_polling: Optional[List[str]] = all_visible[:tp_size] if all_visible else [str(i) for i in range(tp_size)] + gpu_indices_for_polling: List[str] = all_visible[:tp_size] if all_visible else [str(i) for i in range(tp_size)] else: gpu_indices_for_polling = [str(i) for i in range(tp_size)] gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller( @@ -184,8 +184,8 @@ def main(): # Collect token counts accumulated by LLM processor(s). token_counts = mapper.get_token_counts() - input_tokens = token_counts["input_tokens"] or None - output_tokens = token_counts["output_tokens"] or None + input_tokens = token_counts.input_tokens or None + output_tokens = token_counts.output_tokens or None model_load_seconds = mapper.get_load_time() or None # Resolve num_gpus from the first processor config that exposes tp_size. diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index 0fc56d7..d13df76 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -6,6 +6,7 @@ from datetime import datetime from dataclasses import dataclass +import humanize import json import logging import os @@ -24,24 +25,10 @@ def format_duration(seconds: Optional[float]) -> Optional[str]: - """Format a duration given in seconds as a human-readable string. - - Examples:: - - format_duration(45.3) -> "45s" - format_duration(125.0) -> "2m 5s" - format_duration(3725.0) -> "1h 2m 5s" - """ + """Format a duration given in seconds as a human-readable string.""" if seconds is None: return None - total = int(seconds) - hours, remainder = divmod(total, 3600) - minutes, secs = divmod(remainder, 60) - if hours: - return f"{hours}h {minutes}m {secs}s" - if minutes: - return f"{minutes}m {secs}s" - return f"{secs}s" + return humanize.precisedelta(seconds) @dataclass From c962618b8130e6866f1631031a8837bcdc70cb81 Mon Sep 17 00:00:00 2001 From: Fabrice Nemo Date: Mon, 11 May 2026 11:13:11 +0200 Subject: [PATCH 22/24] fixed TokenCounts logic --- src/mmirage/core/process/base.py | 20 +++++++++++++++++++ src/mmirage/core/process/mapper.py | 15 ++++---------- .../process/processors/llm/llm_processor.py | 15 +++++++------- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/mmirage/core/process/base.py b/src/mmirage/core/process/base.py index 988bae7..db918df 100644 --- a/src/mmirage/core/process/base.py +++ b/src/mmirage/core/process/base.py @@ -24,6 +24,14 @@ class BaseProcessorConfig: C = TypeVar("C", bound=OutputVar) +@dataclass +class TokenCounts: + """Cumulative token counts from LLM processors.""" + + input_tokens: int + output_tokens: int + + class BaseProcessor(abc.ABC, Generic[C]): """Abstract base class for data processors. @@ -64,6 +72,18 @@ def batch_process_sample( """ raise NotImplementedError() + @abstract + def get_token_counts(self) -> TokenCounts: + """Get cumulative token counts from this processor. + + Returns: + TokenCounts object containing input and output token counts. + + Raises: + NotImplementedError: If not implemented by subclass. + """ + raise NotImplementedError() + class ProcessorRegistry: """Registry for managing and accessing available processors. diff --git a/src/mmirage/core/process/mapper.py b/src/mmirage/core/process/mapper.py index 86e1073..f14c436 100644 --- a/src/mmirage/core/process/mapper.py +++ b/src/mmirage/core/process/mapper.py @@ -1,19 +1,12 @@ """Mapper for orchestrating variable transformations.""" from dataclasses import dataclass -from typing import Dict, Any, List, cast +from typing import Any, Dict, List, Optional, cast +from mmirage.core.process.base import AutoProcessor, BaseProcessor, BaseProcessorConfig, TokenCounts from mmirage.core.process.variables import BaseVar, InputVar, OutputVar -from mmirage.core.process.base import AutoProcessor, BaseProcessor, BaseProcessorConfig -@dataclass -class TokenCounts: - """Cumulative token counts from LLM processors.""" - - input_tokens: int - output_tokens: int - import logging from mmirage.core.process.variables import VariableEnvironment @@ -127,8 +120,8 @@ def get_token_counts(self) -> TokenCounts: for proc in self.processors.values(): if hasattr(proc, "get_token_counts"): counts = proc.get_token_counts() - total_input += counts.get("input_tokens", 0) - total_output += counts.get("output_tokens", 0) + total_input += counts.input_tokens + total_output += counts.output_tokens return TokenCounts(input_tokens=total_input, output_tokens=total_output) def get_load_time(self) -> float: diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 6c686ab..5107582 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -12,7 +12,7 @@ import sglang as sgl from transformers import AutoTokenizer -from mmirage.core.process.base import BaseProcessor, ProcessorRegistry +from mmirage.core.process.base import BaseProcessor, ProcessorRegistry, TokenCounts from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig from mmirage.core.process.variables import VariableEnvironment @@ -79,17 +79,16 @@ def get_load_time(self) -> float: """Return the wall-clock seconds spent initializing the SGLang engine.""" return self._model_load_seconds - def get_token_counts(self) -> dict: + def get_token_counts(self) -> TokenCounts: """Return cumulative token counts for this processor. Returns: - Dict with ``input_tokens`` (prompt tokens) and ``output_tokens`` - (completion tokens) accumulated since this processor was created. + TokenCounts object containing input and output token counts accumulated since this processor was created. """ - return { - "input_tokens": self._total_input_tokens, - "output_tokens": self._total_output_tokens, - } + return TokenCounts( + input_tokens=self._total_input_tokens, + output_tokens=self._total_output_tokens + ) def _accumulate_tokens(self, outputs: list) -> None: """Add token counts from a list of SGLang generate() outputs.""" From fdb0d4265451772b771dcb55b661f4d3f6ddf306 Mon Sep 17 00:00:00 2001 From: Fabrice Nemo Date: Mon, 11 May 2026 11:26:27 +0200 Subject: [PATCH 23/24] fixed various typing and logic errors --- src/mmirage/config/loading.py | 4 ++-- src/mmirage/core/process/base.py | 14 +++++++++++++- src/mmirage/core/process/mapper.py | 2 +- src/mmirage/merge_shards.py | 2 +- src/mmirage/shard_process.py | 14 +++++++++----- 5 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index 929d76d..afb1386 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -47,7 +47,7 @@ def is_unresolved_env_var(s: str) -> bool: if self.num_shards < 1: raise ValueError() except (ValueError, TypeError): - if is_unresolved_env_var(self.num_shards): + if isinstance(self.num_shards, str) and is_unresolved_env_var(self.num_shards): self.num_shards = 1 else: raise ValueError(f"Invalid value for num_shards: {self.num_shards!r}") @@ -56,7 +56,7 @@ def is_unresolved_env_var(s: str) -> bool: try: self.shard_id = int(self.shard_id) except (ValueError, TypeError): - if is_unresolved_env_var(self.shard_id): + if isinstance(self.shard_id, str) and is_unresolved_env_var(self.shard_id): self.shard_id = 0 else: raise ValueError(f"Invalid value for shard_id: {self.shard_id!r}") diff --git a/src/mmirage/core/process/base.py b/src/mmirage/core/process/base.py index db918df..6e8a283 100644 --- a/src/mmirage/core/process/base.py +++ b/src/mmirage/core/process/base.py @@ -72,7 +72,7 @@ def batch_process_sample( """ raise NotImplementedError() - @abstract + @abc.abstractmethod def get_token_counts(self) -> TokenCounts: """Get cumulative token counts from this processor. @@ -84,6 +84,18 @@ def get_token_counts(self) -> TokenCounts: """ raise NotImplementedError() + @abc.abstractmethod + def get_load_time(self) -> float: + """Get the time taken to load any necessary resources (e.g., models). + + Returns: + Time in seconds taken to load resources. + + Raises: + NotImplementedError: If not implemented by subclass. + """ + raise NotImplementedError() + class ProcessorRegistry: """Registry for managing and accessing available processors. diff --git a/src/mmirage/core/process/mapper.py b/src/mmirage/core/process/mapper.py index f14c436..877741b 100644 --- a/src/mmirage/core/process/mapper.py +++ b/src/mmirage/core/process/mapper.py @@ -75,7 +75,7 @@ def validate_vars(self) -> bool: def rewrite_batch( self, batch: Dict[str, List[Any]], - image_base_path: str = None, + image_base_path: Optional[str] = None, ) -> List[VariableEnvironment]: """Transform a batch of samples by computing output variables. diff --git a/src/mmirage/merge_shards.py b/src/mmirage/merge_shards.py index fbddf10..a54444a 100644 --- a/src/mmirage/merge_shards.py +++ b/src/mmirage/merge_shards.py @@ -49,7 +49,7 @@ def _merge_datasetdict(shard_dsets: List[DatasetDict]) -> DatasetDict: merged[str(split)] = concatenate_datasets(split_dsets) if not merged: raise RuntimeError("All splits were empty after merging.") - return DatasetDict(merged) + return DatasetDict(**merged) def _merge_shards(shard_dsets: List[DatasetLike]) -> DatasetLike: diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 50fe150..63f36fb 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -37,7 +37,7 @@ def rewrite_batch( batch: Dict[str, List[Any]], mapper: MMIRAGEMapper, renderer: TemplateRenderer, - image_base_path: str = None, + image_base_path: Optional[str] = None, ) -> Dict[str, List[Any]]: """Rewrite a batch of samples by applying transformations. Args: @@ -91,6 +91,8 @@ def main(): state_dir = shard_state_dir(shard_id, loading_params.get_state_root()) + gpu_poller: Optional[GpuUtilizationPoller] = None + collect_stats = os.environ.get("MMIRAGE_COLLECT_STATS", "") == "1" if collect_stats: # Determine which physical GPU indices SGLang will use so the poller @@ -112,9 +114,11 @@ def main(): gpu_indices_for_polling: List[str] = all_visible[:tp_size] if all_visible else [str(i) for i in range(tp_size)] else: gpu_indices_for_polling = [str(i) for i in range(tp_size)] - gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller( + + gpu_poller = GpuUtilizationPoller( interval_seconds=5.0, gpu_indices=gpu_indices_for_polling ) + try: retry_count = _mark_running(state_dir, shard_id, datasets_config) logger.info(f"Starting shard {shard_id}/{last_shard_id} (attempt #{retry_count})") @@ -144,7 +148,7 @@ def main(): # Start GPU polling after model loading so utilisation samples reflect # inference only, not weight transfers during sgl.Engine() init. - if collect_stats: + if collect_stats and gpu_poller is not None: gpu_poller.start() ds_processed_all: List[DatasetLike] = [] @@ -180,7 +184,7 @@ def main(): _save_dataset_atomic(ds_processed, out_dir) logger.info(f"✅ Saved dataset {ds_idx} shard in: {out_dir}") - gpu_info = gpu_poller.stop() if collect_stats else {"mean": None, "min": None, "max": None, "samples": 0} + gpu_info = gpu_poller.stop() if collect_stats and gpu_poller is not None else {"mean": None, "min": None, "max": None, "samples": 0} # Collect token counts accumulated by LLM processor(s). token_counts = mapper.get_token_counts() @@ -214,7 +218,7 @@ def main(): error_msg = f"{type(e).__name__}: {str(e)}" logger.error(f"❌ Shard {shard_id} failed: {error_msg}") logger.error(traceback.format_exc()) - if collect_stats: + if collect_stats and gpu_poller is not None: gpu_poller.stop() _mark_failure(state_dir, error_msg) sys.exit(1) From 37ca65a8d08ba993dcd4c08539b224888d53b9b8 Mon Sep 17 00:00:00 2001 From: Fabrice Nemo Date: Mon, 11 May 2026 15:16:36 +0200 Subject: [PATCH 24/24] fixed image base path --- src/mmirage/shard_process.py | 2 +- src/mmirage/shard_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 63f36fb..d12e07d 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -161,7 +161,7 @@ def main(): logger.info( f"Processing dataset {ds_idx} for shard {shard_id}: " - f"path={ds_config.path}, output_dir={ds_config.output_dir}" + f"image_base_path={ds_config.image_base_path}, output_dir={ds_config.output_dir}" ) ds_processed = ds_shard.map( diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index d13df76..31a254e 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -477,7 +477,7 @@ def _mark_running( slurm_array_task_id=os.environ.get("SLURM_ARRAY_TASK_ID"), datasets=[ { - "path": ds_config.path, + "image_base_path": ds_config.image_base_path, "output_dir": ds_config.output_dir, } for ds_config in datasets_config