Skip to content

Commit 7946278

Browse files
committed
Implemented changes requested by Fabrice
1 parent dfa7fae commit 7946278

6 files changed

Lines changed: 50 additions & 48 deletions

File tree

README.md

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,36 @@ Key metrics:
315315

316316
Reference benchmark:
317317
- [DataTrove Benchmark](https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark)
318-
- `mmirage run --config configs/config_benchmark_datatrove.yaml --stats`
318+
319+
The config `configs/config_benchmark_datatrove.yaml` mirrors the DataTrove inference benchmark conditions:
320+
321+
| Setting | Value |
322+
|---|---|
323+
| Dataset | `simplescaling/s1K-1.1` (train split, 1 000 samples) |
324+
| Prompt | raw `question` field, no system prompt |
325+
| Output | up to 1 024 tokens per sample |
326+
| Context | 2 048-token model max context |
327+
| Model | `Qwen/Qwen3-4B` (DataTrove baseline: tp=1 on a single GPU) |
328+
329+
Download the dataset before running:
330+
331+
```python
332+
from datasets import load_dataset
333+
ds = load_dataset('simplescaling/s1K-1.1', split='train')
334+
ds.save_to_disk('data/s1K-1.1')
335+
```
336+
337+
Then run with stats collection enabled:
338+
339+
```bash
340+
mmirage run --config configs/config_benchmark_datatrove.yaml --stats
341+
```
342+
343+
Inspect results:
344+
345+
```bash
346+
mmirage stats --config configs/config_benchmark_datatrove.yaml
347+
```
319348

320349
## Architecture
321350

configs/config_benchmark_datatrove.yaml

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,5 @@
11
# MMIRAGE — DataTrove-compatible throughput benchmark
2-
#
3-
# Mirrors the conditions used in the DataTrove inference benchmark
4-
# (https://github.com/huggingface/datatrove/tree/main/examples/inference/benchmark):
5-
#
6-
# dataset : simplescaling/s1K-1.1 (train split, 1 000 samples)
7-
# prompt : raw `question` field, no system prompt
8-
# output : up to 1 024 tokens per sample
9-
# context : 2 048-token model max context
10-
# model : Qwen/Qwen3-4B (DataTrove baseline: tp=1 on a single GPU)
11-
#
12-
# Download the dataset before running:
13-
#
14-
# python -c "
15-
# from datasets import load_dataset
16-
# ds = load_dataset('simplescaling/s1K-1.1', split='train')
17-
# ds.save_to_disk('data/s1K-1.1')
18-
# "
19-
#
20-
# Then run with stats collection enabled:
21-
#
22-
# mmirage run --config configs/config_benchmark_datatrove.yaml --stats
23-
#
24-
# Inspect results:
25-
#
26-
# mmirage stats --config configs/config_benchmark_datatrove.yaml
2+
# See README.md for setup instructions and benchmark details.
273

284
processors:
295
- type: llm

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dependencies = [
3838
"jmespath",
3939
"jinja2>=3.0.0",
4040
"pillow>=9.0.0",
41+
"humanize>=4.0.0",
4142
]
4243

4344
[project.optional-dependencies]

src/mmirage/core/process/mapper.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
"""Mapper for orchestrating variable transformations."""
22

3+
from dataclasses import dataclass
34
from typing import Dict, Any, List, cast
45

56
from mmirage.core.process.variables import BaseVar, InputVar, OutputVar
67
from mmirage.core.process.base import AutoProcessor, BaseProcessor, BaseProcessorConfig
78

9+
10+
@dataclass
11+
class TokenCounts:
12+
"""Cumulative token counts from LLM processors."""
13+
14+
input_tokens: int
15+
output_tokens: int
16+
817
import logging
918

1019
from mmirage.core.process.variables import VariableEnvironment
@@ -104,14 +113,14 @@ def rewrite_batch(
104113

105114
return batch_environment
106115

107-
def get_token_counts(self) -> Dict[str, int]:
116+
def get_token_counts(self) -> TokenCounts:
108117
"""Return cumulative token counts aggregated across all LLM processors.
109118
110119
Sums ``input_tokens`` and ``output_tokens`` from every processor that
111120
exposes a ``get_token_counts()`` method (i.e., ``LLMProcessor``).
112121
113122
Returns:
114-
Dict with ``input_tokens`` and ``output_tokens`` keys.
123+
TokenCounts with ``input_tokens`` and ``output_tokens`` fields.
115124
"""
116125
total_input = 0
117126
total_output = 0
@@ -120,7 +129,7 @@ def get_token_counts(self) -> Dict[str, int]:
120129
counts = proc.get_token_counts()
121130
total_input += counts.get("input_tokens", 0)
122131
total_output += counts.get("output_tokens", 0)
123-
return {"input_tokens": total_input, "output_tokens": total_output}
132+
return TokenCounts(input_tokens=total_input, output_tokens=total_output)
124133

125134
def get_load_time(self) -> float:
126135
"""Return total model-loading time (seconds) summed across all LLM processors."""

src/mmirage/shard_process.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def main():
109109
all_visible = [x.strip() for x in cuda_visible.split(",") if x.strip()]
110110
# Fall back to range-based indices if CUDA_VISIBLE_DEVICES was set
111111
# but contained only whitespace/empty entries after stripping.
112-
gpu_indices_for_polling: Optional[List[str]] = all_visible[:tp_size] if all_visible else [str(i) for i in range(tp_size)]
112+
gpu_indices_for_polling: List[str] = all_visible[:tp_size] if all_visible else [str(i) for i in range(tp_size)]
113113
else:
114114
gpu_indices_for_polling = [str(i) for i in range(tp_size)]
115115
gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller(
@@ -184,8 +184,8 @@ def main():
184184

185185
# Collect token counts accumulated by LLM processor(s).
186186
token_counts = mapper.get_token_counts()
187-
input_tokens = token_counts["input_tokens"] or None
188-
output_tokens = token_counts["output_tokens"] or None
187+
input_tokens = token_counts.input_tokens or None
188+
output_tokens = token_counts.output_tokens or None
189189
model_load_seconds = mapper.get_load_time() or None
190190

191191
# Resolve num_gpus from the first processor config that exposes tp_size.

src/mmirage/shard_utils.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from datetime import datetime
88
from dataclasses import dataclass
9+
import humanize
910
import json
1011
import logging
1112
import os
@@ -24,24 +25,10 @@
2425

2526

2627
def format_duration(seconds: Optional[float]) -> Optional[str]:
27-
"""Format a duration given in seconds as a human-readable string.
28-
29-
Examples::
30-
31-
format_duration(45.3) -> "45s"
32-
format_duration(125.0) -> "2m 5s"
33-
format_duration(3725.0) -> "1h 2m 5s"
34-
"""
28+
"""Format a duration given in seconds as a human-readable string via ``humanize.precisedelta``; returns ``None`` when *seconds* is ``None``."""
3529
if seconds is None:
3630
return None
37-
total = int(seconds)
38-
hours, remainder = divmod(total, 3600)
39-
minutes, secs = divmod(remainder, 60)
40-
if hours:
41-
return f"{hours}h {minutes}m {secs}s"
42-
if minutes:
43-
return f"{minutes}m {secs}s"
44-
return f"{secs}s"
31+
return humanize.precisedelta(seconds)
4532

4633

4734
@dataclass

0 commit comments

Comments
 (0)