Skip to content

Commit 8bd10c9

Browse files
Merge pull request #79 from reactome/enhance-guest
User Feedback on Guest + S3StorageClient
2 parents bf34543 + 3bcf88a commit 8bd10c9

File tree

7 files changed

+79
-58
lines changed

7 files changed

+79
-58
lines changed

.chainlit/translations/en-US.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
"negative": "Not helpful",
9595
"edit": "Edit feedback",
9696
"dialog": {
97-
"title": "Add a comment",
97+
"title": "Add a comment (optional)",
9898
"submit": "Submit feedback"
9999
},
100100
"status": {
@@ -188,4 +188,4 @@
188188
"saved": "Saved successfully"
189189
}
190190
}
191-
}
191+
}

bin/chat-chainlit.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from agent.graph import AgentGraph
1111
from agent.profiles import ProfileName, get_chat_profiles
1212
from agent.profiles.base import OutputState
13-
from util.chainlit_helpers import (is_feature_enabled, message_rate_limited,
14-
save_openai_metrics, static_messages,
15-
update_search_results)
13+
from util.chainlit_helpers import (PrefixedS3StorageClient, is_feature_enabled,
14+
message_rate_limited, save_openai_metrics,
15+
static_messages, update_search_results)
1616
from util.config_yml import Config, TriggerEvent
1717
from util.logging import logging
1818

@@ -22,12 +22,26 @@
2222
profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me]
2323
llm_graph = AgentGraph(profiles)
2424

25-
if os.getenv("POSTGRES_CHAINLIT_DB"):
26-
CHAINLIT_DB_URI = f"postgresql+psycopg://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_CHAINLIT_DB')}?sslmode=disable"
25+
POSTGRES_CHAINLIT_DB = os.getenv("POSTGRES_CHAINLIT_DB")
26+
POSTGRES_USER = os.getenv("POSTGRES_USER")
27+
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
28+
S3_BUCKET = os.getenv("S3_BUCKET")
29+
S3_CHAINLIT_PREFIX = os.getenv("S3_CHAINLIT_PREFIX")
30+
31+
if POSTGRES_CHAINLIT_DB and POSTGRES_USER and POSTGRES_PASSWORD:
32+
CHAINLIT_DB_URI = f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@postgres:5432/{POSTGRES_CHAINLIT_DB}?sslmode=disable"
33+
34+
if S3_BUCKET and S3_CHAINLIT_PREFIX:
35+
storage_client = PrefixedS3StorageClient(S3_BUCKET, S3_CHAINLIT_PREFIX)
36+
else:
37+
storage_client = None
2738

2839
@cl.data_layer
2940
def get_data_layer() -> BaseDataLayer:
30-
return SQLAlchemyDataLayer(conninfo=CHAINLIT_DB_URI)
41+
return SQLAlchemyDataLayer(
42+
conninfo=CHAINLIT_DB_URI,
43+
storage_provider=storage_client,
44+
)
3145

3246
else:
3347
logging.warning("POSTGRES_CHAINLIT_DB undefined; Chainlit persistence disabled.")
@@ -57,8 +71,8 @@ async def chat_profiles() -> list[cl.ChatProfile]:
5771

5872
@cl.on_chat_start
5973
async def start() -> None:
60-
thread_id: str = cl.user_session.get("id")
61-
cl.user_session.set("thread_id", thread_id)
74+
if cl.user_session.get("thread_id") is None:
75+
cl.user_session.set("thread_id", cl.user_session.get("id"))
6276
await static_messages(config, TriggerEvent.on_chat_start)
6377

6478

@@ -100,15 +114,16 @@ async def main(message: cl.Message) -> None:
100114
thread_id=thread_id,
101115
enable_postprocess=enable_postprocess,
102116
)
117+
assistant_message: cl.Message | None = chainlit_cb.final_stream
103118

104119
if (
105120
enable_postprocess
106-
and chainlit_cb.final_stream
121+
and assistant_message
107122
and len(result["additional_content"]["search_results"]) > 0
108123
):
109124
await update_search_results(
110125
result["additional_content"]["search_results"],
111-
chainlit_cb.final_stream,
126+
assistant_message,
112127
)
113128

114129
await static_messages(config, after_messages=message_count)

