Enables return of activation cache variables during generation #838
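This PR extends HookedTransformer.generate so that return_type may also be a list of output types, including a new 'cache' entry that accumulates activation-cache variables across generation steps. A minimal usage sketch of the new API (the model name, prompt, and sampling arguments are illustrative, and assume this branch is installed):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

# Passing a list of return types yields a dict with one entry per requested type.
out = model.generate(
    "The quick brown fox",
    max_new_tokens=8,
    return_type=["str", "tokens", "cache"],
)
print(out["str"])           # the generated text
print(out["tokens"].shape)  # [batch, pos] tensor of token ids
print(len(out["cache"]))    # activation names accumulated over all generation steps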

Draft · wants to merge 19 commits into main
116 changes: 104 additions & 12 deletions transformer_lens/HookedTransformer.py
@@ -59,6 +59,7 @@
from transformer_lens.utilities import devices
from transformer_lens.utils import (
USE_DEFAULT_VALUE,
Slice,
init_kaiming_normal_,
init_kaiming_uniform_,
init_xavier_normal_,
@@ -2084,7 +2085,7 @@ def generate(
use_past_kv_cache: bool = True,
prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
return_type: Optional[Union[str, List[str]]] = "input",
verbose: bool = True,
) -> Union[
str,
@@ -2130,7 +2131,7 @@
padding_side (Union[Literal["left", "right"], None], optional): Overrides
self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
strings of different lengths.
return_type (Optional[Union[str, List[str]]]): The type of the output to return - a string or a list of
    strings ('str'), a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds'), the
    accumulated activation cache ('cache'), or whatever the format of the input was ('input'). A list
    of return types yields a dict mapping each requested type to its output.
verbose (bool): If True, show tqdm progress bars for generation.
@@ -2151,12 +2152,15 @@
or not isinstance(input, list)
), "Input must be either string, torch.Tensor, or List[str]"

return_types = return_type if isinstance(return_type, list) else [return_type]

assert all(
    rt in ["input", "str", "tokens", "embeds", "cache"] for rt in return_types
), "return_type must be one of ['input', 'str', 'tokens', 'embeds', 'cache']"

if return_type == "input":
if isinstance(input, (str, list)):
@@ -2225,6 +2229,49 @@
# that changes in the future.
self.eval()
sampled_tokens_list = []
cache_dict_tape = None
logits_tape = None
token_tape = None

if "cache" in return_types:
    # defaults from hook_points.py#L510
    names_filter = None
    device = None
    remove_batch_dim: bool = False
    incl_bwd: bool = False
    reset_hooks_end: bool = True
    clear_contexts: bool = True
    pos_slice = Slice.unwrap(None)

    cache_dict, fwd, bwd = self.get_caching_hooks(
        names_filter,
        incl_bwd,
        device,
        remove_batch_dim=remove_batch_dim,
        pos_slice=pos_slice,
    )

def forward_(*model_args, **model_kwargs):
    if "cache" in return_types:
        # The caching hooks populate cache_dict in place during the forward pass.
        with self.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            model_out = self(*model_args, **model_kwargs)
            if incl_bwd:
                model_out.backward()
        return model_out
    else:
        return self.forward(*model_args, **model_kwargs)

for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
pos_offset = self.get_pos_offset(past_kv_cache, batch_size)

@@ -2247,7 +2294,7 @@
if use_past_kv_cache:
# We just take the final tokens, as a [batch, 1] tensor
if index > 0:
logits = forward_(
residual[:, -1:],
return_type="logits",
prepend_bos=prepend_bos,
@@ -2257,7 +2304,7 @@
shortformer_pos_embed=shortformer_pos_embed,
)
else:
logits = forward_(
residual,
return_type="logits",
prepend_bos=prepend_bos,
@@ -2269,7 +2316,7 @@
else:
# We input the entire sequence, as a [batch, pos] tensor, since we aren't using
# the cache.
logits = forward_(
residual,
return_type="logits",
prepend_bos=prepend_bos,
@@ -2279,6 +2326,7 @@
)
final_logits = logits[:, -1, :]

# sampling
if do_sample:
if input_type in [
"str",
@@ -2317,7 +2365,42 @@
)
)

# append the embeddings of the newly sampled tokens
embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))])

# append the newly sampled tokens
tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)

# accumulate the token tape; clone on the first pass so we don't overwrite the prompt tokens
token_tape = (
    torch.cat([token_tape, sampled_tokens.unsqueeze(-1)], dim=-1)
    if token_tape is not None
    else torch.clone(tokens[:, -ctx_length:])  # starts from the prompt tokens
)
if "cache" in return_types:

    def cat_cache_var(key, var_tape, var):
        if not any(key_ in key for key_ in ["attn_scores", "hook_pattern"]):
            # Vector-valued activations ([batch, pos, ...]): append the newest position.
            return torch.cat([var_tape, var[:, -1:]], dim=1)
        else:
            # Attention scores/patterns ([batch, head, query_pos, key_pos]) grow by
            # one query row and one key column per step: zero-pad the last two dims
            # on the right, then write in the newest row and column.
            var_tape = torch.nn.functional.pad(var_tape, (0, 1, 0, 1), value=0)
            slice1 = var[:, :, -1:, :]  # newest query row
            T = slice1.shape[-1]
            var_tape[..., -1:, -T:] = slice1
            slice2 = var[:, :, :, -1:]  # newest key column
            var_tape[..., -T:, -1:] = slice2
            return var_tape

    cache_dict_tape = (
        {k: cat_cache_var(k, cache_dict_tape[k], cache_dict[k]) for k in cache_dict}
        if cache_dict_tape is not None
        else {k: torch.clone(v) for k, v in cache_dict.items()}  # initialize with the first step's cache
    )

logits_tape = (
    torch.cat([logits_tape, logits[:, -1:]], dim=1)
    if logits_tape is not None
    else torch.clone(logits[:, -ctx_length:])
)

if stop_at_eos and finished_sequences.all():
break
@@ -2328,16 +2411,25 @@
else:
output_tokens = sampled_tokens

# compute return objects as requested
result = {}
if "str" in return_types:
    decoded_texts = [
        self.tokenizer.decode(tokens, skip_special_tokens=True)
        for tokens in output_tokens
    ]
    result["str"] = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts
if "tokens" in return_types:
    result["tokens"] = output_tokens
if "cache" in return_types:
    result["cache"] = cache_dict_tape
if "embeds" in return_types:
    result["embeds"] = embeds

if not isinstance(return_type, list):
    return result[return_type]
else:
    return result

# Give access to all weights as properties.
@property
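For reference, a standalone sketch of the padding scheme cat_cache_var applies to attention patterns (toy shapes and a constant score value, not part of the diff): each generated token grows the square [query_pos, key_pos] pattern by one row and one column.

import torch
import torch.nn.functional as F

# Toy attention-pattern tape: [batch, head, query_pos, key_pos].
tape = torch.ones(1, 2, 3, 3)

# One generation step: zero-pad the last two dims on the right -> [1, 2, 4, 4],
# leaving an empty last row (new query) and last column (new key).
tape = F.pad(tape, (0, 1, 0, 1), value=0)

# Write the new query's scores over all keys into the last row. Under causal
# attention the new key's column is zero for every earlier query, so the zero
# padding already accounts for it.
new_row = torch.full((1, 2, 1, 4), 0.25)
tape[..., -1:, -4:] = new_row
print(tape.shape)  # torch.Size([1, 2, 4, 4])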