Skip to content

Commit 3c99fe5

Browse files
committed
Refactor
Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com>
1 parent dec0362 commit 3c99fe5

5 files changed

Lines changed: 162 additions & 173 deletions

File tree

examples/06_GPT-OSS-120B_SGLang_Example/eval_accuracy.py renamed to examples/07_GPT-OSS-120B_SGLang_Example/eval_accuracy.py

File renamed without changes.
File renamed without changes.

examples/06_GPT-OSS-120B_SGLang_Example/run.py renamed to examples/07_GPT-OSS-120B_SGLang_Example/run.py

Lines changed: 7 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,12 @@
3131
import random
3232
from pathlib import Path
3333

34-
import pandas as pd
3534
from inference_endpoint import metrics
3635
from inference_endpoint.config.runtime_settings import RuntimeSettings
3736
from inference_endpoint.config.schema import LoadPattern, LoadPatternType
3837
from inference_endpoint.dataset_manager.dataset import Dataset
39-
from inference_endpoint.dataset_manager.predefined.aime25 import AIME25
40-
from inference_endpoint.dataset_manager.predefined.gpqa import GPQA
41-
from inference_endpoint.dataset_manager.transforms import (
42-
AddStaticColumns,
43-
DropColumns,
44-
Harmonize,
45-
UserPromptFormatter,
46-
)
38+
from inference_endpoint.dataset_manager.predefined.aime25 import AIME25, AIME_MLPerf
39+
from inference_endpoint.dataset_manager.predefined.gpqa import GPQA, GPQA_MLPerf
4740
from inference_endpoint.endpoint_client.configs import (
4841
AioHttpConfig,
4942
HTTPClientConfig,
@@ -64,7 +57,7 @@
6457

6558
# Configuration for SGLang server
6659
SGLANG_SERVER_HOST = "localhost"
67-
SGLANG_SERVER_PORT = 30000
60+
SGLANG_SERVER_PORT = 3000
6861
SGLANG_ENDPOINT = f"http://{SGLANG_SERVER_HOST}:{SGLANG_SERVER_PORT}/generate"
6962

7063

@@ -82,133 +75,6 @@ def set_pbar(self, pbar: tqdm):
8275
self.pbar = pbar
8376

8477

85-
def generate_gpqa_dataset(
86-
datasets_dir: Path,
87-
variant: str = "diamond",
88-
max_samples: int | None = None,
89-
force: bool = False,
90-
) -> pd.DataFrame:
91-
"""Generate the GPQA dataset to a file.
92-
93-
Args:
94-
datasets_dir: Directory where datasets are stored
95-
variant: GPQA variant to use (default: "diamond")
96-
max_samples: Maximum number of samples to include (default: None = all)
97-
force: Force regeneration of dataset even if it exists
98-
99-
Returns:
100-
DataFrame containing the GPQA dataset
101-
"""
102-
df = GPQA.generate(
103-
datasets_dir=Path(datasets_dir),
104-
variant=variant,
105-
max_samples=max_samples,
106-
force=force,
107-
)
108-
return df
109-
110-
111-
def generate_aime25_dataset(
112-
datasets_dir: Path,
113-
max_samples: int | None = None,
114-
force: bool = False,
115-
) -> pd.DataFrame:
116-
"""Generate the AIME25 dataset to a file."""
117-
df = AIME25.generate(
118-
datasets_dir=Path(datasets_dir),
119-
max_samples=max_samples,
120-
force=force,
121-
)
122-
return df
123-
124-
125-
def create_transforms() -> list:
126-
"""Create the list of transforms to apply to the GPQA dataset.
127-
128-
Returns:
129-
List of transforms to apply
130-
"""
131-
prompt_format = (
132-
"{question}\n\n"
133-
"(A) {choice1}\n"
134-
"(B) {choice2}\n"
135-
"(C) {choice3}\n"
136-
"(D) {choice4}\n\n"
137-
"Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'."
138-
)
139-
140-
return [
141-
# Step 1: Format the prompt from question and choices
142-
UserPromptFormatter(
143-
user_prompt_format=prompt_format,
144-
output_column="user_prompt",
145-
),
146-
# Step 2: Harmonize the prompt for SGLang/GPT-OSS
147-
Harmonize(
148-
prompt_column="user_prompt",
149-
),
150-
# Step 3: Drop columns we don't need for inference
151-
DropColumns(
152-
columns=[
153-
"question",
154-
"choice1",
155-
"choice2",
156-
"choice3",
157-
"choice4",
158-
"domain",
159-
"subdomain",
160-
"user_prompt",
161-
],
162-
errors="ignore",
163-
),
164-
# Step 4: Add metadata columns since we don't want to do a dict update every iteration
165-
AddStaticColumns(
166-
{
167-
"stream": True,
168-
"max_new_tokens": 32768,
169-
"temperature": 1.0,
170-
"top_p": 1.0,
171-
"top_k": -1,
172-
}
173-
),
174-
]
175-
176-
177-
def create_aime25_transforms() -> list:
178-
"""Create the list of transforms to apply to the AIME25 dataset."""
179-
prompt_format = "{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
180-
181-
return [
182-
# Step 1: Format the prompt from question and choices
183-
UserPromptFormatter(
184-
user_prompt_format=prompt_format,
185-
output_column="user_prompt",
186-
),
187-
# Step 2: Harmonize the prompt for SGLang/GPT-OSS
188-
Harmonize(
189-
prompt_column="user_prompt",
190-
),
191-
# Step 3: Drop columns we don't need for inference
192-
DropColumns(
193-
columns=[
194-
"question",
195-
"user_prompt",
196-
],
197-
errors="ignore",
198-
),
199-
# Step 4: Add metadata columns since we don't want to do a dict update every iteration
200-
AddStaticColumns(
201-
{
202-
"stream": True,
203-
"max_new_tokens": 32768,
204-
"temperature": 1.0,
205-
"top_p": 1.0,
206-
"tok_k": -1,
207-
}
208-
),
209-
]
210-
211-
21278
def create_sglang_client(tmp_dir: Path) -> HTTPEndpointClient:
21379
"""Create an SGLang HTTP client for issuing queries.
21480
@@ -334,43 +200,11 @@ def run_main(args):
334200
try:
335201
# Always generate GPQA diamond dataset
336202
print("Generating GPQA diamond dataset...")
337-
df = generate_gpqa_dataset(
338-
datasets_dir="datasets",
339-
force=args.force_regenerate,
340-
)
341-
print(f"Loaded {len(df)} samples from GPQA diamond")
342-
343-
# Step 2: Create transforms
344-
print("Creating transforms...")
345-
transforms = create_transforms()
346-
347-
# Step 3: Create Dataset with transforms (transforms will be applied during load())
348-
print("Creating dataset with transforms...")
349-
print(df.columns)
350-
df.to_parquet("datasets/gqpa_diamond_pre-transformed_gpt-oss.parquet")
351-
gpqa_dataset = GPQA(
352-
df, transforms=transforms, repeats=num_repeats
353-
) # Artificial Analysis uses 5 repeats
203+
gpqa_dataset = GPQA_MLPerf.get_gpqa_dataloader(num_repeats=num_repeats)
354204
gpqa_dataset.load()
355205
# Always generate AIME25 dataset
356206
print("Generating AIME25 dataset...")
357-
df = generate_aime25_dataset(
358-
datasets_dir="datasets",
359-
force=args.force_regenerate,
360-
)
361-
print(f"Loaded {len(df)} samples from AIME25")
362-
363-
# Step 2: Create transforms
364-
print("Creating transforms...")
365-
transforms = create_aime25_transforms()
366-
367-
# Step 3: Create Dataset with transforms (transforms will be applied during load())
368-
print("Creating dataset with transforms...")
369-
print(df.columns)
370-
df.to_parquet("datasets/aime25_pre-transformed_gpt-oss.parquet")
371-
aime25_dataset = AIME25(
372-
df, transforms=transforms, repeats=num_repeats
373-
) # Artificial Analysis uses 5 repeats
207+
aime25_dataset = AIME_MLPerf.get_aime25_dataloader(num_repeats=num_repeats)
374208
aime25_dataset.load()
375209
print(f"Dataset loaded with {aime25_dataset.num_samples()} samples")
376210

@@ -394,7 +228,7 @@ def run_main(args):
394228
def main():
395229
"""Main entry point for the manual example."""
396230
parser = argparse.ArgumentParser(
397-
description="GPQA dataset example with SGLang endpoint",
231+
description="GPQA and AIME25 MLPerf dataset example with SGLang endpoint",
398232
formatter_class=argparse.RawDescriptionHelpFormatter,
399233
epilog=__doc__,
400234
)
@@ -436,7 +270,7 @@ def main():
436270
args = parser.parse_args()
437271

438272
print("=" * 60)
439-
print("GPQA Dataset Example with SGLang")
273+
print("GPQA and AIME25 MLPerf Dataset Example with SGLang")
440274
print("=" * 60)
441275
print("\nConfiguration:")
442276
print(f" SGLang endpoint: {SGLANG_ENDPOINT}")

src/inference_endpoint/dataset_manager/predefined/aime25/__init__.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
from pathlib import Path
1919

2020
import pandas as pd
21+
from inference_endpoint.dataset_manager.transforms import (
22+
AddStaticColumns,
23+
DropColumns,
24+
Harmonize,
25+
UserPromptFormatter,
26+
)
2127

2228
from ...dataset import Dataset, load_from_huggingface
2329

@@ -119,3 +125,81 @@ def generate(
119125
df.to_parquet(dst_path)
120126
logger.info(f"Saved {len(df)} samples to {dst_path}")
121127
return df
128+
129+
# @classmethod
130+
# def generate_aime25_dataset(
131+
# cls,
132+
# datasets_dir: Path,
133+
# max_samples: int | None = None,
134+
# force: bool = False,
135+
# ) -> pd.DataFrame:
136+
# """Generate the AIME25 dataset to a file."""
137+
# df = AIME25.generate(
138+
# datasets_dir=Path(datasets_dir),
139+
# max_samples=max_samples,
140+
# force=force,
141+
# )
142+
# return df
143+
144+
145+
class AIME_MLPerf(AIME25):
146+
"""AIME_MLPerf: AIME 2025 MLPerf Dataset
147+
Reference: https://huggingface.co/datasets/opencompass/AIME2025/
148+
"""
149+
150+
@classmethod
151+
def generate(
152+
cls,
153+
datasets_dir: Path,
154+
max_samples: int | None = None,
155+
force: bool = False,
156+
) -> pd.DataFrame:
157+
"""Generate the AIME25 MLPerf dataset to a file."""
158+
df = AIME25.generate(
159+
datasets_dir=Path(datasets_dir),
160+
max_samples=max_samples,
161+
force=force,
162+
)
163+
return df
164+
165+
@classmethod
166+
def create_aime25_transforms(cls) -> list:
167+
"""Create the list of transforms to apply to the AIME25 dataset."""
168+
prompt_format = "{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
169+
170+
return [
171+
# Step 1: Format the user prompt from the question text
172+
UserPromptFormatter(
173+
user_prompt_format=prompt_format,
174+
output_column="user_prompt",
175+
),
176+
# Step 2: Harmonize the prompt for SGLang/GPT-OSS
177+
Harmonize(
178+
prompt_column="user_prompt",
179+
),
180+
# Step 3: Drop columns we don't need for inference
181+
DropColumns(
182+
columns=[
183+
"question",
184+
"user_prompt",
185+
],
186+
errors="ignore",
187+
),
188+
# Step 4: Add metadata columns since we don't want to do a dict update every iteration
189+
AddStaticColumns(
190+
{
191+
"stream": True,
192+
"max_new_tokens": 32768,
193+
"temperature": 1.0,
194+
"top_p": 1.0,
195+
"top_k": -1,
196+
}
197+
),
198+
]
199+
200+
@classmethod
201+
def get_aime25_dataloader(cls, num_repeats: int = 5):
202+
df = AIME25.generate(datasets_dir=Path("datasets"))
203+
transforms = AIME_MLPerf.create_aime25_transforms()
204+
aime25_dataset = AIME25(df, transforms=transforms, repeats=num_repeats)
205+
return aime25_dataset

0 commit comments

Comments
 (0)