-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpatched_generation.py
More file actions
411 lines (332 loc) · 19.8 KB
/
patched_generation.py
File metadata and controls
411 lines (332 loc) · 19.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
# Modified from: https://github.com/edenbiran/HoppingTooLate
# Original Authors: Eden Biran, Daniela Gottesman, Sohee Yang, Mor Geva, Amir Globerson
import functools
import numpy as np
import torch
import torch.nn.functional as F
from baukit import Trace, TraceDict
from utils import get_prepend_space, get_layer_names, get_attention_modules, get_attention_layers_names, \
get_norm_module, get_mlp_layers_names, get_all_attention_modules
def find_tokens(tokenizer, string_tokens, substring, prepend_space, last=True):
    """Locate the last occurrence of ``substring``'s token ids inside ``string_tokens``.

    ``substring`` is tokenized on its own (no special tokens) and slid across
    ``string_tokens`` looking for an exact id-for-id match.

    Args:
        tokenizer: Callable tokenizer whose output exposes ``input_ids``.
        string_tokens: 1-D tensor of token ids to search in.
        substring: Text whose tokens are searched for.
        prepend_space: If true, a leading space is added to ``substring`` before
            tokenizing (for tokenizers that fold the space into the first token).
        last: If true, return the index of the final token of the match;
            otherwise return an inclusive ``(start, end)`` tuple.

    Returns:
        The index (or ``(start, end)`` tuple) of the LAST match, or ``None`` when
        ``substring`` does not occur in ``string_tokens``.
    """
    if prepend_space:
        substring = " " + substring
    substring_tokens = tokenizer(substring, add_special_tokens=False, return_tensors="pt").input_ids[0]
    substring_tokens = substring_tokens.to(string_tokens.device)
    span = len(substring_tokens)
    match = None
    for start in range(len(string_tokens) - span + 1):
        end = start + span
        if torch.all(string_tokens[start:end] == substring_tokens):
            # Keep overwriting so the value left at the end is the last occurrence.
            match = end - 1 if last else (start, end - 1)
    return match
def get_hidden_states(model, tokenizer, entity, prompt):
    """Record, for every prompt, each traced layer's output at the entity's last token.

    Args:
        model: Model whose layers (from ``get_layer_names``) are traced with ``TraceDict``.
        tokenizer: Tokenizer used to batch-encode ``prompt`` with padding.
        entity: One entity string per prompt; located with ``find_tokens``.
        prompt: Prompt strings.

    Returns:
        ``torch.stack`` over prompts of ``torch.stack`` over layers of
        ``output[0][row][position]`` (moved to CPU). Presumably
        (batch, num_layers, hidden) — TODO confirm the layer output layout.
    """
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    use_space = get_prepend_space(model)
    positions = [find_tokens(tokenizer, ids, ent, use_space) for ids, ent in zip(encoded.input_ids, entity)]
    per_example = [None] * len(positions)
    layer_names = get_layer_names(model)
    with torch.no_grad(), TraceDict(model, layer_names) as traced:
        encoded = encoded.to(model.device)
        model(**encoded)
        for row, pos in enumerate(positions):
            if pos is None:
                continue
            per_example[row] = torch.stack([traced[name].output[0][row][pos].cpu() for name in layer_names])
    return torch.stack(per_example)
