-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_knowledge_swap_pair.py
More file actions
67 lines (49 loc) · 2.17 KB
/
evaluate_knowledge_swap_pair.py
File metadata and controls
67 lines (49 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import random
import os
import argparse
from pathlib import Path
import json
import numpy as np
import pandas as pd
import torch
import copy
from tqdm import tqdm
import re
from utils import get_layer_names, load_model, load_tokenizer
def main(args):
    """Evaluate knowledge recall on swapped-pair modified main queries.

    Reads filtered intervention samples for ``args.model_name``, greedily
    generates a short continuation for each ``modified_main_query``, and
    records whether any accepted sub-answer (``subanswer_2``) appears as a
    whole word in the generation. One JSON object per sample is written to
    ``results/evaluate_knowledge/<model>/evaluate_knowledge_swap_pair.jsonl``.
    """
    input_path = os.path.join(os.path.dirname(__file__), "data", "intervention", args.model_name, "data_intervention_filtered.jsonl")
    output_path = os.path.join(os.path.dirname(__file__), "results", "evaluate_knowledge", args.model_name, "evaluate_knowledge_swap_pair.jsonl")
    # Create the output directory BEFORE touching the output file; the original
    # removed first, which only worked when the directory already existed.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    if os.path.exists(output_path):
        os.remove(output_path)
    # Load input data: one JSON object per line.
    with open(input_path, encoding="utf-8") as rf:
        input_data = [json.loads(line) for line in rf]
    # Load model/tokenizer; left padding so the prompt ends where generation starts.
    model = load_model(args.model_name)
    model.eval()
    tokenizer = load_tokenizer(args.model_name)
    tokenizer.padding_side = "left"
    # Inference: open the output file once instead of re-opening per sample.
    with open(output_path, "w", encoding="utf-8") as wf:
        for sample in tqdm(input_data):
            result = copy.deepcopy(sample)
            subanswer_2 = result["subanswer_2"]
            modified_main_query = result["modified_main_query"]
            inputs = tokenizer([modified_main_query], return_tensors="pt", padding=True).to(model.device)
            # Greedy decoding; no_grad avoids building an autograd graph and
            # cuts peak memory during generation.
            with torch.no_grad():
                output_ids = model.generate(**inputs, do_sample=False, pad_token_id=tokenizer.eos_token_id, max_new_tokens=10)
            pred_modified_main_query = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            # Keep only the newly generated continuation, collapsed to one line.
            result["pred_modified_main_query"] = pred_modified_main_query[len(modified_main_query):].strip().replace("\n", " ")
            # Correct iff any accepted sub-answer occurs as a whole word (case-insensitive).
            result["modified_main_query_iscorrect"] = int(any(re.search(rf"\b{re.escape(a.lower())}\b", result["pred_modified_main_query"].lower()) for a in subanswer_2))
            wf.write(json.dumps(result) + "\n")
    return
if __name__ == "__main__":
    # CLI entry point: the model name selects both the input and output paths.
    supported_models = [
        "meta-llama/Llama-2-13b-hf",
        "google/gemma-7b",
        "Qwen/Qwen2.5-14B",
    ]
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_name", type=str, choices=supported_models)
    main(cli.parse_args())