docker-compose.yml

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ services:
88
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
99
- POSTGRES_CHAINLIT_DB=${POSTGRES_CHAINLIT_DB}
1010
- POSTGRES_LANGGRAPH_DB=${POSTGRES_LANGGRAPH_DB}
11+
- S3_BUCKET=${S3_BUCKET}
12+
- S3_CHAINLIT_PREFIX=${S3_CHAINLIT_PREFIX}
1113
- LOG_LEVEL=${LOG_LEVEL}
1214
- UVICORN_LOG_LEVEL=${LOG_LEVEL}
1315
- CHAT_ENV=${CHAT_ENV}
@@ -42,6 +44,7 @@ services:
4244
- UVICORN_LOG_LEVEL=${LOG_LEVEL}
4345
- POSTGRES_USER=${POSTGRES_USER}
4446
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
47+
- POSTGRES_CHAINLIT_DB=${POSTGRES_CHAINLIT_DB}
4548
- POSTGRES_LANGGRAPH_DB=${POSTGRES_LANGGRAPH_DB}_no_login
4649
- CHAT_ENV=${CHAT_ENV}
4750
- CLOUDFLARE_SECRET_KEY=${CLOUDFLARE_SECRET_KEY}

src/agent/profiles/base.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from langgraph.graph.message import add_messages
88

99
from agent.tasks.rephrase import create_rephrase_chain
10-
from tools.external_search.state import WebSearchResult
10+
from tools.external_search.state import SearchState, WebSearchResult
11+
from tools.external_search.workflow import create_search_workflow
1112

1213

1314
class AdditionalContent(TypedDict, total=False):
@@ -37,6 +38,7 @@ def __init__(
3738
embedding: Embeddings,
3839
) -> None:
3940
self.rephrase_chain: Runnable = create_rephrase_chain(llm)
41+
self.search_workflow: Runnable = create_search_workflow(llm)
4042

4143
async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
4244
rephrased_input: str = await self.rephrase_chain.ainvoke(
@@ -47,3 +49,18 @@ async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseStat
4749
config,
4850
)
4951
return BaseState(rephrased_input=rephrased_input)
52+
53+
async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
54+
search_results: list[WebSearchResult] = []
55+
if config["configurable"]["enable_postprocess"]:
56+
result: SearchState = await self.search_workflow.ainvoke(
57+
SearchState(
58+
input=state["rephrased_input"],
59+
generation=state["answer"],
60+
),
61+
config=RunnableConfig(callbacks=config["callbacks"]),
62+
)
63+
search_results = result["search_results"]
64+
return BaseState(
65+
additional_content=AdditionalContent(search_results=search_results)
66+
)

src/agent/profiles/cross_database.py

+2-22
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from langchain_core.language_models.chat_models import BaseChatModel
55
from langchain_core.messages import AIMessage, HumanMessage
66
from langchain_core.runnables import Runnable, RunnableConfig
7-
from langgraph.graph.state import CompiledStateGraph, StateGraph
7+
from langgraph.graph.state import StateGraph
88

9-
from agent.profiles.base import AdditionalContent, BaseGraphBuilder, BaseState
9+
from agent.profiles.base import BaseGraphBuilder, BaseState
1010
from agent.tasks.completeness_grader import (CompletenessGrade,
1111
create_completeness_grader)
1212
from agent.tasks.cross_database.rewrite_reactome_with_uniprot import \
@@ -19,8 +19,6 @@
1919
from agent.tasks.safety_checker import SafetyCheck, create_safety_checker
2020
from retrievers.reactome.rag import create_reactome_rag
2121
from retrievers.uniprot.rag import create_uniprot_rag
22-
from tools.external_search.state import SearchState, WebSearchResult
23-
from tools.external_search.workflow import create_search_workflow
2422

2523