def get_hidden_states_with_patching(model, tokenizer, entity, prompt, hidden_state, target_layer):
    """Record per-layer hidden states at each entity token while patching ``target_layer``.

    During a single forward pass, the output of ``target_layer`` at prompt ``i``'s
    entity position is overwritten with row ``i`` of ``hidden_state``. The outputs of
    every layer from ``get_layer_names(model)`` at that entity position are then recorded.

    Args:
        model: Model traced with baukit ``Trace``/``TraceDict``.
        tokenizer: Tokenizer used to batch-encode ``prompt`` with padding.
        entity: One entity string per prompt; located with ``find_tokens``.
        prompt: Prompt strings.
        hidden_state: Replacement vectors, one row per prompt (assumed
            (batch, hidden) — TODO confirm against callers).
        target_layer: Name of the module whose output is patched.

    Returns:
        ``torch.stack`` over prompts of ``torch.stack`` over layers of the entity-position output.

    Raises:
        ValueError: If an entity cannot be located in its prompt.
    """
    prompt_inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    prepend_space = get_prepend_space(model)
    entity_idx = [find_tokens(tokenizer, q, e, prepend_space) for q, e in zip(prompt_inputs.input_ids, entity)]
    if any(idx is None for idx in entity_idx):
        raise ValueError("Could not locate the entity tokens in every prompt")
    # Pair batch row i with its own entity position. Indexing with the bare list
    # (``hs[entity_idx]``) would select whole *batch rows* numbered by token position.
    batch_index = torch.arange(len(entity_idx))
    position_index = torch.tensor(entity_idx, dtype=torch.long)

    def replace_hidden_state_hook(output):
        # Parameter must stay named ``output``: baukit matches hook arguments by name.
        hs = output[0]
        if hs.shape[1] == 1:  # Single-token (cached decoding) passes are left untouched
            return output
        hs[batch_index.to(hs.device), position_index.to(hs.device)] = hidden_state.to(hs.device)
        return (hs,) + output[1:]

    layers = get_layer_names(model)
    with torch.no_grad(), \
            Trace(model, layer=target_layer, edit_output=replace_hidden_state_hook), \
            TraceDict(model, layers) as trace:
        prompt_inputs = prompt_inputs.to(model.device)
        model(**prompt_inputs)
        entity_hidden_states = [
            torch.stack([trace[layer].output[0][i][idx].cpu() for layer in layers])
            for i, idx in enumerate(entity_idx)
        ]
    return torch.stack(entity_hidden_states)
def generate_patching_inputs(model, tokenizer, target_prompt, target_token_str):
    """Tokenize target prompts and locate, per prompt, the last occurrence of a marker token.

    The marker is the last real token of ``target_token_str`` (one string shared by
    all prompts, or one string per prompt).

    Args:
        model: Model providing ``device``; ``get_prepend_space(model)`` decides
            whether a leading space is added to the marker string(s).
        tokenizer: Tokenizer used for both markers and prompts (with padding).
        target_prompt: Prompt strings.
        target_token_str: A string or a list of strings.

    Returns:
        ``(inputs, (batch_indices, positions))``, where ``inputs`` is moved to
        ``model.device`` and ``positions[i]`` is the index of the last marker
        occurrence in prompt ``i``, e.g. ``(tensor([0, 1]), tensor([18, 12]))``.
        NOTE: when the marker does not occur, the position is silently 0.
    """
    if get_prepend_space(model):
        if isinstance(target_token_str, str):
            target_token_str = " " + target_token_str
        else:
            target_token_str = [" " + t for t in target_token_str]
    target_encoding = tokenizer(target_token_str, return_tensors="pt", add_special_tokens=False, padding=True)
    target_ids = target_encoding.input_ids
    # Take the last *real* token of each marker. A plain ``[..., -1]`` picks a pad
    # token when the tokenizer pads on the right and markers differ in length;
    # ``cumsum().argmax()`` on the attention mask gives the index of its last 1
    # for either padding side.
    attention_mask = target_encoding.get("attention_mask")
    if attention_mask is None:
        last_real = torch.full((target_ids.shape[0],), target_ids.shape[-1] - 1, dtype=torch.long)
    else:
        last_real = attention_mask.cumsum(dim=-1).argmax(dim=-1)
    # e.g. tensor([311, 311]) for target_token_str [" to", " to"]
    target_token = target_ids[torch.arange(target_ids.shape[0]), last_real]
    inputs = tokenizer(target_prompt, return_tensors="pt", padding=True)
    # The cumulative match count first reaches its maximum at the last match,
    # and argmax returns the first maximal index.
    target_position = (inputs.input_ids == target_token.unsqueeze(1)).cumsum(dim=1).argmax(dim=1)
    return inputs.to(model.device), (torch.arange(target_position.shape[0]), target_position)
