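"""run.py: evaluate per-token perplexity of a llama-family model, optionally
with the LaCache compressed KV cache, streaming one token at a time."""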
import argparse

import torch
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

from lacache.kv_cache import LaCache


def load(model_name_or_path):
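    """Load a causal LM and its tokenizer in fp16, sharded across available GPUs."""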
print(f"Loading model from {model_name_or_path} ...")
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
# attn_implementation="flash_attention_2", # uncomment for enabling flash-attention2
device_map="auto",
torch_dtype=torch.float16,
trust_remote_code=True,
)
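    # Some llama tokenizers ship without a pad token; fall back to EOS (or 0).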
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0
    model.eval()
    return model, tokenizer


def main(args):
    data = load_dataset(args.dataset_name, args.task, split=args.split)
    model, tokenizer = load(args.model_name_or_path)
    device = "cuda"
    if "llama" not in model.config.model_type:
        raise ValueError(f"Only llama-family models are supported, got {model.config.model_type}")
    nlls = []
    loss_fn = CrossEntropyLoss(reduction="none")
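    # LaCache keeps the KV cache under a fixed token budget; the pos-shift
    # attention patch re-indexes RoPE so attention positions stay consistent
    # with the compressed cache.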
    if args.enable_lacache:
        kv_cache = LaCache(cache_size=args.cache_size, span=args.span, overlap=args.overlap)
        from lacache.llama_patch import enable_llama_pos_shift_attention

        enable_llama_pos_shift_attention(model)
    else:
        kv_cache = None
    if args.dataset_name == "wikitext":
        encodings = tokenizer("".join(data["text"]), return_tensors="pt")
        # Optionally save the encodings to a .pt file once and load them
        # directly on later runs, which is faster than re-tokenizing, e.g.:
        # torch.save(encodings, "encodings_wikitext_llama2.pt")  # or ..._llama3.pt
        # encodings = torch.load("encodings_wikitext_llama2.pt")
    elif args.dataset_name == "emozilla/pg19-test":
        # [:2]: use only the first 2 books for faster tokenization; adjust for longer evaluations.
        encodings = tokenizer("".join(data["text"][:2]), return_tensors="pt")
    else:
        raise ValueError(f"Unsupported dataset, got {args.dataset_name}")
    num_eval_tokens = 0
    seq_len = encodings.input_ids.shape[1]
    # Default --num_eval_tokens is None: evaluate the whole sequence. Clamp so
    # the final step still has a next-token label.
    if args.num_eval_tokens is not None:
        num_steps = min(args.num_eval_tokens + 1, seq_len - 1)
    else:
        num_steps = seq_len - 1
    pbar = tqdm(range(num_steps))
    past_key_values, cache_position = DynamicCache(), None
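    # Each step feeds a single token; the cross-entropy against the next token
    # is that position's negative log-likelihood.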
    for idx in pbar:
        input_ids = encodings.input_ids[:, idx : idx + 1].to(device)
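        # First step: positions 0..len-1. Afterwards advance by one, unless the
        # LaCache budget is exhausted, in which case the position is held so
        # RoPE indices never run past the retained cache length.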
        if cache_position is None:
            cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
        else:
            if (not args.enable_lacache) or cache_position[-1:] + 1 <= legacy_cache[0][0].shape[2]:
                cache_position = cache_position[-1:] + 1
        with torch.no_grad():
            inputs = model.prepare_inputs_for_generation(
                input_ids=input_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                use_cache=True,
            )
            outputs = model(**inputs)
            logits = outputs.logits.view(-1, model.config.vocab_size)
            past_key_values = outputs.past_key_values
            label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
            neg_log_likelihood = loss_fn(logits, label)
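        # Compress the cache after every step so its length stays within the
        # cache_size budget before the next forward pass.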
        if args.enable_lacache:
            legacy_cache = kv_cache(past_key_values.to_legacy_cache())
            past_key_values = DynamicCache.from_legacy_cache(legacy_cache)
        nlls.append(neg_log_likelihood)
        pbar.set_description(f"nll: {neg_log_likelihood.item():.2f}, ppl: {torch.exp(neg_log_likelihood).item():.2f}")
        num_eval_tokens += 1
        if args.num_eval_tokens is not None and num_eval_tokens > args.num_eval_tokens:
            break
    ppl = torch.exp(torch.stack(nlls).mean())
    print(f"ppl: {ppl.item()}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="models/llama/llama-7b")
parser.add_argument("--dataset_name", type=str, default="wikitext")
parser.add_argument("--task", type=str, default="wikitext-2-raw-v1")
parser.add_argument("--split", type=str, default="test", choices=["validation", "test"])
parser.add_argument("--enable_lacache", action="store_true")
parser.add_argument("--cache_size", type=int, default=256)
parser.add_argument("--span", type=int, default=16)
parser.add_argument("--overlap", type=int, default=0)
parser.add_argument("--num_eval_tokens", type=int, default=None)
args = parser.parse_args()
main(args)
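# Example invocation (the model id below is illustrative; point it at any local
# or Hugging Face llama checkpoint):
#   python run.py --model_name_or_path meta-llama/Llama-2-7b-hf \
#       --enable_lacache --cache_size 256 --span 16 --num_eval_tokens 4096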