This repository was archived by the owner on Jan 23, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathgeneration.py
More file actions
151 lines (124 loc) · 4.71 KB
/
Copy pathgeneration.py
File metadata and controls
151 lines (124 loc) · 4.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/python
import argparse
import datetime
import os
import platform
import time
from typing import List
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, StaticCache
from optimum.tpu.modeling import AutoModelForCausalLM
# Point the XLA runtime at the TPU by exporting PJRT_DEVICE for this process.
# NOTE(review): this runs after `torch_xla` has already been imported above;
# presumably the variable is only read when the device is first used — confirm.
os.environ["PJRT_DEVICE"] = "TPU"
def sample_greedy(logits):
    """Pick the highest-scoring token at the last sequence position.

    Args:
        logits: Tensor indexed as ``logits[:, -1]``; assumed to be laid out
            as (batch, sequence, vocab) — TODO confirm against the model.

    Returns:
        An ``int32`` tensor with one column holding the argmax index of the
        final position for every row of the batch.
    """
    last_position = logits[:, -1]
    best_index = last_position.argmax(dim=-1)
    # Re-insert the sequence axis so the result can be fed back as input ids.
    return best_index.unsqueeze(1).int()
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    """Run one decoding step and return the greedily chosen next token.

    Args:
        model: Callable invoked with the current token and cache arguments;
            its first output is treated as the logits.
        cur_token: Token ids fed to the model for this step.
        input_pos: Passed to the model as ``position_ids``.
        cache_position: Passed to the model as ``cache_position``.
        past_key_values: Key/value cache handed to the model.

    Returns:
        The result of ``sample_greedy`` applied to the step's logits.
    """
    forward_kwargs = {
        "position_ids": input_pos,
        "cache_position": cache_position,
        "return_dict": False,  # tuple output, so the logits are element 0
        "use_cache": True,
        "past_key_values": past_key_values,
    }
    outputs = model(cur_token, **forward_kwargs)
    return sample_greedy(outputs[0])
def conditional_compile(func):
    """Optionally wrap ``func`` with ``torch.compile`` using the openxla backend.

    Compilation is opt-in: it only happens when the ``DBG_COMPILE``
    environment variable is present (its value is not inspected).

    Args:
        func: The callable to (maybe) compile.

    Returns:
        ``func`` unchanged when ``DBG_COMPILE`` is unset, otherwise the
        object returned by ``torch.compile(func, backend="openxla")``.
    """
    if "DBG_COMPILE" not in os.environ:
        return func
    return torch.compile(func, backend="openxla")
def summary(values: List[float]) -> None:
    """Print the total, mean and median of a list of timings.

    The input list is left untouched: statistics are computed on a sorted
    copy, so callers keep their measurements in recording order.

    Args:
        values: Durations to summarise. May be empty, in which case a short
            notice is printed instead of statistics.
    """
    if not values:
        # Nothing was measured (e.g. zero decode steps); avoid indexing into
        # an empty list or dividing by zero.
        print("Decode time: no decode steps were recorded.")
        return

    ordered = sorted(values)
    n = len(ordered)
    mid = n // 2
    if n % 2 == 0:
        # Even count: the median is the mean of the two central values.
        median = (ordered[mid - 1] + ordered[mid]) / 2
    else:
        median = ordered[mid]
    total = sum(ordered)
    mean = total / n
    print(f"Decode time: {total}, average: {mean}, median: {median}")
