Skip to content

Commit 4727204

Browse files
committed
Fix post-processing aime dataset
Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com>
1 parent 3c99fe5 commit 4727204

5 files changed

Lines changed: 48 additions & 42 deletions

File tree

examples/07_GPT-OSS-120B_SGLang_Example/run.py

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
from inference_endpoint import metrics
3535
from inference_endpoint.config.runtime_settings import RuntimeSettings
3636
from inference_endpoint.config.schema import LoadPattern, LoadPatternType
37-
from inference_endpoint.dataset_manager.dataset import Dataset
37+
from inference_endpoint.dataset_manager import Dataset, EmptyDataset
3838
from inference_endpoint.dataset_manager.predefined.aime25 import AIME25, AIME_MLPerf
3939
from inference_endpoint.dataset_manager.predefined.gpqa import GPQA, GPQA_MLPerf
4040
from inference_endpoint.endpoint_client.configs import (
@@ -57,7 +57,7 @@
5757

5858
# Configuration for SGLang server
5959
SGLANG_SERVER_HOST = "localhost"
60-
SGLANG_SERVER_PORT = 3000
60+
SGLANG_SERVER_PORT = 30000
6161
SGLANG_ENDPOINT = f"http://{SGLANG_SERVER_HOST}:{SGLANG_SERVER_PORT}/generate"
6262

6363

@@ -102,19 +102,6 @@ def create_sglang_client(tmp_dir: Path) -> HTTPEndpointClient:
102102
return client
103103

104104

105-
class EmptyDataset(Dataset):
106-
"""Empty dataset for performance run."""
107-
108-
def __init__(self):
109-
super().__init__(None)
110-
111-
def load_sample(self, index: int):
112-
return None
113-
114-
def num_samples(self):
115-
return 0
116-
117-
118105
def run_benchmark_session(
119106
accuracy_datasets: list[Dataset], issuer: HttpClientSampleIssuer, args
120107
):

src/inference_endpoint/dataset_manager/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
This module handles dataset loading, preprocessing, and management.
2020
"""
2121

22-
from .dataset import Dataset
22+
from .dataset import Dataset, EmptyDataset
2323
from .factory import DataLoaderFactory
2424
from .transforms import (
2525
AddStaticColumns,
@@ -33,6 +33,7 @@
3333

3434
__all__ = [
3535
"Dataset",
36+
"EmptyDataset",
3637
"DataLoaderFactory",
3738
"ColumnNameRemap",
3839
"AddStaticColumns",

src/inference_endpoint/dataset_manager/dataset.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -323,12 +323,6 @@ def load_from_huggingface(
323323
load_options = load_options or {}
324324
cache_options = cache_options or {}
325325

326-
# if cache_dir is not None and cache_dir.exists():
327-
# try:
328-
# ds = load_from_disk(str(cache_dir), **cache_options)
329-
# return ds[split].to_pandas()
330-
# except Exception as e:
331-
# logger.warning(f"Error loading dataset from cache: {e}")
332326
ds = load_dataset(dataset_path, dataset_name, **load_options)
333327

334328
if cache_dir is not None:
@@ -450,3 +444,16 @@ def load_sample(self, index: int) -> Any:
450444

451445
def num_samples(self) -> int:
452446
return len(self.data)
447+
448+
449+
class EmptyDataset(Dataset):
    """Empty dataset for performance run.

    A placeholder ``Dataset`` that carries no data: it reports zero samples
    and ``load_sample`` yields ``None`` for every index.
    """

    def __init__(self) -> None:
        # There is nothing to back this dataset, so the base initializer
        # receives ``None``.
        super().__init__(None)

    def load_sample(self, index: int) -> None:
        """Return ``None``; no sample exists at any ``index``."""
        return None

    def num_samples(self) -> int:
        """Return ``0`` because the dataset is empty."""
        return 0

src/inference_endpoint/dataset_manager/predefined/aime25/__init__.py

Lines changed: 17 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import random
17+
import re
1718
from logging import getLogger
1819
from pathlib import Path
1920

@@ -30,6 +31,16 @@
3031
logger = getLogger(__name__)
3132

3233

34+
def normalize_number(s):
35+
"""Normalize a number string to an integer.
36+
Reference https://github.com/openai/gpt-oss/blob/48db88d8e29f48493fe75f084a8c9bd900a2b92f/gpt_oss/evals/aime_eval.py#L20
37+
"""
38+
match = re.match(r"\d+", s) # match digits from the start
39+
if not match:
40+
return None
41+
return int(match.group(0))
42+
43+
3344
class AIME25(
3445
Dataset,
3546
dataset_id="aime25",
@@ -110,12 +121,15 @@ def generate(
110121

111122
processed_rows = []
112123
for _, row in df.iterrows():
113-
correct_answer = row["answer"]
114-
124+
correct_answer = (
125+
normalize_number(row["answer"])
126+
if isinstance(row["answer"], str)
127+
else row["answer"]
128+
)
115129
# Create processed row
116130
processed_row = {
117131
"question": row["question"], # Original question
118-
"answer": correct_answer,
132+
"answer": str(correct_answer),
119133
}
120134

121135
processed_rows.append(processed_row)
@@ -126,21 +140,6 @@ def generate(
126140
logger.info(f"Saved {len(df)} samples to {dst_path}")
127141
return df
128142

129-
# @classmethod
130-
# def generate_aime25_dataset(
131-
# cls,
132-
# datasets_dir: Path,
133-
# max_samples: int | None = None,
134-
# force: bool = False,
135-
# ) -> pd.DataFrame:
136-
# """Generate the AIME25 dataset to a file."""
137-
# df = AIME25.generate(
138-
# datasets_dir=Path(datasets_dir),
139-
# max_samples=max_samples,
140-
# force=force,
141-
# )
142-
# return df
143-
144143

145144
class AIME_MLPerf(AIME25):
146145
"""AIME_MLPerf: AIME 2025 MLPerf Dataset

src/inference_endpoint/evaluation/extractor.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,11 +165,23 @@ def extract(cls, text: str) -> str | None:
165165

166166

167167
class BoxedMathExtractor(Extractor):
    """Extract boxed math answer from response text.

    Based on OpenAI's extract_boxed_math function from GPT-OSS.
    https://github.com/openai/gpt-oss/blob/main/gpt_oss/evals/aime_eval.py

    The last non-empty ``boxed{...}`` / ``framebox{...}`` payload in the
    text wins. When no such payload exists, the last run of digits in the
    text is used as a fallback.
    """

    # Braces are escaped so they are unambiguously literal. DOTALL lets a
    # boxed payload span multiple lines. The match is non-greedy, so a
    # payload ends at the first closing brace.
    _BOXED_PATTERN = re.compile(r"boxed\{(.*?)\}|framebox\{(.*?)\}", re.DOTALL)
    # Fallback: any run of digits.
    _INTEGER_PATTERN = re.compile(r"\d+")

    @classmethod
    def extract(cls, text: str) -> str | None:
        """Return the extracted answer string, or ``None`` if nothing is found.

        Args:
            text: The model response to search.

        Returns:
            The text after the last comma of the last non-empty boxed
            payload, stripped of surrounding whitespace; otherwise the last
            run of digits in ``text``; otherwise ``None``.
        """
        # Each findall entry is a (boxed, framebox) tuple in which the
        # alternative that did not match is an empty string. Walk the
        # matches from last to first so the final answer takes priority.
        for groups in reversed(cls._BOXED_PATTERN.findall(text)):
            for group in groups:
                if group != "":
                    # Keep only the part after the last comma.
                    return group.split(",")[-1].strip()

        # No usable boxed payload: fall back to the last integer.
        integers = cls._INTEGER_PATTERN.findall(text)
        if integers:
            return integers[-1]
        return None

0 commit comments

Comments
 (0)