def generate_patching_inputs_long_target_prompt(model, tokenizer, target_prompt, target_token_str, verbose=True):
    '''Tokenize prompts and locate a (possibly multi-token) target string in each.

    Example:
        target_prompt : ['The Adventure as the Norwood Builder is to Arthur Conan Doyle as Convivio as is to ', 'Vichy as France is to French as Swedish Empire is to ']
        target_token_str : ['Arthur Conan Doyle as', 'French as']

    For every prompt, the token ids of its target string (special tokens removed)
    are searched for in the prompt's ids. The index of the last token of the FIRST
    match is recorded, or ``-1`` when the target is not found.

    Args:
        model: Model providing ``device``; ``get_prepend_space(model)`` decides
            whether a leading space is added to the target string(s).
        tokenizer: Tokenizer used for targets and prompts (with padding).
        target_prompt: Prompt strings.
        target_token_str: One string per prompt, or a single string reused for
            every prompt.
        verbose: Print debugging information. Defaults to True, which keeps the
            previous always-on output.

    Returns:
        ``(inputs, (batch_indices, target_positions))`` with every tensor on
        ``model.device``.
    '''
    if get_prepend_space(model):
        # Add space if required by the model's tokenizer
        if isinstance(target_token_str, str):
            target_token_str = " " + target_token_str
        else:
            target_token_str = [" " + t for t in target_token_str]
    target_token_ids = tokenizer(target_token_str, return_tensors="pt", add_special_tokens=False, padding=True).input_ids
    # Drop special tokens (e.g. padding) from the targets. The id tensor is built once
    # instead of once per row.
    special_ids = torch.tensor(sorted(set(tokenizer.all_special_ids)), device=target_token_ids.device)
    target_token_ids_clean = [seq[~torch.isin(seq, special_ids)] for seq in target_token_ids]
    target_token_ids_clean = torch.nn.utils.rnn.pad_sequence(
        target_token_ids_clean, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    inputs = tokenizer(target_prompt, return_tensors="pt", padding=True)
    batch_size = inputs.input_ids.shape[0]
    # A single target string applies to every prompt (as in generate_patching_inputs);
    # otherwise rows past the first would raise IndexError.
    single_target = target_token_ids_clean.shape[0] == 1
    target_positions = []
    for batch_idx in range(batch_size):
        prompt_ids = inputs.input_ids[batch_idx]
        target_ids = target_token_ids_clean[0 if single_target else batch_idx]
        prompt_text = tokenizer.decode(prompt_ids, skip_special_tokens=True)
        target_text = tokenizer.decode(target_ids, skip_special_tokens=True)
        if verbose:
            print(f"prompt : {prompt_text}")
            print(f"target : {target_text}")
        # A target that opens the prompt has no leading space, so the space-prefixed
        # ids would not match. Re-encode the bare text instead. encode() keeps its
        # default special tokens (e.g. a BOS), which the prompt presumably also
        # carries at that spot — TODO confirm for tokenizers without BOS.
        if prompt_text.startswith(target_text.strip()):
            target_ids = tokenizer.encode(target_text.strip(), return_tensors="pt")[0]
        if verbose:
            print(f"prompt ids : {prompt_ids}")
            print(f"target ids : {target_ids}")
        match_length = target_ids[target_ids != tokenizer.pad_token_id].shape[-1]
        position = -1
        for i in range(prompt_ids.shape[-1] - match_length + 1):
            if torch.equal(prompt_ids[i:i + match_length], target_ids[:match_length]):
                position = i + match_length - 1
                break
        target_positions.append(position)
    target_positions = torch.tensor(target_positions, device=model.device)
    batch_indices = torch.arange(batch_size, device=model.device)
    if verbose:
        print(f"matched target positions : {target_positions}")
    return inputs.to(model.device), (batch_indices, target_positions)
def decode_generated(tokenizer, generated, target_prompt, sampled=False):
    """Decode generated sequences and strip the prompt text from each.

    Each decoded text has its first ``len(prompt)`` characters removed. The rest is
    whitespace-stripped, and newlines are replaced by spaces.

    Args:
        tokenizer: Object with ``batch_decode``.
        generated: Generated token sequences.
        target_prompt: The prompts that produced ``generated``.
        sampled: If true, ``generated`` holds several samples per prompt (grouped
            consecutively); the result is then split into one numpy array per prompt.

    Returns:
        A list of completions, or (when ``sampled``) a list of numpy arrays.
    """
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    n_prompts = len(target_prompt)
    per_prompt = len(decoded) // n_prompts
    prompts = np.repeat(target_prompt, per_prompt) if sampled else target_prompt
    completions = []
    for full_text, prompt_text in zip(decoded, prompts):
        completions.append(full_text[len(prompt_text):].strip().replace("\n", " "))
    if not sampled:
        return completions
    return np.split(np.array(completions), n_prompts)
def generate_with_patching_layer(model, tokenizer, hidden_state, target_layer, target_prompt, target_token_str,
                                 do_sample):
    """Generate from ``target_prompt`` after overwriting ``target_layer`` at the marker token.

    ``generate_patching_inputs`` provides ``(batch_indices, positions)``. On every
    forward pass with more than one token, the output of ``target_layer`` at those
    indices is replaced by ``hidden_state``. Up to 20 new tokens are generated
    (sampled when ``do_sample``, greedy otherwise) and decoded with ``decode_generated``.
    """
    inputs, target_position = generate_patching_inputs(model, tokenizer, target_prompt, target_token_str)

    def replace_hidden_state_hook(output):
        # Parameter must stay named ``output``: baukit matches hook arguments by name.
        states = output[0]
        if states.shape[1] != 1:
            # Single-token passes (cached decoding steps) are left untouched.
            states[target_position] = hidden_state.to(states.device)
            return (states,) + output[1:]
        return output

    if do_sample:
        generation_kwargs = dict(pad_token_id=tokenizer.eos_token_id, max_new_tokens=20, do_sample=True)
    else:
        generation_kwargs = dict(do_sample=False, temperature=1, top_p=1, num_beams=1,
                                 pad_token_id=tokenizer.eos_token_id, max_new_tokens=20)
    with torch.no_grad(), Trace(model, layer=target_layer, retain_output=False,
                                edit_output=replace_hidden_state_hook):
        generated = model.generate(**inputs, **generation_kwargs)
    return decode_generated(tokenizer, generated, target_prompt)
def generate_with_patching_same_layers(model, tokenizer, hidden_states, target_prompt, target_token_str, do_sample):
    """Patch each layer's source hidden state back into the same layer and generate.

    Returns one entry per layer, in ``get_layer_names(model)`` order. Each entry is
    the decoded generations from ``generate_with_patching_layer`` using
    ``hidden_states[:, layer, :]``.
    """
    return [
        generate_with_patching_layer(model, tokenizer, hidden_states[:, layer_pos, :], layer_name,
                                     target_prompt, target_token_str, do_sample)
        for layer_pos, layer_name in enumerate(get_layer_names(model))
    ]
def generate_with_patching_all_layers(model, tokenizer, hidden_states, target_prompt, target_token_str, do_sample):
    """Patch every source layer's hidden state into every target layer and generate.

    Args:
        hidden_states: Source states, indexed as ``hidden_states[:, source_layer, :]``.
        Other arguments are forwarded to ``generate_with_patching_layer``.

    Returns:
        An object ``np.ndarray`` of shape ``(num_layers, num_layers)``. Entry ``[s, t]``
        holds the decoded generations obtained by writing layer ``s``'s states into layer ``t``.
    """
    layer_names = get_layer_names(model)
    layer_count = len(layer_names)  # reuse the list instead of querying the model a second time
    generations = np.ndarray((layer_count, layer_count), dtype=object)
    for source_layer in range(layer_count):
        # Invariant across target layers: slice the source layer once.
        hidden_state = hidden_states[:, source_layer, :]
        for target_layer in range(layer_count):
            generations[source_layer, target_layer] = generate_with_patching_layer(
                model, tokenizer, hidden_state, layer_names[target_layer],
                target_prompt, target_token_str, do_sample)
    return generations
def get_top_k_tokens(tokenizer, scores, k=10):
    """Return the top-``k`` decoded tokens and their softmax probabilities.

    ``scores`` is softmaxed over its last dimension. The word array has the same shape
    as the top-k index tensor. Slot ``[i][j]`` is filled with
    ``tokenizer.batch_decode`` of ``indices[i][j]`` (one id per sequence).
    For 3-D scores this fills a row of ``k`` strings. For 2-D scores each
    element presumably receives a one-element list.

    Returns:
        ``(top_k_words, top_k_probabilities)``: an object numpy array and a tensor.
    """
    with torch.no_grad():
        probs = F.softmax(scores, dim=-1)
        top_probs, top_ids = torch.topk(probs, k, dim=-1)
        top_words = np.empty_like(top_ids.cpu(), dtype=object)
        for row_idx, row in enumerate(top_ids):
            for col_idx, ids in enumerate(row):
                top_words[row_idx][col_idx] = tokenizer.batch_decode(ids.unsqueeze(-1))
    return top_words, top_probs
def get_token_ranks(projections, tokens):
    """Rank of each token in ``tokens`` within every row of ``projections``.

    The rank is the number of entries in the row strictly greater than the token's
    score (0 = top). ``projections`` is indexed as ``[:, tokens]``, so it is
    treated as (rows, vocab).

    Returns:
        Integer numpy array of shape ``(rows, len(tokens))``.
    """
    logits = projections.float().cpu()
    picked = logits[:, tokens]
    ranks = np.empty_like(picked, dtype=int)
    for layer_idx, (layer_logits, layer_picked) in enumerate(zip(logits, picked)):
        for tok_idx, value in enumerate(layer_picked):
            ranks[layer_idx, tok_idx] = (layer_logits > value).sum()
    return ranks
def generate_with_attention_knockout(model, tokenizer, layer_idx, prompt, source, target_token_str, default_decoding,
                                     k=0, return_probabilities=False):
    '''Generate while blocking attention from source positions to a target position.

    The ``forward`` of the selected attention modules is wrapped for the duration of
    the call. The wrapper writes ``torch.finfo(model.dtype).min`` into the
    ``attention_mask`` at (query = source, key = target), cutting that attention edge.

    Args:
        layer_idx : -1 ~ (number of hidden layers). -1 generates without knockout.
        prompt : main_query (list of strings)
        source : which query positions are blocked:
            "all"   -> every query position (applied on every forward pass),
            "final" -> the last input position, only on passes with more than one token,
            list    -> one entity string per prompt, located with ``find_tokens``
                       (also only on passes with more than one token).
        target_token_str : string for knockout target position ["e1", "link", "e2", "e3", "last"], e.g. "to" ("last") (list)
        default_decoding : if False, greedy decoding (do_sample=False, num_beams=1)
        k : forwarded to ``get_attention_modules`` with ``layer_idx``; "all" uses
            ``get_all_attention_modules`` instead.
        return_probabilities : also return top-k tokens for the first generated step.

    Returns:
        Decoded generations, or ``(generations, top_k_words, top_k_probabilities)``
        when ``return_probabilities`` is true.

    Raises:
        ValueError: if knockout is active and ``source`` is a string other than
            "all" or "final".
    '''
    # inputs: padded batch of prompts.
    # target_position: (batch indices, target index per prompt), e.g. (tensor([0, 1]), tensor([18, 12])).
    inputs, target_position = generate_patching_inputs_long_target_prompt(model, tokenizer, prompt, target_token_str)
    knockout_active = layer_idx != -1
    if isinstance(source, str):
        if knockout_active and source not in ("all", "final"):
            # Any other string would reach the masking branch below with no source_position.
            raise ValueError(f"Unsupported knockout source: {source!r}")
        source_position = None
    else:
        prepend_space = get_prepend_space(model)
        source_position = torch.LongTensor([find_tokens(tokenizer, q, e, prepend_space) for q, e in
                                            zip(inputs.input_ids, source)])

    def wrap_attention_forward(original_forward_func):
        @functools.wraps(original_forward_func)
        def knockout_attention_forward(*args, **kwargs):
            new_args = list(args)
            new_kwargs = dict(kwargs)
            hidden_states = kwargs["hidden_states"] if "hidden_states" in kwargs else args[0]
            batch_size = hidden_states.shape[0]
            num_tokens = hidden_states.shape[1]
            attention_weight_size = (batch_size, model.config.num_attention_heads, num_tokens, num_tokens)
            prev_attention_mask = kwargs["attention_mask"]
            print("attention_weight_size:", attention_weight_size)
            print("prev_attention_mask.shape:", prev_attention_mask.shape)
            # Broadcast the incoming mask over all heads, e.g. (32,32,27,27) + (32,1,27,27).
            new_attention_mask = torch.zeros(attention_weight_size, dtype=prev_attention_mask.dtype,
                                             device=prev_attention_mask.device) + prev_attention_mask
            # Index tensors live on model.device; the mask may sit on another device
            # when the model is split across devices.
            mask_device = new_attention_mask.device
            batch_idx = target_position[0].to(mask_device)
            target_idx = target_position[1].to(mask_device)
            blocked = torch.finfo(model.dtype).min
            # dims: [batch, head, query (source), key (target)]
            if source == "all":
                new_attention_mask[batch_idx, :, :, target_idx] = blocked
            elif source == "final":
                if num_tokens != 1:
                    new_attention_mask[batch_idx, :, -1, target_idx] = blocked
            else:
                if num_tokens != 1:
                    new_attention_mask[batch_idx, :, source_position.to(mask_device), target_idx] = blocked
            new_kwargs["attention_mask"] = new_attention_mask
            return original_forward_func(*new_args, **new_kwargs)
        return knockout_attention_forward

    attention_modules = []
    original_forward_funcs = []
    if knockout_active:
        if k == "all":
            attention_modules = list(get_all_attention_modules(model))
        else:
            attention_modules = list(get_attention_modules(model, layer_idx, k))
        original_forward_funcs = [module.forward for module in attention_modules]
    try:
        for module in attention_modules:
            module.forward = wrap_attention_forward(module.forward)
        with torch.no_grad():
            if default_decoding:
                generated = model.generate(**inputs, pad_token_id=tokenizer.eos_token_id, max_new_tokens=10,
                                           return_dict_in_generate=return_probabilities,
                                           output_scores=return_probabilities)
            else:
                generated = model.generate(**inputs, do_sample=False, temperature=1, top_p=1, num_beams=1,
                                           pad_token_id=tokenizer.eos_token_id, max_new_tokens=10,
                                           return_dict_in_generate=return_probabilities,
                                           output_scores=return_probabilities)
    finally:
        # Always undo the monkey-patching, even if generation raised; otherwise the
        # wrappers stay installed and silently apply this knockout to later calls.
        for module, original_forward in zip(attention_modules, original_forward_funcs):
            module.forward = original_forward
    if return_probabilities:
        generations = decode_generated(tokenizer, generated["sequences"], prompt)
        top_k_words, top_k_probabilities = get_top_k_tokens(tokenizer, generated["scores"][0])
        return generations, top_k_words, top_k_probabilities
    return decode_generated(tokenizer, generated, prompt)
def generate_with_attention_knockout_all_layers(model, tokenizer, prompt, source, target_token_str, default_decoding,
                                                k=0, return_probabilities=False):
    """Run ``generate_with_attention_knockout`` for every layer index.

    prompt : list, source : string, target_token_str : list

    Returns:
        Dict keyed by layer index from -1 (no knockout) to
        ``model.config.num_hidden_layers - 1``, in ascending order.
    """
    return {
        layer: generate_with_attention_knockout(model, tokenizer, layer, prompt, source, target_token_str,
                                                default_decoding, k, return_probabilities)
        for layer in range(-1, model.config.num_hidden_layers)
    }
def get_entity_ranks(model, tokenizer, projections, entity):
    """Rank the first token of each entity surface form in every projection.

    ``entity`` holds one list of strings per example. Each list is extended with its
    lowercased forms. If ``get_prepend_space(model)`` is true, it is further extended
    with space-prefixed copies of all of those. Only the first token id of every
    variant is ranked.

    Returns:
        One ``get_token_ranks`` result per (projection, example) pair.
    """
    variants = [names + [name.lower() for name in names] for names in entity]
    if get_prepend_space(model):
        variants = [names + [" " + name for name in names] for names in variants]
    tokenized = [tokenizer(names, add_special_tokens=False).input_ids for names in variants]
    return [
        get_token_ranks(proj, [ids[0] for ids in token_lists])
        for proj, token_lists in zip(projections, tokenized)
    ]
# def get_sublayer_projection(model, tokenizer, prompt, bridge_entities, answers, layers):
def get_sublayer_projection(model, tokenizer, prompt, answers, layers):
    """Project each traced sublayer's output at the final prompt token onto the vocabulary.

    Runs one traced forward pass and collects every module in ``layers`` at each
    prompt's last real token. Tuple outputs contribute their first element. The
    collected outputs go through the final norm (``get_norm_module``) and the
    output embeddings. The projections are then summarized.

    Args:
        model: Model traced with ``TraceDict``.
        tokenizer: Tokenizer used to batch-encode ``prompt`` with padding.
        prompt: Prompt strings.
        answers: Per-prompt lists of answer strings, ranked with ``get_entity_ranks``.
        layers: Module names to trace.

    Returns:
        ``(top_k_words, top_k_probabilities, answer_ranks, prediction_ranks)``, where
        ``prediction_ranks[i]`` is a flat list with the rank of the model's own
        next-token prediction at each traced module for prompt ``i``.
    """
    prompt_inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
    # Index of each prompt's last real token. ``-1`` is only right with left padding;
    # ``cumsum().argmax()`` on the attention mask yields its last 1 for either side.
    attention_mask = prompt_inputs.get("attention_mask")
    if attention_mask is None:
        last_positions = [-1] * prompt_inputs.input_ids.shape[0]
    else:
        last_positions = attention_mask.cumsum(dim=-1).argmax(dim=-1).tolist()
    sublayer_values = []
    with torch.no_grad(), TraceDict(model, layers) as trace:
        logits = model(**prompt_inputs).logits
        for i, last in enumerate(last_positions):
            values = []
            for layer in layers:
                output = trace[layer].output
                if type(output) is tuple:
                    output = output[0]
                values.append(output[i][last].cpu())
            sublayer_values.append(torch.stack(values))
    sublayer_values = torch.stack(sublayer_values)
    out_embeddings = model.get_output_embeddings()
    norm = get_norm_module(model)
    norm_device = next(norm.parameters()).device
    with torch.no_grad():
        projections = out_embeddings(norm(sublayer_values.to(norm_device)))
        top_k_words, top_k_probabilities = get_top_k_tokens(tokenizer, projections)
        answer_ranks = get_entity_ranks(model, tokenizer, projections, answers)
        batch_index = torch.arange(logits.shape[0], device=logits.device)
        last_index = torch.tensor(last_positions, device=logits.device)
        predicted_tokens = logits[batch_index, last_index].max(dim=-1).indices
        prediction_ranks = [get_token_ranks(proj, [predicted_token])
                            for proj, predicted_token in zip(projections, predicted_tokens)]
    prediction_ranks = [r.flatten().tolist() for r in prediction_ranks]
    return top_k_words, top_k_probabilities, answer_ranks, prediction_ranks
# def get_attention_projection(model, tokenizer, prompt, bridge_entities, answers):
# return get_sublayer_projection(model, tokenizer, prompt, bridge_entities, answers,
# get_attention_layers_names(model))
# def get_mlp_projection(model, tokenizer, prompt, bridge_entities, answers):
# return get_sublayer_projection(model, tokenizer, prompt, bridge_entities, answers, get_mlp_layers_names(model))
def get_attention_projection(model, tokenizer, prompt, answers):
    """Run ``get_sublayer_projection`` over every attention sublayer of ``model``."""
    attention_layers = get_attention_layers_names(model)
    return get_sublayer_projection(model, tokenizer, prompt, answers, attention_layers)
def get_mlp_projection(model, tokenizer, prompt, answers):
    """Run ``get_sublayer_projection`` over every MLP sublayer of ``model``."""
    mlp_layers = get_mlp_layers_names(model)
    return get_sublayer_projection(model, tokenizer, prompt, answers, mlp_layers)