-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_knowledge.py
More file actions
85 lines (66 loc) · 3.81 KB
/
evaluate_knowledge.py
File metadata and controls
85 lines (66 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import argparse
from pathlib import Path
import json
import numpy as np
import pandas as pd
import copy
from tqdm import tqdm
import re
from utils import load_model, load_tokenizer
def _read_jsonl(path):
    """Load a JSON Lines file into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as rf:
        return [json.loads(line) for line in rf if line.strip()]


def _shortcut_prompts(main_query, bridge):
    """Build the two reasoning-shortcut probe prompts derived from ``main_query``.

    Args:
        main_query: the (stripped) main query; it must contain the substring
            ``"<...> is to <bridge> as <...>"``.
        bridge: the string removed from ``main_query`` (the caller passes the
            first listed answer to subquery 1).

    Returns:
        ``(exclude_e1_e2, exclude_e2)`` where ``exclude_e1_e2`` is the text that
        follows the first ``" is to <bridge> as "`` match, and ``exclude_e2`` is
        ``main_query`` with every ``" <bridge>"`` occurrence removed.

    Raises:
        ValueError: if ``main_query`` does not contain the expected pattern
            (previously this surfaced as an opaque ``AttributeError`` on ``None``).
    """
    match = re.search(f"(.+? is to {re.escape(bridge)} as )(.+)", main_query)
    if match is None:
        raise ValueError(
            f"main_query does not contain ' is to {bridge} as ': {main_query!r}"
        )
    exclude_e1_e2 = match.group(2)
    exclude_e2 = re.sub(f" {re.escape(bridge)}", "", main_query)
    return exclude_e1_e2, exclude_e2


def _generate_continuations(model, tokenizer, prompts, max_new_tokens=10):
    """Greedily generate a short continuation for each prompt in one padded batch.

    Only the newly generated token ids are decoded. The models offered on the
    CLI (Llama-2, Gemma, Qwen2.5) are decoder-only, so ``generate`` returns the
    (left-padded) prompt ids followed by the new ids. Slicing by token count is
    robust, whereas slicing the decoded text by ``len(prompt)`` silently breaks
    whenever detokenization does not reproduce the prompt string exactly.

    Returns:
        One string per prompt, stripped, with newlines replaced by spaces.
    """
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    output_ids = model.generate(
        **inputs,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
    )
    # Every row shares the same padded prompt length, so one slice drops all prompts.
    prompt_len = inputs["input_ids"].shape[1]
    continuations = tokenizer.batch_decode(
        output_ids[:, prompt_len:], skip_special_tokens=True
    )
    return [text.strip().replace("\n", " ") for text in continuations]


def _is_correct(prediction, answers):
    """Return 1 if any non-empty answer occurs as a whole token in ``prediction``, else 0.

    Matching is case-insensitive (via ``str.lower``). ``(?<!\\w)``/``(?!\\w)``
    behave like ``\\b`` for answers that start/end with a word character, but
    still match answers that start/end with punctuation (e.g. "Inc.").
    ``\\b`` would need a word character on the far side of the punctuation.
    Empty answers are skipped because they would match almost anything.
    """
    pred = prediction.lower()
    return int(
        any(
            re.search(rf"(?<!\w){re.escape(ans.lower())}(?!\w)", pred)
            for ans in answers
            if ans
        )
    )


def main(args):
    """Probe a model's single-hop and multi-hop knowledge, plus two shortcut variants.

    For every record of ``args.input_path`` (JSON Lines with at least the keys
    ``subquery_1``, ``subanswer_1``, ``subquery_2``, ``subanswer_2`` and
    ``main_query``; both ``subanswer_*`` are lists of strings), five prompts are
    completed greedily in one batch:
    ``subquery_1``, ``subquery_2``, ``main_query``, ``exclude_e1_e2``, and
    ``exclude_e2`` (see ``_shortcut_prompts``). Each record is written, augmented
    with ``pred_<name>`` and ``<name>_iscorrect`` fields, to
    ``<args.save_path>/<args.model_name>/evaluate_knowledge.jsonl``.
    ``subquery_1`` is scored against ``subanswer_1``; all other prompts are
    scored against ``subanswer_2``.

    NOTE: the output file is opened in append mode, so re-running the script
    appends duplicate records instead of overwriting previous results.

    Args:
        args: namespace with ``model_name``, ``input_path`` and ``save_path``.

    Raises:
        ValueError: if a ``main_query`` lacks the ``" is to <subanswer_1[0]> as "``
            pattern.
    """
    output_dir = os.path.join(args.save_path, args.model_name)
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "evaluate_knowledge.jsonl")

    input_data = _read_jsonl(args.input_path)

    model = load_model(args.model_name)
    model.eval()
    tokenizer = load_tokenizer(args.model_name)
    # Left padding keeps each prompt adjacent to its generated tokens in a batch.
    tokenizer.padding_side = "left"

    with open(output_path, "a", encoding="utf-8") as wf:
        for sample in tqdm(input_data):
            # Shallow copy suffices: only top-level keys are (re)assigned below.
            result = dict(sample)
            subquery_1 = result["subquery_1"].strip()
            subquery_2 = result["subquery_2"].strip()
            main_query = result["main_query"].strip()
            subanswer_1 = result["subanswer_1"]
            subanswer_2 = result["subanswer_2"]

            # Store the whitespace-stripped prompts that are actually evaluated.
            result["subquery_1"] = subquery_1
            result["subquery_2"] = subquery_2
            result["main_query"] = main_query

            # Reasoning-shortcut probes built by removing subanswer_1[0] from main_query.
            exclude_e1_e2, exclude_e2 = _shortcut_prompts(main_query, subanswer_1[0])
            result["exclude_e1_e2"] = exclude_e1_e2
            result["exclude_e2"] = exclude_e2

            # Insertion order fixes both the batch order and the output key order.
            prompts = {
                "subquery_1": subquery_1,
                "subquery_2": subquery_2,
                "main_query": main_query,
                "exclude_e1_e2": exclude_e1_e2,
                "exclude_e2": exclude_e2,
            }
            gold = {
                "subquery_1": subanswer_1,
                "subquery_2": subanswer_2,
                "main_query": subanswer_2,
                "exclude_e1_e2": subanswer_2,
                "exclude_e2": subanswer_2,
            }

            predictions = _generate_continuations(model, tokenizer, list(prompts.values()))
            for name, prediction in zip(prompts, predictions):
                result[f"pred_{name}"] = prediction
            for name in prompts:
                result[f"{name}_iscorrect"] = _is_correct(result[f"pred_{name}"], gold[name])

            wf.write(json.dumps(result) + "\n")
            # Flush per record so partial results survive an interrupted run.
            wf.flush()
if __name__ == "__main__":
    # Defaults are resolved relative to this script, not the current working directory.
    script_dir = os.path.abspath(os.path.dirname(__file__))

    parser = argparse.ArgumentParser(
        description="Evaluate single-hop, multi-hop and shortcut knowledge of a causal LM."
    )
    # required=True: without it, omitting the flag passes None into os.path.join
    # inside main() and fails with an unhelpful TypeError.
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        choices=["meta-llama/Llama-2-13b-hf", "google/gemma-7b", "Qwen/Qwen2.5-14B"],
        help="Model identifier passed to load_model/load_tokenizer.",
    )
    parser.add_argument(
        "--input_path",
        type=str,
        default=os.path.join(script_dir, "data", "wikidata_analogies_50k_prompts.jsonl"),
        help="JSON Lines file with the prompts to evaluate.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default=os.path.join(script_dir, "results", "evaluate_knowledge"),
        help="Root directory; results go to <save_path>/<model_name>/.",
    )
    args = parser.parse_args()
    main(args)