83 changes: 83 additions & 0 deletions README.md
Expand Up @@ -235,6 +235,88 @@ Key multimodal features:
- `image_base_path`: Base directory for resolving relative image paths
- Supports PIL Images, URLs, and file paths

### Benchmarking shard performance

Pass `--stats` to `run` or `submit` to enable per-shard benchmarking (`check` and `retry`
also accept the flag for the shards they resubmit). This activates GPU utilization polling
and throughput tracking on compute nodes. Both are disabled by default to avoid
unnecessary overhead.

```bash
# Local run with stats collection
mmirage run --config configs/config_mock.yaml --stats
```

After the run completes, inspect the results with:

```bash
mmirage stats --config configs/config_mock.yaml
```

This prints a JSON report with per-shard details and an aggregate summary:

```json
{
"per_shard": [
{
"shard_id": 0,
"status": "success",
"started_at": "2026-04-30T10:00:00",
"finished_at": "2026-04-30T10:01:05",
"stats": {
"runtime_seconds": 65.2,
"runtime_human": "1m 5s",
"rows_processed": 1024,
"throughput_rows_per_sec": 15.7,
"gpu_util_mean": 88.4,
"gpu_util_min": 72.0,
"gpu_util_max": 98.0,
"gpu_util_samples": 13,
"input_tokens": 512000,
"output_tokens": 196608,
"num_gpus": 4,
"tokens_per_sec_per_gpu": 753.1,
"gpu_days_per_billion_tokens": 0.0015
}
}
],
"aggregate": {
"total_shards": 1,
"completed_shards": 1,
"total_rows_processed": 1000,
"wall_clock_runtime_seconds": 133.04,
"wall_clock_runtime_human": "2m 13s",
"sum_shard_runtime_seconds": 133.04,
"sum_shard_runtime_human": "2m 13s",
"min_shard_runtime_seconds": 133.04,
"min_shard_runtime_human": "2m 13s",
"max_shard_runtime_seconds": 133.04,
"max_shard_runtime_human": "2m 13s",
"overall_throughput_rows_per_sec": 7.52,
"mean_gpu_util_pct": 86.2,
"num_gpus": 1,
"total_input_tokens": 146214,
"total_output_tokens": 1022046,
"sum_model_load_seconds": 38.272,
"sum_inference_runtime_seconds": 94.768,
"tokens_per_sec_per_gpu": 10784.72,
"gpu_days_per_billion_tokens": 1.0732
}
}
```
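
Because the report is plain JSON, it can be post-processed with a few lines of Python. The sketch below assumes the output of `mmirage stats` was redirected to a file and loaded with `json.load`; here an inline sample stands in for that file, and its values are illustrative only. The helper `slowest_shard` is hypothetical and not part of MMIRAGE. It relies only on the `per_shard`, `shard_id`, `stats`, and `runtime_seconds` fields shown in the example above.

```python
import json
from typing import Optional


def slowest_shard(report: dict) -> Optional[dict]:
    """Return the per-shard entry with the longest runtime, or None if no shard has stats."""
    timed = [
        shard
        for shard in report.get("per_shard", [])
        if shard.get("stats") and shard["stats"].get("runtime_seconds") is not None
    ]
    if not timed:
        return None
    return max(timed, key=lambda shard: shard["stats"]["runtime_seconds"])


# Illustrative sample standing in for a saved report file.
sample_report = json.loads(
    """
    {
      "per_shard": [
        {"shard_id": 0, "status": "success", "stats": {"runtime_seconds": 65.2}},
        {"shard_id": 1, "status": "success", "stats": {"runtime_seconds": 80.0}},
        {"shard_id": 2, "status": "failed", "stats": null}
      ]
    }
    """
)

worst = slowest_shard(sample_report)
if worst is not None:
    print(worst["shard_id"], worst["stats"]["runtime_seconds"])
```

Shards whose `stats` entry is `null` are skipped, so the helper also works on runs where some shards failed or ran without `--stats`.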

Key metrics:
- **`runtime_seconds`** / **`runtime_human`**: elapsed time measured from when the shard starts executing on the cluster (after dispatch); queue wait time is excluded.
- **`overall_throughput_rows_per_sec`**: total rows processed divided by wall-clock time, measured across all shards running in parallel.
- **`mean_gpu_util_pct`**: GPU utilization in percent, averaged across shards.
- **`tokens_per_sec_per_gpu`**: output tokens generated per second per GPU. This is the primary throughput metric used by frameworks such as [DataTrove](https://github.com/huggingface/datatrove).
- **`gpu_days_per_billion_tokens`**: GPU-days needed to generate 1 billion output tokens; useful for cost and scaling comparisons across hardware configurations.
- Token metrics are `null` when no LLM processor was active. GPU stats are `null` when `nvidia-smi` is unavailable or `--stats` was not passed.
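
As a worked example, the derived metrics in the `aggregate` block above can be reproduced from its raw fields. The formulas below are an assumption inferred from those example numbers (output tokens divided by inference time and GPU count, then converted to GPU-days per billion tokens); the diff does not show how MMIRAGE computes them internally, and the function names are hypothetical.

```python
SECONDS_PER_DAY = 86_400


def tokens_per_sec_per_gpu(output_tokens: int, inference_seconds: float, num_gpus: int) -> float:
    """Output tokens generated per second, normalized by the number of GPUs."""
    return output_tokens / inference_seconds / num_gpus


def gpu_days_per_billion_tokens(tps_per_gpu: float) -> float:
    """GPU-days needed to generate 1e9 output tokens at the given per-GPU rate."""
    return 1e9 / tps_per_gpu / SECONDS_PER_DAY


# Raw values copied from the "aggregate" block of the example report.
aggregate = {
    "total_rows_processed": 1000,
    "wall_clock_runtime_seconds": 133.04,
    "total_output_tokens": 1022046,
    "sum_inference_runtime_seconds": 94.768,
    "num_gpus": 1,
}

tps = tokens_per_sec_per_gpu(
    aggregate["total_output_tokens"],
    aggregate["sum_inference_runtime_seconds"],
    aggregate["num_gpus"],
)
rows_per_sec = aggregate["total_rows_processed"] / aggregate["wall_clock_runtime_seconds"]

print(round(tps, 2))                               # 10784.72
print(round(gpu_days_per_billion_tokens(tps), 4))  # 1.0732
print(round(rows_per_sec, 2))                      # 7.52
```

Under this reading, model load time (`sum_model_load_seconds`) does not count toward token throughput, while row throughput is measured against the full wall-clock time.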

Reference benchmark:
- [DataTrove Benchmark](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark)
- `mmirage run --config configs/config_benchmark_datatrove.yaml --stats`

## Architecture

MMIRAGE uses a modular architecture:
Expand All @@ -258,3 +340,4 @@ mmirage/
- JMESPath for JSON queries: [link](https://jmespath.org/)
- SGLang for fast inference: [link](https://github.com/sgl-project/sglang)
- Performance paper: [link](https://arxiv.org/abs/2408.02442)
- DataTrove Benchmark: [link](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark)
87 changes: 87 additions & 0 deletions configs/config_benchmark_datatrove.yaml
@@ -0,0 +1,87 @@
# MMIRAGE — DataTrove-compatible throughput benchmark
#
# Mirrors the conditions used in the DataTrove inference benchmark
# (https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark):
#
# dataset : simplescaling/s1K-1.1 (train split, 1 000 samples)
# prompt : raw `question` field, no system prompt
# output : up to 1 024 tokens per sample
# context : 2 048-token model max context
# model : Qwen/Qwen3-4B (DataTrove baseline: tp=1 on a single GPU)
#
# Download the dataset before running:
#
# python -c "
# from datasets import load_dataset
# ds = load_dataset('simplescaling/s1K-1.1', split='train')
# ds.save_to_disk('data/s1K-1.1')
# "
#
# Then run with stats collection enabled:
#
# mmirage run --config configs/config_benchmark_datatrove.yaml --stats
#
# Inspect results:
#
# mmirage stats --config configs/config_benchmark_datatrove.yaml

processors:
- type: llm
server_args:
model_path: Qwen/Qwen3-4B # same model family as DataTrove baseline
tp_size: 1 # DataTrove baseline: tp=1
trust_remote_code: true
disable_custom_all_reduce: true
# SGLang engine tuning — equivalents of DataTrove's vLLM mns/mnbt knobs
extra_engine_args:
max_running_requests: 1000
default_sampling_params:
temperature: 0.0
max_new_tokens: 1024 # DataTrove: max-tokens=1024

loading_params:
state_dir: data/benchmark_s1k/_pipeline_state
datasets:
- path: data/s1K-1.1 # save_to_disk() target above
type: loadable
output_dir: data/benchmark_s1k/output
num_shards: 1
shard_id: "$SLURM_ARRAY_TASK_ID"
batch_size: 1000

processing_params:
inputs:
- name: question
key: question # DataTrove: prompt-column=question

outputs:
- name: answer
type: llm
output_type: plain
# Qwen3 thinking is disabled by embedding an empty <think> block in the prompt.
# This is equivalent to passing enable_thinking=False to the chat template and
# avoids any dependency on SGLang sampling-param support for that flag.
prompt: "<|im_start|>user\n{{ question }}\n<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n"

remove_columns: false
output_schema:
question: "{{ question }}"
answer: "{{ answer }}"

execution_params:
mode: slurm
retry: false
merge: false
max_retries: 3
account: a127
job_name: mmirage-sharded
nodes: 1
ntasks_per_node: 1
gpus: 4
cpus_per_task: 288
time_limit: "11:59:59"
report_dir: "/users/${USER}/reports"
hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf"
edf_env: "/users/${USER}/.edf/mmirage.toml"
poll_interval_seconds: 30
settle_time_seconds: 60
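
To make the effect of the `prompt` template in this config concrete, the sketch below renders it for one question. It is a minimal illustration only: a plain string replacement stands in for whatever template engine MMIRAGE uses (not shown in this diff), and `render_prompt` is a hypothetical helper. The template string is the one from the config, with the YAML `\n` escapes written as Python newlines.

```python
# Template copied from the `prompt` field of the benchmark config.
PROMPT_TEMPLATE = (
    "<|im_start|>user\n{{ question }}\n<|im_end|>\n"
    "<|im_start|>assistant\n<think>\n\n</think>\n"
)


def render_prompt(question: str) -> str:
    """Substitute the question into the template (naive stand-in for real templating)."""
    return PROMPT_TEMPLATE.replace("{{ question }}", question)


rendered = render_prompt("What is 2 + 2?")
print(rendered)
```

The rendered prompt already ends with an empty `<think>` block, so the model continues directly with its answer instead of opening a reasoning section.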
64 changes: 58 additions & 6 deletions src/mmirage/cli.py
Expand Up @@ -15,6 +15,7 @@
from mmirage.cli_utils.slurm import require_slurm, submit_slurm_job, wait_for_slurm_job
from mmirage.cli_utils.status import (
check_failed_shards,
collect_bench_stats,
is_retry_budget_exceeded,
shard_state_dir,
get_shard_status,
Expand All @@ -29,12 +30,13 @@
logger = logging.getLogger(__name__)


def run_local(config_path: str, shard_id: Optional[int] = None) -> int:
def run_local(config_path: str, shard_id: Optional[int] = None, collect_stats: bool = False) -> int:
"""Run one shard in the current Python environment.

Args:
config_path: Absolute path to the MMIRAGE YAML config file.
shard_id: Optional shard id to inject via SLURM_ARRAY_TASK_ID.
collect_stats: If True, enable GPU utilization polling in the shard process.

Returns:
Process return code from shard execution.
Expand All @@ -43,6 +45,8 @@ def run_local(config_path: str, shard_id: Optional[int] = None) -> int:
env = os.environ.copy()
if shard_id is not None:
env["SLURM_ARRAY_TASK_ID"] = str(shard_id)
if collect_stats:
env["MMIRAGE_COLLECT_STATS"] = "1"

logger.info("Running local shard processing: %s", " ".join(command))
result = subprocess.run(command, env=env, check=False)
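
The hunk above only sets `MMIRAGE_COLLECT_STATS` in the child environment; the shard-side code that reads it is not part of this diff. A minimal sketch of what the receiving side could look like follows. Both `stats_enabled` and `summarize_gpu_util` are hypothetical helpers, and the input format assumes output from `nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits` (one utilization percentage per line). The returned keys mirror the `gpu_util_*` fields of the README example.

```python
import os
import statistics
from typing import Mapping, Optional


def stats_enabled(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True when the launcher requested stats collection for this shard."""
    env = os.environ if env is None else env
    return env.get("MMIRAGE_COLLECT_STATS") == "1"


def summarize_gpu_util(nvidia_smi_output: str) -> Optional[dict]:
    """Summarize utilization samples given as one percentage per line.

    Returns None when there are no samples, for example when nvidia-smi
    is unavailable and nothing was collected.
    """
    samples = [float(line) for line in nvidia_smi_output.splitlines() if line.strip()]
    if not samples:
        return None
    return {
        "gpu_util_mean": statistics.fmean(samples),
        "gpu_util_min": min(samples),
        "gpu_util_max": max(samples),
        "gpu_util_samples": len(samples),
    }


if stats_enabled({"MMIRAGE_COLLECT_STATS": "1"}):
    print(summarize_gpu_util("72\n98\n85\n"))
```

Passing the flag through an environment variable keeps the shard command line unchanged, which is consistent with how `SLURM_ARRAY_TASK_ID` is injected in the same function.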
Expand All @@ -54,6 +58,7 @@ def launch_pipeline(
config_path: str,
force_retry: bool = False,
require_completion: bool = False,
collect_stats: bool = False,
) -> int:
"""Launch the pipeline according to execution mode and retry settings.

Expand All @@ -63,6 +68,7 @@ def launch_pipeline(
force_retry: If True, enable retry orchestration regardless of config flag.
require_completion: If True, wait for completion and verify shard
status before returning success in SLURM mode when auto-retry is off.
collect_stats: If True, enable GPU utilization polling on compute nodes.

Returns:
Exit code: 0 on success, 1 on failure.
Expand All @@ -72,7 +78,7 @@ def launch_pipeline(
if not cfg.execution_params.is_slurm():
initial_shard_id = cfg.loading_params.get_shard_id()
if not auto_retry:
exit_code = run_local(config_path, initial_shard_id)
exit_code = run_local(config_path, initial_shard_id, collect_stats=collect_stats)
if exit_code == 0:
logger.info("All shards completed successfully")
return exit_code
Expand All @@ -84,7 +90,7 @@ def launch_pipeline(
run_exit_codes = {}
for shard_id in shard_ids:
attempts_by_shard[shard_id] = attempts_by_shard.get(shard_id, 0) + 1
run_exit_codes[shard_id] = run_local(config_path, shard_id)
run_exit_codes[shard_id] = run_local(config_path, shard_id, collect_stats=collect_stats)

failed_shards, summary = check_failed_shards(cfg)
if status_exit_code(failed_shards, summary) == 0:
Expand Down Expand Up @@ -115,7 +121,7 @@ def launch_pipeline(
shard_ids: List[int] = []

while True:
job_id = submit_slurm_job(cfg, config_path, shard_ids)
job_id = submit_slurm_job(cfg, config_path, shard_ids, collect_stats=collect_stats)
if job_id is None:
return 1

Expand Down Expand Up @@ -192,6 +198,11 @@ def build_argparser() -> argparse.ArgumentParser:
help="Comma-separated shard ids to submit instead of the full array",
)
submit_parser.add_argument("--wait", action="store_true", help="Wait for the submitted job")
submit_parser.add_argument(
"--stats",
action="store_true",
help="Enable GPU utilization and throughput collection on compute nodes",
)

check_parser = subparsers.add_parser("check", help="Inspect shard status")
add_shared_arguments(check_parser)
Expand All @@ -211,6 +222,11 @@ def build_argparser() -> argparse.ArgumentParser:
help="Submit retries without prompting.",
)
check_parser.set_defaults(confirm_mode="prompt")
check_parser.add_argument(
"--stats",
action="store_true",
help="Enable GPU utilization and throughput collection on retried compute nodes",
)

retry_parser = subparsers.add_parser("retry", help="Submit only failed shards")
add_shared_arguments(retry_parser)
Expand All @@ -223,6 +239,11 @@ def build_argparser() -> argparse.ArgumentParser:
help="Submit retries without prompting.",
)
retry_parser.set_defaults(confirm_mode="prompt")
retry_parser.add_argument(
"--stats",
action="store_true",
help="Enable GPU utilization and throughput collection on retried compute nodes",
)

run_parser = subparsers.add_parser(
"run",
Expand All @@ -240,6 +261,11 @@ def build_argparser() -> argparse.ArgumentParser:
default=None,
help="Run a single shard locally (overrides execution mode)",
)
run_parser.add_argument(
"--stats",
action="store_true",
help="Enable GPU utilization and throughput collection during shard execution",
)

merge_parser = subparsers.add_parser(
"merge",
Expand Down Expand Up @@ -282,6 +308,12 @@ def build_argparser() -> argparse.ArgumentParser:
help="Log verbosity",
)

stats_parser = subparsers.add_parser(
"stats",
help="Show per-shard benchmark statistics (runtime, throughput, GPU utilization)",
)
add_shared_arguments(stats_parser)

return parser


Expand Down Expand Up @@ -346,13 +378,14 @@ def handle_run(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -
Exit code for the run operation.
"""
if args.shard_id is not None:
return run_local(config_path, args.shard_id)
return run_local(config_path, args.shard_id, collect_stats=args.stats)

exit_code = launch_pipeline(
cfg,
config_path,
force_retry=args.force_retry,
require_completion=cfg.execution_params.merge,
collect_stats=args.stats,
)
if exit_code != 0:
return exit_code
Expand Down Expand Up @@ -380,7 +413,7 @@ def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str
return 1

shard_ids = parse_shard_ids(args.shard_ids, cfg.loading_params.get_num_shards())
job_id = submit_slurm_job(cfg, config_path, shard_ids)
job_id = submit_slurm_job(cfg, config_path, shard_ids, collect_stats=args.stats)
if job_id is None:
return 1

Expand Down Expand Up @@ -426,6 +459,7 @@ def handle_check(args: argparse.Namespace, cfg: MMirageConfig, config_path: str)
config_path=config_path,
failed_shards=failed_shards,
confirm_mode=args.confirm_mode,
collect_stats=args.stats,
)


Expand Down Expand Up @@ -458,9 +492,26 @@ def handle_retry(args: argparse.Namespace, cfg: MMirageConfig, config_path: str)
config_path=config_path,
failed_shards=failed_shards,
confirm_mode=args.confirm_mode,
collect_stats=args.stats,
)


def handle_stats(args: argparse.Namespace, cfg: MMirageConfig, _config_path: str) -> int:
"""Print per-shard benchmark statistics and aggregate totals.

Args:
args: Parsed CLI namespace.
cfg: Parsed MMIRAGE configuration object.
_config_path: Absolute path to the MMIRAGE YAML config file (not needed here).

Returns:
Exit code: 0 always (stats are informational).
"""
report = collect_bench_stats(cfg)
print(json.dumps(report, indent=2))
return 0


def handle_merge(args: argparse.Namespace, cfg: MMirageConfig, _config_path: str) -> int:
"""Merge shard outputs defined in config.loading_params.datasets.

Expand Down Expand Up @@ -513,6 +564,7 @@ def main() -> None:
"check": handle_check,
"retry": handle_retry,
"merge": handle_merge,
"stats": handle_stats,
}
handler = handlers.get(args.command)
if handler is None: