diff --git a/.chainlit/translations/en-US.json b/.chainlit/translations/en-US.json index 21cb885..3202e2c 100644 --- a/.chainlit/translations/en-US.json +++ b/.chainlit/translations/en-US.json @@ -94,7 +94,7 @@ "negative": "Not helpful", "edit": "Edit feedback", "dialog": { - "title": "Add a comment", + "title": "Add a comment (optional)", "submit": "Submit feedback" }, "status": { @@ -188,4 +188,4 @@ "saved": "Saved successfully" } } -} \ No newline at end of file +} diff --git a/bin/chat-chainlit.py b/bin/chat-chainlit.py index c46e9d7..fa4faf6 100644 --- a/bin/chat-chainlit.py +++ b/bin/chat-chainlit.py @@ -10,9 +10,9 @@ from agent.graph import AgentGraph from agent.profiles import ProfileName, get_chat_profiles from agent.profiles.base import OutputState -from util.chainlit_helpers import (is_feature_enabled, message_rate_limited, - save_openai_metrics, static_messages, - update_search_results) +from util.chainlit_helpers import (PrefixedS3StorageClient, is_feature_enabled, + message_rate_limited, save_openai_metrics, + static_messages, update_search_results) from util.config_yml import Config, TriggerEvent from util.logging import logging @@ -22,12 +22,26 @@ profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me] llm_graph = AgentGraph(profiles) -if os.getenv("POSTGRES_CHAINLIT_DB"): - CHAINLIT_DB_URI = f"postgresql+psycopg://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_CHAINLIT_DB')}?sslmode=disable" +POSTGRES_CHAINLIT_DB = os.getenv("POSTGRES_CHAINLIT_DB") +POSTGRES_USER = os.getenv("POSTGRES_USER") +POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD") +S3_BUCKET = os.getenv("S3_BUCKET") +S3_CHAINLIT_PREFIX = os.getenv("S3_CHAINLIT_PREFIX") + +if POSTGRES_CHAINLIT_DB and POSTGRES_USER and POSTGRES_PASSWORD: + CHAINLIT_DB_URI = f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@postgres:5432/{POSTGRES_CHAINLIT_DB}?sslmode=disable" + + if S3_BUCKET and S3_CHAINLIT_PREFIX: + storage_client = PrefixedS3StorageClient(S3_BUCKET, S3_CHAINLIT_PREFIX) + else: + storage_client = None @cl.data_layer def get_data_layer() -> BaseDataLayer: - return SQLAlchemyDataLayer(conninfo=CHAINLIT_DB_URI) + return SQLAlchemyDataLayer( + conninfo=CHAINLIT_DB_URI, + storage_provider=storage_client, + ) else: logging.warning("POSTGRES_CHAINLIT_DB undefined; Chainlit persistence disabled.") @@ -57,8 +71,8 @@ async def chat_profiles() -> list[cl.ChatProfile]: @cl.on_chat_start async def start() -> None: - thread_id: str = cl.user_session.get("id") - cl.user_session.set("thread_id", thread_id) + if cl.user_session.get("thread_id") is None: + cl.user_session.set("thread_id", cl.user_session.get("id")) await static_messages(config, TriggerEvent.on_chat_start) @@ -100,15 +114,16 @@ async def main(message: cl.Message) -> None: thread_id=thread_id, enable_postprocess=enable_postprocess, ) + assistant_message: cl.Message | None = chainlit_cb.final_stream if ( enable_postprocess - and chainlit_cb.final_stream + and assistant_message and len(result["additional_content"]["search_results"]) > 0 ): await update_search_results( result["additional_content"]["search_results"], - chainlit_cb.final_stream, + assistant_message, ) await static_messages(config, after_messages=message_count) diff --git a/docker-compose.yml b/docker-compose.yml index 37d8a57..bf353ad 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,8 @@ services: - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} - POSTGRES_CHAINLIT_DB=${POSTGRES_CHAINLIT_DB} - POSTGRES_LANGGRAPH_DB=${POSTGRES_LANGGRAPH_DB} + - S3_BUCKET=${S3_BUCKET} + - S3_CHAINLIT_PREFIX=${S3_CHAINLIT_PREFIX} - LOG_LEVEL=${LOG_LEVEL} - UVICORN_LOG_LEVEL=${LOG_LEVEL} - CHAT_ENV=${CHAT_ENV} @@ -42,6 +44,7 @@ services: - UVICORN_LOG_LEVEL=${LOG_LEVEL} - POSTGRES_USER=${POSTGRES_USER} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} + - POSTGRES_CHAINLIT_DB=${POSTGRES_CHAINLIT_DB} - POSTGRES_LANGGRAPH_DB=${POSTGRES_LANGGRAPH_DB}_no_login - CHAT_ENV=${CHAT_ENV} - CLOUDFLARE_SECRET_KEY=${CLOUDFLARE_SECRET_KEY} diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index f5e6580..9a6e26c 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -7,7 +7,8 @@ from langgraph.graph.message import add_messages from agent.tasks.rephrase import create_rephrase_chain -from tools.external_search.state import WebSearchResult +from tools.external_search.state import SearchState, WebSearchResult +from tools.external_search.workflow import create_search_workflow class AdditionalContent(TypedDict, total=False): @@ -37,6 +38,7 @@ def __init__( embedding: Embeddings, ) -> None: self.rephrase_chain: Runnable = create_rephrase_chain(llm) + self.search_workflow: Runnable = create_search_workflow(llm) async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: rephrased_input: str = await self.rephrase_chain.ainvoke( @@ -47,3 +49,18 @@ async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseStat config, ) return BaseState(rephrased_input=rephrased_input) + + async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: + search_results: list[WebSearchResult] = [] + if config["configurable"]["enable_postprocess"]: + result: SearchState = await self.search_workflow.ainvoke( + SearchState( + input=state["rephrased_input"], + generation=state["answer"], + ), + config=RunnableConfig(callbacks=config["callbacks"]), + ) + search_results = result["search_results"] + return BaseState( + additional_content=AdditionalContent(search_results=search_results) + ) diff --git a/src/agent/profiles/cross_database.py b/src/agent/profiles/cross_database.py index 50e95e4..74ef26c 100644 --- a/src/agent/profiles/cross_database.py +++ b/src/agent/profiles/cross_database.py @@ -4,9 +4,9 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables import Runnable, RunnableConfig -from langgraph.graph.state import CompiledStateGraph, StateGraph +from langgraph.graph.state import StateGraph -from agent.profiles.base import AdditionalContent, BaseGraphBuilder, BaseState +from agent.profiles.base import BaseGraphBuilder, BaseState from agent.tasks.completeness_grader import (CompletenessGrade, create_completeness_grader) from agent.tasks.cross_database.rewrite_reactome_with_uniprot import \ @@ -19,8 +19,6 @@ from agent.tasks.safety_checker import SafetyCheck, create_safety_checker from retrievers.reactome.rag import create_reactome_rag from retrievers.uniprot.rag import create_uniprot_rag -from tools.external_search.state import SearchState, WebSearchResult -from tools.external_search.workflow import create_search_workflow class CrossDatabaseState(BaseState): @@ -45,7 +43,6 @@ def __init__( super().__init__(llm, embedding) # Create runnables (tasks & tools) - self.search_workflow: CompiledStateGraph = create_search_workflow(llm) self.reactome_rag: Runnable = create_reactome_rag(llm, embedding) self.uniprot_rag: Runnable = create_uniprot_rag(llm, embedding) @@ -273,23 +270,6 @@ async def generate_final_response( answer=final_response, ) - async def postprocess( - self, state: CrossDatabaseState, config: RunnableConfig - ) -> CrossDatabaseState: - search_results: list[WebSearchResult] = [] - if config["configurable"]["enable_postprocess"]: - result: SearchState = await self.search_workflow.ainvoke( - SearchState( - input=state["rephrased_input"], - generation=state["answer"], - ), - config=RunnableConfig(callbacks=config["callbacks"]), - ) - search_results = result["search_results"] - return CrossDatabaseState( - additional_content=AdditionalContent(search_results=search_results) - ) - def create_cross_database_graph( llm: BaseChatModel, diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index 35002b5..5878979 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -4,12 +4,10 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables import Runnable, RunnableConfig -from langgraph.graph.state import CompiledStateGraph, StateGraph +from langgraph.graph.state import StateGraph -from agent.profiles.base import AdditionalContent, BaseGraphBuilder, BaseState +from agent.profiles.base import BaseGraphBuilder, BaseState from retrievers.reactome.rag import create_reactome_rag -from tools.external_search.state import SearchState, WebSearchResult -from tools.external_search.workflow import create_search_workflow class ReactToMeState(BaseState): @@ -28,7 +26,6 @@ def __init__( self.reactome_rag: Runnable = create_reactome_rag( llm, embedding, streaming=True ) - self.search_workflow: CompiledStateGraph = create_search_workflow(llm) # Create graph state_graph = StateGraph(ReactToMeState) @@ -62,23 +59,6 @@ async def call_model( answer=result["answer"], ) - async def postprocess( - self, state: ReactToMeState, config: RunnableConfig - ) -> ReactToMeState: - search_results: list[WebSearchResult] = [] - if config["configurable"]["enable_postprocess"]: - result: SearchState = await self.search_workflow.ainvoke( - SearchState( - input=state["rephrased_input"], - generation=state["answer"], - ), - config=RunnableConfig(callbacks=config["callbacks"]), - ) - search_results = result["search_results"] - return ReactToMeState( - additional_content=AdditionalContent(search_results=search_results) - ) - def create_reactome_graph( llm: BaseChatModel, diff --git a/src/util/chainlit_helpers.py b/src/util/chainlit_helpers.py index c4b1bd5..4c1b4fd 100644 --- a/src/util/chainlit_helpers.py +++ b/src/util/chainlit_helpers.py @@ -1,9 +1,11 @@ import os from datetime import datetime +from pathlib import PurePosixPath from typing import Any, Iterable import chainlit as cl from chainlit.data import get_data_layer +from chainlit.data.storage_clients.s3 import S3StorageClient from langchain_community.callbacks import OpenAICallbackHandler from util.config_yml import Config, TriggerEvent @@ -12,6 +14,30 @@ guest_user_metadata: dict[str, Any] = {} +class PrefixedS3StorageClient(S3StorageClient): + def __init__(self, bucket: str, prefix: str, **kwargs: Any) -> None: + super().__init__(bucket, **kwargs) + self._prefix = PurePosixPath(prefix) + + async def upload_file( + self, + object_key: str, + data: bytes | str, + mime: str = "application/octet-stream", + overwrite: bool = True, + ) -> dict[str, Any]: + object_key = str(self._prefix / object_key) + return await super().upload_file(object_key, data, mime, overwrite) + + async def delete_file(self, object_key: str) -> bool: + object_key = str(self._prefix / object_key) + return await super().delete_file(object_key) + + async def get_read_url(self, object_key: str) -> str: + object_key = str(self._prefix / object_key) + return await super().get_read_url(object_key) + + def get_user_id() -> str | None: user: cl.User | None = cl.user_session.get("user") return user.identifier if user else None @@ -136,7 +162,7 @@ async def update_search_results( name="SearchResults", props={"results": search_results}, ) - message.elements = [search_results_element] # type: ignore + message.elements.append(search_results_element) # type: ignore[arg-type] await message.update()