-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinfer.py
More file actions
91 lines (73 loc) · 3.8 KB
/
infer.py
File metadata and controls
91 lines (73 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ======================================================
# Fixed configuration
# ======================================================
# Checkpoint location is read from the environment so the script needs no
# per-deployment code edits.  NOTE: os.getenv returns None when the variable
# is unset — validated before use in main()/load_model_and_tokenizer().
MODEL_PATH = os.getenv("MODEL_PATH")
DTYPE = "bfloat16"               # "bfloat16" or anything else -> float16 (see load_model_and_tokenizer)
ATTN_IMPL = "flash_attention_2"  # attention backend handed to transformers
DEVICE_MAP = "auto"              # let transformers/accelerate place the weights
MAX_NEW_TOKENS = 1024            # generation budget per request
TEMPERATURE = 0.6                # sampling temperature
TOP_P = 0.95                     # nucleus-sampling cutoff
DO_SAMPLE = True                 # sample rather than greedy-decode
# ======================================================
# Prompts — USER ONLY EDITS THESE TWO
# ======================================================
SYSTEM_PROMPT = """You are a bio-expert scientific assistant. Your task is to provide a single, continuous string containing a complete, structured solution. Your response must begin with <think> and have four components, each enclosed in its respective tag in the following order: <think>, <key>, <orc>, and <note>. The <think> tag must contain your scientific reasoning and strategy. The <key> tag is the structured plan, and it is crucial that each step is a single JSON object with its content distilled into atomic keywords for the action, objects, and parameters. The <orc> tag will be a human-readable summary of the plan, and <note> will provide critical safety information. The final output structure must be: <think>...</think>\n\n<key>...</key>\n\n<orc>...</orc>\n\n<note>...</note>"""
USER_PROMPT = """You need to prepare gel embedding solution according to the protocol, which specifies mixing 5 mL gel embedding premix with 25 µL of 10% ammonium persulfate and 2.5 µL TEMED. However, you only need to embed a single thin brain slice and want to scale this recipe down to 1 mL of gel embedding premix while keeping the same ratios of components. What exact volumes of ammonium persulfate and TEMED should you add, and how should you prepare this scaled-down solution?"""
def load_model_and_tokenizer(model_path: str):
    """Load a causal language model and its tokenizer from *model_path*.

    Args:
        model_path: Local path or hub id of the checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple ready for generation.

    Raises:
        ValueError: If *model_path* is empty or ``None`` — typically because
            the ``MODEL_PATH`` environment variable was never set.  Failing
            fast here gives a clear message instead of a cryptic traceback
            from deep inside ``from_pretrained``.
    """
    if not model_path:
        raise ValueError(
            "model_path is empty — set the MODEL_PATH environment variable "
            "to the checkpoint directory or hub id."
        )
    # Map the string config to a torch dtype; any value other than
    # "bfloat16" falls back to float16 (matches the original behavior).
    dtype = torch.bfloat16 if DTYPE == "bfloat16" else torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        attn_implementation=ATTN_IMPL,
        device_map=DEVICE_MAP,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Some checkpoints ship without a pad token; generate() needs one to
    # pad batches, so fall back to the EOS token id.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return model, tokenizer
# ======================================================
# Single inference
# ======================================================
def generate(model, tokenizer, system_prompt: str, user_prompt: str) -> str:
    """Run a single chat-style generation and return the decoded completion.

    The system/user messages are rendered through the tokenizer's chat
    template (generation prompt and thinking mode enabled), sampled with
    the module-level generation settings, and only the newly produced
    tokens — not the prompt — are decoded and returned.
    """
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )
    encoded = tokenizer([prompt_text], return_tensors="pt").to(model.device)
    prompt_len = encoded.input_ids.shape[1]
    with torch.inference_mode():
        generated = model.generate(
            **encoded,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            do_sample=DO_SAMPLE,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Drop the prompt tokens; keep only the model's continuation.
    completion_ids = generated[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
# ======================================================
# Main
# ======================================================
def main():
    """Entry point: load the model once, run one inference, print the result.

    Raises:
        SystemExit: If the ``MODEL_PATH`` environment variable is unset, so
            the user gets an actionable message instead of a transformers
            traceback triggered by a ``None`` path.
    """
    if not MODEL_PATH:
        raise SystemExit(
            "MODEL_PATH environment variable is not set; "
            "export MODEL_PATH=/path/to/checkpoint and retry."
        )
    print("Loading model...")
    model, tokenizer = load_model_and_tokenizer(MODEL_PATH)
    print("Model loaded.\n")
    result = generate(model, tokenizer, SYSTEM_PROMPT, USER_PROMPT)
    print("=== Model Output ===")
    print(result)
    print("====================\n")
if __name__ == "__main__":
    main()