5151)
5252from inference_endpoint .endpoint_client .http_client import HTTPEndpointClient
5353from inference_endpoint .endpoint_client .http_sample_issuer import HttpClientSampleIssuer
54+ from inference_endpoint .evaluation .extractor import ABCDExtractor , BoxedMathExtractor
55+ from inference_endpoint .evaluation .scoring import PassAt1Scorer
5456from inference_endpoint .load_generator import (
5557 BenchmarkSession ,
5658 MaxThroughputScheduler ,
6264
6365# Configuration for SGLang server
6466SGLANG_SERVER_HOST = "localhost"
65- SGLANG_SERVER_PORT = 3000
67+ SGLANG_SERVER_PORT = 30000
6668SGLANG_ENDPOINT = f"http://{ SGLANG_SERVER_HOST } :{ SGLANG_SERVER_PORT } /generate"
6769
6870
@@ -166,7 +168,7 @@ def create_transforms() -> list:
166168 "max_new_tokens" : 32768 ,
167169 "temperature" : 1.0 ,
168170 "top_p" : 1.0 ,
169- "tok_k " : - 1 ,
171+ "top_k " : - 1 ,
170172 }
171173 ),
172174 ]
@@ -247,7 +249,9 @@ def num_samples(self):
247249 return 0
248250
249251
250- def run_benchmark_session (dataset : Dataset , issuer : HttpClientSampleIssuer , args ):
252+ def run_benchmark_session (
253+ accuracy_datasets : list [Dataset ], issuer : HttpClientSampleIssuer , args
254+ ):
251255 """Run a benchmark session with the SGLang endpoint.
252256
253257 Args:
@@ -276,7 +280,9 @@ def run_benchmark_session(dataset: Dataset, issuer: HttpClientSampleIssuer, args
276280 scheduler = MaxThroughputScheduler (rt_settings , WithoutReplacementSampleOrder )
277281
278282 # Run the benchmark session
279- n_total = dataset .num_samples () * dataset .repeats
283+ n_total = sum (
284+ [dataset .num_samples () * dataset .repeats for dataset in accuracy_datasets ]
285+ )
280286
281287 with tqdm (desc = "GPQA Benchmark" , total = n_total , unit = "samples" ) as pbar :
282288 pbar_hook .set_pbar (pbar )
@@ -285,65 +291,88 @@ def run_benchmark_session(dataset: Dataset, issuer: HttpClientSampleIssuer, args
285291 EmptyDataset (),
286292 issuer ,
287293 scheduler ,
288- accuracy_datasets = [ dataset ] ,
289- name = "gpqa_sglang_benchmark " ,
294+ accuracy_datasets = accuracy_datasets ,
295+ name = "gpqa_aime25_sglang_benchmark " ,
290296 report_dir = args .report_dir ,
291297 dump_events_log = True ,
292298 max_shutdown_timeout_s = None ,
293299 )
294300 sess .wait_for_test_end ()
295301
302+ # Create the scorer for GPQA (first entry of accuracy_datasets)
303+ scorer = PassAt1Scorer (
304+ GPQA .DATASET_ID ,
305+ accuracy_datasets [0 ],
306+ args .report_dir ,
307+ extractor = ABCDExtractor ,
308+ )
309+
310+ # Score the dataset
311+ score , n_repeats = scorer .score ()
312+ print (f"Pass@1 Score ({ n_repeats } repeats): { score } " )
313+
314+ scorer = PassAt1Scorer (
315+ AIME25 .DATASET_ID ,
316+ accuracy_datasets [1 ],
317+ args .report_dir ,
318+ extractor = BoxedMathExtractor ,
319+ ground_truth_column = "answer" ,
320+ )
321+
322+ # Score the dataset
323+ score , n_repeats = scorer .score ()
324+ print (f"Pass@1 Score ({ n_repeats } repeats): { score } " )
325+
296326
297327def run_main (args ):
298328 """Main function to run the example."""
299329 # Setup paths
300330 tmp_dir = Path ("/tmp/sglang_manual_example" )
301331 tmp_dir .mkdir (parents = True , exist_ok = True )
332+ num_repeats = args .num_repeats
302333
303334 try :
304- if False :
305- print ("Generating GPQA diamond dataset..." )
306- df = generate_gpqa_dataset (
307- datasets_dir = "datasets" ,
308- force = args .force_regenerate ,
309- )
310- print (f"Loaded { len (df )} samples from GPQA diamond" )
311-
312- # Step 2: Create transforms
313- print ("Creating transforms..." )
314- transforms = create_transforms ()
315-
316- # Step 3: Create Dataset with transforms (transforms will be applied during load())
317- print ("Creating dataset with transforms..." )
318- print (df .columns )
319- df .to_parquet ("datasets/gqpa_diamond_pre-transformed_gpt-oss.parquet" )
320- dataset = GPQA (
321- df , transforms = transforms , repeats = 5
322- ) # Artificial Analysis uses 5 repeats
323- dataset .load ()
324- else :
325- print ("Generating AIME25 dataset..." )
326- df = generate_aime25_dataset (
327- datasets_dir = "datasets" ,
328- force = args .force_regenerate ,
329- )
330- print (f"Loaded { len (df )} samples from AIME25" )
331-
332- # Step 2: Create transforms
333- print ("Creating transforms..." )
334- transforms = create_aime25_transforms ()
335-
336- # Step 3: Create Dataset with transforms (transforms will be applied during load())
337- print ("Creating dataset with transforms..." )
338- print (df .columns )
339- df .to_parquet ("datasets/aime25_pre-transformed_gpt-oss.parquet" )
340- # breakpoint()
341- dataset = AIME25 (
342- df , transforms = transforms , repeats = 5
343- ) # Artificial Analysis uses 5 repeats
344- dataset .load ()
345-
346- print (f"Dataset loaded with { dataset .num_samples ()} samples" )
335+ # Always generate GPQA diamond dataset
336+ print ("Generating GPQA diamond dataset..." )
337+ df = generate_gpqa_dataset (
338+ datasets_dir = "datasets" ,
339+ force = args .force_regenerate ,
340+ )
341+ print (f"Loaded { len (df )} samples from GPQA diamond" )
342+
343+ # Step 2: Create transforms
344+ print ("Creating transforms..." )
345+ transforms = create_transforms ()
346+
347+ # Step 3: Create Dataset with transforms (transforms will be applied during load())
348+ print ("Creating dataset with transforms..." )
349+ print (df .columns )
350+ df .to_parquet ("datasets/gqpa_diamond_pre-transformed_gpt-oss.parquet" )
351+ gpqa_dataset = GPQA (
352+ df , transforms = transforms , repeats = num_repeats
353+ ) # Artificial Analysis uses 5 repeats; here the count comes from --num-repeats (default: 1)
354+ gpqa_dataset .load ()
355+ # Always generate AIME25 dataset
356+ print ("Generating AIME25 dataset..." )
357+ df = generate_aime25_dataset (
358+ datasets_dir = "datasets" ,
359+ force = args .force_regenerate ,
360+ )
361+ print (f"Loaded { len (df )} samples from AIME25" )
362+
363+ # Step 2: Create transforms
364+ print ("Creating transforms..." )
365+ transforms = create_aime25_transforms ()
366+
367+ # Step 3: Create Dataset with transforms (transforms will be applied during load())
368+ print ("Creating dataset with transforms..." )
369+ print (df .columns )
370+ df .to_parquet ("datasets/aime25_pre-transformed_gpt-oss.parquet" )
371+ aime25_dataset = AIME25 (
372+ df , transforms = transforms , repeats = num_repeats
373+ ) # Artificial Analysis uses 5 repeats; here the count comes from --num-repeats (default: 1)
374+ aime25_dataset .load ()
375+ print (f"Dataset loaded with { aime25_dataset .num_samples ()} samples" )
347376
348377 # Step 4: Create SGLang client
349378 print (f"Creating SGLang client for endpoint: { SGLANG_ENDPOINT } " )
@@ -352,7 +381,7 @@ def run_main(args):
352381
353382 # Step 5: Run benchmark session
354383 print ("Starting benchmark session..." )
355- run_benchmark_session (dataset , sample_issuer , args )
384+ run_benchmark_session ([ gpqa_dataset , aime25_dataset ] , sample_issuer , args )
356385
357386 print (f"\n Benchmark complete! Results saved to { args .report_dir } /" )
358387
@@ -397,6 +426,13 @@ def main():
397426 help = "Directory to save benchmark reports (default: gpqa_sglang_report)" ,
398427 )
399428
429+ parser .add_argument (
430+ "--num-repeats" ,
431+ type = int ,
432+ default = 1 ,
433+ help = "Number of repeats to run (default: 1)" ,
434+ )
435+
400436 args = parser .parse_args ()
401437
402438 print ("=" * 60 )
0 commit comments