Skip to content

Commit 922b9aa

Browse files
[CI] Add MMSU CI for Qwen3 Omni (Stage 5 + 6) (#298)
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
1 parent 3ed40e3 commit 922b9aa

11 files changed

Lines changed: 494 additions & 74 deletions

File tree

.github/workflows/test-qwen3-omni-ci.yaml

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,88 @@ jobs:
181181
shell: bash
182182
run: |
183183
bash .github/scripts/delete_gpu_process.sh
184+
185+
stage-4-mmsu:
186+
name: stage 4 - MMSU accuracy + speed
187+
needs: stage-3-mmmu-tts-consistency
188+
runs-on: [self-hosted]
189+
timeout-minutes: 20
190+
container:
191+
image: frankleeeee/sglang-omni:dev
192+
options: --gpus all --rm -v /dev/shm:/dev/shm
193+
steps:
194+
- name: Checkout code
195+
uses: actions/checkout@v4
196+
197+
- uses: ./.github/actions/omni-setup
198+
with:
199+
venv-name: omni-qwen3
200+
201+
- name: Run MMSU CI (accuracy + speed)
202+
shell: bash
203+
run: |
204+
source omni-qwen3/bin/activate
205+
export PYTHONPATH=$PWD
206+
pytest tests/test_model/test_qwen3_omni_mmsu_ci.py -v -s -x
207+
env:
208+
HF_ENDPOINT: https://hf-mirror.com
209+
210+
- name: Print MMSU CI artifacts (accuracy + speed)
211+
if: always()
212+
shell: bash
213+
run: |
214+
source omni-qwen3/bin/activate
215+
echo "=== Qwen3-Omni MMSU CI results (summary only) ==="
216+
for f in $(find /tmp -path '*/mmsu/mmsu_results.json' 2>/dev/null); do
217+
echo "--- $f ---"
218+
python -c "import json,sys; d=json.load(open(sys.argv[1])); d.pop('per_sample',None); print(json.dumps(d, indent=2, ensure_ascii=False))" "$f"
219+
echo ""
220+
done
221+
222+
- name: Kill GPU processes
223+
if: always()
224+
shell: bash
225+
run: |
226+
bash .github/scripts/delete_gpu_process.sh
227+
228+
stage-5-mmsu-tts-consistency:
229+
name: stage 5 - MMSU TTS consistency
230+
needs: stage-4-mmsu
231+
runs-on: [self-hosted]
232+
timeout-minutes: 15
233+
container:
234+
image: frankleeeee/sglang-omni:dev
235+
options: --gpus all --rm -v /dev/shm:/dev/shm
236+
steps:
237+
- name: Checkout code
238+
uses: actions/checkout@v4
239+
240+
- uses: ./.github/actions/omni-setup
241+
with:
242+
venv-name: omni-qwen3
243+
244+
- name: Run MMSU TTS Consistency CI (WER + speed)
245+
shell: bash
246+
run: |
247+
source omni-qwen3/bin/activate
248+
export PYTHONPATH=$PWD
249+
pytest tests/test_model/test_qwen3_omni_mmsu_tts_consistency_ci.py -v -s -x
250+
env:
251+
HF_ENDPOINT: https://hf-mirror.com
252+
253+
- name: Print MMSU TTS Consistency CI artifacts (WER + speed)
254+
if: always()
255+
shell: bash
256+
run: |
257+
echo "=== Qwen3-Omni MMSU TTS Consistency CI results ==="
258+
for f in $(find /tmp -path '*/mmsu_audio/mmsu_results.json' 2>/dev/null); do
259+
echo "--- $f ---"
260+
cat "$f"
261+
echo ""
262+
done
263+
264+
- name: Kill GPU processes
265+
if: always()
266+
shell: bash
267+
run: |
268+
bash .github/scripts/delete_gpu_process.sh

benchmarks/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ and accuracy (WER, MMSU, MMMU) across supported modality combinations.
77

88
```
99
benchmarks/
10-
├── tasks/ # Per-task logic (tts, mmsu, visual_understand)
10+
├── tasks/ # Per-task logic (tts, audio_understanding, visual_understand)
1111
├── metrics/ # Metric computation (performance, accuracy)
1212
├── dataset/ # Dataset loaders + download helpers
1313
├── benchmarker/ # Framework: runner, data structures, utilities

benchmarks/dataset/mmsu.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55

66
import random
77
import re
8+
import tempfile
89
from dataclasses import dataclass
910
from pathlib import Path
1011

12+
from datasets import Audio, load_dataset
13+
1114

1215
@dataclass
1316
class MmsuSample:
@@ -50,13 +53,19 @@ def load_mmsu_samples(
5053
task_names: list[str] | None = None,
5154
categories: list[str] | None = None,
5255
seed: int | None = None,
56+
*,
57+
repo_id: str | None = None,
5358
) -> list[MmsuSample]:
54-
"""Load MMSU samples from HuggingFace dataset ``ddwang2000/MMSU``."""
55-
import tempfile
59+
"""Load MMSU samples.
60+
5661
57-
from datasets import Audio, load_dataset
62+
Note (Yifei, Chenyang):
63+
repo_id defaults to None which loads the full ddwang2000/MMSU
64+
(train split, ~5000 samples). zhaochenyang20/mmsu-ci-2000 to
65+
load our pre-built subset for CI.
66+
"""
5867

59-
ds = load_dataset("ddwang2000/MMSU")
68+
ds = load_dataset(repo_id or "ddwang2000/MMSU")
6069
assert list(ds.keys()) == [
6170
"train"
6271
], f"Expected only 'train' split, got {list(ds.keys())}"

benchmarks/dataset/prepare.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"mmmu": "MMMU/MMMU",
3030
"mmmu-ci-50": "zhaochenyang20/mmmu-ci-50",
3131
"mmsu": "ddwang2000/MMSU",
32+
"mmsu-ci-2000": "zhaochenyang20/mmsu-ci-2000",
3233
}
3334

3435
_CLI_LOCAL_DIRS: dict[str, str] = {

benchmarks/eval/benchmark_omni_mmmu.py

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,9 @@
6666
from benchmarks.dataset.mmmu import load_mmmu_samples
6767
from benchmarks.metrics.performance import compute_speed_metrics
6868
from benchmarks.tasks.tts import (
69-
SampleOutput,
70-
calculate_wer_metrics,
71-
load_asr_model,
69+
compute_text_audio_consistency,
7270
print_speed_summary,
7371
print_wer_summary,
74-
transcribe_and_compute_wer,
7572
)
7673
from benchmarks.tasks.visual_understand import (
7774
compute_mmmu_metrics,
@@ -170,70 +167,16 @@ async def run_mmmu_eval(config: MMMUEvalConfig) -> dict:
170167
}
171168

172169
if config.enable_audio:
173-
wer_results = _compute_audio_wer(
170+
results["wer"] = compute_text_audio_consistency(
174171
request_results, config.lang, config.asr_device
175172
)
176-
results["wer"] = wer_results
177173

178174
if config.output_dir:
179175
save_json_results(results, config.output_dir, "mmmu_results.json")
180176

181177
return results
182178

183179

184-
def _compute_audio_wer(
185-
request_results: list,
186-
lang: str,
187-
asr_device: str,
188-
) -> dict:
189-
"""Transcribe audio outputs with ASR and compute WER against text outputs.
190-
191-
Text output is the reference; ASR transcription of the audio is the
192-
hypothesis. Returns a dict with summary and per_sample keys.
193-
"""
194-
asr = load_asr_model(lang, asr_device)
195-
196-
outputs: list[SampleOutput] = []
197-
for result in request_results:
198-
199-
ref_text = " ".join(result.text.split())
200-
output = SampleOutput(
201-
sample_id=result.request_id,
202-
target_text=ref_text,
203-
latency_s=result.latency_s,
204-
audio_duration_s=result.audio_duration_s,
205-
)
206-
207-
if not result.is_success or not result.wav_path:
208-
output.error = result.error or "No audio in response"
209-
outputs.append(output)
210-
continue
211-
212-
output = transcribe_and_compute_wer(
213-
output, result.wav_path, asr, lang, asr_device
214-
)
215-
outputs.append(output)
216-
217-
wer_summary = calculate_wer_metrics(outputs, lang)
218-
219-
per_sample = [
220-
{
221-
"id": o.sample_id,
222-
"is_success": o.is_success,
223-
"wer": o.wer if o.is_success else None,
224-
"ref_text": o.target_text[:100],
225-
"hyp_text": o.whisper_text[:100],
226-
"ref_norm": o.ref_norm,
227-
"hyp_norm": o.hyp_norm,
228-
"audio_duration_s": o.audio_duration_s,
229-
"error": o.error,
230-
}
231-
for o in outputs
232-
]
233-
234-
return {"summary": wer_summary, "per_sample": per_sample}
235-
236-
237180
def _config_from_args(args: argparse.Namespace) -> MMMUEvalConfig:
238181
return MMMUEvalConfig(
239182
base_url=args.base_url,

benchmarks/eval/benchmark_omni_mmsu.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,33 +86,40 @@
8686

8787
from benchmarks.benchmarker.runner import BenchmarkRunner, RunConfig
8888
from benchmarks.benchmarker.utils import wait_for_service
89-
from benchmarks.dataset.mmsu import load_mmsu_samples
89+
from benchmarks.dataset.mmsu import MmsuSample, load_mmsu_samples
9090
from benchmarks.metrics.performance import compute_speed_metrics
91-
from benchmarks.tasks.mmsu import (
91+
from benchmarks.tasks.audio_understanding import (
9292
build_mmsu_results,
9393
compute_mmsu_metrics,
9494
make_mmsu_send_fn,
9595
print_mmsu_summary,
9696
save_mmsu_results,
9797
)
98+
from benchmarks.tasks.tts import compute_text_audio_consistency, print_wer_summary
9899

99100
logging.basicConfig(
100101
level=logging.INFO,
101102
format="%(asctime)s %(name)s %(levelname)s %(message)s",
102103
)
103104

104105

105-
async def run(args: argparse.Namespace) -> dict:
106+
async def run(
107+
args: argparse.Namespace,
108+
*,
109+
samples: list[MmsuSample] | None = None,
110+
) -> dict:
106111
base_url = args.base_url or f"http://{args.host}:{args.port}"
107112
api_url = f"{base_url}/v1/chat/completions"
108113
modalities = ["text", "audio"] if args.modalities == "text+audio" else ["text"]
109114

110-
samples = load_mmsu_samples(
111-
max_samples=args.max_samples,
112-
task_names=args.task_names.split(",") if args.task_names else None,
113-
categories=args.categories.split(",") if args.categories else None,
114-
seed=args.seed,
115-
)
115+
if samples is None:
116+
samples = load_mmsu_samples(
117+
max_samples=args.max_samples,
118+
task_names=args.task_names.split(",") if args.task_names else None,
119+
categories=args.categories.split(",") if args.categories else None,
120+
seed=args.seed,
121+
repo_id=args.repo_id,
122+
)
116123

117124
save_audio_dir = None
118125
if args.save_audio and args.output_dir:
@@ -150,6 +157,17 @@ async def run(args: argparse.Namespace) -> dict:
150157

151158
print_mmsu_summary(metrics, args.model, speed_metrics=speed)
152159

160+
output: dict = {"accuracy": metrics, "speed": speed}
161+
wer_results = None
162+
if audio_mode:
163+
wer_results = compute_text_audio_consistency(
164+
request_results,
165+
args.lang,
166+
args.asr_device,
167+
)
168+
output["wer"] = wer_results
169+
print_wer_summary(wer_results["summary"], args.model)
170+
153171
if args.output_dir:
154172
save_mmsu_results(
155173
results,
@@ -165,9 +183,10 @@ async def run(args: argparse.Namespace) -> dict:
165183
},
166184
args.output_dir,
167185
speed_metrics=speed,
186+
wer_metrics=wer_results,
168187
)
169188

170-
return {"accuracy": metrics, "speed": speed}
189+
return output
171190

172191

173192
def main() -> None:
@@ -190,6 +209,19 @@ def main() -> None:
190209
p.add_argument("--save-audio", action="store_true")
191210
p.add_argument("--disable-tqdm", action="store_true")
192211
p.add_argument("--seed", type=int, default=None)
212+
p.add_argument(
213+
"--repo-id",
214+
type=str,
215+
default=None,
216+
help="HuggingFace dataset repo (e.g. 'zhaochenyang20/mmsu-ci-2000'). "
217+
"Defaults to loading the full ddwang2000/MMSU (train split).",
218+
)
219+
p.add_argument(
220+
"--lang", type=str, default="en", help="Language for ASR WER evaluation"
221+
)
222+
p.add_argument(
223+
"--asr-device", type=str, default="cuda:0", help="Device for ASR model"
224+
)
193225

194226
args = p.parse_args()
195227
wait_for_service(args.base_url or f"http://{args.host}:{args.port}")
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ class MmsuResult:
182182
raw_response: str = ""
183183
is_correct: bool = False
184184
is_parseable: bool = False
185+
is_success: bool = False
185186
latency_s: float = 0.0
186187
has_audio: bool = False
187188
audio_duration_s: float = 0.0
@@ -294,6 +295,7 @@ def build_mmsu_results(
294295
raw_response=request_result.text,
295296
is_correct=index_match or text_match,
296297
is_parseable=predicted_index is not None or bool(predicted_answer),
298+
is_success=bool(request_result.is_success),
297299
latency_s=request_result.latency_s,
298300
error=request_result.error,
299301
)
@@ -310,12 +312,15 @@ def build_mmsu_results(
310312
def compute_mmsu_metrics(results: list[MmsuResult]) -> dict[str, Any]:
311313
total = len(results)
312314
parseable = sum(1 for result in results if result.is_parseable)
315+
successful = sum(1 for result in results if result.is_success)
313316
correct = sum(1 for result in results if result.is_correct)
314317

315318
return {
316319
"total_samples": total,
317320
"parseable_samples": parseable,
318321
"unparseable_samples": total - parseable,
322+
"successful_samples": successful,
323+
"failed_samples": total - successful,
319324
"correct": correct,
320325
"incorrect": total - correct,
321326
"overall_accuracy": round(correct / total, 4) if total else 0.0,
@@ -340,6 +345,9 @@ def print_mmsu_summary(
340345
print(f" MMSU Results - {model_name}")
341346
print("=" * 60)
342347
print(f" Total samples: {metrics['total_samples']}")
348+
print(
349+
f" Successful: {metrics.get('successful_samples', metrics['total_samples'])}"
350+
)
343351
print(f" Parseable: {metrics['parseable_samples']}")
344352
print(f" Correct: {metrics['correct']}")
345353
print(f" Overall accuracy: {metrics['overall_accuracy']:.2%}")
@@ -359,6 +367,7 @@ def print_mmsu_summary(
359367
if speed_metrics.get("rtf_mean") is not None:
360368
print(f" RTF mean: {speed_metrics.get('rtf_mean', 0):.4f}")
361369
print(f" Throughput: {speed_metrics.get('throughput_qps', 0):.2f} req/s")
370+
print(f" Tok/s agg: {speed_metrics.get('tok_per_s_agg', 0):.2f}")
362371
audio_returned = speed_metrics.get("audio_returned")
363372
audio_expected = speed_metrics.get("audio_expected")
364373
if audio_expected:
@@ -373,6 +382,7 @@ def save_mmsu_results(
373382
output_dir: str,
374383
*,
375384
speed_metrics: dict[str, Any] | None = None,
385+
wer_metrics: dict[str, Any] | None = None,
376386
) -> None:
377387
summary_output = {
378388
"summary": metrics,
@@ -381,6 +391,8 @@ def save_mmsu_results(
381391
}
382392
if speed_metrics:
383393
summary_output["speed_metrics"] = speed_metrics
394+
if wer_metrics:
395+
summary_output["wer"] = wer_metrics
384396

385397
save_json_results(summary_output, output_dir, "mmsu_results.json")
386398

0 commit comments

Comments
 (0)