2624
class CrossDatabaseState(BaseState):
@@ -45,7 +43,6 @@ def __init__(
4543
super().__init__(llm, embedding)
4644

4745
# Create runnables (tasks & tools)
48-
self.search_workflow: CompiledStateGraph = create_search_workflow(llm)
4946
self.reactome_rag: Runnable = create_reactome_rag(llm, embedding)
5047
self.uniprot_rag: Runnable = create_uniprot_rag(llm, embedding)
5148

@@ -273,23 +270,6 @@ async def generate_final_response(
273270
answer=final_response,
274271
)
275272

276-
async def postprocess(
277-
self, state: CrossDatabaseState, config: RunnableConfig
278-
) -> CrossDatabaseState:
279-
search_results: list[WebSearchResult] = []
280-
if config["configurable"]["enable_postprocess"]:
281-
result: SearchState = await self.search_workflow.ainvoke(
282-
SearchState(
283-
input=state["rephrased_input"],
284-
generation=state["answer"],
285-
),
286-
config=RunnableConfig(callbacks=config["callbacks"]),
287-
)
288-
search_results = result["search_results"]
289-
return CrossDatabaseState(
290-
additional_content=AdditionalContent(search_results=search_results)
291-
)
292-
293273

294274
def create_cross_database_graph(
295275
llm: BaseChatModel,

src/agent/profiles/react_to_me.py

+2-22
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
from langchain_core.language_models.chat_models import BaseChatModel
55
from langchain_core.messages import AIMessage, HumanMessage
66
from langchain_core.runnables import Runnable, RunnableConfig
7-
from langgraph.graph.state import CompiledStateGraph, StateGraph
7+
from langgraph.graph.state import StateGraph
88

9-
from agent.profiles.base import AdditionalContent, BaseGraphBuilder, BaseState
9+
from agent.profiles.base import BaseGraphBuilder, BaseState
1010
from retrievers.reactome.rag import create_reactome_rag
11-
from tools.external_search.state import SearchState, WebSearchResult
12-
from tools.external_search.workflow import create_search_workflow
1311

1412

1513
class ReactToMeState(BaseState):
@@ -28,7 +26,6 @@ def __init__(
2826
self.reactome_rag: Runnable = create_reactome_rag(
2927
llm, embedding, streaming=True
3028
)
31-
self.search_workflow: CompiledStateGraph = create_search_workflow(llm)
3229

3330
# Create graph
3431
state_graph = StateGraph(ReactToMeState)
@@ -62,23 +59,6 @@ async def call_model(
6259
answer=result["answer"],
6360
)
6461

65-
async def postprocess(
66-
self, state: ReactToMeState, config: RunnableConfig
67-
) -> ReactToMeState:
68-
search_results: list[WebSearchResult] = []
69-
if config["configurable"]["enable_postprocess"]:
70-
result: SearchState = await self.search_workflow.ainvoke(
71-
SearchState(
72-
input=state["rephrased_input"],
73-
generation=state["answer"],
74-
),
75-
config=RunnableConfig(callbacks=config["callbacks"]),
76-
)
77-
search_results = result["search_results"]
78-
return ReactToMeState(
79-
additional_content=AdditionalContent(search_results=search_results)
80-
)
81-
8262

8363
def create_reactome_graph(
8464
llm: BaseChatModel,

src/util/chainlit_helpers.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
22
from datetime import datetime
3+
from pathlib import PurePosixPath
34
from typing import Any, Iterable
45

56
import chainlit as cl
67
from chainlit.data import get_data_layer
8+
from chainlit.data.storage_clients.s3 import S3StorageClient
79
from langchain_community.callbacks import OpenAICallbackHandler
810

911
from util.config_yml import Config, TriggerEvent
@@ -12,6 +14,30 @@
1214
guest_user_metadata: dict[str, Any] = {}
1315

1416

17+
class PrefixedS3StorageClient(S3StorageClient):
18+
def __init__(self, bucket: str, prefix: str, **kwargs: Any) -> None:
19+
super().__init__(bucket, **kwargs)
20+
self._prefix = PurePosixPath(prefix)
21+
22+
async def upload_file(
23+
self,
24+
object_key: str,
25+
data: bytes | str,
26+
mime: str = "application/octet-stream",
27+
overwrite: bool = True,
28+
) -> dict[str, Any]:
29+
object_key = str(self._prefix / object_key)
30+
return await super().upload_file(object_key, data, mime, overwrite)
31+
32+
async def delete_file(self, object_key: str) -> bool:
33+
object_key = str(self._prefix / object_key)
34+
return await super().delete_file(object_key)
35+
36+
async def get_read_url(self, object_key: str) -> str:
37+
object_key = str(self._prefix / object_key)
38+
return await super().get_read_url(object_key)
39+
40+
1541
def get_user_id() -> str | None:
1642
user: cl.User | None = cl.user_session.get("user")
1743
return user.identifier if user else None
@@ -136,7 +162,7 @@ async def update_search_results(
136162
name="SearchResults",
137163
props={"results": search_results},
138164
)
139-
message.elements = [search_results_element] # type: ignore
165+
message.elements.append(search_results_element) # type: ignore[arg-type]
140166
await message.update()
141167

142168

0 commit comments

Comments
 (0)