diff --git a/README.md b/README.md index 2660832..d35da70 100644 --- a/README.md +++ b/README.md @@ -235,6 +235,117 @@ Key multimodal features: - `image_base_path`: Base directory for resolving relative image paths - Supports PIL Images, URLs, and file paths +### Benchmarking shard performance + +Pass `--stats` to `run` or `submit` (or to `check` / `retry` when resubmitting failed shards) to enable per-shard benchmarking. This activates GPU +utilization polling and throughput tracking on compute nodes — disabled by default to +avoid unnecessary overhead. + +```bash +# Local run with stats collection +mmirage run --config configs/config_mock.yaml --stats + +``` + +After the run completes, inspect the results with: + +```bash +mmirage stats --config configs/config_mock.yaml +``` + +This prints a JSON report with per-shard details and an aggregate summary: + +```json +{ + "per_shard": [ + { + "shard_id": 0, + "status": "success", + "started_at": "2026-04-30T10:00:00", + "finished_at": "2026-04-30T10:01:05", + "stats": { + "runtime_seconds": 65.2, + "runtime_human": "1m 5s", + "rows_processed": 1024, + "throughput_rows_per_sec": 15.7, + "gpu_util_mean": 88.4, + "gpu_util_min": 72.0, + "gpu_util_max": 98.0, + "gpu_util_samples": 13, + "input_tokens": 512000, + "output_tokens": 196608, + "num_gpus": 4, + "tokens_per_sec_per_gpu": 753.1, + "gpu_days_per_billion_tokens": 15.3686 + } + } + ], + "aggregate": { + "total_shards": 1, + "completed_shards": 1, + "total_rows_processed": 1000, + "wall_clock_runtime_seconds": 133.04, + "wall_clock_runtime_human": "2m 13s", + "sum_shard_runtime_seconds": 133.04, + "sum_shard_runtime_human": "2m 13s", + "min_shard_runtime_seconds": 133.04, + "min_shard_runtime_human": "2m 13s", + "max_shard_runtime_seconds": 133.04, + "max_shard_runtime_human": "2m 13s", + "overall_throughput_rows_per_sec": 7.52, + "mean_gpu_util_pct": 86.2, + "num_gpus": 1, + "total_input_tokens": 146214, + "total_output_tokens": 1022046, + "sum_model_load_seconds": 38.272, + "sum_inference_runtime_seconds": 94.768, + 
"tokens_per_sec_per_gpu": 10784.72, + "gpu_days_per_billion_tokens": 1.0732 + } +} +``` + +Key metrics: +- **`runtime_seconds`** / **`runtime_human`**: time from when the shard started on the cluster (after dispatch), excluding queue wait time. +- **`overall_throughput_rows_per_sec`**: total rows / wall-clock time across all shards running in parallel. +- **`mean_gpu_util_pct`**: mean GPU utilization (in percent) across successful shards, weighted by the number of rows each shard processed. +- **`tokens_per_sec_per_gpu`**: output tokens generated per second per GPU (the aggregate value excludes model-loading time) — the primary throughput metric used by frameworks such as [DataTrove](https://github.com/huggingface/datatrove). +- **`gpu_days_per_billion_tokens`**: total GPU-days consumed to generate 1 billion output tokens — useful for cost and scaling comparisons across different hardware configurations. +- Token metrics are `null` when no LLM processor was active, and GPU stats are `null` when `nvidia-smi` is unavailable or `--stats` was not passed. + +Reference benchmark: +- [DataTrove Benchmark](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark) + +The config `configs/config_benchmark_datatrove.yaml` mirrors the DataTrove inference benchmark conditions: + +| Setting | Value | +|---|---| +| Dataset | `simplescaling/s1K-1.1` (train split, 1 000 samples) | +| Prompt | raw `question` field, no system prompt | +| Output | up to 1 024 tokens per sample | +| Context | 2 048-token model max context | +| Model | `Qwen/Qwen3-4B` (DataTrove baseline: tp=1 on a single GPU) | + +Download the dataset before running: + +```python +from datasets import load_dataset +ds = load_dataset('simplescaling/s1K-1.1', split='train') +ds.save_to_disk('data/s1K-1.1') +``` + +Then run with stats collection enabled: + +```bash +mmirage run --config configs/config_benchmark_datatrove.yaml --stats +``` + +Inspect results: + +```bash +mmirage stats --config configs/config_benchmark_datatrove.yaml +``` + ## Architecture MMIRAGE uses a modular architecture: @@ -258,3 +369,4 @@ 
mmirage/ - JMESPath for JSON queries: [link](https://jmespath.org/) - SGLang for fast inference: [link](https://github.com/sgl-project/sglang) - Performance paper: [link](https://arxiv.org/abs/2408.02442) +- DataTrove Benchmark: [link](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark) diff --git a/configs/config_benchmark_datatrove.yaml b/configs/config_benchmark_datatrove.yaml new file mode 100644 index 0000000..911cd65 --- /dev/null +++ b/configs/config_benchmark_datatrove.yaml @@ -0,0 +1,63 @@ +# MMIRAGE — DataTrove-compatible throughput benchmark +# See README.md for setup instructions and benchmark details. + +processors: + - type: llm + server_args: + model_path: Qwen/Qwen3-4B # same model family as DataTrove baseline + tp_size: 1 # DataTrove baseline: tp=1 + trust_remote_code: true + disable_custom_all_reduce: true + # SGLang engine tuning — equivalents of DataTrove's vLLM mns/mnbt knobs + extra_engine_args: + max_running_requests: 1000 + default_sampling_params: + temperature: 0.0 + max_new_tokens: 1024 # DataTrove: max-tokens=1024 + +loading_params: + state_dir: data/benchmark_s1k/_pipeline_state + datasets: + - path: data/s1K-1.1 # save_to_disk() target from the README download step + type: loadable + output_dir: data/benchmark_s1k/output + num_shards: 1 + shard_id: "$SLURM_ARRAY_TASK_ID" + batch_size: 1000 + +processing_params: + inputs: + - name: question + key: question # DataTrove: prompt-column=question + + outputs: + - name: answer + type: llm + output_type: plain + # Qwen3 thinking is disabled by embedding an empty thinking block in the prompt. + # This is equivalent to passing enable_thinking=False to the chat template and + # avoids any dependency on SGLang sampling-param support for that flag. 
+ prompt: "<|im_start|>user\n{{ question }}\n<|im_end|>\n<|im_start|>assistant\n\n\n\n" + + remove_columns: false + output_schema: + question: "{{ question }}" + answer: "{{ answer }}" + +execution_params: + mode: slurm + retry: false + merge: false + max_retries: 3 + account: a127 + job_name: mmirage-sharded + nodes: 1 + ntasks_per_node: 1 + gpus: 1 + cpus_per_task: 288 + time_limit: "11:59:59" + report_dir: "/users/${USER}/reports" + hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf" + edf_env: "/users/${USER}/.edf/mmirage.toml" + poll_interval_seconds: 30 + settle_time_seconds: 60 diff --git a/pyproject.toml b/pyproject.toml index 5804d4e..6ccc56a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "jmespath", "jinja2>=3.0.0", "pillow>=9.0.0", + "humanize>=4.0.0", ] [project.optional-dependencies] diff --git a/src/mmirage/cli.py b/src/mmirage/cli.py index 4b40cc0..a4e3908 100644 --- a/src/mmirage/cli.py +++ b/src/mmirage/cli.py @@ -15,6 +15,7 @@ from mmirage.cli_utils.slurm import require_slurm, submit_slurm_job, wait_for_slurm_job from mmirage.cli_utils.status import ( check_failed_shards, + collect_bench_stats, is_retry_budget_exceeded, shard_state_dir, get_shard_status, @@ -29,12 +30,13 @@ logger = logging.getLogger(__name__) -def run_local(config_path: str, shard_id: Optional[int] = None) -> int: +def run_local(config_path: str, shard_id: Optional[int] = None, collect_stats: bool = False) -> int: """Run one shard in the current Python environment. Args: config_path: Absolute path to the MMIRAGE YAML config file. shard_id: Optional shard id to inject via SLURM_ARRAY_TASK_ID. + collect_stats: If True, enable GPU utilization polling in the shard process. Returns: Process return code from shard execution. 
@@ -43,6 +45,8 @@ def run_local(config_path: str, shard_id: Optional[int] = None) -> int: env = os.environ.copy() if shard_id is not None: env["SLURM_ARRAY_TASK_ID"] = str(shard_id) + if collect_stats: + env["MMIRAGE_COLLECT_STATS"] = "1" logger.info("Running local shard processing: %s", " ".join(command)) result = subprocess.run(command, env=env, check=False) @@ -54,6 +58,7 @@ def launch_pipeline( config_path: str, force_retry: bool = False, require_completion: bool = False, + collect_stats: bool = False, ) -> int: """Launch the pipeline according to execution mode and retry settings. @@ -63,6 +68,7 @@ def launch_pipeline( force_retry: If True, enable retry orchestration regardless of config flag. require_completion: If True, wait for completion and verify shard status before returning success in SLURM mode when auto-retry is off. + collect_stats: If True, enable GPU utilization polling on compute nodes. Returns: Exit code: 0 on success, 1 on failure. @@ -72,7 +78,7 @@ def launch_pipeline( if not cfg.execution_params.is_slurm(): initial_shard_id = cfg.loading_params.get_shard_id() if not auto_retry: - exit_code = run_local(config_path, initial_shard_id) + exit_code = run_local(config_path, initial_shard_id, collect_stats=collect_stats) if exit_code == 0: logger.info("All shards completed successfully") return exit_code @@ -84,7 +90,7 @@ def launch_pipeline( run_exit_codes = {} for shard_id in shard_ids: attempts_by_shard[shard_id] = attempts_by_shard.get(shard_id, 0) + 1 - run_exit_codes[shard_id] = run_local(config_path, shard_id) + run_exit_codes[shard_id] = run_local(config_path, shard_id, collect_stats=collect_stats) failed_shards, summary = check_failed_shards(cfg) if status_exit_code(failed_shards, summary) == 0: @@ -95,7 +101,7 @@ def launch_pipeline( candidates = sorted(set(failed_shards) | set(runtime_failed)) retryable_shards: List[int] = [] for shard_id in candidates: - _, state_attempt_count = get_shard_status(shard_state_dir(state_root, shard_id)) + 
_, state_attempt_count = get_shard_status(shard_state_dir(shard_id, state_root)) memory_attempt_count = attempts_by_shard.get(shard_id, 0) effective_attempt_count = max(state_attempt_count, memory_attempt_count) @@ -115,7 +121,7 @@ def launch_pipeline( shard_ids: List[int] = [] while True: - job_id = submit_slurm_job(cfg, config_path, shard_ids) + job_id = submit_slurm_job(cfg, config_path, shard_ids, collect_stats=collect_stats) if job_id is None: return 1 @@ -192,6 +198,11 @@ def build_argparser() -> argparse.ArgumentParser: help="Comma-separated shard ids to submit instead of the full array", ) submit_parser.add_argument("--wait", action="store_true", help="Wait for the submitted job") + submit_parser.add_argument( + "--stats", + action="store_true", + help="Enable GPU utilization and throughput collection on compute nodes", + ) check_parser = subparsers.add_parser("check", help="Inspect shard status") add_shared_arguments(check_parser) @@ -211,6 +222,11 @@ def build_argparser() -> argparse.ArgumentParser: help="Submit retries without prompting.", ) check_parser.set_defaults(confirm_mode="prompt") + check_parser.add_argument( + "--stats", + action="store_true", + help="Enable GPU utilization and throughput collection on retried compute nodes", + ) retry_parser = subparsers.add_parser("retry", help="Submit only failed shards") add_shared_arguments(retry_parser) @@ -223,6 +239,11 @@ def build_argparser() -> argparse.ArgumentParser: help="Submit retries without prompting.", ) retry_parser.set_defaults(confirm_mode="prompt") + retry_parser.add_argument( + "--stats", + action="store_true", + help="Enable GPU utilization and throughput collection on retried compute nodes", + ) run_parser = subparsers.add_parser( "run", @@ -240,6 +261,11 @@ def build_argparser() -> argparse.ArgumentParser: default=None, help="Run a single shard locally (overrides execution mode)", ) + run_parser.add_argument( + "--stats", + action="store_true", + help="Enable GPU utilization and 
throughput collection during shard execution", + ) merge_parser = subparsers.add_parser( "merge", @@ -282,6 +308,12 @@ def build_argparser() -> argparse.ArgumentParser: help="Log verbosity", ) + stats_parser = subparsers.add_parser( + "stats", + help="Show per-shard benchmark statistics (runtime, throughput, GPU utilization)", + ) + add_shared_arguments(stats_parser) + return parser @@ -346,13 +378,14 @@ def handle_run(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) - Exit code for the run operation. """ if args.shard_id is not None: - return run_local(config_path, args.shard_id) + return run_local(config_path, args.shard_id, collect_stats=args.stats) exit_code = launch_pipeline( cfg, config_path, force_retry=args.force_retry, require_completion=cfg.execution_params.merge, + collect_stats=args.stats, ) if exit_code != 0: return exit_code @@ -380,7 +413,7 @@ def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str return 1 shard_ids = parse_shard_ids(args.shard_ids, cfg.loading_params.get_num_shards()) - job_id = submit_slurm_job(cfg, config_path, shard_ids) + job_id = submit_slurm_job(cfg, config_path, shard_ids, collect_stats=args.stats) if job_id is None: return 1 @@ -426,6 +459,7 @@ def handle_check(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) config_path=config_path, failed_shards=failed_shards, confirm_mode=args.confirm_mode, + collect_stats=args.stats, ) @@ -458,9 +492,26 @@ def handle_retry(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) config_path=config_path, failed_shards=failed_shards, confirm_mode=args.confirm_mode, + collect_stats=args.stats, ) +def handle_stats(args: argparse.Namespace, cfg: MMirageConfig, _config_path: str) -> int: + """Print per-shard benchmark statistics and aggregate totals. + + Args: + args: Parsed CLI namespace. + cfg: Parsed MMIRAGE configuration object. + _config_path: Absolute path to the MMIRAGE YAML config file (not needed here). 
+ + Returns: + Exit code: 0 always (stats are informational). + """ + report = collect_bench_stats(cfg) + print(json.dumps(report, indent=2)) + return 0 + + def handle_merge(args: argparse.Namespace, cfg: MMirageConfig, _config_path: str) -> int: """Merge shard outputs defined in config.loading_params.datasets. @@ -513,6 +564,7 @@ def main() -> None: "check": handle_check, "retry": handle_retry, "merge": handle_merge, + "stats": handle_stats, } handler = handlers.get(args.command) if handler is None: diff --git a/src/mmirage/cli_utils/slurm.py b/src/mmirage/cli_utils/slurm.py index ad69d45..bb2c7d4 100644 --- a/src/mmirage/cli_utils/slurm.py +++ b/src/mmirage/cli_utils/slurm.py @@ -54,7 +54,7 @@ def _shell_path(value: str, project_root: str) -> str: return raw -def build_sbatch_script(cfg: MMirageConfig, config_path: str) -> str: +def build_sbatch_script(cfg: MMirageConfig, config_path: str, collect_stats: bool = False) -> str: """Build the sbatch payload executed for each array task.""" project_root = get_project_root(cfg) hf_home = _shell_path(cfg.execution_params.hf_home, project_root) @@ -69,10 +69,14 @@ def build_sbatch_script(cfg: MMirageConfig, config_path: str) -> str: f"export SHARD_PROCESS={_bash_double_quote(shard_process_path)}", f"export HF_HOME={_bash_double_quote(hf_home)}", f"export MMIRAGE_CONFIG={_bash_double_quote(config_path)}", + ] + if collect_stats: + lines.append("export MMIRAGE_COLLECT_STATS=1") + lines.extend([ f"mkdir -p {_bash_double_quote(hf_home)}", f"mkdir -p {_bash_double_quote(state_root)}", "srun_args=(--cpus-per-task ${SLURM_CPUS_PER_TASK:-1} --wait 60)", - ] + ]) if cfg.execution_params.edf_env: edf_env = expand_path(cfg.execution_params.edf_env, project_root) @@ -99,6 +103,7 @@ def submit_slurm_job( cfg: MMirageConfig, config_path: str, shard_ids: Optional[Sequence[int]] = None, + collect_stats: bool = False, ) -> Optional[int]: """Submit a SLURM array job and return its job ID.""" project_root = get_project_root(cfg) @@ -134,7 
+139,7 @@ def submit_slurm_job( logger.info("Submitting SLURM job: %s", " ".join(command)) result = subprocess.run( command, - input=build_sbatch_script(cfg, config_path), + input=build_sbatch_script(cfg, config_path, collect_stats=collect_stats), text=True, capture_output=True, check=False, diff --git a/src/mmirage/cli_utils/status.py b/src/mmirage/cli_utils/status.py index 838099a..410b9f8 100644 --- a/src/mmirage/cli_utils/status.py +++ b/src/mmirage/cli_utils/status.py @@ -7,11 +7,11 @@ import os import sys from dataclasses import dataclass -from typing import List, Literal, Sequence, Tuple +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple from mmirage.config.config import MMirageConfig from mmirage.cli_utils.slurm import submit_slurm_job -from mmirage.shard_utils import ShardStatus +from mmirage.shard_utils import ShardStatus, format_duration, read_status, shard_state_dir logger = logging.getLogger(__name__) @@ -41,11 +41,6 @@ def is_retry_budget_exceeded(attempt_count: int, max_retries: int) -> bool: return attempt_count > max_allowed_attempts(max_retries) -def shard_state_dir(state_root: str, shard_id: int) -> str: - """Return the state directory for a shard.""" - return os.path.join(state_root, f"shard_{shard_id}") - - def get_shard_status(state_dir: str) -> Tuple[str, int]: """Read the current status and attempt counter for a shard.""" status_file = os.path.join(state_dir, "status.json") @@ -79,7 +74,7 @@ def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]: allowed_attempts = max_allowed_attempts(max_retries) for shard_id in range(num_shards): - status, attempt_count = get_shard_status(shard_state_dir(state_root, shard_id)) + status, attempt_count = get_shard_status(shard_state_dir(shard_id, state_root)) if status == "success": success_count += 1 elif status == "running": @@ -140,6 +135,7 @@ def submit_failed_shards( config_path: str, failed_shards: Sequence[int], confirm_mode: Literal["prompt", "yes"], + 
collect_stats: bool = False, ) -> int: """Submit retry jobs for failed shards when requested.""" if not failed_shards: @@ -148,8 +144,141 @@ def submit_failed_shards( if not confirm_retry(len(failed_shards), confirm_mode): return 1 - job_id = submit_slurm_job(cfg, config_path, failed_shards) + job_id = submit_slurm_job(cfg, config_path, failed_shards, collect_stats=collect_stats) if job_id is None: return 1 return 0 + + +def collect_bench_stats(cfg: MMirageConfig) -> Dict[str, Any]: + """Collect per-shard benchmark statistics and compute aggregate totals. + + Returns a dict with two keys: + + - ``per_shard``: list of dicts, one per shard, each containing the full + :class:`~mmirage.shard_utils.ShardStatus` payload plus a flattened + ``stats`` sub-dict. + - ``aggregate``: rolled-up totals across all completed shards. + + Shards without ``stats`` (e.g. still running or from older runs) are + included in ``per_shard`` but excluded from aggregate calculations. + """ + state_root = cfg.loading_params.get_state_root() + num_shards = cfg.loading_params.get_num_shards() + + per_shard: List[Dict[str, Any]] = [] + + total_rows: int = 0 + sum_runtime: float = 0.0 + runtimes: List[float] = [] + gpu_util_weighted: List[float] = [] # util * rows for weighted mean + gpu_total_rows_for_weight: int = 0 + earliest_start: Optional[str] = None + latest_finish: Optional[str] = None + # Token-level aggregates (DataTrove-compatible benchmark format). 
+ total_input_tokens: int = 0 + total_output_tokens: int = 0 + has_token_data: bool = False + sum_model_load_seconds: float = 0.0 + num_gpus: Optional[int] = None # taken from first shard that has it + + for shard_id in range(num_shards): + state_dir = shard_state_dir(shard_id, state_root) + status = read_status(state_dir) + entry: Dict[str, Any] = status.to_dict() + per_shard.append(entry) + + if status.status != "success" or status.stats is None: + continue + + s = status.stats + if s.runtime_seconds is not None: + sum_runtime += s.runtime_seconds + runtimes.append(s.runtime_seconds) + if s.rows_processed is not None: + total_rows += s.rows_processed + if s.gpu_util_mean is not None and s.rows_processed: + gpu_util_weighted.append(s.gpu_util_mean * s.rows_processed) + gpu_total_rows_for_weight += s.rows_processed + + # Accumulate token counts. + if s.input_tokens is not None: + total_input_tokens += s.input_tokens + has_token_data = True + if s.output_tokens is not None: + total_output_tokens += s.output_tokens + has_token_data = True + if s.model_load_seconds is not None: + sum_model_load_seconds += s.model_load_seconds + if num_gpus is None and s.num_gpus is not None: + num_gpus = s.num_gpus + + # Track earliest start / latest finish for wall-clock runtime. + if status.started_at: + if earliest_start is None or status.started_at < earliest_start: + earliest_start = status.started_at + if status.finished_at: + if latest_finish is None or status.finished_at > latest_finish: + latest_finish = status.finished_at + + # Wall-clock runtime: time from first shard start to last shard finish. 
+ wall_clock: Optional[float] = None + if earliest_start and latest_finish: + try: + from datetime import datetime as _dt + wall_clock = round( + (_dt.fromisoformat(latest_finish) - _dt.fromisoformat(earliest_start)).total_seconds(), + 3, + ) + except (ValueError, TypeError): + pass + + overall_throughput: Optional[float] = None + if total_rows > 0 and wall_clock and wall_clock > 0: + overall_throughput = round(total_rows / wall_clock, 2) + + mean_gpu_util: Optional[float] = None + if gpu_util_weighted and gpu_total_rows_for_weight > 0: + mean_gpu_util = round(sum(gpu_util_weighted) / gpu_total_rows_for_weight, 1) + + # Aggregate token-throughput metrics (DataTrove-compatible benchmark format). + # Uses sum of inference runtimes (total minus model loading) for a per-GPU token rate + # that excludes one-time model initialisation overhead. + agg_tokens_per_sec_per_gpu: Optional[float] = None + agg_gpu_days_per_billion_tokens: Optional[float] = None + agg_inference_runtime: Optional[float] = None + if has_token_data and total_output_tokens > 0 and runtimes and num_gpus and num_gpus > 0: + agg_inference_runtime = max(0.0, sum_runtime - sum_model_load_seconds) + if agg_inference_runtime > 0: + total_gpu_seconds = agg_inference_runtime * num_gpus + agg_tokens_per_sec_per_gpu = round(total_output_tokens / total_gpu_seconds, 2) + total_gpu_days = total_gpu_seconds / 86_400 + agg_gpu_days_per_billion_tokens = round(total_gpu_days / (total_output_tokens / 1e9), 4) + + aggregate: Dict[str, Any] = { + "total_shards": num_shards, + "completed_shards": sum(1 for e in per_shard if e.get("status") == "success"), + "total_rows_processed": total_rows if total_rows > 0 else None, + "wall_clock_runtime_seconds": wall_clock, + "wall_clock_runtime_human": format_duration(wall_clock), + "sum_shard_runtime_seconds": round(sum_runtime, 3) if runtimes else None, + "sum_shard_runtime_human": format_duration(round(sum_runtime, 3) if runtimes else None), + "min_shard_runtime_seconds": 
round(min(runtimes), 3) if runtimes else None, + "min_shard_runtime_human": format_duration(round(min(runtimes), 3) if runtimes else None), + "max_shard_runtime_seconds": round(max(runtimes), 3) if runtimes else None, + "max_shard_runtime_human": format_duration(round(max(runtimes), 3) if runtimes else None), + "overall_throughput_rows_per_sec": overall_throughput, + "mean_gpu_util_pct": mean_gpu_util, + # Token-level benchmark metrics (DataTrove-compatible). + "num_gpus": num_gpus, + "total_input_tokens": total_input_tokens if has_token_data else None, + "total_output_tokens": total_output_tokens if has_token_data else None, + "sum_model_load_seconds": round(sum_model_load_seconds, 3) if sum_model_load_seconds > 0 else None, + "sum_inference_runtime_seconds": round(agg_inference_runtime, 3) if agg_inference_runtime is not None else None, + "tokens_per_sec_per_gpu": agg_tokens_per_sec_per_gpu, + "gpu_days_per_billion_tokens": agg_gpu_days_per_billion_tokens, + } + + return {"per_shard": per_shard, "aggregate": aggregate} + diff --git a/src/mmirage/config/loading.py b/src/mmirage/config/loading.py index 929d76d..afb1386 100644 --- a/src/mmirage/config/loading.py +++ b/src/mmirage/config/loading.py @@ -47,7 +47,7 @@ def is_unresolved_env_var(s: str) -> bool: if self.num_shards < 1: raise ValueError() except (ValueError, TypeError): - if is_unresolved_env_var(self.num_shards): + if isinstance(self.num_shards, str) and is_unresolved_env_var(self.num_shards): self.num_shards = 1 else: raise ValueError(f"Invalid value for num_shards: {self.num_shards!r}") @@ -56,7 +56,7 @@ def is_unresolved_env_var(s: str) -> bool: try: self.shard_id = int(self.shard_id) except (ValueError, TypeError): - if is_unresolved_env_var(self.shard_id): + if isinstance(self.shard_id, str) and is_unresolved_env_var(self.shard_id): self.shard_id = 0 else: raise ValueError(f"Invalid value for shard_id: {self.shard_id!r}") diff --git a/src/mmirage/core/process/base.py 
b/src/mmirage/core/process/base.py index 988bae7..6e8a283 100644 --- a/src/mmirage/core/process/base.py +++ b/src/mmirage/core/process/base.py @@ -24,6 +24,14 @@ class BaseProcessorConfig: C = TypeVar("C", bound=OutputVar) +@dataclass +class TokenCounts: + """Cumulative token counts from LLM processors.""" + + input_tokens: int + output_tokens: int + + class BaseProcessor(abc.ABC, Generic[C]): """Abstract base class for data processors. @@ -64,6 +72,30 @@ def batch_process_sample( """ raise NotImplementedError() + @abc.abstractmethod + def get_token_counts(self) -> TokenCounts: + """Get cumulative token counts from this processor. + + Returns: + TokenCounts object containing input and output token counts. + + Raises: + NotImplementedError: If not implemented by subclass. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_load_time(self) -> float: + """Get the time taken to load any necessary resources (e.g., models). + + Returns: + Time in seconds taken to load resources. + + Raises: + NotImplementedError: If not implemented by subclass. + """ + raise NotImplementedError() + class ProcessorRegistry: """Registry for managing and accessing available processors. 
diff --git a/src/mmirage/core/process/mapper.py b/src/mmirage/core/process/mapper.py index 5310150..877741b 100644 --- a/src/mmirage/core/process/mapper.py +++ b/src/mmirage/core/process/mapper.py @@ -1,9 +1,11 @@ """Mapper for orchestrating variable transformations.""" -from typing import Dict, Any, List, cast +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, cast +from mmirage.core.process.base import AutoProcessor, BaseProcessor, BaseProcessorConfig, TokenCounts from mmirage.core.process.variables import BaseVar, InputVar, OutputVar -from mmirage.core.process.base import AutoProcessor, BaseProcessor, BaseProcessorConfig + import logging @@ -73,7 +75,7 @@ def validate_vars(self) -> bool: def rewrite_batch( self, batch: Dict[str, List[Any]], - image_base_path: str = None, + image_base_path: Optional[str] = None, ) -> List[VariableEnvironment]: """Transform a batch of samples by computing output variables. @@ -103,3 +105,29 @@ def rewrite_batch( ) return batch_environment + + def get_token_counts(self) -> TokenCounts: + """Return cumulative token counts aggregated across all LLM processors. + + Sums ``input_tokens`` and ``output_tokens`` from every processor that + exposes a ``get_token_counts()`` method (i.e., ``LLMProcessor``). + + Returns: + TokenCounts with ``input_tokens`` and ``output_tokens`` fields. 
+ """ + total_input = 0 + total_output = 0 + for proc in self.processors.values(): + if hasattr(proc, "get_token_counts"): + counts = proc.get_token_counts() + total_input += counts.input_tokens + total_output += counts.output_tokens + return TokenCounts(input_tokens=total_input, output_tokens=total_output) + + def get_load_time(self) -> float: + """Return total model-loading time (seconds) summed across all LLM processors.""" + total = 0.0 + for proc in self.processors.values(): + if hasattr(proc, "get_load_time"): + total += proc.get_load_time() + return total diff --git a/src/mmirage/core/process/processors/llm/config.py b/src/mmirage/core/process/processors/llm/config.py index f323599..dde3029 100644 --- a/src/mmirage/core/process/processors/llm/config.py +++ b/src/mmirage/core/process/processors/llm/config.py @@ -59,12 +59,21 @@ class SGLangServerArgs: tp_size: Tensor parallelism size. trust_remote_code: Whether to trust remote code from HuggingFace. disable_custom_all_reduce: Whether to disable custom all reduce. + extra_engine_args: Any additional keyword arguments forwarded verbatim + to ``sgl.Engine``. 
Use this to pass SGLang-specific options that + are not listed above, e.g.:: + + extra_engine_args: + max_running_requests: 512 + chunked_prefill_size: 32768 + mem_fraction_static: 0.88 """ model_path: str = "none" tp_size: int = field(default_factory=_parse_tp_size_from_env) trust_remote_code: bool = True disable_custom_all_reduce: bool = False + extra_engine_args: Dict[str, Any] = field(default_factory=dict) @dataclass diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 56afd17..5107582 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -5,13 +5,14 @@ from dataclasses import asdict import json import logging +import time from typing import Any, List, Tuple import jinja2 import sglang as sgl from transformers import AutoTokenizer -from mmirage.core.process.base import BaseProcessor, ProcessorRegistry +from mmirage.core.process.base import BaseProcessor, ProcessorRegistry, TokenCounts from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig from mmirage.core.process.variables import VariableEnvironment @@ -58,13 +59,43 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: **kwargs: Additional arguments passed to base class. 
""" super().__init__(engine_args, **kwargs) - self.llm = sgl.Engine(**asdict(engine_args.server_args)) + server_kwargs = asdict(engine_args.server_args) + extra = server_kwargs.pop("extra_engine_args", {}) or {} + server_kwargs.update(extra) + _load_start = time.monotonic() + self.llm = sgl.Engine(**server_kwargs) + self._model_load_seconds: float = time.monotonic() - _load_start self.tokenizer = AutoTokenizer.from_pretrained( engine_args.server_args.model_path, trust_remote_code=getattr(engine_args.server_args, "trust_remote_code", False), ) self.sampling_params = engine_args.default_sampling_params self.chat_template = engine_args.chat_template + # Cumulative token counts across all generate() calls in this processor's lifetime. + self._total_input_tokens: int = 0 + self._total_output_tokens: int = 0 + + def get_load_time(self) -> float: + """Return the wall-clock seconds spent initializing the SGLang engine.""" + return self._model_load_seconds + + def get_token_counts(self) -> TokenCounts: + """Return cumulative token counts for this processor. + + Returns: + TokenCounts object containing input and output token counts accumulated since this processor was created. 
+ """ + return TokenCounts( + input_tokens=self._total_input_tokens, + output_tokens=self._total_output_tokens + ) + + def _accumulate_tokens(self, outputs: list) -> None: + """Add token counts from a list of SGLang generate() outputs.""" + for out in outputs: + meta = out.get("meta_info") or {} + self._total_input_tokens += int(meta.get("prompt_tokens") or 0) + self._total_output_tokens += int(meta.get("completion_tokens") or 0) def build_prompt( self, prompt_template: str, vars_samples: List[VariableEnvironment] @@ -190,6 +221,8 @@ def batch_process_sample( f"{len(text_only_outputs) if isinstance(text_only_outputs, list) else 'non-list'}" ) + self._accumulate_tokens(text_only_outputs) + for local_idx, global_i in enumerate(text_only_indices): value = text_only_outputs[local_idx].get("text", "").strip() if output_var.output_type == "JSON": @@ -252,6 +285,8 @@ def batch_process_sample( f"{len(multimodal_outputs) if isinstance(multimodal_outputs, list) else 'non-list'}" ) + self._accumulate_tokens(multimodal_outputs) + for local_idx, global_i in enumerate(multimodal_indices): value = multimodal_outputs[local_idx].get("text", "").strip() if output_var.output_type == "JSON": diff --git a/src/mmirage/merge_shards.py b/src/mmirage/merge_shards.py index fbddf10..a54444a 100644 --- a/src/mmirage/merge_shards.py +++ b/src/mmirage/merge_shards.py @@ -49,7 +49,7 @@ def _merge_datasetdict(shard_dsets: List[DatasetDict]) -> DatasetDict: merged[str(split)] = concatenate_datasets(split_dsets) if not merged: raise RuntimeError("All splits were empty after merging.") - return DatasetDict(merged) + return DatasetDict(**merged) def _merge_shards(shard_dsets: List[DatasetLike]) -> DatasetLike: diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 66e8529..d12e07d 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -5,9 +5,10 @@ import argparse import logging +import os import sys import traceback -from typing import Any, Dict, List 
+from typing import Any, Dict, List, Optional from mmirage.config.utils import load_mmirage_config from mmirage.core.loader.base import DatasetLike @@ -15,6 +16,8 @@ from mmirage.core.process.mapper import MMIRAGEMapper from mmirage.core.writer.renderer import TemplateRenderer from mmirage.shard_utils import ( + GpuUtilizationPoller, + ShardStats, _cleanup_old_shard_data, _count_rows, _dataset_out_dir, @@ -24,7 +27,7 @@ _remove_columns, _save_dataset_atomic, _shard_dataset, - _shard_state_dir, + shard_state_dir, ) logger = logging.getLogger(__name__) @@ -34,7 +37,7 @@ def rewrite_batch( batch: Dict[str, List[Any]], mapper: MMIRAGEMapper, renderer: TemplateRenderer, - image_base_path: str = None, + image_base_path: Optional[str] = None, ) -> Dict[str, List[Any]]: """Rewrite a batch of samples by applying transformations. Args: @@ -86,7 +89,35 @@ def main(): if not (0 <= shard_id < num_shards): raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}") - state_dir = _shard_state_dir(shard_id, loading_params.get_state_root()) + state_dir = shard_state_dir(shard_id, loading_params.get_state_root()) + + gpu_poller: Optional[GpuUtilizationPoller] = None + + collect_stats = os.environ.get("MMIRAGE_COLLECT_STATS", "") == "1" + if collect_stats: + # Determine which physical GPU indices SGLang will use so the poller + # measures only the active GPU(s) — not all GPUs on the node. + # SLURM may allocate more GPUs than tp_size (e.g. gpus=4, tp_size=1). + # We take only the first tp_size entries from CUDA_VISIBLE_DEVICES so + # nvidia-smi --id receives exactly the GPUs SGLang is using. 
+ tp_size = 1 + for proc_cfg in cfg.processors: + tp = getattr(getattr(proc_cfg, "server_args", None), "tp_size", None) + if tp and int(tp) > 0: + tp_size = int(tp) + break + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): + all_visible = [x.strip() for x in cuda_visible.split(",") if x.strip()] + # Fall back to range-based indices if CUDA_VISIBLE_DEVICES was set + # but contained only whitespace/empty entries after stripping. + gpu_indices_for_polling: List[str] = all_visible[:tp_size] if all_visible else [str(i) for i in range(tp_size)] + else: + gpu_indices_for_polling = [str(i) for i in range(tp_size)] + + gpu_poller = GpuUtilizationPoller( + interval_seconds=5.0, gpu_indices=gpu_indices_for_polling + ) try: retry_count = _mark_running(state_dir, shard_id, datasets_config) @@ -115,6 +146,11 @@ def main(): ) renderer = TemplateRenderer(processing_params.output_schema) + # Start GPU polling after model loading so utilisation samples reflect + # inference only, not weight transfers during sgl.Engine() init. + if collect_stats and gpu_poller is not None: + gpu_poller.start() + ds_processed_all: List[DatasetLike] = [] for ds_idx, ds_shard in enumerate(ds_all_shard): ds_config = datasets_config[ds_idx] @@ -125,7 +161,7 @@ def main(): logger.info( f"Processing dataset {ds_idx} for shard {shard_id}: " - f"path={ds_config.path}, output_dir={ds_config.output_dir}" + f"image_base_path={ds_config.image_base_path}, output_dir={ds_config.output_dir}" ) ds_processed = ds_shard.map( @@ -148,13 +184,42 @@ def main(): _save_dataset_atomic(ds_processed, out_dir) logger.info(f"✅ Saved dataset {ds_idx} shard in: {out_dir}") - _mark_success(state_dir) + gpu_info = gpu_poller.stop() if collect_stats and gpu_poller is not None else {"mean": None, "min": None, "max": None, "samples": 0} + + # Collect token counts accumulated by LLM processor(s). 
+ token_counts = mapper.get_token_counts() + input_tokens = token_counts.input_tokens or None + output_tokens = token_counts.output_tokens or None + model_load_seconds = mapper.get_load_time() or None + + # Resolve num_gpus from the first processor config that exposes tp_size. + num_gpus: Optional[int] = None + for proc_cfg in cfg.processors: + tp = getattr(getattr(proc_cfg, "server_args", None), "tp_size", None) + if tp and tp > 0: + num_gpus = int(tp) + break + + stats = ShardStats( + rows_processed=shard_rows, + gpu_util_mean=gpu_info["mean"], + gpu_util_min=gpu_info["min"], + gpu_util_max=gpu_info["max"], + gpu_util_samples=gpu_info["samples"], + input_tokens=input_tokens, + output_tokens=output_tokens, + num_gpus=num_gpus, + model_load_seconds=model_load_seconds, + ) + _mark_success(state_dir, stats=stats) logger.info(f"✅ Logical shard {shard_id} completed successfully") except Exception as e: error_msg = f"{type(e).__name__}: {str(e)}" logger.error(f"❌ Shard {shard_id} failed: {error_msg}") logger.error(traceback.format_exc()) + if collect_stats and gpu_poller is not None: + gpu_poller.stop() _mark_failure(state_dir, error_msg) sys.exit(1) diff --git a/src/mmirage/shard_utils.py b/src/mmirage/shard_utils.py index 76af6c9..31a254e 100644 --- a/src/mmirage/shard_utils.py +++ b/src/mmirage/shard_utils.py @@ -6,11 +6,14 @@ from datetime import datetime from dataclasses import dataclass +import humanize import json import logging import os import shutil import socket +import subprocess +import threading import uuid from typing import Any, Dict, List, Optional @@ -21,6 +24,204 @@ logger = logging.getLogger(__name__) +def format_duration(seconds: Optional[float]) -> Optional[str]: + """Format a duration given in seconds as a human-readable string.""" + if seconds is None: + return None + return humanize.precisedelta(seconds) + + +@dataclass +class ShardStats: + """Per-shard benchmark statistics recorded at completion.""" + + runtime_seconds: Optional[float] = None + 
rows_processed: Optional[int] = None + throughput_rows_per_sec: Optional[float] = None + gpu_util_mean: Optional[float] = None + gpu_util_min: Optional[float] = None + gpu_util_max: Optional[float] = None + gpu_util_samples: Optional[int] = None + # Token-level throughput metrics (DataTrove-compatible benchmark format). + # input_tokens: total prompt tokens consumed across all LLM calls in this shard. + # output_tokens: total completion tokens generated across all LLM calls in this shard. + # num_gpus: number of GPUs used (tensor-parallel size from the LLM processor config). + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None + num_gpus: Optional[int] = None + model_load_seconds: Optional[float] = None + + @classmethod + def from_dict(cls, data: Optional[Dict[str, Any]]) -> Optional["ShardStats"]: + """Build a ShardStats from a JSON payload, or return None if data is missing.""" + if not isinstance(data, dict): + return None + + def _opt_float(v: Any) -> Optional[float]: + try: + return float(v) if v is not None else None + except (TypeError, ValueError): + return None + + def _opt_int(v: Any) -> Optional[int]: + try: + return int(v) if v is not None else None + except (TypeError, ValueError): + return None + + return cls( + runtime_seconds=_opt_float(data.get("runtime_seconds")), + rows_processed=_opt_int(data.get("rows_processed")), + throughput_rows_per_sec=_opt_float(data.get("throughput_rows_per_sec")), + gpu_util_mean=_opt_float(data.get("gpu_util_mean")), + gpu_util_min=_opt_float(data.get("gpu_util_min")), + gpu_util_max=_opt_float(data.get("gpu_util_max")), + gpu_util_samples=_opt_int(data.get("gpu_util_samples")), + input_tokens=_opt_int(data.get("input_tokens")), + output_tokens=_opt_int(data.get("output_tokens")), + num_gpus=_opt_int(data.get("num_gpus")), + model_load_seconds=_opt_float(data.get("model_load_seconds")), + ) + + def to_dict(self) -> Dict[str, Any]: + # Derived token-throughput metrics (DataTrove-compatible 
benchmark format). + # Use inference_runtime (total minus model loading) so metrics reflect + # pure generation speed, excluding one-time model initialisation overhead. + tokens_per_sec_per_gpu: Optional[float] = None + gpu_days_per_billion_tokens: Optional[float] = None + inference_runtime: Optional[float] = None + if self.runtime_seconds is not None: + if self.model_load_seconds is not None: + inference_runtime = max(0.0, self.runtime_seconds - self.model_load_seconds) + else: + inference_runtime = self.runtime_seconds + if ( + self.output_tokens is not None + and self.output_tokens > 0 + and inference_runtime is not None + and inference_runtime > 0 + and self.num_gpus is not None + and self.num_gpus > 0 + ): + tokens_per_sec_per_gpu = round( + self.output_tokens / (inference_runtime * self.num_gpus), 2 + ) + gpu_days_per_billion_tokens = round( + (self.num_gpus * inference_runtime / 86_400) / (self.output_tokens / 1e9), 4 + ) + + return { + "runtime_seconds": self.runtime_seconds, + "runtime_human": format_duration(self.runtime_seconds), + "model_load_seconds": round(self.model_load_seconds, 3) if self.model_load_seconds is not None else None, + "inference_runtime_seconds": round(inference_runtime, 3) if inference_runtime is not None else None, + "rows_processed": self.rows_processed, + "throughput_rows_per_sec": self.throughput_rows_per_sec, + "gpu_util_mean": self.gpu_util_mean, + "gpu_util_min": self.gpu_util_min, + "gpu_util_max": self.gpu_util_max, + "gpu_util_samples": self.gpu_util_samples, + "input_tokens": self.input_tokens, + "output_tokens": self.output_tokens, + "num_gpus": self.num_gpus, + "tokens_per_sec_per_gpu": tokens_per_sec_per_gpu, + "gpu_days_per_billion_tokens": gpu_days_per_billion_tokens, + } + + +class GpuUtilizationPoller: + """Polls ``nvidia-smi`` in a background daemon thread. + + Usage:: + + poller = GpuUtilizationPoller() + poller.start() + # ... do work ... 
+ gpu_info = poller.stop() # {"mean": 85.2, "min": 70.0, "max": 98.0, "samples": 24} + + If ``nvidia-smi`` is unavailable all values are ``None`` and samples is 0. + """ + + def __init__(self, interval_seconds: float = 5.0, gpu_indices: Optional[List[str]] = None) -> None: + self._interval = interval_seconds + self._samples: List[float] = [] + self._thread: Optional[threading.Thread] = None + self._stop_event = threading.Event() + # Explicit GPU indices take priority over CUDA_VISIBLE_DEVICES. + # Pass the indices SGLang will use (0..tp_size-1 in local mode). + self._gpu_indices = gpu_indices + + def start(self) -> None: + """Start background polling.""" + self._stop_event.clear() + self._samples = [] + self._thread = threading.Thread(target=self._poll_loop, daemon=True) + self._thread.start() + + def stop(self) -> Dict[str, Any]: + """Stop polling and return a summary dict.""" + self._stop_event.set() + if self._thread is not None: + self._thread.join(timeout=self._interval + 2.0) + samples = self._samples + if not samples: + return {"mean": None, "min": None, "max": None, "samples": 0} + return { + "mean": round(sum(samples) / len(samples), 1), + "min": float(min(samples)), + "max": float(max(samples)), + "samples": len(samples), + } + + def _poll_loop(self) -> None: + while not self._stop_event.wait(timeout=self._interval): + util = self._query_gpu_util() + if util is not None: + self._samples.append(util) + + def _query_gpu_util(self) -> Optional[float]: + try: + cmd = [ + "nvidia-smi", + "--query-gpu=utilization.gpu", + "--format=csv,noheader,nounits", + ] + # Restrict to the GPUs this process actually uses so we don't + # dilute utilization by averaging over idle GPUs on the same node. + # Priority: explicit gpu_indices > CUDA_VISIBLE_DEVICES > all GPUs. + if self._gpu_indices is not None: + if not self._gpu_indices: + # Empty list would produce --id= which is invalid; skip filtering. 
+ pass + else: + cmd += [f"--id={','.join(str(i) for i in self._gpu_indices)}"] + else: + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible and cuda_visible.lower() not in ("all", "nodevfiles"): + cmd += [f"--id={cuda_visible}"] + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=5, + check=False, + ) + if result.returncode == 0: + values = [] + for line in result.stdout.strip().splitlines(): + line = line.strip() + if line: + try: + values.append(float(line)) + except ValueError: + pass + if values: + return sum(values) / len(values) + except Exception: + pass + return None + + @dataclass class ShardStatus: """Typed representation of the shard status.json payload.""" @@ -36,6 +237,7 @@ class ShardStatus: slurm_job_id: Optional[str] = None slurm_array_task_id: Optional[str] = None datasets: Optional[List[Dict[str, Any]]] = None + stats: Optional[ShardStats] = None @classmethod def from_dict(cls, payload: Dict[str, Any]) -> "ShardStatus": @@ -64,6 +266,8 @@ def from_dict(cls, payload: Dict[str, Any]) -> "ShardStatus": if not isinstance(datasets, list): datasets = None + stats = ShardStats.from_dict(data.get("stats")) + return cls( status=str(data.get("status", "unknown")), retry_count=retry_count, @@ -76,6 +280,7 @@ def from_dict(cls, payload: Dict[str, Any]) -> "ShardStatus": slurm_job_id=data.get("slurm_job_id"), slurm_array_task_id=data.get("slurm_array_task_id"), datasets=datasets, + stats=stats, ) def to_dict(self) -> Dict[str, Any]: @@ -92,6 +297,7 @@ def to_dict(self) -> Dict[str, Any]: "slurm_job_id": self.slurm_job_id, "slurm_array_task_id": self.slurm_array_task_id, "datasets": self.datasets, + "stats": self.stats.to_dict() if self.stats is not None else None, } @@ -187,7 +393,7 @@ def _dataset_out_dir(shard_idx: int, ds_config: BaseDataLoaderConfig) -> str: return os.path.join(ds_config.output_dir, f"shard_{shard_idx}") -def _shard_state_dir(shard_idx: int, state_root: str) -> str: +def 
shard_state_dir(shard_idx: int, state_root: str) -> str: """Get central state directory for a logical shard.""" return os.path.join(state_root, f"shard_{shard_idx}") @@ -204,7 +410,7 @@ def _status_file(state_dir: str) -> str: return os.path.join(state_dir, "status.json") -def _read_status(state_dir: str) -> ShardStatus: +def read_status(state_dir: str) -> ShardStatus: """Read status.json if present.""" path = _status_file(state_dir) if not os.path.exists(path): @@ -255,7 +461,7 @@ def _mark_running( datasets_config: List[BaseDataLoaderConfig], ) -> int: """Mark shard as running and increment retry count.""" - prev = _read_status(state_dir) + prev = read_status(state_dir) retry_count = prev.retry_count + 1 payload = ShardStatus( @@ -271,7 +477,7 @@ def _mark_running( slurm_array_task_id=os.environ.get("SLURM_ARRAY_TASK_ID"), datasets=[ { - "path": ds_config.path, + "image_base_path": ds_config.image_base_path, "output_dir": ds_config.output_dir, } for ds_config in datasets_config @@ -284,12 +490,50 @@ def _mark_running( return retry_count -def _mark_success(state_dir: str): - """Mark shard as successful.""" - prev = _read_status(state_dir) +def _mark_success(state_dir: str, stats: Optional[ShardStats] = None): + """Mark shard as successful and record benchmark statistics. + + Args: + state_dir: Shard state directory. + stats: Optional benchmark stats; ``runtime_seconds`` and + ``throughput_rows_per_sec`` are computed from the stored timestamps + when not already set. + """ + prev = read_status(state_dir) prev.status = "success" - prev.finished_at = datetime.now().isoformat() + now = datetime.now() + prev.finished_at = now.isoformat() prev.error = None + + if stats is not None: + # Derive runtime from stored start timestamp when not already supplied. 
+ if stats.runtime_seconds is None and prev.started_at: + try: + started = datetime.fromisoformat(prev.started_at) + stats.runtime_seconds = round((now - started).total_seconds(), 3) + except (ValueError, TypeError): + pass + + # Derive throughput once we have both rows and runtime. + # Use inference_runtime (total minus model loading) so the metric + # reflects pure generation speed, consistent with tokens_per_sec_per_gpu. + if ( + stats.throughput_rows_per_sec is None + and stats.rows_processed is not None + and stats.runtime_seconds is not None + and stats.runtime_seconds > 0 + ): + inference_runtime = ( + max(0.0, stats.runtime_seconds - stats.model_load_seconds) + if stats.model_load_seconds is not None + else stats.runtime_seconds + ) + if inference_runtime > 0: + stats.throughput_rows_per_sec = round( + stats.rows_processed / inference_runtime, 2 + ) + + prev.stats = stats _write_status(state_dir, prev) _clear_markers(state_dir) _touch_marker(state_dir, ".SUCCESS") @@ -297,7 +541,7 @@ def _mark_success(state_dir: str): def _mark_failure(state_dir: str, error_msg: str): """Mark shard as failed.""" - prev = _read_status(state_dir) + prev = read_status(state_dir) prev.status = "failed" prev.finished_at = datetime.now().isoformat() prev.error = error_msg