Skip to content

User Feedback on Guest + S3StorageClient #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .chainlit/translations/en-US.json
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
"negative": "Not helpful",
"edit": "Edit feedback",
"dialog": {
"title": "Add a comment",
"title": "Add a comment (optional)",
"submit": "Submit feedback"
},
"status": {
Expand Down Expand Up @@ -188,4 +188,4 @@
"saved": "Saved successfully"
}
}
}
}
35 changes: 25 additions & 10 deletions bin/chat-chainlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from agent.graph import AgentGraph
from agent.profiles import ProfileName, get_chat_profiles
from agent.profiles.base import OutputState
from util.chainlit_helpers import (is_feature_enabled, message_rate_limited,
save_openai_metrics, static_messages,
update_search_results)
from util.chainlit_helpers import (PrefixedS3StorageClient, is_feature_enabled,
message_rate_limited, save_openai_metrics,
static_messages, update_search_results)
from util.config_yml import Config, TriggerEvent
from util.logging import logging

Expand All @@ -22,12 +22,26 @@
profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me]
llm_graph = AgentGraph(profiles)

if os.getenv("POSTGRES_CHAINLIT_DB"):
CHAINLIT_DB_URI = f"postgresql+psycopg://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_CHAINLIT_DB')}?sslmode=disable"
POSTGRES_CHAINLIT_DB = os.getenv("POSTGRES_CHAINLIT_DB")
POSTGRES_USER = os.getenv("POSTGRES_USER")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
S3_BUCKET = os.getenv("S3_BUCKET")
S3_CHAINLIT_PREFIX = os.getenv("S3_CHAINLIT_PREFIX")

if POSTGRES_CHAINLIT_DB and POSTGRES_USER and POSTGRES_PASSWORD:
CHAINLIT_DB_URI = f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@postgres:5432/{POSTGRES_CHAINLIT_DB}?sslmode=disable"

if S3_BUCKET and S3_CHAINLIT_PREFIX:
storage_client = PrefixedS3StorageClient(S3_BUCKET, S3_CHAINLIT_PREFIX)
else:
storage_client = None

@cl.data_layer
def get_data_layer() -> BaseDataLayer:
return SQLAlchemyDataLayer(conninfo=CHAINLIT_DB_URI)
return SQLAlchemyDataLayer(
conninfo=CHAINLIT_DB_URI,
storage_provider=storage_client,
)

else:
logging.warning("POSTGRES_CHAINLIT_DB undefined; Chainlit persistence disabled.")
Expand Down Expand Up @@ -57,8 +71,8 @@ async def chat_profiles() -> list[cl.ChatProfile]:

@cl.on_chat_start
async def start() -> None:
thread_id: str = cl.user_session.get("id")
cl.user_session.set("thread_id", thread_id)
if cl.user_session.get("thread_id") is None:
cl.user_session.set("thread_id", cl.user_session.get("id"))
await static_messages(config, TriggerEvent.on_chat_start)


Expand Down Expand Up @@ -100,15 +114,16 @@ async def main(message: cl.Message) -> None:
thread_id=thread_id,
enable_postprocess=enable_postprocess,
)
assistant_message: cl.Message | None = chainlit_cb.final_stream

if (
enable_postprocess
and chainlit_cb.final_stream
and assistant_message
and len(result["additional_content"]["search_results"]) > 0
):
await update_search_results(
result["additional_content"]["search_results"],
chainlit_cb.final_stream,
assistant_message,
)

await static_messages(config, after_messages=message_count)
Expand Down
3 changes: 3 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ services:
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
- POSTGRES_CHAINLIT_DB=${POSTGRES_CHAINLIT_DB}
- POSTGRES_LANGGRAPH_DB=${POSTGRES_LANGGRAPH_DB}
- S3_BUCKET=${S3_BUCKET}
- S3_CHAINLIT_PREFIX=${S3_CHAINLIT_PREFIX}
- LOG_LEVEL=${LOG_LEVEL}
- UVICORN_LOG_LEVEL=${LOG_LEVEL}
- CHAT_ENV=${CHAT_ENV}
Expand Down Expand Up @@ -42,6 +44,7 @@ services:
- UVICORN_LOG_LEVEL=${LOG_LEVEL}
- POSTGRES_USER=${POSTGRES_USER}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
- POSTGRES_CHAINLIT_DB=${POSTGRES_CHAINLIT_DB}
- POSTGRES_LANGGRAPH_DB=${POSTGRES_LANGGRAPH_DB}_no_login
- CHAT_ENV=${CHAT_ENV}
- CLOUDFLARE_SECRET_KEY=${CLOUDFLARE_SECRET_KEY}
Expand Down
19 changes: 18 additions & 1 deletion src/agent/profiles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from langgraph.graph.message import add_messages

from agent.tasks.rephrase import create_rephrase_chain
from tools.external_search.state import WebSearchResult
from tools.external_search.state import SearchState, WebSearchResult
from tools.external_search.workflow import create_search_workflow


class AdditionalContent(TypedDict, total=False):
Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(
embedding: Embeddings,
) -> None:
self.rephrase_chain: Runnable = create_rephrase_chain(llm)
self.search_workflow: Runnable = create_search_workflow(llm)

async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
rephrased_input: str = await self.rephrase_chain.ainvoke(
Expand All @@ -47,3 +49,18 @@ async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseStat
config,
)
return BaseState(rephrased_input=rephrased_input)

