-
Notifications
You must be signed in to change notification settings - Fork 187
Add DSBench-DA evaluation #1254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ef38e4f
a411396
e316d49
66ed0db
9250b87
ff6b113
a0dbb8d
72792a2
67a018c
a16051d
752e834
5e03c39
d3b039b
ac674fc
a478c3d
9778a78
f594928
c42e12b
3fdcada
2176ccd
f5986e5
719a29c
1e8dbe6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # settings that define how evaluation should be done by default (all can be changed from cmdline) | ||
| EVAL_SPLIT = "test" | ||
| METRICS_TYPE = "math" | ||
|
|
||
| # Use DSBench evaluator (extends MathEvaluator) with relaxed extraction and case-insensitive MCQ and handling of dict and list. | ||
| GENERATION_ARGS = "++prompt_config=generic/dsbench-da ++eval_type=dsbench ++eval_config.relaxed_extraction=true" | ||
|
|
||
| # Recommend running LLM judge to verify dicts and lists correctly | ||
| # JUDGE_PIPELINE_ARGS = { | ||
| # "generation_type": "math_judge", | ||
| # "model": "gpt-4.1", | ||
| # "server_type": "openai", | ||
| # "server_address": "https://api.openai.com/v1", | ||
| # } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,232 @@ | ||
| # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import argparse | ||
| import json | ||
| import zipfile | ||
| from pathlib import Path | ||
|
|
||
| from huggingface_hub import hf_hub_download | ||
|
|
||
|
|
||
| def read_excel_to_text(excel_path: Path) -> str: | ||
| """Read Excel file and convert to text representation.""" | ||
| import pandas as pd | ||
|
|
||
| try: | ||
| # Explicitly handle .xlsb files with pyxlsb engine | ||
| engine = "pyxlsb" if excel_path.suffix == ".xlsb" else None | ||
| with pd.ExcelFile(excel_path, engine=engine) as xls: | ||
| sheets = {sheet_name: xls.parse(sheet_name) for sheet_name in xls.sheet_names} | ||
| except Exception as e: | ||
| raise RuntimeError(f"Failed to read Excel file {excel_path}: {e}") from e | ||
|
|
||
| combined_text = "" | ||
| for sheet_name, df in sheets.items(): | ||
| sheet_text = df.to_string(index=False) | ||
| combined_text += f"Sheet name: {sheet_name}\n{sheet_text}\n\n" | ||
| return combined_text | ||
|
|
||
|
|
||
| def format_paths_for_prompt(paths: list[Path], actual_root: Path, display_root: Path) -> str: | ||
| """Format file paths for display in prompt. | ||
|
|
||
| Args: | ||
| paths: List of absolute Path objects to format | ||
| actual_root: Root directory where files actually exist | ||
| display_root: Root directory to display in paths (absolute for abs paths, Path(".") for relative) | ||
| """ | ||
| if not paths: | ||
| return "" | ||
|
|
||
| formatted = [] | ||
| for path in paths: | ||
| try: | ||
| rel = path.relative_to(actual_root) | ||
| disp_path = display_root / rel | ||
| except ValueError: | ||
| disp_path = path | ||
| formatted.append(str(disp_path)) | ||
|
|
||
| return " ".join(formatted) | ||
|
|
||
|
|
||
| def save_data(split: str, data_dir: str | Path, display_root: str | Path | None, incontext_data: bool) -> None: | ||
| """Download and prepare DSBench data.""" | ||
| print(f"Preparing DSBench data for {split} split and saving to {data_dir}...") | ||
|
|
||
| data_dir = Path(data_dir) | ||
| data_dir.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| extracted_data_dir = data_dir / "data" | ||
|
|
||
| # Extract if not already cached (hf_hub_download handles download caching) | ||
| if not extracted_data_dir.exists(): | ||
| print(" Downloading dataset from HuggingFace...") | ||
| zip_path = Path( | ||
| hf_hub_download(repo_id="liqiang888/DSBench", filename="data_analysis/data.zip", repo_type="dataset") | ||
| ) | ||
| print(" Extracting data...") | ||
| with zipfile.ZipFile(zip_path, "r") as zip_ref: | ||
| zip_ref.extractall(data_dir) | ||
| if not extracted_data_dir.exists(): | ||
| raise FileNotFoundError(f"Could not find data directory after extraction in {extracted_data_dir}") | ||
| print(f" Dataset cached to {data_dir}") | ||
| else: | ||
| print(f" Using cached dataset from {data_dir}") | ||
|
|
||
| # Load metadata | ||
| print(" Loading metadata...") | ||
| metadata_path = Path( | ||
| hf_hub_download(repo_id="liqiang888/DSBench", filename="data_analysis/data.json", repo_type="dataset") | ||
| ) | ||
| metadata = [] | ||
| with open(metadata_path, "r") as f: | ||
| for line in f: | ||
| if line.strip(): | ||
| metadata.append(json.loads(line.strip())) | ||
|
|
||
| # Process all tasks | ||
| if not display_root: | ||
| display_root = extracted_data_dir | ||
| else: | ||
| display_root = Path(display_root) | ||
|
|
||
| print( | ||
| f" Processing {len(metadata)} tasks at {extracted_data_dir} - using display root {display_root} for paths shown in the prompt..." | ||
| ) | ||
| all_entries = [] | ||
|
|
||
| for task in metadata: | ||
| task_id = task["id"] | ||
| task_dir = extracted_data_dir / task_id | ||
|
|
||
| if not task_dir.exists(): | ||
| raise FileNotFoundError( | ||
| f"Task directory not found: {task_dir}. " | ||
| f"Expected task {task_id} from metadata but directory is missing. " | ||
| "Data extraction may have failed." | ||
| ) | ||
| if len(task["answers"]) != len(task["questions"]): | ||
| raise ValueError( | ||
| f"Task {task_id}: mismatched questions ({len(task['questions'])}) " | ||
| f"and answers ({len(task['answers'])}) counts in metadata." | ||
| ) | ||
|
|
||
| # Read introduction | ||
| intro_file = task_dir / "introduction.txt" | ||
| introduction = "" | ||
| if intro_file.exists(): | ||
| introduction = intro_file.read_text(encoding="utf-8", errors="ignore") | ||
|
|
||
| # Get data files - support all Excel formats | ||
| excel_files = [] | ||
| for ext in ["*.xlsx", "*.xlsb", "*.xlsm"]: | ||
| excel_files.extend(task_dir.glob(ext)) | ||
| excel_files = [f for f in excel_files if "answer" not in f.name.lower()] | ||
|
|
||
| # Read Excel content for in-context mode | ||
| if incontext_data: | ||
| excel_content = "" | ||
| for excel_file in excel_files: | ||
| sheets_text = read_excel_to_text(excel_file) | ||
| excel_content += f"The excel file {excel_file.name} is: {sheets_text}\n\n" | ||
|
|
||
| # Format paths for tool mode (relative to data directory) | ||
| excel_paths = format_paths_for_prompt(excel_files, actual_root=extracted_data_dir, display_root=display_root) | ||
|
|
||
| # Uncomment to get image files and csv files (for future multimodal and agentic support) | ||
| # image_files = [] | ||
| # for ext in ["*.jpg", "*.png", "*.jpeg"]: | ||
| # image_files.extend(task_dir.glob(ext)) | ||
| # csv_files = list(task_dir.glob("*.csv")) | ||
|
|
||
| # Process each question | ||
| for idx, question_name in enumerate(task["questions"]): | ||
| question_file = task_dir / f"{question_name}.txt" | ||
|
|
||
| if not question_file.exists(): | ||
| print(f" Warning: {task_id}/{question_name}.txt not found, skipping") | ||
| continue | ||
|
|
||
| question_text = question_file.read_text(encoding="utf-8", errors="ignore").strip() | ||
|
|
||
| # Build problem text (introduction + question) | ||
| problem_text = "" | ||
| if introduction: | ||
| problem_text += f"The introduction is detailed as follows.\n{introduction}\n\n" | ||
| problem_text += f"The question for this task is detailed as follows.\n{question_text}" | ||
|
|
||
| # Create entry with all necessary fields | ||
| entry = { | ||
| # Skills standard fields | ||
| "problem": problem_text, | ||
| "expected_answer": task["answers"][idx], | ||
|
sgunasekar marked this conversation as resolved.
|
||
| # For tool mode | ||
| "excel_paths": excel_paths, | ||
| # Metadata | ||
| "task_id": task_id, | ||
| "question_id": question_name, | ||
| "task_name": task["name"], | ||
| "task_url": task["url"], | ||
| "task_year": task["year"], | ||
| } | ||
|
|
||
| if incontext_data: | ||
| entry["excel_content"] = excel_content.strip() | ||
|
|
||
| all_entries.append(entry) | ||
|
|
||
| # Validate we got some entries | ||
| if not all_entries: | ||
| raise ValueError( | ||
| f"No valid entries created! Processed {len(metadata)} tasks but all failed. " | ||
| "Check that data was downloaded correctly and Excel files are readable." | ||
| ) | ||
|
|
||
| # Save to output file | ||
| output_file = data_dir / f"{split}.jsonl" | ||
| with open(output_file, "w") as f: | ||
| for entry in all_entries: | ||
| f.write(json.dumps(entry) + "\n") | ||
|
|
||
| print(f" ✓ Saved {len(all_entries)} questions to {output_file}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--split", default="test", choices=("test",), help="DSBench only has test split") | ||
| parser.add_argument( | ||
| "--data_dir", type=str, default=None, help="Directory to save the data (defaults to dataset directory)" | ||
| ) | ||
| parser.add_argument( | ||
| "--display_root", | ||
| type=str, | ||
| default=None, | ||
| help='Root directory to display in paths (absolute for abs paths, Path(".") for relative)', | ||
| ) | ||
| parser.add_argument( | ||
| "--incontext_data", | ||
| action="store_true", | ||
| help="Have the excel files read in-context under 'excel_content' field (Default: False)", | ||
| ) | ||
| args = parser.parse_args() | ||
| print(args) | ||
| if args.data_dir is None: | ||
| # Save to the same directory as this script | ||
| data_dir = Path(__file__).absolute().parent | ||
| else: | ||
| data_dir = Path(args.data_dir) | ||
|
|
||
| save_data(args.split, data_dir, args.display_root, args.incontext_data) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import json | ||
| import logging | ||
| import re | ||
| from typing import Any | ||
|
|
||
| from math_verify import StringExtractionConfig, parse, verify | ||
|
|
||
| from nemo_skills.evaluation.evaluator.math import MathEvaluator | ||
| from nemo_skills.evaluation.math_grader import math_equal | ||
| from nemo_skills.utils import get_logger_name | ||
|
|
||
| LOG = logging.getLogger(get_logger_name(__file__)) | ||
|
|
||
|
|
||
| def relaxed_equal(gt_answer: Any, predicted_answer: Any) -> bool: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we update the original math_equal with these changes?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we do that, I guess we'd be able to fully reuse math evaluator here
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes updating math evaluator would be most cleanest: two issues
One option is to use the "relaxed" argument that is already there for extract_answer and use it to branch into relaxed-mcq.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd probably make a change directly. It feels like this is the right way to compare things. E.g. if options are A, B, C, D and llm says \boxed{a}, where A is correct, that should be counted as correct I guess. And the same for the other change. So my suggestion would be to make a change directly but please run e.g. nano-v3 math eval on maybe comp-math-24-25 and if we get score within normal random variance, we should be good
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And maybe also some mcq benchmark, eg gpqa
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would still feel more comfortable doing this as a new PR so that if it breaks anyones workflow they can revert to it. Plus would unblock dsbench for now. |
||
| """ | ||
| Relaxed equality check with: | ||
| 1. Case-insensitive MCQ matching | ||
| 2. Dict/list comparison using math_equal recursively | ||
| """ | ||
| if predicted_answer is None: | ||
| return gt_answer is None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can gt_answer be None? If not, probably better to just return False here for clarity
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not in this eval per se - but scenario I am thinking is a benchmark where some questions have short answers and others require some action (e.g., saving files). Then the gt answer can be None. |
||
|
|
||
| try: | ||
| predicted_answer = json.loads(predicted_answer) | ||
| except Exception: | ||
| pass # keep original string form | ||
| try: | ||
| gt_answer = json.loads(gt_answer) | ||
| except Exception: | ||
| pass # keep original string form | ||
|
|
||
| if isinstance(predicted_answer, dict): | ||
| if not isinstance(gt_answer, dict): | ||
| # check if any of the values in predicted_answer are equal to gt_answer | ||
| return any(relaxed_equal(gt_answer, p) for p in predicted_answer.values()) | ||
|
|
||
| # check if all the keys in gt_answer are in predicted_answer and if the values are equal; ok for predicted_answer to have more keys | ||
| return all( | ||
| k in predicted_answer and relaxed_equal(gt_answer[k], predicted_answer[k]) for k in gt_answer.keys() | ||
| ) | ||
|
|
||
| if isinstance(predicted_answer, list): | ||
| if not isinstance(gt_answer, list): | ||
| # check if any of the values in predicted_answer are equal to gt_answer | ||
| return any(relaxed_equal(gt_answer, p) for p in predicted_answer) | ||
| # check if the lengths are equal and if all the values are equal | ||
| return len(gt_answer) == len(predicted_answer) and all( | ||
| relaxed_equal(e, p) for e, p in zip(gt_answer, predicted_answer) | ||
| ) | ||
|
|
||
| # Try case-insensitive MCQ matching | ||
| # TODO: add support for numeric and roman numeral MCQs (i.e. "1", "I", "2", "II", etc.) | ||
| mcq_options = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" | ||
| norm_gt_mcq = str(gt_answer).strip().upper() | ||
| norm_pred_mcq = str(predicted_answer).strip().upper() | ||
| is_mcq = re.fullmatch("|".join(mcq_options), norm_gt_mcq) | ||
| if is_mcq: | ||
| parsed_gt = parse(norm_gt_mcq, [StringExtractionConfig(strings=tuple(mcq_options))]) | ||
| parsed_pred = parse(norm_pred_mcq, [StringExtractionConfig(strings=tuple(mcq_options))]) | ||
| mcq_result = verify(parsed_gt, parsed_pred) | ||
| if mcq_result: | ||
| return mcq_result | ||
|
|
||
| return math_equal(str(gt_answer), str(predicted_answer)) | ||
|
|
||
|
|
||
| class DSBenchEvaluator(MathEvaluator): | ||
| def __init__(self, config: dict, num_parallel_requests=10): | ||
| super().__init__(config, num_parallel_requests) | ||
| self.eval_config.extract_regex = r"(?:The final answer is |\\boxed=)(.+)$" | ||
|
sgunasekar marked this conversation as resolved.
|
||
|
|
||
| async def eval_single(self, data_point: dict[str, Any]) -> dict[str, Any]: | ||
| """Evaluate single DSBench problem with relaxed fallback.""" | ||
| # First try standard math evaluation | ||
| data_point = await super().eval_single(data_point) | ||
|
|
||
| # If symbolic_correct is False, try relaxed_equal | ||
| if not data_point["symbolic_correct"]: | ||
| expected_answer = data_point["expected_answer"] | ||
| predicted_answer = data_point["predicted_answer"] | ||
|
|
||
| if relaxed_equal(expected_answer, predicted_answer): | ||
| data_point["symbolic_correct"] = True | ||
|
|
||
| return data_point | ||
Uh oh!
There was an error while loading. Please reload this page.