-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
34 lines (31 loc) · 1.28 KB
/
utils.py
File metadata and controls
34 lines (31 loc) · 1.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import numpy as np
def extract_first_tensor_from_output(output):
    """Return the first tensor found in ``output``, searching depth-first.

    Module outputs may be a bare tensor, a tuple/list (possibly nested), or a
    dict-style container. Containers are walked in iteration order (dict
    values in insertion order) and the first tensor encountered is returned.

    Args:
        output: A tensor, a tuple/list/dict that may contain tensors at any
            nesting depth, or any other object.

    Returns:
        The first ``torch.Tensor`` found, or ``None`` if there is none
        (including when ``output`` is not a tensor or supported container).
    """
    if torch.is_tensor(output):
        return output

    # Dict-style outputs carry their tensors as values; sequences carry them
    # as items. Anything else (None, str, int, ...) cannot contain a tensor.
    if isinstance(output, dict):
        children = output.values()
    elif isinstance(output, (tuple, list)):
        children = output
    else:
        return None

    # Recurse into every child: the recursive call itself handles both the
    # "child is a tensor" and the "child is a nested container" cases.
    for child in children:
        found = extract_first_tensor_from_output(child)
        if found is not None:
            return found
    return None
def capture_activations(model, tokenizer, prompts, layer_num, device):
    """Collect one block's last-token output for each prompt.

    A forward hook is registered on ``model.transformer.h[layer_num]``. Each
    prompt is tokenized and run through the model in its own forward pass
    (under ``torch.no_grad()``), and the hook records the block output at the
    final sequence position. The hook is always removed before returning,
    even if tokenization or a forward pass raises.

    Args:
        model: Callable model whose blocks are reachable as
            ``model.transformer.h`` (presumably a GPT-2-style Hugging Face
            model -- confirm against callers). It is not moved to ``device``
            here.
        tokenizer: Callable invoked as ``tokenizer(prompt,
            return_tensors="pt")``; its result must support ``.to(device)``
            and ``**`` unpacking into ``model``.
        prompts: Iterable of prompts; each element is one forward pass.
        layer_num: Index into ``model.transformer.h`` of the block to hook.
        device: Device the tokenized inputs are moved to.

    Returns:
        ``numpy.ndarray`` built with ``np.vstack`` from the per-prompt
        captures. Assuming the block output is (batch, seq, hidden) with one
        sequence per prompt, this is (num_prompts, hidden) -- TODO confirm.

    Raises:
        RuntimeError: If the hooked block's output contains no tensor.
        ValueError: From ``np.vstack`` if nothing was captured (for example
            when ``prompts`` is empty).
    """
    captured = []

    def hook(module, inputs, output):
        out_tensor = extract_first_tensor_from_output(output)
        if out_tensor is None:
            raise RuntimeError("no tensor in hook output")
        # Keep only the last sequence position for each batch item, detached
        # and copied to host memory so no graph/device memory is retained.
        captured.append(out_tensor[:, -1, :].detach().cpu().numpy())

    handle = model.transformer.h[layer_num].register_forward_hook(hook)
    try:
        # One forward pass per prompt; the hook fills `captured` as a side
        # effect, so the model's own return value is discarded.
        for prompt in prompts:
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            with torch.no_grad():
                model(**inputs)
    finally:
        # Always unregister, otherwise a failed run would leave the hook
        # attached and it would fire on every later forward pass.
        handle.remove()

    return np.vstack(captured)