def main():
    """Load a causal LM, prefill two prompts and greedily decode new tokens.

    Command-line options:
        --model_id: model identifier passed to ``from_pretrained``.
        --max_new_tokens: number of decode steps to run after the prefill.
        --max_cache_length: length of the static key/value cache.

    The function prints the model load time, the prefill time, the time of
    every decode step, a summary of the decode timings and finally the
    decoded text of each prompt.

    Raises:
        ValueError: if the prompt length plus ``--max_new_tokens`` does not
            fit in ``--max_cache_length``.
    """
    parser = argparse.ArgumentParser(description="Text generation example")
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-3.2-1B-Instruct",
        help="Model ID (e.g.: google/gemma-2b, mistralai/Mistral-7B-v0.3)",
    )
    parser.add_argument("--max_new_tokens", type=int, default=20, help="Number of tokens to generate")
    parser.add_argument("--max_cache_length", type=int, default=256, help="Maximum cache length for the model")
    args = parser.parse_args()

    prg_start = time.time()
    print(f"⏳ Loading model {args.model_id}...")
    model_id = args.model_id
    torch_dtype = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
    device = model.device
    model = model.eval()
    print(f"✅ Model loaded in {time.time() - prg_start} seconds on {device=}.")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Set pad token for cases where it is None, e.g. for Mistral
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    prompts = ["Here's a funny thing:", "Once upon a time,"]
    # NOTE(review): the first generated token is read from the last position
    # of the padded batch; if the tokenizer pads on the right and the prompts
    # tokenize to different lengths, that position is padding for the shorter
    # prompt — confirm the tokenizer's padding side.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    batch_size, sequence_length = inputs["input_ids"].shape

    max_new_tokens = args.max_new_tokens
    # Honour the CLI option instead of a hard-coded cache size.
    max_cache_length = args.max_cache_length
    # The prefill fills cache slots [0, sequence_length) and every decode step
    # writes one more slot, so this many slots are needed in total.
    required_cache_length = sequence_length + max_new_tokens
    if required_cache_length > max_cache_length:
        raise ValueError(
            f"--max_cache_length={max_cache_length} is too small: the prompt "
            f"({sequence_length} tokens) plus --max_new_tokens={max_new_tokens} "
            f"needs at least {required_cache_length} cache positions."
        )

    # setup static cache
    past_key_values = StaticCache(
        config=model.config,
        max_batch_size=batch_size,
        max_cache_len=max_cache_length,
        device=model.device,
        dtype=model.dtype,
    )

    start = time.time()
    cache_position = torch.arange(sequence_length, device=device)
    # One extra column: the prefill itself already yields one token on top of
    # the ``max_new_tokens`` produced by the decode loop.
    generated_ids = torch.zeros(
        (batch_size, sequence_length + max_new_tokens + 1),
        dtype=torch.int,
        device=device,
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch.int)

    # prefill here
    attention_mask = inputs["attention_mask"]
    # Position of each token counted over attended tokens only; masked
    # positions are set to 0.
    pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
    logits = model(
        **inputs,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
        position_ids=pos_ids,
        past_key_values=past_key_values,
    )[0]
    next_token = sample_greedy(logits)
    xm.mark_step()
    generated_ids[:, sequence_length] = next_token[:, 0]
    end = time.time()
    print(f"Prefill took {end - start} seconds.")

    # Next position id per sequence: one past the largest position seen.
    pos_ids = pos_ids.max(axis=-1)[0].unsqueeze(1) + 1

    model = conditional_compile(model)
    cache_position = torch.tensor([sequence_length], device=device)
    decode_times = []
    for i in range(max_new_tokens):
        step_start = time.time()
        next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position, past_key_values)
        # Advance first: the new token belongs in the column after the one
        # that was just fed to the model.
        cache_position += 1
        generated_ids[:, cache_position] = next_token
        pos_ids += 1
        xm.mark_step()
        step_end = time.time()
        step_time = step_end - step_start
        decode_times.append(step_time)
        print(f"Step {i} took {step_time} seconds.")
    summary(decode_times)

    print(f"Decoding start at {datetime.datetime.now()}")
    decoded_texts = tokenizer.batch_decode(generated_ids)
    for i, text in enumerate(decoded_texts):
        print(i, text)

    end = time.time()
    print(f"Program run in {end - prg_start} seconds. Device: {device} System: {platform.system()}")
if __name__ == "__main__":
    # Script entry point: the whole run is inference, so autograd tracking is
    # switched off for the duration of main().
    with torch.no_grad():
        main()