-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_vec.py
More file actions
92 lines (69 loc) · 5.15 KB
/
generate_vec.py
File metadata and controls
92 lines (69 loc) · 5.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import torch
import os
import argparse
def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Whitespace-only lines (e.g. an empty line at the end of the file) are
    skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.
    """
    # JSON Lines is UTF-8 by definition; don't depend on the platform locale.
    with open(file_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def get_hidden_p_and_r(model, tokenizer, prompts, responses, layer_list=None):
    """Collect per-layer hidden-state summaries for prompt/response pairs.

    Each pair is run through ``model`` as the concatenated text
    ``prompt + response`` (tokenized without special tokens). For every layer
    index in ``layer_list`` three summaries are recorded per example:

    * ``prompt_avg``   - mean hidden state over the prompt tokens,
    * ``prompt_last``  - hidden state at the last prompt token,
    * ``response_avg`` - mean hidden state over the response tokens.

    Args:
        model: causal LM called with ``output_hidden_states=True``; its
            ``config.num_hidden_layers`` sizes the result lists.
        tokenizer: tokenizer matching ``model``.
        prompts: prompt strings.
        responses: response strings, paired positionally with ``prompts``
            (extra items in the longer sequence are ignored, as with ``zip``).
        layer_list: indices into ``outputs.hidden_states`` to record; defaults
            to every index ``0..num_hidden_layers``.

    Returns:
        ``(prompt_avg, prompt_last, response_avg)`` - three lists of length
        ``num_hidden_layers + 1``. Entries at indices in ``layer_list`` are CPU
        tensors concatenated over examples along dim 0; all other entries are
        left as empty lists.
    """
    max_layer = model.config.num_hidden_layers
    if layer_list is None:
        layer_list = list(range(max_layer + 1))
    prompt_avg = [[] for _ in range(max_layer + 1)]
    response_avg = [[] for _ in range(max_layer + 1)]
    prompt_last = [[] for _ in range(max_layer + 1)]
    texts = [p + a for p, a in zip(prompts, responses)]
    # Inference only: without no_grad, autograd records every forward pass and
    # keeps its activations alive even though the results are only detached.
    with torch.no_grad():
        for text, prompt in tqdm(zip(texts, prompts), total=len(texts)):
            inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)
            # NOTE(review): assumes the prompt tokenized on its own matches the
            # prompt prefix of ``prompt + response``; a token merging across the
            # boundary would shift this split point.
            prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False))
            outputs = model(**inputs, output_hidden_states=True)
            for layer in layer_list:
                # Indexed as (batch, seq, hidden); batch is 1 here.
                hidden = outputs.hidden_states[layer]
                prompt_avg[layer].append(hidden[:, :prompt_len, :].mean(dim=1).detach().cpu())
                # NOTE(review): an empty response leaves an empty slice here,
                # whose mean is NaN and would poison downstream averages.
                response_avg[layer].append(hidden[:, prompt_len:, :].mean(dim=1).detach().cpu())
                prompt_last[layer].append(hidden[:, prompt_len - 1, :].detach().cpu())
            del outputs
    for layer in layer_list:
        prompt_avg[layer] = torch.cat(prompt_avg[layer], dim=0)
        prompt_last[layer] = torch.cat(prompt_last[layer], dim=0)
        response_avg[layer] = torch.cat(response_avg[layer], dim=0)
    return prompt_avg, prompt_last, response_avg
import pandas as pd
import os
def get_persona_effective(pos_path, neg_path, trait, threshold=50, *, coherence_threshold=50):
    """Select prompt/response pairs with a clear, coherent persona contrast.

    ``pos_path`` and ``neg_path`` are CSV files read with ``pandas.read_csv``.
    Each must have ``prompt``, ``answer`` and ``coherence`` columns, plus a
    score column named after ``trait`` up to its first ``_``
    (``"evil_v2"`` -> ``"evil"``).

    The boolean mask combines both frames, so pandas aligns rows by index
    label. With ``read_csv``'s default index, that means row *i* of the
    positive file is paired with row *i* of the negative file. A pair is kept
    only when all of the following hold:

    * the positive score is ``>= threshold``,
    * the negative score is ``< 100 - threshold``,
    * both coherence scores are ``>= coherence_threshold``.

    Rows with missing (NaN) scores compare False and are dropped.

    NOTE(review): files of different lengths are not rejected; rows without a
    partner simply fail the mask. Presumably both files should be row-aligned.

    Args:
        pos_path: CSV of responses generated under the positive persona.
        neg_path: CSV of responses generated under the negative persona.
        trait: trait name; the text before the first ``_`` names the score column.
        threshold: score cutoff (see above).
        coherence_threshold: minimum coherence score required in both files.

    Returns:
        ``(pos_df, neg_df, pos_prompts, neg_prompts, pos_responses, neg_responses)``:
        the filtered DataFrames, then their ``prompt`` and ``answer`` columns
        as lists.
    """
    persona_pos = pd.read_csv(pos_path)
    persona_neg = pd.read_csv(neg_path)
    score_col = trait.split("_")[0]
    mask = (
        (persona_pos[score_col] >= threshold)
        & (persona_neg[score_col] < 100 - threshold)
        & (persona_pos["coherence"] >= coherence_threshold)
        & (persona_neg["coherence"] >= coherence_threshold)
    )
    pos_effective = persona_pos[mask]
    neg_effective = persona_neg[mask]
    return (
        pos_effective,
        neg_effective,
        pos_effective["prompt"].tolist(),
        neg_effective["prompt"].tolist(),
        pos_effective["answer"].tolist(),
        neg_effective["answer"].tolist(),
    )
def _layerwise_mean_diff(pos_layers, neg_layers):
    """Stack, layer by layer, (mean over pos examples) - (mean over neg examples) in float32."""
    return torch.stack(
        [pos.mean(0).float() - neg.mean(0).float() for pos, neg in zip(pos_layers, neg_layers)],
        dim=0,
    )


def save_persona_vector(model_name, pos_path, neg_path, trait, save_dir, threshold=50):
    """Compute per-layer persona vectors and save them as ``.pt`` files.

    Filters the positive/negative CSVs with ``get_persona_effective``. It then
    extracts hidden-state summaries for each side with ``get_hidden_p_and_r``,
    using every layer. Finally it saves, per summary kind, the stacked
    per-layer difference of means (positive minus negative) to
    ``save_dir/{trait}_{kind}_diff.pt`` for kind in ``prompt_avg``,
    ``response_avg`` and ``prompt_last``.

    Args:
        model_name: name or path passed to ``from_pretrained`` for both the
            model and the tokenizer.
        pos_path: CSV of responses under the positive persona.
        neg_path: CSV of responses under the negative persona.
        trait: trait name; also used as the output file prefix.
        save_dir: output directory (created if missing).
        threshold: score cutoff forwarded to ``get_persona_effective``.

    Raises:
        ValueError: if no prompt/response pair passes the filter. Without this
            check, ``torch.cat`` would later fail on an empty list, and only
            after the model had been loaded.
    """
    # Filter first: it is cheap, and an empty result should fail before the
    # (expensive) model load.
    (
        _pos_df,
        _neg_df,
        pos_prompts,
        neg_prompts,
        pos_responses,
        neg_responses,
    ) = get_persona_effective(pos_path, neg_path, trait, threshold)
    if not pos_prompts:
        raise ValueError(
            f"No effective pairs for trait {trait!r} in {pos_path} / {neg_path} "
            f"at threshold={threshold}; cannot compute persona vectors."
        )

    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    pos_prompt_avg, pos_prompt_last, pos_response_avg = get_hidden_p_and_r(
        model, tokenizer, pos_prompts, pos_responses
    )
    neg_prompt_avg, neg_prompt_last, neg_response_avg = get_hidden_p_and_r(
        model, tokenizer, neg_prompts, neg_responses
    )

    # Insertion order fixes the save order: prompt_avg, response_avg, prompt_last.
    vectors = {
        "prompt_avg_diff": _layerwise_mean_diff(pos_prompt_avg, neg_prompt_avg),
        "response_avg_diff": _layerwise_mean_diff(pos_response_avg, neg_response_avg),
        "prompt_last_diff": _layerwise_mean_diff(pos_prompt_last, neg_prompt_last),
    }
    os.makedirs(save_dir, exist_ok=True)
    for suffix, vector in vectors.items():
        torch.save(vector, os.path.join(save_dir, f"{trait}_{suffix}.pt"))
    print(f"Persona vectors saved to {save_dir}")
if __name__ == "__main__":
    # Command-line entry point: parse the paths/options and run save_persona_vector.
    parser = argparse.ArgumentParser(
        description=(
            "Compute per-layer persona vectors (difference of mean hidden states "
            "between positive- and negative-persona responses) and save them as .pt files."
        )
    )
    parser.add_argument(
        "--model_name", type=str, required=True,
        help="Model name or local path passed to from_pretrained (model and tokenizer).",
    )
    parser.add_argument(
        "--pos_path", type=str, required=True,
        help="CSV of positive-persona responses; needs prompt, answer, coherence and trait-score columns.",
    )
    parser.add_argument(
        "--neg_path", type=str, required=True,
        help="CSV of negative-persona responses, row-aligned with --pos_path and with the same columns.",
    )
    parser.add_argument(
        "--trait", type=str, required=True,
        help="Trait name; the part before the first '_' is the score column, the full name prefixes output files.",
    )
    parser.add_argument(
        "--save_dir", type=str, required=True,
        help="Directory to write {trait}_*_diff.pt files to (created if missing).",
    )
    parser.add_argument(
        "--threshold", type=int, default=50,
        help="Keep a pair only if the positive score is >= THRESHOLD and the negative score is < 100 - THRESHOLD (default: 50).",
    )
    args = parser.parse_args()
    save_persona_vector(
        args.model_name, args.pos_path, args.neg_path, args.trait, args.save_dir, args.threshold
    )