import json
import os

import boto3
from botocore.exceptions import ClientError
from nltk import word_tokenize
from openai import OpenAI

from prompts import *

def calculate_avg_std(df, file_path):
    """Write the unique-token count and per-column mean/std of `df` to `file_path`."""
    with open(file_path, "w") as f:
        # Count unique (case-insensitive) word tokens across all questions.
        concatenated_text = df['Question'].str.cat(sep=' ')
        token_count = len({token.lower() for token in word_tokenize(concatenated_text)})
        print(f"All Tokens count, {token_count}\n")
        f.write(f"All Tokens count, {token_count}\n")
        # Report mean and standard deviation for every numeric column.
        for column in df.columns:
            if df[column].dtype in ['int64', 'float64', 'float32']:
                avg = df[column].mean()
                std = df[column].std()
                print(f"{column}, {avg:.2f},{std:.2f}")
                f.write(f"{column}, {avg:.2f},{std:.2f}\n")
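
# Example usage (a minimal sketch, not part of the original module): assumes a
# pandas DataFrame with a 'Question' text column plus numeric score columns.
#
#   import pandas as pd
#
#   df = pd.DataFrame({
#       "Question": ["What causes fever?", "How is influenza treated?"],
#       "Difficulty": [3.0, 4.5],
#   })
#   calculate_avg_std(df, "stats.txt")  # prints and writes the same summary lines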

def chat(prompt: str, model: str, system: str = None) -> str:
    """Send `prompt` to the chosen backend and return the model's reply text."""
    ############## GPT ##############
    if "gpt" in model:
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        msg = [{"role": "user", "content": prompt}]
        if system:
            # The system message must come before the user turn.
            msg.insert(0, {"role": "system", "content": system})
        chat_completion = client.chat.completions.create(messages=msg, model=model)
        return chat_completion.choices[0].message.content
    ############## Claude ##############
    elif model == "claude":
        bedrock = boto3.client(service_name="bedrock-runtime")
        body = json.dumps({
            # Claude 3.5 Sonnet on Bedrock allows at most 4096 output tokens.
            "max_tokens": 4096,
            "messages": [{"role": "user", "content": prompt}],
            "anthropic_version": "bedrock-2023-05-31"
        })
        response = bedrock.invoke_model(body=body, modelId="anthropic.claude-3-5-sonnet-20240620-v1:0")
        response_body = json.loads(response.get("body").read())
        return response_body.get("content")[0]['text']
    ############## llama3 ##############
    elif model == "llama3":
        client = boto3.client("bedrock-runtime", region_name="us-east-1")
        model_id = "meta.llama3-70b-instruct-v1:0"
        # Llama 3 instruct models expect this chat template around the prompt.
        formatted_prompt = f"""
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{prompt}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
        native_request = {
            "prompt": formatted_prompt,
            "max_gen_len": 2048,
            "temperature": 0.5,
        }
        request = json.dumps(native_request)
        try:
            response = client.invoke_model(modelId=model_id, body=request)
        except ClientError as e:
            print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
            exit(1)
        model_response = json.loads(response["body"].read())
        return model_response["generation"]
    else:
        raise ValueError(f"Invalid model version: {model}")
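
# Example calls (a sketch; the model names and prompts here are illustrative):
# "gpt*" models need OPENAI_API_KEY set, while "claude" and "llama3" need AWS
# credentials with Amazon Bedrock access configured.
#
#   answer = chat("List three symptoms of anemia.", model="gpt-4o",
#                 system="You are a concise medical assistant.")
#   answer = chat("List three symptoms of anemia.", model="claude")
#   answer = chat("List three symptoms of anemia.", model="llama3")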

def generate_vignettes(disease, context, model, has_context=True):
    """Generate brief clinical vignettes for `disease`, optionally grounded in `context`."""
    if has_context:
        vignette_prompt = brief_vignette_template.format(disease=disease, context=context)
    else:
        vignette_prompt = brief_vignette_template_wo_context
    return chat(vignette_prompt, model)
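
# Example usage (a sketch; the disease and context strings are invented, and
# the vignette templates themselves live in prompts.py):
#
#   vignette = generate_vignettes(
#       disease="myasthenia gravis",
#       context="Excerpt from a review article on myasthenia gravis...",
#       model="gpt-4o",
#   )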

# def truncate_text(input_text: str, limit=9830) -> str:
#     tokenizer = tiktoken.encoding_for_model(model)
#     tokens = tokenizer.encode(input_text)
#     print("No of tokens:" + str(len(tokens)))
#
#     if len(tokens) <= limit:
#         return input_text
#
#     truncated_tokens = tokens[:limit]
#     truncated_text = tokenizer.decode(truncated_tokens)
#     return truncated_text
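
# A runnable version of the commented helper above (a sketch: it assumes
# tiktoken is installed, and adds the `model` argument that the commented
# body references but never receives):
#
#   import tiktoken
#
#   def truncate_text(input_text: str, model: str, limit: int = 9830) -> str:
#       tokenizer = tiktoken.encoding_for_model(model)
#       tokens = tokenizer.encode(input_text)
#       if len(tokens) <= limit:
#           return input_text
#       return tokenizer.decode(tokens[:limit])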

def generate_pubmed_query(disease):
    """Build a PubMed query requiring every word of `disease` in the title."""
    query = ""
    for word in disease.split():
        query += f'AND ({word}[Title]) \n\n'
    # Drop the leading "AND" from the first clause, then append the shared filter.
    return query[3:] + pubmed_query
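
# Example (a sketch; `pubmed_query` is the shared suffix from prompts.py):
#
#   generate_pubmed_query("cystic fibrosis")
#   # -> ' (cystic[Title]) \n\nAND (fibrosis[Title]) \n\n' + pubmed_query
#
# i.e. every word of the disease name must appear in the article title.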