
Commit cebf824

Commit message: started to add support for getting json responses from LLMs calls : Prompt_To_Json__Open_AI

1 parent 90d9c8c · commit cebf824

7 files changed: +151 -22 lines changed

osbot_llms/fast_api/routes/Routes__Chat.py
Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@ async def completion(self, request: Request, llm_chat_completion: LLMs__Chat_Com
         request_id       = self.request_id(request)
         chat_save_result = self.chats_storage_s3_minio.save_user_request(llm_chat_completion, request_id)

-        routes_open_ai = Routes__OpenAI()
+        routes_open_ai = Routes__OpenAI()               # todo: fix this mess of having to use a new instance of Routes__OpenAI
         user_data      = llm_chat_completion.user_data
         if user_data is None:
             user_data = dict(selected_platform = llm_chat_completion.llm_platform ,

osbot_llms/llms/API_Open_AI.py
Lines changed: 4 additions & 0 deletions

@@ -35,6 +35,9 @@ def api_key(self):
         load_dotenv()
         return getenv(OPEN_AI__API_KEY)

+    def client(self):
+        return OpenAI(api_key=self.api_key())
+
     def embeddings(self, input, model='text-embedding-3-small', dimensions=None):
         url     = 'https://api.openai.com/v1/embeddings'
         headers = { "Content-Type" : "application/json" ,
@@ -53,6 +56,7 @@ def embeddings(self, input, model='text-embedding-3-small', dimensions=None):
                           total_tokens = total_tokens )
         return result

+
     def open_ai_available(self):
         if self.api_key():
             #if is_url_online(URL_OPEN_AI_BASE):         # todo, find a better way (or url) to do this
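The new client() method centralises construction of the OpenAI SDK client from the configured API key. A minimal sketch of how it could be called directly, assuming a valid OPEN_AI__API_KEY is set in the environment; the model name and prompt below are illustrative, not part of the commit:

# Illustrative sketch only (not from the commit): calls the new client() helper directly.
from osbot_llms.llms.API_Open_AI import API_Open_AI

client   = API_Open_AI().client()                        # returns an openai.OpenAI instance
response = client.chat.completions.create(model    = "gpt-4o-mini",
                                           messages = [{"role": "user", "content": "Say hello"}])
print(response.choices[0].message.content)               # plain-text completion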
osbot_llms/llms/prompt_to_json/Prompt_To_Json__Open_AI.py
Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+from typing import List, Dict
+
+import openai
+from pydantic import BaseModel
+from pydantic._internal._model_construction import ModelMetaclass
+
+from osbot_llms.llms.API_Open_AI import API_Open_AI
+from osbot_utils.base_classes.Type_Safe import Type_Safe
+from osbot_utils.utils.Json import str_to_json
+
+
+class Prompt_To_Json__Open_AI(Type_Safe):
+    response_format : ModelMetaclass
+    messages        : List[Dict[str, str]]
+    model           : str
+    temperature     : float
+    seed            : int
+
+    def add_message__assistant(self, message):
+        return self.add_message("assistant", message)
+
+    def add_message__user(self, message):
+        return self.add_message("user", message)
+
+    def add_message__system(self, message):
+        return self.add_message("system", message)
+
+    def add_message(self, role, content):
+        self.messages.append(dict(role=role, content=content))
+        return self
+
+    def invoke(self):
+        response        = self.invoke__raw()
+        response_parsed = self.parse_response(response)
+        return response_parsed
+
+    def invoke__raw(self):
+        client = API_Open_AI().client()
+
+        try:
+            completion = client.beta.chat.completions.parse(**self.invoke_kwargs())
+            return completion
+        except Exception as exception:                   # todo: figure out the exceptions to handle here
+            raise exception
+            # # Handle edge cases
+            # if type(e) == openai.LengthFinishReasonError:
+            #     # Retry with a higher max tokens
+            #     print("Too many tokens: ", e)
+            #     pass
+            # else:
+            #     # Handle other exceptions
+            #     print(e)
+            #     pass
+
+    def invoke_kwargs(self):
+        return dict(model           = self.model          ,
+                    messages        = self.messages       ,
+                    response_format = self.response_format,
+                    seed            = self.seed           ,
+                    temperature     = self.temperature    )
+
+    def set_model(self, model):
+        self.model = model
+        return self
+
+    def set_model__gpt_4o(self):
+        return self.set_model("gpt-4o")
+
+    def set_model__gpt_4o_mini(self):
+        return self.set_model("gpt-4o-mini")
+
+    def set_response_format(self, response_format):
+        self.response_format = response_format
+        return self
+
+    def parse_response(self, response):
+        choice  = response.choices[0]
+        message = choice.message
+        usage   = response.usage
+        content = str_to_json(message.content)
+        model   = message.parsed
+        tokens  = usage.total_tokens
+        return dict(content = content,
+                    model   = model  ,
+                    tokens  = tokens )
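The class builds the request through chainable setters and message helpers, then sends it to OpenAI's structured-output endpoint (client.beta.chat.completions.parse) with a Pydantic class as response_format. A minimal usage sketch, assuming a valid OPEN_AI__API_KEY in the environment; the Ticket model and prompt text are illustrative, not from the commit:

# Illustrative usage sketch (not from the commit): the Ticket model and prompts are assumptions.
from pydantic import BaseModel
from osbot_llms.llms.prompt_to_json.Prompt_To_Json__Open_AI import Prompt_To_Json__Open_AI

class Ticket(BaseModel):                                 # schema the JSON response must follow
    title   : str
    severity: str

result = (Prompt_To_Json__Open_AI()
             .set_model__gpt_4o_mini()                   # selects "gpt-4o-mini"
             .set_response_format(Ticket)                # Pydantic class forwarded to .parse()
             .add_message__system("Extract a support ticket from the user message.")
             .add_message__user  ("The login page crashes with a 500 error, this is urgent.")
             .invoke())                                  # returns dict(content=..., model=..., tokens=...)

print(result['content'])                                 # plain dict parsed from the JSON content
print(result['model'])                                   # the parsed Ticket instance
print(result['tokens'])                                  # total tokens reported by the API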

osbot_llms/llms/prompt_to_json/__init__.py

Whitespace-only changes.
Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+from unittest import TestCase
+from pydantic import BaseModel
+from osbot_llms.llms.prompt_to_json.Prompt_To_Json__Open_AI import Prompt_To_Json__Open_AI
+
+
+
+class test_Prompt_To_Json__Open_AI(TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.prompt_to_json = Prompt_To_Json__Open_AI()
+
+    def test_invoke(self):
+        class CalendarEvent(BaseModel):
+            name: str
+            date: str
+            participants: list[str]
+            action: str
+
+        with self.prompt_to_json as _:
+            _.set_model__gpt_4o_mini()
+            _.set_response_format(CalendarEvent)
+            _.add_message__system("Extract the event information."                           )
+            _.add_message__user  ("Alice and Bob are going to a science fair on Friday."     )
+
+            response = _.invoke()
+
+            expected_content = { 'action'      : 'Attend'         ,
+                                 'date'        : 'Friday'         ,
+                                 'name'        : 'Science Fair'   ,
+                                 'participants': ['Alice', 'Bob'] }
+            assert response == dict(content = expected_content                 ,
+                                    model   = CalendarEvent(**expected_content),
+                                    tokens  = 124                              )

tests/integration/testing/test_TestCase__S3_Minio__Temp_Chat_Threads.py
Lines changed: 12 additions & 10 deletions

@@ -1,5 +1,6 @@
 from osbot_aws.AWS_Config import aws_config
 from osbot_aws.aws.s3.S3__DB_Base import S3_DB_BASE__BUCKET_NAME__PREFIX
+from osbot_utils.utils.Env import in_github_action
 from osbot_utils.utils.Misc import list_set
 from osbot_llms.backend.s3_minio.S3_DB__Chat_Threads import S3_DB__Chat_Threads
 from osbot_llms.testing.TestCase__S3_Minio__Temp_Chat_Threads import TestCase__S3_Minio__Temp_Chat_Threads
@@ -16,16 +17,17 @@ def tearDownClass(cls):
         assert cls.s3_db_chat_threads.bucket_exists() is False

     def test__setUpClass(self):
-        assert list_set(self.extra_env_vars) == [ 'AWS_ACCESS_KEY_ID'     ,
-                                                  'AWS_ACCOUNT_ID'        ,
-                                                  'AWS_DEFAULT_REGION'    ,
-                                                  'AWS_SECRET_ACCESS_KEY' ,
-                                                  'USE_MINIO_AS_S3'       ]
-        assert self.random_aws_creds.original_env_vars == { 'AWS_ACCESS_KEY_ID'    : None ,
-                                                            'AWS_ACCOUNT_ID'       : None ,
-                                                            'AWS_DEFAULT_REGION'   : None ,
-                                                            'AWS_SECRET_ACCESS_KEY': None ,
-                                                            'USE_MINIO_AS_S3'      : None }
+        if in_github_action():
+            assert list_set(self.extra_env_vars) == [ 'AWS_ACCESS_KEY_ID'     ,
+                                                      'AWS_ACCOUNT_ID'        ,
+                                                      'AWS_DEFAULT_REGION'    ,
+                                                      'AWS_SECRET_ACCESS_KEY' ,
+                                                      'USE_MINIO_AS_S3'       ]
+            assert self.random_aws_creds.original_env_vars == { 'AWS_ACCESS_KEY_ID'    : None ,
+                                                                'AWS_ACCOUNT_ID'       : None ,
+                                                                'AWS_DEFAULT_REGION'   : None ,
+                                                                'AWS_SECRET_ACCESS_KEY': None ,
+                                                                'USE_MINIO_AS_S3'      : None }
         assert self.server_name == 'osbot-llms'
         assert type(self.s3_db_chat_threads) is S3_DB__Chat_Threads
         assert self.s3_db_chat_threads.bucket_exists() is True

tests/integration/testing/test_TestCase__S3_Minio__Temp_S3_Bucket.py
Lines changed: 15 additions & 11 deletions

@@ -1,7 +1,10 @@
+import pytest
+
 from osbot_aws.AWS_Config import aws_config
-from osbot_aws.aws.s3.S3__DB_Base import S3__DB_Base, S3_DB_BASE__BUCKET_NAME__PREFIX, S3_DB_BASE__SERVER_NAME, \
+from osbot_aws.aws.s3.S3__DB_Base import S3__DB_Base, S3_DB_BASE__BUCKET_NAME__PREFIX, S3_DB_BASE__SERVER_NAME, \
     S3_DB_BASE__BUCKET_NAME__SUFFIX
 from osbot_aws.testing.TestCase__S3_Minio__Temp_S3_Bucket import TestCase__S3_Minio__Temp_S3_Bucket
+from osbot_utils.utils.Env import in_github_action
 from osbot_utils.utils.Misc import list_set

 from osbot_llms.OSBot_LLMs__Server_Config import DEFAULT__SERVER_CONFIG__SERVER_NAME
@@ -19,16 +22,17 @@ def tearDownClass(cls):
         assert cls.s3_db_base.bucket_exists() is False

     def test__setUpClass(self):
-        assert list_set(self.extra_env_vars) == [ 'AWS_ACCESS_KEY_ID'     ,
-                                                  'AWS_ACCOUNT_ID'        ,
-                                                  'AWS_DEFAULT_REGION'    ,
-                                                  'AWS_SECRET_ACCESS_KEY' ,
-                                                  'USE_MINIO_AS_S3'       ]
-        assert self.random_aws_creds.original_env_vars == { 'AWS_ACCESS_KEY_ID'    : None ,
-                                                            'AWS_ACCOUNT_ID'       : None ,
-                                                            'AWS_DEFAULT_REGION'   : None ,
-                                                            'AWS_SECRET_ACCESS_KEY': None ,
-                                                            'USE_MINIO_AS_S3'      : None }
+        if in_github_action():
+            assert list_set(self.extra_env_vars) == [ 'AWS_ACCESS_KEY_ID'     ,
+                                                      'AWS_ACCOUNT_ID'        ,
+                                                      'AWS_DEFAULT_REGION'    ,
+                                                      'AWS_SECRET_ACCESS_KEY' ,
+                                                      'USE_MINIO_AS_S3'       ]
+            assert self.random_aws_creds.original_env_vars == { 'AWS_ACCESS_KEY_ID'    : None ,
+                                                                'AWS_ACCOUNT_ID'       : None ,
+                                                                'AWS_DEFAULT_REGION'   : None ,
+                                                                'AWS_SECRET_ACCESS_KEY': None ,
+                                                                'USE_MINIO_AS_S3'      : None }
         assert type(self.s3_db_base) is S3__DB_Base
         assert self.s3_db_base.bucket_exists() is True
         assert aws_config.account_id() == self.random_aws_creds.env_vars['AWS_ACCOUNT_ID']
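Both integration test changes above gate the temporary-AWS-credential assertions so they only run in CI. in_github_action comes from osbot_utils.utils.Env; the sketch below shows the kind of check such a helper typically performs, assuming it reads the GITHUB_ACTIONS variable that GitHub Actions sets to "true" on its runners (the real osbot_utils implementation may differ):

# Hedged sketch of an equivalent check, not the osbot_utils implementation.
import os

def in_github_action() -> bool:
    return os.getenv('GITHUB_ACTIONS') == 'true'         # only true inside a GitHub Actions runner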