async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
search_results: list[WebSearchResult] = []
if config["configurable"]["enable_postprocess"]:
result: SearchState = await self.search_workflow.ainvoke(
SearchState(
input=state["rephrased_input"],
generation=state["answer"],
),
config=RunnableConfig(callbacks=config["callbacks"]),
)
search_results = result["search_results"]
return BaseState(
additional_content=AdditionalContent(search_results=search_results)
)
24 changes: 2 additions & 22 deletions src/agent/profiles/cross_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.graph.state import StateGraph

from agent.profiles.base import AdditionalContent, BaseGraphBuilder, BaseState
from agent.profiles.base import BaseGraphBuilder, BaseState
from agent.tasks.completeness_grader import (CompletenessGrade,
create_completeness_grader)
from agent.tasks.cross_database.rewrite_reactome_with_uniprot import \
Expand All @@ -19,8 +19,6 @@
from agent.tasks.safety_checker import SafetyCheck, create_safety_checker
from retrievers.reactome.rag import create_reactome_rag
from retrievers.uniprot.rag import create_uniprot_rag
from tools.external_search.state import SearchState, WebSearchResult
from tools.external_search.workflow import create_search_workflow


class CrossDatabaseState(BaseState):
Expand All @@ -45,7 +43,6 @@ def __init__(
super().__init__(llm, embedding)

# Create runnables (tasks & tools)
self.search_workflow: CompiledStateGraph = create_search_workflow(llm)
self.reactome_rag: Runnable = create_reactome_rag(llm, embedding)
self.uniprot_rag: Runnable = create_uniprot_rag(llm, embedding)

Expand Down Expand Up @@ -273,23 +270,6 @@ async def generate_final_response(
answer=final_response,
)

async def postprocess(
self, state: CrossDatabaseState, config: RunnableConfig
) -> CrossDatabaseState:
search_results: list[WebSearchResult] = []
if config["configurable"]["enable_postprocess"]:
result: SearchState = await self.search_workflow.ainvoke(
SearchState(
input=state["rephrased_input"],
generation=state["answer"],
),
config=RunnableConfig(callbacks=config["callbacks"]),
)
search_results = result["search_results"]
return CrossDatabaseState(
additional_content=AdditionalContent(search_results=search_results)
)


def create_cross_database_graph(
llm: BaseChatModel,
Expand Down
24 changes: 2 additions & 22 deletions src/agent/profiles/react_to_me.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.graph.state import StateGraph

from agent.profiles.base import AdditionalContent, BaseGraphBuilder, BaseState
from agent.profiles.base import BaseGraphBuilder, BaseState
from retrievers.reactome.rag import create_reactome_rag
from tools.external_search.state import SearchState, WebSearchResult
from tools.external_search.workflow import create_search_workflow


class ReactToMeState(BaseState):
Expand All @@ -28,7 +26,6 @@ def __init__(
self.reactome_rag: Runnable = create_reactome_rag(
llm, embedding, streaming=True
)
self.search_workflow: CompiledStateGraph = create_search_workflow(llm)

# Create graph
state_graph = StateGraph(ReactToMeState)
Expand Down Expand Up @@ -62,23 +59,6 @@ async def call_model(
answer=result["answer"],
)

async def postprocess(
self, state: ReactToMeState, config: RunnableConfig
) -> ReactToMeState:
search_results: list[WebSearchResult] = []
if config["configurable"]["enable_postprocess"]:
result: SearchState = await self.search_workflow.ainvoke(
SearchState(
input=state["rephrased_input"],
generation=state["answer"],
),
config=RunnableConfig(callbacks=config["callbacks"]),
)
search_results = result["search_results"]
return ReactToMeState(
additional_content=AdditionalContent(search_results=search_results)
)


def create_reactome_graph(
llm: BaseChatModel,
Expand Down
28 changes: 27 additions & 1 deletion src/util/chainlit_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
from datetime import datetime
from pathlib import PurePosixPath
from typing import Any, Iterable

import chainlit as cl
from chainlit.data import get_data_layer
from chainlit.data.storage_clients.s3 import S3StorageClient
from langchain_community.callbacks import OpenAICallbackHandler

from util.config_yml import Config, TriggerEvent
Expand All @@ -12,6 +14,30 @@
guest_user_metadata: dict[str, Any] = {}


class PrefixedS3StorageClient(S3StorageClient):
def __init__(self, bucket: str, prefix: str, **kwargs: Any) -> None:
super().__init__(bucket, **kwargs)
self._prefix = PurePosixPath(prefix)

async def upload_file(
self,
object_key: str,
data: bytes | str,
mime: str = "application/octet-stream",
overwrite: bool = True,
) -> dict[str, Any]:
object_key = str(self._prefix / object_key)
return await super().upload_file(object_key, data, mime, overwrite)

async def delete_file(self, object_key: str) -> bool:
object_key = str(self._prefix / object_key)
return await super().delete_file(object_key)

async def get_read_url(self, object_key: str) -> str:
object_key = str(self._prefix / object_key)
return await super().get_read_url(object_key)


def get_user_id() -> str | None:
user: cl.User | None = cl.user_session.get("user")
return user.identifier if user else None
Expand Down Expand Up @@ -136,7 +162,7 @@ async def update_search_results(
name="SearchResults",
props={"results": search_results},
)
message.elements = [search_results_element] # type: ignore
message.elements.append(search_results_element) # type: ignore[arg-type]
await message.update()


Expand Down