-
Notifications
You must be signed in to change notification settings - Fork 585
Expand file tree
/
Copy pathevaluate_json.py
More file actions
104 lines (87 loc) · 3.21 KB
/
evaluate_json.py
File metadata and controls
104 lines (87 loc) · 3.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
# Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
# Code repository: https://github.com/rasbt/reasoning-from-scratch
import argparse
import json
from pathlib import Path
from reasoning_from_scratch.ch03 import (
extract_final_candidate,
grade_answer,
)
def parse_args(argv=None):
    """Parse command-line options for the accuracy evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. The default ``None`` makes argparse read
            ``sys.argv[1:]`` (the original behavior); passing a list lets the
            parser be driven programmatically, e.g. from tests or notebooks.

    Returns:
        argparse.Namespace: with attributes ``json_path`` (path to the input
        records file), ``gtruth_answer`` (name of the record key holding the
        ground-truth answer) and ``generated_text`` (name of the record key
        holding the model output).

    Raises:
        SystemExit: raised by argparse for ``--help`` or when arguments are
            invalid or missing (``--json_path`` is required).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Compute answer accuracy for model outputs stored in a "
            "JSON or JSONL records file."
        ),
        # Show each option's default value in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--json_path",
        type=str,
        required=True,
        help="Path to the records file (.json or .jsonl).",
    )
    parser.add_argument(
        "--gtruth_answer",
        type=str,
        default="gtruth_answer",
        help="Key name for the ground-truth answer",
    )
    parser.add_argument(
        "--generated_text",
        type=str,
        default="generated_text",
        help="Key name for generated model output",
    )
    return parser.parse_args(argv)
def load_records(json_path):
    """Load evaluation records from a JSON or JSON Lines (JSONL) file.

    The whole file is first parsed as a single JSON document. If that fails,
    it is parsed again as JSON Lines: one JSON value per non-blank line.

    A single JSON document is normalized to a list of records:

    * a list is returned unchanged;
    * a dict whose ``"records"`` entry is a list returns that list;
    * any other dict is wrapped in a one-element list.

    Args:
        json_path: Path to the input file (``str`` or ``pathlib.Path``).

    Returns:
        list: The loaded records. Their contents are not validated here.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        ValueError: If a JSONL line is not valid JSON, or if the single JSON
            document's root is neither a list nor a dict.
    """
    path = Path(json_path)
    if not path.exists():
        raise FileNotFoundError(f"Input file not found: {path}")

    # "utf-8-sig" drops a leading UTF-8 byte-order mark when present and is
    # otherwise identical to "utf-8". With plain "utf-8" the BOM survives as
    # U+FEFF, which json rejects and str.strip() does not remove, so a
    # BOM-prefixed file failed in both the JSON and the JSONL code paths.
    text = path.read_text(encoding="utf-8-sig")

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Not a single JSON document: fall back to JSON Lines. Text-mode
        # reading already normalized line endings to "\n", so split on that
        # only; str.splitlines() would also split on U+2028/U+2029 and other
        # characters that may legally appear raw inside JSON strings.
        records = []
        for line_num, line in enumerate(text.split("\n"), start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines (e.g. trailing newline)
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as exc:
                raise ValueError(
                    f"Invalid JSON on line {line_num} in {path}: {exc}"
                ) from exc
        return records

    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, dict):
        # Support a wrapper object of the form {"records": [...]}.
        if "records" in parsed and isinstance(parsed["records"], list):
            return parsed["records"]
        # Otherwise treat the dict itself as the only record.
        return [parsed]
    raise ValueError(
        f"Unsupported JSON root type in {path}: {type(parsed).__name__}"
    )
def evaluate_records(records, gtruth_key, generated_text_key):
    """Grade every record and report how many answers were correct.

    For each record, the final answer candidate is extracted from the
    generated text and graded against the ground-truth answer.

    Args:
        records: Sized sequence of record dicts (``len()`` is taken first).
        gtruth_key: Name of the key holding the ground-truth answer.
        generated_text_key: Name of the key holding the model's output text.

    Returns:
        tuple: ``(num_correct, num_examples, accuracy)``, where ``accuracy``
        is ``num_correct / num_examples``, or ``0.0`` when there are no
        records.

    Raises:
        ValueError: If a record is not a dict.
        KeyError: If a record lacks ``gtruth_key`` or ``generated_text_key``
            (ground-truth key is checked first).
    """

    def _validated(position, rec):
        # Fail fast with a 1-based record index so bad inputs are easy to find.
        if not isinstance(rec, dict):
            raise ValueError(
                f"Record {position} is not a JSON object: {type(rec).__name__}"
            )
        for required in (gtruth_key, generated_text_key):
            if required not in rec:
                raise KeyError(f"Record {position} is missing key: {required}")
        return rec

    total = len(records)
    hits = 0
    for position, rec in enumerate(records, start=1):
        rec = _validated(position, rec)
        candidate = extract_final_candidate(rec[generated_text_key])
        hits += int(grade_answer(candidate, rec[gtruth_key]))

    accuracy = hits / total if total else 0.0
    return hits, total, accuracy
if __name__ == "__main__":
    # Script entry point: read CLI options, load the records file, grade all
    # records, and print the overall accuracy as a percentage with counts.
    cli_args = parse_args()
    num_correct, num_examples, acc = evaluate_records(
        records=load_records(cli_args.json_path),
        gtruth_key=cli_args.gtruth_answer,
        generated_text_key=cli_args.generated_text,
    )
    print(f"Accuracy: {acc*100:.1f}% ({num_correct}/{num_examples})")