diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..77cd4ea
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+.ruff_cache
\ No newline at end of file
diff --git a/README.md b/README.md
index 5b1804f..d75cf13 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,8 @@ To execute a demonstration of SelfExtend on the Passkey Retrieval task, you can
 python llama_example.py # llama
 
 python mistral_example.py # mistra
+
+python gemma_example.py # gemma
 ```
 
 
@@ -170,6 +172,39 @@ SelfExtend: [What is the pass key? The pass key is 58328.]
 -----------------------------------
 ```
 
+For Gemma
+```bash
+-----------------------------------
+#Tokens of Prompt: 5142 Passkey target: 89427
+Gemma: [What is the pass key? The pass key is 89427.]
+SelfExtend: [What is the pass key? The pass key is 89427.]
+-----------------------------------
+
+-----------------------------------
+#Tokens of Prompt: 5142 Passkey target: 51906
+Gemma: [What is the pass key? The pass key is 519. Here.]
+SelfExtend: [What is the pass key? The pass key is 51906.]
+-----------------------------------
+
+-----------------------------------
+#Tokens of Prompt: 5142 Passkey target: 38117
+Gemma: [What is the pass key? The pass key is 38117.]
+SelfExtend: [What is the pass key? The pass key is 38117.]
+-----------------------------------
+
+-----------------------------------
+#Tokens of Prompt: 5142 Passkey target: 60151
+Gemma: [What is the pass key? The pass key is 60151.]
+SelfExtend: [What is the pass key? The pass key is 60151.]
+-----------------------------------
+
+-----------------------------------
+#Tokens of Prompt: 5142 Passkey target: 23789
+Gemma: [What is the pass key? The pass key is 2378. The]
+SelfExtend: [What is the pass key? The pass key is 23789.]
+-----------------------------------
+```
+
 ## 4.How to choose the group_size and neighbor_window
 
diff --git a/gemma_example.py b/gemma_example.py
new file mode 100644
index 0000000..21860bc
--- /dev/null
+++ b/gemma_example.py
@@ -0,0 +1,61 @@
+# transformers version 4.38.1
+import warnings
+
+warnings.filterwarnings("ignore")
+
+import gemma_self_extend_patch as GemmaSE
+from modify_utils import modify_method_of_instance
+from functools import partial
+import json
+from transformers.models.gemma.modeling_gemma import GemmaAttention
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+original_gemma_forward = GemmaAttention.forward
+self_extend_forward = partial(
+    GemmaSE.self_extend_forward, group_size_1=8, group_size_2=1024
+)
+
+device = "cpu"
+model_path = "google/gemma-2b-it"
+model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+model.eval()
+
+
+for line in open("passkey_examples_5k.jsonl", "r"):
+    example = json.loads(line)
+    prompt_postfix = "What is the pass key? The pass key is "
+    prompt = example["input"] + prompt_postfix
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+    print("-----------------------------------")
+    print("#Tokens of Prompt:", input_ids.shape[1], end=" ")
+    print("Passkey target:", example["target"])
+
+    modify_method_of_instance(
+        model, "GemmaAttention", "forward", original_gemma_forward
+    )
+    tokens = model.generate(input_ids, max_new_tokens=6)
+    answer = (
+        "Gemma: ["
+        + prompt_postfix
+        + tokenizer.decode(
+            tokens[0].tolist()[input_ids.shape[1] :], skip_special_tokens=True
+        )
+        + "]"
+    )
+    answer = answer.replace("\n", "\\n")
+    print(answer)
+
+    modify_method_of_instance(model, "GemmaAttention", "forward", self_extend_forward)
+    tokens = model.generate(input_ids, max_new_tokens=6)
+    answer = (
+        "SelfExtend: ["
+        + prompt_postfix
+        + tokenizer.decode(
+            tokens[0].tolist()[input_ids.shape[1] :], skip_special_tokens=True
+        )
+        + "]"
+    )
+    answer = answer.replace("\n", "\\n")
+    print(answer)
+    print("-----------------------------------\n")
diff --git a/gemma_self_extend_patch.py b/gemma_self_extend_patch.py
index 1c1a8f1..e530ba3 100644
--- a/gemma_self_extend_patch.py
+++ b/gemma_self_extend_patch.py
@@ -1,18 +1,16 @@
 # transformers version: 4.38.1
-import torch
-from transformers.models.llama.modeling_llama import *
-from transformers.models.gpt_neox.modeling_gpt_neox import *
-import numpy as np
-import torch.nn as nn
-import math
+import warnings
 from typing import Optional, Tuple
+import math
+import torch
 import transformers
 
-if transformers.__version__ >= '4.36':
+if transformers.__version__ >= "4.36":
     from transformers.cache_utils import Cache
 
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -21,9 +19,12 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -83,52 +84,88 @@ def self_extend_forward(
     key_states = self.k_proj(hidden_states)
     value_states = self.v_proj(hidden_states)
 
-    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    query_states = query_states.view(
+        bsz, q_len, self.num_heads, self.head_dim
+    ).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
 
     if log_scale_base > 0:
-        scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(log_scale_base)).clip(1).to(query_states.dtype)
+        scaled_query = query_states * (
+            (position_ids + 1)[:, None, :, None].log() / math.log(log_scale_base)
+        ).clip(1).to(query_states.dtype)
     else:
         scaled_query = query_states
