-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcheck_attribute_rate.py
More file actions
119 lines (95 loc) · 4.83 KB
/
check_attribute_rate.py
File metadata and controls
119 lines (95 loc) · 4.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import json
import re
import argparse
import ast
from tqdm import tqdm
import pandas as pd
from heatmap import draw_heatmap
import os
def _repair_and_parse_list(cell):
    """Parse a stringified Python list from a CSV cell.

    The lists were serialized via ``str(list)`` and lost some separating
    commas (numpy-style printing); reinsert them before ``ast.literal_eval``.
    NOTE(review): these substitutions are heuristic repairs inherited from the
    original code — confirm against the writer side if cells ever change shape.
    """
    cell = re.sub("' ", "', ", cell)
    cell = re.sub("'\n ", "', ", cell)
    cell = re.sub('" \'', '", \'', cell)
    return ast.literal_eval(cell)


def _load_filtered_attributes():
    """Load entity->attributes from the jsonl file, dropping short attributes.

    Attributes with fewer than 3 non-whitespace characters are removed, since
    they would produce spurious substring matches in the generations.
    """
    path = os.path.join(os.path.dirname(__file__), "data", "entity_attribute_pairs_proper.jsonl")
    entity_attribute_dict = {}
    with open(path) as rf:
        for line in rf:
            pair = json.loads(line)
            entity_attribute_dict[pair["entity"]] = [
                att for att in pair["attributes"]
                if len(re.sub(r"\s+", "", att)) >= 3
            ]
    return entity_attribute_dict


def main(args):
    """Score patchscopes generations for attribute mentions and draw a heatmap.

    For every main query, parse the gold answers and the bridge entity e2,
    extract e3 from the analogy prompt, then for each (target_layer,
    source_layer) cell record 1 if the generation mentions an attribute of
    the entity selected by ``args.attribute_entity`` (0 otherwise). The
    per-cell lists are rendered by ``draw_heatmap`` and saved under figures/.
    """
    print(args)
    save_path = os.path.join(os.path.dirname(__file__), "figures", f"{args.mode}", os.path.basename(args.model_name), f"heatmap_attribute_rate_{args.source}_{args.attribute_entity}_fonted.png")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    num_layers_dict = {
        "meta-llama/Llama-3.1-8B": 32,
        "meta-llama/Llama-2-7b-hf": 32,
        "meta-llama/Llama-2-13b-hf": 40,
        "Qwen/Qwen2.5-7B": 28,
        "Qwen/Qwen2.5-14B": 48,
        "google/gemma-7b": 28,
    }
    num_layers = num_layers_dict[args.model_name]
    # Validate once up front instead of raising per-row inside the scoring loop.
    if args.attribute_entity not in ("e2", "e3", "e2e3"):
        raise ValueError(f"attribute entity {args.attribute_entity} not supported")
    entity_attribute_dict_filtered = _load_filtered_attributes()
    print("entity_attribute_dict_filtered ready")
    generation_path = os.path.join(os.path.dirname(__file__), "results", "patchscopes", args.mode, args.model_name, f"entity_description_default_description_{args.source}_max20.csv")
    results = pd.read_csv(generation_path, lineterminator='\n')
    main_queries = list(set(results["main_query"]))
    tar_src_layer_dict = {(i, j): [] for i in range(num_layers) for j in range(num_layers)}
    invalid_noe3 = 0  # queries where e3 could not be extracted from the prompt
    invalid_noatt = 0  # queries with no valid attributes after filtering
    for main_query in tqdm(main_queries):
        rows_per_query = results[results["main_query"] == main_query]
        answers = _repair_and_parse_list(rows_per_query.iloc[0]["subanswer_2"])
        e2 = _repair_and_parse_list(rows_per_query.iloc[0]["subanswer_1"])[0]
        analogy_pattern = re.compile(rf"(.+?) is to {re.escape(e2)} as (.+?) is to")
        analogy_match = analogy_pattern.search(main_query)
        if analogy_match is None:
            # BUG FIX: the original bare except fell through without `continue`,
            # leaving e3 unbound (NameError) or stale from the previous query.
            invalid_noe3 += 1
            continue
        e3 = analogy_match.group(2)
        e2_attributes = entity_attribute_dict_filtered.get(e2, [])
        e3_attributes = entity_attribute_dict_filtered.get(e3, [])
        if len(e2_attributes) == 0 or len(e3_attributes) == 0:
            # BUG FIX: previously counted as invalid yet still scored, which
            # padded the heatmap cells with guaranteed zeros for these queries.
            invalid_noatt += 1
            continue
        answer_patterns = [re.compile(rf'\b{re.escape(answer.lower())}\b') for answer in answers]
        e2_patterns = [re.compile(rf'\b{re.escape(attr.lower())}\b') for attr in e2_attributes]
        e3_patterns = [re.compile(rf'\b{re.escape(attr.lower())}\b') for attr in e3_attributes]
        # Select the attribute pattern set once per query; "e2e3" means a hit
        # on either entity's attributes counts.
        if args.attribute_entity == "e2":
            attribute_patterns = e2_patterns
        elif args.attribute_entity == "e3":
            attribute_patterns = e3_patterns
        else:  # "e2e3"
            attribute_patterns = e2_patterns + e3_patterns
        for source, target, generation in zip(rows_per_query["source_layer"], rows_per_query["target_layer"], rows_per_query["generation"]):
            gen_lower = str(generation).lower()
            has_attribute = any(pattern.search(gen_lower) for pattern in attribute_patterns)
            if args.source == "last":
                # Exclude generations that merely repeat an answer string, so
                # answer echoes are not mistaken for attribute mentions.
                has_answer = any(pattern.search(gen_lower) for pattern in answer_patterns)
                tar_src_layer_dict[(target, source)].append(1 if has_attribute and not has_answer else 0)
            else:
                tar_src_layer_dict[(target, source)].append(1 if has_attribute else 0)
    # BUG FIX: invalid_noe3 was tallied but never reported.
    print(f"invalid_noe3 : {invalid_noe3}")
    print(f"invalid_noatt : {invalid_noatt}")
    draw_heatmap(num_layers, tar_src_layer_dict, args.vmax, save_path)
if __name__ == "__main__":
    # CLI: all flags use restricted choice sets so typos fail fast at parse time.
    cli = argparse.ArgumentParser()
    supported_models = [
        "meta-llama/Llama-2-7b-hf",
        "meta-llama/Llama-2-13b-hf",
        "meta-llama/Llama-3.1-8B",
        "google/gemma-7b",
        "Qwen/Qwen2.5-7B",
        "Qwen/Qwen2.5-14B",
    ]
    cli.add_argument("--model_name", type=str, choices=supported_models)
    cli.add_argument("--mode", choices=["correct", "incorrect"])
    cli.add_argument("--source", choices=["e2", "link", "e3", "last"])
    cli.add_argument("--attribute_entity", choices=["e2", "e3", "e2e3"])
    cli.add_argument("--vmax", type=str)  # "free" or float (e.g. 0.6)
    main(cli.parse_args())