|
10 | 10 | from agent.graph import AgentGraph
|
11 | 11 | from agent.profiles import ProfileName, get_chat_profiles
|
12 | 12 | from agent.profiles.base import OutputState
|
13 |
| -from util.chainlit_helpers import (is_feature_enabled, message_rate_limited, |
14 |
| - save_openai_metrics, static_messages, |
15 |
| - update_search_results) |
| 13 | +from util.chainlit_helpers import (PrefixedS3StorageClient, is_feature_enabled, |
| 14 | + message_rate_limited, save_openai_metrics, |
| 15 | + static_messages, update_search_results) |
16 | 16 | from util.config_yml import Config, TriggerEvent
|
17 | 17 | from util.logging import logging
|
18 | 18 |
|
|
22 | 22 | profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me]
|
23 | 23 | llm_graph = AgentGraph(profiles)
|
24 | 24 |
|
25 |
| -if os.getenv("POSTGRES_CHAINLIT_DB"): |
26 |
| - CHAINLIT_DB_URI = f"postgresql+psycopg://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_CHAINLIT_DB')}?sslmode=disable" |
| 25 | +POSTGRES_CHAINLIT_DB = os.getenv("POSTGRES_CHAINLIT_DB") |
| 26 | +POSTGRES_USER = os.getenv("POSTGRES_USER") |
| 27 | +POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD") |
| 28 | +S3_BUCKET = os.getenv("S3_BUCKET") |
| 29 | +S3_CHAINLIT_PREFIX = os.getenv("S3_CHAINLIT_PREFIX") |
| 30 | + |
| 31 | +if POSTGRES_CHAINLIT_DB and POSTGRES_USER and POSTGRES_PASSWORD: |
| 32 | + CHAINLIT_DB_URI = f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@postgres:5432/{POSTGRES_CHAINLIT_DB}?sslmode=disable" |
| 33 | + |
| 34 | + if S3_BUCKET and S3_CHAINLIT_PREFIX: |
| 35 | + storage_client = PrefixedS3StorageClient(S3_BUCKET, S3_CHAINLIT_PREFIX) |
| 36 | + else: |
| 37 | + storage_client = None |
27 | 38 |
|
28 | 39 | @cl.data_layer
|
29 | 40 | def get_data_layer() -> BaseDataLayer:
|
30 |
| - return SQLAlchemyDataLayer(conninfo=CHAINLIT_DB_URI) |
| 41 | + return SQLAlchemyDataLayer( |
| 42 | + conninfo=CHAINLIT_DB_URI, |
| 43 | + storage_provider=storage_client, |
| 44 | + ) |
31 | 45 |
|
32 | 46 | else:
|
33 | 47 | logging.warning("POSTGRES_CHAINLIT_DB undefined; Chainlit persistence disabled.")
|
@@ -57,8 +71,8 @@ async def chat_profiles() -> list[cl.ChatProfile]:
|
57 | 71 |
|
58 | 72 | @cl.on_chat_start
|
59 | 73 | async def start() -> None:
|
60 |
| - thread_id: str = cl.user_session.get("id") |
61 |
| - cl.user_session.set("thread_id", thread_id) |
| 74 | + if cl.user_session.get("thread_id") is None: |
| 75 | + cl.user_session.set("thread_id", cl.user_session.get("id")) |
62 | 76 | await static_messages(config, TriggerEvent.on_chat_start)
|
63 | 77 |
|
64 | 78 |
|
@@ -100,15 +114,16 @@ async def main(message: cl.Message) -> None:
|
100 | 114 | thread_id=thread_id,
|
101 | 115 | enable_postprocess=enable_postprocess,
|
102 | 116 | )
|
| 117 | + assistant_message: cl.Message | None = chainlit_cb.final_stream |
103 | 118 |
|
104 | 119 | if (
|
105 | 120 | enable_postprocess
|
106 |
| - and chainlit_cb.final_stream |
| 121 | + and assistant_message |
107 | 122 | and len(result["additional_content"]["search_results"]) > 0
|
108 | 123 | ):
|
109 | 124 | await update_search_results(
|
110 | 125 | result["additional_content"]["search_results"],
|
111 |
| - chainlit_cb.final_stream, |
| 126 | + assistant_message, |
112 | 127 | )
|
113 | 128 |
|
114 | 129 | await static_messages(config, after_messages=message_count)
|
|
0 commit comments