-
+
     past_key_value = getattr(self, "past_key_value", past_key_value)
 
     if past_key_value is not None:
         # sin and cos are specific to RoPE models; position_ids needed for the static cache
         cache_kwargs = {"cache_position": cache_position}
-        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+        key_states, value_states = past_key_value.update(
+            key_states, value_states, self.layer_idx, cache_kwargs
+        )
 
     kv_seq_len = key_states.shape[-2]
     query_position = position_ids
-    key_position = position_ids if q_len != 1 else torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(bsz, kv_seq_len)
-
-
-    neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position, seq_len=None)
-    neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position, seq_len=None)
-
-
-    _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2 # in case that, the smallest q position, g2-g2//g1 exceed the max position
-    group_query_position = query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 / group_size_1
+    key_position = (
+        position_ids
+        if q_len != 1
+        else torch.arange(kv_seq_len, dtype=position_ids.dtype)
+        .to(query_position.device)
+        .view(bsz, kv_seq_len)
+    )
+
+    neighbor_q_cos, neighbor_q_sin = self.rotary_emb(
+        value_states, query_position, seq_len=None
+    )
+    neighbor_k_cos, neighbor_k_sin = self.rotary_emb(
+        value_states, key_position, seq_len=None
+    )
+
+    _re_group_size_2 = (
+        0 if query_position.max() < group_size_2 else group_size_2
+    )  # in case the smallest q position, g2 - g2 // g1, exceeds the max position
+    group_query_position = (
+        query_position // group_size_1
+        + _re_group_size_2
+        - _re_group_size_2 / group_size_1
+    )
     group_key_position = key_position // group_size_1
 
-    group_q_cos, group_q_sin = self.rotary_emb(value_states, group_query_position, seq_len=None)
-    group_k_cos, group_k_sin = self.rotary_emb(value_states, group_key_position, seq_len=None)
-
-
-
-    neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
-    _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)
-    group_query_states, _ = apply_rotary_pos_emb(scaled_query, None, group_q_cos, group_q_sin, None)
-    _, group_key_states = apply_rotary_pos_emb(None, key_states, group_k_cos, group_k_sin, None)
-
+    group_q_cos, group_q_sin = self.rotary_emb(
+        value_states, group_query_position, seq_len=None
+    )
+    group_k_cos, group_k_sin = self.rotary_emb(
+        value_states, group_key_position, seq_len=None
+    )
+
+    neighbor_query_states, _ = apply_rotary_pos_emb(
+        scaled_query, None, neighbor_q_cos, neighbor_q_sin, None
+    )
+    _, neighbor_key_states = apply_rotary_pos_emb(
+        None, key_states, neighbor_k_cos, neighbor_k_sin, None
+    )
+    group_query_states, _ = apply_rotary_pos_emb(
+        scaled_query, None, group_q_cos, group_q_sin, None
+    )
+    _, group_key_states = apply_rotary_pos_emb(
+        None, key_states, group_k_cos, group_k_sin, None
+    )
     neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
     group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-    group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
+    neighbor_attn_weights = torch.matmul(
+        neighbor_query_states, neighbor_key_states.transpose(2, 3)
+    ) / math.sqrt(self.head_dim)
+    group_attn_weights = torch.matmul(
+        group_query_states, group_key_states.transpose(2, 3)
+    ) / math.sqrt(self.head_dim)
 
     if attention_mask is not None:  # no matter the length, we just slice it
         if cache_position is not None:
@@ -139,23 +176,40 @@ def self_extend_forward(
             neighbor_attn_weights = neighbor_attn_weights + causal_mask
 
     if q_len == 1:
-        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
+        neighbor_attention_mask = torch.zeros(
+            (q_len, kv_seq_len), device=neighbor_attn_weights.device
+        )
         neighbor_attention_mask[:, -group_size_2:] = 1
     elif q_len == kv_seq_len:
-        neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
+        neighbor_attention_mask = torch.ones(
+            (q_len, kv_seq_len), device=neighbor_attn_weights.device
+        )
         neighbor_attention_mask = torch.tril(neighbor_attention_mask)
-        if q_len-group_size_2 > 0:
-            group_attention_mask = torch.tril(torch.ones((q_len-group_size_2, kv_seq_len-group_size_2), device=group_attn_weights.device))
-            neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
+        if q_len - group_size_2 > 0:
+            group_attention_mask = torch.tril(
+                torch.ones(
+                    (q_len - group_size_2, kv_seq_len - group_size_2),
+                    device=group_attn_weights.device,
+                )
+            )
+            neighbor_attention_mask[group_size_2:, :-group_size_2] -= (
+                group_attention_mask
+            )
     else:
         raise ValueError("q_len should be 1 or seq_len.")
-
+
     neighbor_attention_mask = neighbor_attention_mask.bool()
-    attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)
-
+    attn_weights = torch.where(
+        neighbor_attention_mask, neighbor_attn_weights, group_attn_weights
+    )
+
     # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+    attn_weights = torch.nn.functional.softmax(
+        attn_weights, dim=-1, dtype=torch.float32
+    ).to(query_states.dtype)
+    attn_weights = torch.nn.functional.dropout(
+        attn_weights, p=self.attention_dropout, training=self.training
+    )
     attn_output = torch.matmul(attn_weights, value_states)
 
     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -169,7 +223,6 @@ def self_extend_forward(
     attn_output = attn_output.view(bsz, q_len, -1)
     attn_output = self.o_proj(attn_output)
 
-
    if not output_attentions:
        attn_weights = None