-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathtest_inference.py
More file actions
71 lines (59 loc) · 2.36 KB
/
test_inference.py
File metadata and controls
71 lines (59 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import os
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
def find_latest_finetuned_model() -> str | None:
"""
Look under:
1. <script dir>/results/auto_antislop_runs
2. /results/auto_antislop_runs
Return the most recent `…/finetuned_model*/merged_16bit` directory or None.
"""
candidate_bases = [
Path(__file__).resolve().parent / "results" / "auto_antislop_runs",
Path("/results/auto_antislop_runs"),
]
latest: tuple[float, Path] | None = None
for base in candidate_bases:
if not base.is_dir():
continue
# run_*/finetuned_model*/merged_16bit
for merged_dir in base.glob("run_*/finetuned_model*/merged_16bit"):
if not merged_dir.is_dir():
continue
mtime = merged_dir.parent.stat().st_mtime # use finetuned_model* dir mtime
if latest is None or mtime > latest[0]:
latest = (mtime, merged_dir.resolve())
return str(latest[1]) if latest else None
def main() -> None:
    """Load the newest fine-tuned model (falling back to the CWD) and
    generate a sample story to sanity-check inference.

    Side effects: prints progress, the applied chat template, the generated
    story, and an approximate new-token count. Any failure is reported and
    swallowed at this top-level boundary.
    """
    model_path = find_latest_finetuned_model() or "."
    print(f"Loading model from: {os.path.abspath(model_path)}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto" if device == "cuda" else None,
        )
        messages = [
            {"role": "system", "content": "You are a creative storyteller."},
            {"role": "user", "content": "Write a short, engaging story about a princess."}
        ]
        # add_generation_prompt=True appends the assistant-turn header so the
        # model produces a reply instead of continuing the user's message.
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print("\nApplied chat template:\n", prompt)
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=500,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        # Decode only the newly generated token ids. Slicing by token count is
        # robust; subtracting decoded-prompt string lengths breaks whenever
        # detokenization is not an exact inverse of tokenization.
        prompt_len = input_ids.shape[1]
        response = tokenizer.decode(
            generated_ids[0][prompt_len:], skip_special_tokens=True
        )
        print("\n--- Generated Story ---\n", response)
        print("\nToken count (approximate):", len(generated_ids[0]) - prompt_len)
    except Exception as e:  # top-level boundary: report and exit cleanly
        print(f"Error: {e}")


if __name__ == "__main__":
    main()