8686
8787from benchmarks .benchmarker .runner import BenchmarkRunner , RunConfig
8888from benchmarks .benchmarker .utils import wait_for_service
89- from benchmarks .dataset .mmsu import load_mmsu_samples
89+ from benchmarks .dataset .mmsu import MmsuSample , load_mmsu_samples
9090from benchmarks .metrics .performance import compute_speed_metrics
91- from benchmarks .tasks .mmsu import (
91+ from benchmarks .tasks .audio_understanding import (
9292 build_mmsu_results ,
9393 compute_mmsu_metrics ,
9494 make_mmsu_send_fn ,
9595 print_mmsu_summary ,
9696 save_mmsu_results ,
9797)
98+ from benchmarks .tasks .tts import compute_text_audio_consistency , print_wer_summary
9899
99100logging .basicConfig (
100101 level = logging .INFO ,
101102 format = "%(asctime)s %(name)s %(levelname)s %(message)s" ,
102103)
103104
104105
105- async def run (args : argparse .Namespace ) -> dict :
106+ async def run (
107+ args : argparse .Namespace ,
108+ * ,
109+ samples : list [MmsuSample ] | None = None ,
110+ ) -> dict :
106111 base_url = args .base_url or f"http://{ args .host } :{ args .port } "
107112 api_url = f"{ base_url } /v1/chat/completions"
108113 modalities = ["text" , "audio" ] if args .modalities == "text+audio" else ["text" ]
109114
110- samples = load_mmsu_samples (
111- max_samples = args .max_samples ,
112- task_names = args .task_names .split ("," ) if args .task_names else None ,
113- categories = args .categories .split ("," ) if args .categories else None ,
114- seed = args .seed ,
115- )
115+ if samples is None :
116+ samples = load_mmsu_samples (
117+ max_samples = args .max_samples ,
118+ task_names = args .task_names .split ("," ) if args .task_names else None ,
119+ categories = args .categories .split ("," ) if args .categories else None ,
120+ seed = args .seed ,
121+ repo_id = args .repo_id ,
122+ )
116123
117124 save_audio_dir = None
118125 if args .save_audio and args .output_dir :
@@ -150,6 +157,17 @@ async def run(args: argparse.Namespace) -> dict:
150157
151158 print_mmsu_summary (metrics , args .model , speed_metrics = speed )
152159
160+ output : dict = {"accuracy" : metrics , "speed" : speed }
161+ wer_results = None
162+ if audio_mode :
163+ wer_results = compute_text_audio_consistency (
164+ request_results ,
165+ args .lang ,
166+ args .asr_device ,
167+ )
168+ output ["wer" ] = wer_results
169+ print_wer_summary (wer_results ["summary" ], args .model )
170+
153171 if args .output_dir :
154172 save_mmsu_results (
155173 results ,
@@ -165,9 +183,10 @@ async def run(args: argparse.Namespace) -> dict:
165183 },
166184 args .output_dir ,
167185 speed_metrics = speed ,
186+ wer_metrics = wer_results ,
168187 )
169188
170- return { "accuracy" : metrics , "speed" : speed }
189+ return output
171190
172191
173192def main () -> None :
@@ -190,6 +209,19 @@ def main() -> None:
190209 p .add_argument ("--save-audio" , action = "store_true" )
191210 p .add_argument ("--disable-tqdm" , action = "store_true" )
192211 p .add_argument ("--seed" , type = int , default = None )
212+ p .add_argument (
213+ "--repo-id" ,
214+ type = str ,
215+ default = None ,
216+ help = "HuggingFace dataset repo (e.g. 'zhaochenyang20/mmsu-ci-2000'). "
217+ "Defaults to loading the full ddwang2000/MMSU (train split)." ,
218+ )
219+ p .add_argument (
220+ "--lang" , type = str , default = "en" , help = "Language for ASR WER evaluation"
221+ )
222+ p .add_argument (
223+ "--asr-device" , type = str , default = "cuda:0" , help = "Device for ASR model"
224+ )
193225
194226 args = p .parse_args ()
195227 wait_for_service (args .base_url or f"http://{ args .host } :{ args .port } " )
0 commit comments