Merged
Changes from 23 commits
112 changes: 112 additions & 0 deletions README.md
@@ -235,6 +235,117 @@ Key multimodal features:
- `image_base_path`: Base directory for resolving relative image paths
- Supports PIL Images, URLs, and file paths

### Benchmarking shard performance

Pass `--stats` to `run` or `submit` to enable per-shard benchmarking. This activates GPU
utilization polling and throughput tracking on compute nodes. Both are disabled by default
to avoid unnecessary overhead.

```bash
# Local run with stats collection
mmirage run --config configs/config_mock.yaml --stats
```
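
As a rough illustration of what GPU utilization polling involves, the sketch below samples `nvidia-smi` and reduces the samples to mean/min/max values. It is not MMIRAGE's implementation: the function names are hypothetical, and it assumes `nvidia-smi` reports plain numeric percentages for every GPU.

```python
import shutil
import statistics
import subprocess
from typing import Optional

# Query that prints one utilization percentage per GPU, one value per line.
NVIDIA_SMI_QUERY = [
    "nvidia-smi",
    "--query-gpu=utilization.gpu",
    "--format=csv,noheader,nounits",
]


def parse_utilization(output: str) -> list[float]:
    """Turn nvidia-smi query output into one float per GPU."""
    return [float(line) for line in output.splitlines() if line.strip()]


def sample_gpu_utilization() -> Optional[list[float]]:
    """Return per-GPU utilization, or None when nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(NVIDIA_SMI_QUERY, capture_output=True, text=True)
    if result.returncode != 0:
        return None
    return parse_utilization(result.stdout)


def summarize(samples: list[float]) -> dict:
    """Reduce polled samples to mean/min/max plus the sample count."""
    if not samples:
        return {
            "gpu_util_mean": None,
            "gpu_util_min": None,
            "gpu_util_max": None,
            "gpu_util_samples": 0,
        }
    return {
        "gpu_util_mean": statistics.fmean(samples),
        "gpu_util_min": min(samples),
        "gpu_util_max": max(samples),
        "gpu_util_samples": len(samples),
    }
```

A poller would call `sample_gpu_utilization()` at a fixed interval while the shard runs and pass the collected values to `summarize()` at the end. Returning `None` when the tool is missing mirrors the `null` GPU stats described for runs without `nvidia-smi`.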

After the run completes, inspect the results with:

```bash
mmirage stats --config configs/config_mock.yaml
```

This prints a JSON report with per-shard details and an aggregate summary:

```json
{
  "per_shard": [
    {
      "shard_id": 0,
      "status": "success",
      "started_at": "2026-04-30T10:00:00",
      "finished_at": "2026-04-30T10:01:05",
      "stats": {
        "runtime_seconds": 65.2,
        "runtime_human": "1m 5s",
        "rows_processed": 1024,
        "throughput_rows_per_sec": 15.7,
        "gpu_util_mean": 88.4,
        "gpu_util_min": 72.0,
        "gpu_util_max": 98.0,
        "gpu_util_samples": 13,
        "input_tokens": 512000,
        "output_tokens": 196608,
        "num_gpus": 4,
        "tokens_per_sec_per_gpu": 753.1,
        "gpu_days_per_billion_tokens": 0.0015
      }
    }
  ],
  "aggregate": {
    "total_shards": 1,
    "completed_shards": 1,
    "total_rows_processed": 1000,
    "wall_clock_runtime_seconds": 133.04,
    "wall_clock_runtime_human": "2m 13s",
    "sum_shard_runtime_seconds": 133.04,
    "sum_shard_runtime_human": "2m 13s",
    "min_shard_runtime_seconds": 133.04,
    "min_shard_runtime_human": "2m 13s",
    "max_shard_runtime_seconds": 133.04,
    "max_shard_runtime_human": "2m 13s",
    "overall_throughput_rows_per_sec": 7.52,
    "mean_gpu_util_pct": 86.2,
    "num_gpus": 4,
    "total_input_tokens": 146214,
    "total_output_tokens": 1022046,
    "sum_model_load_seconds": 38.272,
    "sum_inference_runtime_seconds": 94.768,
    "tokens_per_sec_per_gpu": 10784.72,
    "gpu_days_per_billion_tokens": 1.0732
  }
}
```
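
Because the report is plain JSON, it can be post-processed with a few lines of Python. The sketch below uses a trimmed, illustrative report with the field names shown above. The `"failed"` status value, the `null` stats for an unfinished shard, and the `summarize_shards` helper are assumptions made for the example, not documented MMIRAGE behavior.

```python
import json

# Trimmed report using field names from the example above; values are illustrative.
REPORT = """
{
  "per_shard": [
    {"shard_id": 0, "status": "success",
     "stats": {"runtime_seconds": 65.2, "rows_processed": 1024}},
    {"shard_id": 1, "status": "success",
     "stats": {"runtime_seconds": 80.0, "rows_processed": 1024}},
    {"shard_id": 2, "status": "failed", "stats": null}
  ]
}
"""


def summarize_shards(report: dict) -> dict:
    """List unfinished shards, the slowest shard, and per-shard row throughput."""
    shards = report["per_shard"]
    done = [s for s in shards if s["status"] == "success" and s.get("stats")]
    unfinished = [s["shard_id"] for s in shards if s["status"] != "success"]
    slowest = max(done, key=lambda s: s["stats"]["runtime_seconds"], default=None)
    return {
        "unfinished_shard_ids": unfinished,
        "slowest_shard_id": slowest["shard_id"] if slowest else None,
        "rows_per_sec": {
            s["shard_id"]: round(
                s["stats"]["rows_processed"] / s["stats"]["runtime_seconds"], 1
            )
            for s in done
        },
    }


summary = summarize_shards(json.loads(REPORT))
print(summary)
```

In practice the JSON text would come from the output of `mmirage stats` rather than an inline string; how that output is captured (for example, by redirecting it to a file) is left to the reader.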

Key metrics:
- **`runtime_seconds`** / **`runtime_human`**: elapsed time from when the shard starts executing on the cluster (after dispatch) until it finishes; queue wait time is excluded.
- **`overall_throughput_rows_per_sec`**: total rows processed divided by wall-clock time, with all shards running in parallel.
- **`mean_gpu_util_pct`**: GPU utilization in percent, averaged across shards.
- **`tokens_per_sec_per_gpu`**: output tokens generated per second per GPU; the primary throughput metric used by frameworks such as [DataTrove](https://github.com/huggingface/datatrove).
- **`gpu_days_per_billion_tokens`**: total GPU-days consumed to generate 1 billion output tokens; useful for cost and scaling comparisons across different hardware configurations.
- Token metrics are `null` when no LLM processor was active, and GPU stats are `null` when `nvidia-smi` is unavailable or `--stats` was not passed.
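
To make the two token metrics concrete, the sketch below shows one plausible way to derive them. The formulas are assumptions rather than code taken from MMIRAGE: throughput is taken as output tokens divided by inference time and GPU count, and GPU-days per billion tokens as its inverse scaled to days. Applied to the aggregate values in the example report, they reproduce the reported figures when a single GPU is assumed (matching `gpus: 1` in the benchmark config).

```python
SECONDS_PER_DAY = 86_400


def tokens_per_sec_per_gpu(
    output_tokens: int, inference_seconds: float, num_gpus: int
) -> float:
    """Output tokens generated per second, normalized by GPU count (assumed formula)."""
    return output_tokens / inference_seconds / num_gpus


def gpu_days_per_billion_tokens(tps_per_gpu: float) -> float:
    """GPU-days needed to generate 1e9 output tokens at the given per-GPU rate."""
    return 1e9 / tps_per_gpu / SECONDS_PER_DAY


# Aggregate values from the example report; num_gpus=1 is an assumption.
tps = tokens_per_sec_per_gpu(1_022_046, 94.768, num_gpus=1)
cost = gpu_days_per_billion_tokens(tps)
print(round(tps, 2), round(cost, 4))
```

Under these assumptions the two metrics carry the same information: doubling per-GPU throughput halves the GPU-days needed per billion tokens, regardless of how many GPUs are used.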

Reference benchmark:
- [DataTrove Benchmark](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark)

The config `configs/config_benchmark_datatrove.yaml` mirrors the DataTrove inference benchmark conditions:

| Setting | Value |
|---|---|
| Dataset | `simplescaling/s1K-1.1` (train split, 1 000 samples) |
| Prompt | raw `question` field, no system prompt |
| Output | up to 1 024 tokens per sample |
| Context | 2 048-token model max context |
| Model | `Qwen/Qwen3-4B` (DataTrove baseline: tp=1 on a single GPU) |

Download the dataset before running:

```python
from datasets import load_dataset
ds = load_dataset('simplescaling/s1K-1.1', split='train')
ds.save_to_disk('data/s1K-1.1')
```

Then run with stats collection enabled:

```bash
mmirage run --config configs/config_benchmark_datatrove.yaml --stats
```

Inspect results:

```bash
mmirage stats --config configs/config_benchmark_datatrove.yaml
```

## Architecture

MMIRAGE uses a modular architecture:
@@ -258,3 +369,4 @@ mmirage/
- JMESPath for JSON queries: [link](https://jmespath.org/)
- SGLang for fast inference: [link](https://github.com/sgl-project/sglang)
- Performance paper: [link](https://arxiv.org/abs/2408.02442)
- DataTrove Benchmark: [link](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark)
63 changes: 63 additions & 0 deletions configs/config_benchmark_datatrove.yaml
@@ -0,0 +1,63 @@
# MMIRAGE — DataTrove-compatible throughput benchmark
# See README.md for setup instructions and benchmark details.

processors:
- type: llm
server_args:
model_path: Qwen/Qwen3-4B # same model family as DataTrove baseline
tp_size: 1 # DataTrove baseline: tp=1
trust_remote_code: true
disable_custom_all_reduce: true
# SGLang engine tuning — equivalents of DataTrove's vLLM mns/mnbt knobs
extra_engine_args:
max_running_requests: 1000
default_sampling_params:
temperature: 0.0
max_new_tokens: 1024 # DataTrove: max-tokens=1024

loading_params:
state_dir: data/benchmark_s1k/_pipeline_state
datasets:
- path: data/s1K-1.1 # save_to_disk() target above
type: loadable
output_dir: data/benchmark_s1k/output
num_shards: 1
shard_id: "$SLURM_ARRAY_TASK_ID"
batch_size: 1000

processing_params:
inputs:
- name: question
key: question # DataTrove: prompt-column=question

outputs:
- name: answer
type: llm
output_type: plain
# Qwen3 thinking is disabled by embedding an empty <think> block in the prompt.
# This is equivalent to passing enable_thinking=False to the chat template and
# avoids any dependency on SGLang sampling-param support for that flag.
prompt: "<|im_start|>user\n{{ question }}\n<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n"

remove_columns: false
output_schema:
question: "{{ question }}"
answer: "{{ answer }}"

execution_params:
mode: slurm
retry: false
merge: false
max_retries: 3
account: a127
job_name: mmirage-sharded
nodes: 1
ntasks_per_node: 1
gpus: 1
cpus_per_task: 288
time_limit: "11:59:59"
report_dir: "/users/${USER}/reports"
hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf"
edf_env: "/users/${USER}/.edf/mmirage.toml"
poll_interval_seconds: 30
settle_time_seconds: 60
1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
"jmespath",
"jinja2>=3.0.0",
"pillow>=9.0.0",
"humanize>=4.0.0",
]

[project.optional-dependencies]