chat-chainlit.py
import os
import chainlit as cl
from chainlit.data.base import BaseDataLayer
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.types import ThreadDict
from dotenv import load_dotenv
from langchain_community.callbacks import OpenAICallbackHandler
from agent.graph import AgentGraph
from agent.profiles import ProfileName, get_chat_profiles
from agent.profiles.base import OutputState
from util.chainlit_helpers import (is_feature_enabled, message_rate_limited,
save_openai_metrics, static_messages,
update_search_results)
from util.config_yml import Config, TriggerEvent
from util.logging import logging
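
# Load environment variables and the optional YAML config, then build the agent
# graph for the configured chat profiles (default profile if no config is found).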
load_dotenv()
config: Config | None = Config.from_yaml()
profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me]
llm_graph = AgentGraph(profiles)
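
# Persist Chainlit threads in Postgres via SQLAlchemy when POSTGRES_CHAINLIT_DB
# is set; otherwise chat history is not persisted.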
if os.getenv("POSTGRES_CHAINLIT_DB"):
CHAINLIT_DB_URI = f"postgresql+psycopg://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_CHAINLIT_DB')}?sslmode=disable"
@cl.data_layer
def get_data_layer() -> BaseDataLayer:
return SQLAlchemyDataLayer(conninfo=CHAINLIT_DB_URI)
else:
logging.warning("POSTGRES_CHAINLIT_DB undefined; Chainlit persistence disabled.")
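
# Register the OAuth callback only when auth is configured; the default user
# resolved by Chainlit is accepted without further checks.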
if os.getenv("CHAINLIT_AUTH_SECRET"):
@cl.oauth_callback
def oauth_callback(
provider_id: str,
token: str,
raw_user_data: dict[str, str],
default_user: cl.User,
) -> cl.User | None:
return default_user
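
# Expose the configured agent profiles as selectable chat profiles in the UI.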
@cl.set_chat_profiles
async def chat_profiles() -> list[cl.ChatProfile]:
return [
cl.ChatProfile(
name=profile.name,
markdown_description=profile.description,
)
for profile in get_chat_profiles(profiles)
]
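
# Lifecycle hooks: keep the session id as the graph thread id and emit any
# static messages configured for chat start, resume, and end events.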
@cl.on_chat_start
async def start() -> None:
thread_id: str = cl.user_session.get("id")
cl.user_session.set("thread_id", thread_id)
await static_messages(config, TriggerEvent.on_chat_start)
@cl.on_chat_resume
async def resume(thread: ThreadDict) -> None:
await static_messages(config, TriggerEvent.on_chat_resume)
@cl.on_chat_end
async def end() -> None:
await static_messages(config, TriggerEvent.on_chat_end)
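
# Main message handler: rate-limit, run the agent graph with streaming callbacks,
# attach search results, and record OpenAI usage metrics.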
@cl.on_message
async def main(message: cl.Message) -> None:
if await message_rate_limited(config):
return
await static_messages(config, TriggerEvent.on_message)
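
    # Track the number of user messages in this session; used below to trigger
    # count-based static messages.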
message_count: int = cl.user_session.get("message_count", 0) + 1
cl.user_session.set("message_count", message_count)
chat_profile: str = cl.user_session.get("chat_profile")
thread_id: str = cl.user_session.get("thread_id")
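
    # Stream the final answer through Chainlit and capture OpenAI token usage.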
chainlit_cb = cl.AsyncLangchainCallbackHandler(
stream_final_answer=True,
force_stream_final_answer=True, # we're not using prefix tokens
)
openai_cb = OpenAICallbackHandler()
enable_postprocess: bool = is_feature_enabled(config, "postprocessing")
result: OutputState = await llm_graph.ainvoke(
message.content,
chat_profile.lower(),
callbacks=[chainlit_cb, openai_cb],
thread_id=thread_id,
enable_postprocess=enable_postprocess,
)
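
    # If postprocessing is enabled and an answer was streamed, attach the graph's
    # search results to the streamed message.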
if (
enable_postprocess
and chainlit_cb.final_stream
and len(result["additional_content"]["search_results"]) > 0
):
await update_search_results(
result["additional_content"]["search_results"],
chainlit_cb.final_stream,
)
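
    # Emit any static messages scheduled after N user messages and record the
    # OpenAI metrics collected for this message.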
await static_messages(config, after_messages=message_count)
save_openai_metrics(message.id, openai_cb)