diff --git a/.config.schema.yaml b/.config.schema.yaml
index 10f67f6..5da62f8 100644
--- a/.config.schema.yaml
+++ b/.config.schema.yaml
@@ -55,7 +55,7 @@ properties:
type: array
items:
type: string
- enum: ["React-to-Me"]
+ enum: ["React-to-Me", "Cross-Database Prototype"]
usage_limits:
type: object
properties:
diff --git a/bin/chat-chainlit.py b/bin/chat-chainlit.py
index 9de6e9f..c46e9d7 100644
--- a/bin/chat-chainlit.py
+++ b/bin/chat-chainlit.py
@@ -1,5 +1,4 @@
import os
-from typing import Any
import chainlit as cl
from chainlit.data.base import BaseDataLayer
@@ -10,6 +9,7 @@
from agent.graph import AgentGraph
from agent.profiles import ProfileName, get_chat_profiles
+from agent.profiles.base import OutputState
from util.chainlit_helpers import (is_feature_enabled, message_rate_limited,
save_openai_metrics, static_messages,
update_search_results)
@@ -93,7 +93,7 @@ async def main(message: cl.Message) -> None:
openai_cb = OpenAICallbackHandler()
enable_postprocess: bool = is_feature_enabled(config, "postprocessing")
- result: dict[str, Any] = await llm_graph.ainvoke(
+ result: OutputState = await llm_graph.ainvoke(
message.content,
chat_profile.lower(),
callbacks=[chainlit_cb, openai_cb],
diff --git a/bin/embeddings_manager b/bin/embeddings_manager
index 893cbb8..385e315 100755
--- a/bin/embeddings_manager
+++ b/bin/embeddings_manager
@@ -14,6 +14,7 @@ from botocore.client import Config
from data_generation.alliance import generate_alliance_embeddings
from data_generation.reactome import generate_reactome_embeddings
+from data_generation.uniprot import generate_uniprot_embeddings
from util.embedding_environment import EM_ARCHIVE, EmbeddingEnvironment
S3_BUCKET = "download.reactome.org"
@@ -86,6 +87,8 @@ def make(
os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_key
if embedding.db == "reactome":
generate_reactome_embeddings(str(embedding_path), hf_model=embedding.model, **kwargs)
+ elif embedding.db == "uniprot":
+ generate_uniprot_embeddings(embedding_path, hf_model=embedding.model, **kwargs)
elif embedding.db == "alliance":
generate_alliance_embeddings(str(embedding_path), hf_model=embedding.model, **kwargs)
else:
diff --git a/mypy.ini b/mypy.ini
index 662f6f3..7ec92a3 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -3,7 +3,9 @@ ignore_missing_imports = True
allow_untyped_calls = True
allow_untyped_defs = True
allow_untyped_globals = True
+explicit_package_bases = True
exclude = data/
+files = bin/,src/
[mypy.plugins.pandas.*]
init_forbid_dynamic = False
diff --git a/poetry.lock b/poetry.lock
index c4c9ae6..71cbb2a 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -956,6 +956,17 @@ files = [
{file = "einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85"},
]
+[[package]]
+name = "et-xmlfile"
+version = "2.0.0"
+description = "An implementation of lxml.xmlfile for the standard library"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"},
+ {file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"},
+]
+
[[package]]
name = "fastapi"
version = "0.115.7"
@@ -2903,6 +2914,20 @@ typing-extensions = ">=4.11,<5"
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
realtime = ["websockets (>=13,<15)"]
+[[package]]
+name = "openpyxl"
+version = "3.1.5"
+description = "A Python library to read/write Excel 2010 xlsx/xlsm files"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"},
+ {file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"},
+]
+
+[package.dependencies]
+et-xmlfile = "*"
+
[[package]]
name = "opentelemetry-api"
version = "1.29.0"
@@ -6055,4 +6080,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.12, <4"
-content-hash = "5b1f865d119b14bd9b18d319d3e20422de1fed9eb3cacbef5e92dbb3594cdf3e"
+content-hash = "3952ce13da91d68ddd78ca4d3d36b075e9da512c72100fa7ee7a75beb1e2abe8"
diff --git a/pyproject.toml b/pyproject.toml
index 2c1f087..362008c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ psycopg = {extras = ["binary"], version = "^3.2.3"}
pydantic = "^2.10.5"
pyyaml = "^6.0.2"
tavily-python = "^0.5.0"
+openpyxl = "^3.1.5"
[tool.poetry.group.dev.dependencies]
ruff = "^0.7.1"
diff --git a/src/agent/graph.py b/src/agent/graph.py
index de311d0..012df27 100644
--- a/src/agent/graph.py
+++ b/src/agent/graph.py
@@ -15,6 +15,7 @@
from agent.models import get_embedding, get_llm
from agent.profiles import ProfileName, create_profile_graphs
+from agent.profiles.base import InputState, OutputState
from util.logging import logging
LANGGRAPH_DB_URI = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_LANGGRAPH_DB')}?sslmode=disable"
@@ -81,13 +82,13 @@ async def ainvoke(
callbacks: Callbacks,
thread_id: str,
enable_postprocess: bool = True,
- ) -> dict[str, Any]:
+ ) -> OutputState:
if self.graph is None:
self.graph = await self.initialize()
if profile not in self.graph:
- return {}
- result: dict[str, Any] = await self.graph[profile].ainvoke(
- {"user_input": user_input},
+ return OutputState()
+ result: OutputState = await self.graph[profile].ainvoke(
+ InputState(user_input=user_input),
config=RunnableConfig(
callbacks=callbacks,
configurable={
diff --git a/src/agent/profiles/__init__.py b/src/agent/profiles/__init__.py
index 92e3787..8061302 100644
--- a/src/agent/profiles/__init__.py
+++ b/src/agent/profiles/__init__.py
@@ -5,12 +5,14 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.graph.state import StateGraph
-from agent.profiles.react_to_me import create_reacttome_graph
+from agent.profiles.cross_database import create_cross_database_graph
+from agent.profiles.react_to_me import create_reactome_graph
class ProfileName(StrEnum):
# These should exactly match names in .config.schema.yaml
React_to_Me = "React-to-Me"
+ Cross_Database_Prototype = "Cross-Database Prototype"
class Profile(NamedTuple):
@@ -23,7 +25,12 @@ class Profile(NamedTuple):
ProfileName.React_to_Me.lower(): Profile(
name=ProfileName.React_to_Me,
description="An AI assistant specialized in exploring **Reactome** biological pathways and processes.",
- graph_builder=create_reacttome_graph,
+ graph_builder=create_reactome_graph,
+ ),
+ ProfileName.Cross_Database_Prototype.lower(): Profile(
+ name=ProfileName.Cross_Database_Prototype,
+ description="Early version of an AI assistant with knowledge from multiple bio-databases (**Reactome** + **Uniprot**).",
+ graph_builder=create_cross_database_graph,
),
}
diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py
index 32d1a3f..f5e6580 100644
--- a/src/agent/profiles/base.py
+++ b/src/agent/profiles/base.py
@@ -1,25 +1,49 @@
from typing import Annotated, TypedDict
-from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
+from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.message import add_messages
+from agent.tasks.rephrase import create_rephrase_chain
from tools.external_search.state import WebSearchResult
-class AdditionalContent(TypedDict):
+class AdditionalContent(TypedDict, total=False):
search_results: list[WebSearchResult]
-class BaseState(TypedDict):
- # (Everything the Chainlit layer uses should be included here)
-
+class InputState(TypedDict, total=False):
user_input: str # User input text
- chat_history: Annotated[list[BaseMessage], add_messages]
- context: list[Document]
+
+
+class OutputState(TypedDict, total=False):
answer: str # primary LLM response that is streamed to the user
additional_content: AdditionalContent # sends on graph completion
+class BaseState(InputState, OutputState, total=False):
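+ # Working state shared by all profile graphs; extends the public input/output fields.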
+ rephrased_input: str # LLM-generated query from user input
+ chat_history: Annotated[list[BaseMessage], add_messages]
+
+
class BaseGraphBuilder:
- pass # NOTE: Anything that is common to all graph builders goes here
+ # NOTE: Anything that is common to all graph builders goes here
+
+ def __init__(
+ self,
+ llm: BaseChatModel,
+ embedding: Embeddings,
+ ) -> None:
+ self.rephrase_chain: Runnable = create_rephrase_chain(llm)
+
+ async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
+ rephrased_input: str = await self.rephrase_chain.ainvoke(
+ {
+ "user_input": state["user_input"],
+ "chat_history": state["chat_history"],
+ },
+ config,
+ )
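+ # Return only the changed field; LangGraph merges this delta into the full graph state.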
+ return BaseState(rephrased_input=rephrased_input)
diff --git a/src/agent/profiles/cross_database.py b/src/agent/profiles/cross_database.py
new file mode 100644
index 0000000..50e95e4
--- /dev/null
+++ b/src/agent/profiles/cross_database.py
@@ -0,0 +1,298 @@
+from typing import Any, Literal
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.runnables import Runnable, RunnableConfig
+from langgraph.graph.state import CompiledStateGraph, StateGraph
+
+from agent.profiles.base import AdditionalContent, BaseGraphBuilder, BaseState
+from agent.tasks.completeness_grader import (CompletenessGrade,
+ create_completeness_grader)
+from agent.tasks.cross_database.rewrite_reactome_with_uniprot import \
+ create_reactome_rewriter_w_uniprot
+from agent.tasks.cross_database.rewrite_uniprot_with_reactome import \
+ create_uniprot_rewriter_w_reactome
+from agent.tasks.cross_database.summarize_reactome_uniprot import \
+ create_reactome_uniprot_summarizer
+from agent.tasks.detect_language import create_language_detector
+from agent.tasks.safety_checker import SafetyCheck, create_safety_checker
+from retrievers.reactome.rag import create_reactome_rag
+from retrievers.uniprot.rag import create_uniprot_rag
+from tools.external_search.state import SearchState, WebSearchResult
+from tools.external_search.workflow import create_search_workflow
+
+
+class CrossDatabaseState(BaseState):
+ safety: str # LLM-assessed safety level of the user input
+ query_language: str # language of the user input
+
+ reactome_query: str # LLM-generated query for Reactome
+ reactome_answer: str # LLM-generated answer from Reactome
+ reactome_completeness: str # LLM-assessed completeness of the Reactome answer
+
+ uniprot_query: str # LLM-generated query for UniProt
+ uniprot_answer: str # LLM-generated answer from UniProt
+ uniprot_completeness: str # LLM-assessed completeness of the UniProt answer
+
+
+class CrossDatabaseGraphBuilder(BaseGraphBuilder):
+ def __init__(
+ self,
+ llm: BaseChatModel,
+ embedding: Embeddings,
+ ) -> None:
+ super().__init__(llm, embedding)
+
+ # Create runnables (tasks & tools)
+ self.search_workflow: CompiledStateGraph = create_search_workflow(llm)
+ self.reactome_rag: Runnable = create_reactome_rag(llm, embedding)
+ self.uniprot_rag: Runnable = create_uniprot_rag(llm, embedding)
+
+ self.safety_checker = create_safety_checker(llm)
+ self.completeness_checker = create_completeness_grader(llm)
+ self.detect_language = create_language_detector(llm)
+ self.write_reactome_query = create_reactome_rewriter_w_uniprot(llm)
+ self.write_uniprot_query = create_uniprot_rewriter_w_reactome(llm)
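+ # Only the final summarizer streams tokens to the user; model_copy avoids mutating the shared llm.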
+ self.summarize_final_answer = create_reactome_uniprot_summarizer(
+ llm.model_copy(update={"streaming": True})
+ )
+
+ # Create graph
+ state_graph = StateGraph(CrossDatabaseState)
+ # Set up nodes
+ state_graph.add_node("check_question_safety", self.check_question_safety)
+ state_graph.add_node("preprocess_question", self.preprocess)
+ state_graph.add_node("identify_query_language", self.identify_query_language)
+ state_graph.add_node("conduct_research", self.conduct_research)
+ state_graph.add_node("generate_reactome_answer", self.generate_reactome_answer)
+ state_graph.add_node("rewrite_reactome_query", self.rewrite_reactome_query)
+ state_graph.add_node("rewrite_reactome_answer", self.rewrite_reactome_answer)
+ state_graph.add_node("generate_uniprot_answer", self.generate_uniprot_answer)
+ state_graph.add_node("rewrite_uniprot_query", self.rewrite_uniprot_query)
+ state_graph.add_node("rewrite_uniprot_answer", self.rewrite_uniprot_answer)
+ state_graph.add_node("assess_completeness", self.assess_completeness)
+ state_graph.add_node("decide_next_steps", self.decide_next_steps)
+ state_graph.add_node("generate_final_response", self.generate_final_response)
+ state_graph.add_node("postprocess", self.postprocess)
+ # Set up edges
+ state_graph.set_entry_point("preprocess_question")
+ state_graph.add_edge("preprocess_question", "identify_query_language")
+ state_graph.add_edge("preprocess_question", "check_question_safety")
+ state_graph.add_conditional_edges(
+ "check_question_safety",
+ self.proceed_with_research,
+ {"Continue": "conduct_research", "Finish": "generate_final_response"},
+ )
+ state_graph.add_edge("conduct_research", "generate_reactome_answer")
+ state_graph.add_edge("conduct_research", "generate_uniprot_answer")
+ state_graph.add_edge("generate_reactome_answer", "assess_completeness")
+ state_graph.add_edge("generate_uniprot_answer", "assess_completeness")
+ state_graph.add_conditional_edges(
+ "assess_completeness",
+ self.decide_next_steps,
+ {
+ "generate_final_response": "generate_final_response",
+ "perform_web_search": "generate_final_response",
+ "rewrite_reactome_query": "rewrite_reactome_query",
+ "rewrite_uniprot_query": "rewrite_uniprot_query",
+ },
+ )
+ state_graph.add_edge("rewrite_reactome_query", "rewrite_reactome_answer")
+ state_graph.add_edge("rewrite_uniprot_query", "rewrite_uniprot_answer")
+ state_graph.add_edge("rewrite_reactome_answer", "generate_final_response")
+ state_graph.add_edge("rewrite_uniprot_answer", "generate_final_response")
+ state_graph.add_edge("generate_final_response", "postprocess")
+ state_graph.set_finish_point("postprocess")
+
+ self.uncompiled_graph: StateGraph = state_graph
+
+ async def check_question_safety(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ result: SafetyCheck = await self.safety_checker.ainvoke(
+ {"input": state["rephrased_input"]},
+ config,
+ )
+ if result.binary_score == "No":
+ inappropriate_input = f"This is the user's question and it is NOT appropriate for you to answer: {state["user_input"]}. \n\n explain that you are unable to answer the question but you can answer questions about topics related to the Reactome Pathway Knowledgebase or UniProt Knowledgebas."
+ return CrossDatabaseState(
+ safety=result.binary_score,
+ user_input=inappropriate_input,
+ reactome_answer="",
+ uniprot_answer="",
+ )
+ else:
+ return CrossDatabaseState(safety=result.binary_score)
+
+ async def proceed_with_research(
+ self, state: CrossDatabaseState
+ ) -> Literal["Continue", "Finish"]:
+ if state["safety"] == "Yes":
+ return "Continue"
+ else:
+ return "Finish"
+
+ async def identify_query_language(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ query_language: str = await self.detect_language.ainvoke(
+ {"user_input": state["user_input"]}, config
+ )
+ return CrossDatabaseState(query_language=query_language)
+
+ async def conduct_research(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
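+ # Fan-out placeholder: the Reactome and UniProt branches run in parallel from this node.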
+ return CrossDatabaseState()
+
+ async def generate_reactome_answer(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ reactome_answer: dict[str, Any] = await self.reactome_rag.ainvoke(
+ {
+ "input": state["rephrased_input"],
+ "chat_history": state["chat_history"],
+ },
+ config,
+ )
+ return CrossDatabaseState(reactome_answer=reactome_answer["answer"])
+
+ async def generate_uniprot_answer(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ uniprot_answer: dict[str, Any] = await self.uniprot_rag.ainvoke(
+ {
+ "input": state["rephrased_input"],
+ "chat_history": state["chat_history"],
+ },
+ config,
+ )
+ return CrossDatabaseState(uniprot_answer=uniprot_answer["answer"])
+
+ async def rewrite_reactome_query(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ reactome_query: str = await self.write_reactome_query.ainvoke(
+ {
+ "input": state["rephrased_input"],
+ "uniprot_answer": state["uniprot_answer"],
+ },
+ config,
+ )
+ return CrossDatabaseState(reactome_query=reactome_query)
+
+ async def rewrite_uniprot_query(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ uniprot_query: str = await self.write_uniprot_query.ainvoke(
+ {
+ "input": state["rephrased_input"],
+ "reactome_answer": state["reactome_answer"],
+ },
+ config,
+ )
+ return CrossDatabaseState(uniprot_query=uniprot_query)
+
+ async def rewrite_reactome_answer(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ rewritten_answer: dict[str, Any] = await self.reactome_rag.ainvoke(
+ {
+ "input": state["reactome_query"],
+ "chat_history": state["chat_history"],
+ },
+ config,
+ )
+ return CrossDatabaseState(reactome_answer=rewritten_answer["answer"])
+
+ async def rewrite_uniprot_answer(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ rewritten_answer: dict[str, Any] = await self.uniprot_rag.ainvoke(
+ {
+ "input": state["uniprot_query"],
+ "chat_history": state["chat_history"],
+ },
+ config,
+ )
+ return CrossDatabaseState(uniprot_answer=rewritten_answer["answer"])
+
+ async def assess_completeness(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ reactome_completeness_async = self.completeness_checker.ainvoke(
+ {"input": state["rephrased_input"], "generation": state["reactome_answer"]},
+ config,
+ )
+ uniprot_completeness_async = self.completeness_checker.ainvoke(
+ {"input": state["rephrased_input"], "generation": state["uniprot_answer"]},
+ config,
+ )
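+ # NOTE: the coroutines above are awaited one after the other; asyncio.gather would grade both concurrently.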
+ reactome_completeness: CompletenessGrade = await reactome_completeness_async
+ uniprot_completeness: CompletenessGrade = await uniprot_completeness_async
+ return CrossDatabaseState(
+ reactome_completeness=reactome_completeness.binary_score,
+ uniprot_completeness=uniprot_completeness.binary_score,
+ )
+
+ async def decide_next_steps(self, state: CrossDatabaseState) -> Literal[
+ "generate_final_response",
+ "perform_web_search",
+ "rewrite_reactome_query",
+ "rewrite_uniprot_query",
+ ]:
+ reactome_complete = state["reactome_completeness"] != "No"
+ uniprot_complete = state["uniprot_completeness"] != "No"
+ if reactome_complete and uniprot_complete:
+ return "generate_final_response"
+ elif not reactome_complete and uniprot_complete:
+ return "rewrite_reactome_query"
+ elif reactome_complete and not uniprot_complete:
+ return "rewrite_uniprot_query"
+ else:
+ return "perform_web_search"
+
+ async def generate_final_response(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ final_response: str = await self.summarize_final_answer.ainvoke(
+ {
+ "input": state["rephrased_input"],
+ "query_language": state["query_language"],
+ "reactome_answer": state["reactome_answer"],
+ "uniprot_answer": state["uniprot_answer"],
+ },
+ config,
+ )
+ return CrossDatabaseState(
+ chat_history=[
+ HumanMessage(state["user_input"]),
+ AIMessage(final_response),
+ ],
+ answer=final_response,
+ )
+
+ async def postprocess(
+ self, state: CrossDatabaseState, config: RunnableConfig
+ ) -> CrossDatabaseState:
+ search_results: list[WebSearchResult] = []
+ if config["configurable"]["enable_postprocess"]:
+ result: SearchState = await self.search_workflow.ainvoke(
+ SearchState(
+ input=state["rephrased_input"],
+ generation=state["answer"],
+ ),
+ config=RunnableConfig(callbacks=config["callbacks"]),
+ )
+ search_results = result["search_results"]
+ return CrossDatabaseState(
+ additional_content=AdditionalContent(search_results=search_results)
+ )
+
+
+def create_cross_database_graph(
+ llm: BaseChatModel,
+ embedding: Embeddings,
+) -> StateGraph:
+ return CrossDatabaseGraphBuilder(llm, embedding).uncompiled_graph
diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py
index 7583f2f..35002b5 100644
--- a/src/agent/profiles/react_to_me.py
+++ b/src/agent/profiles/react_to_me.py
@@ -6,15 +6,14 @@
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.state import CompiledStateGraph, StateGraph
-from agent.profiles.base import BaseGraphBuilder, BaseState
-from agent.tasks.rephrase import create_rephrase_chain
+from agent.profiles.base import AdditionalContent, BaseGraphBuilder, BaseState
from retrievers.reactome.rag import create_reactome_rag
-from tools.external_search.state import WebSearchResult
+from tools.external_search.state import SearchState, WebSearchResult
from tools.external_search.workflow import create_search_workflow
class ReactToMeState(BaseState):
- rephrased_input: str # LLM-generated query from user input
+ pass
class ReactToMeGraphBuilder(BaseGraphBuilder):
@@ -23,11 +22,12 @@ def __init__(
llm: BaseChatModel,
embedding: Embeddings,
) -> None:
+ super().__init__(llm, embedding)
+
# Create runnables (tasks & tools)
self.reactome_rag: Runnable = create_reactome_rag(
llm, embedding, streaming=True
)
- self.rephrase_chain: Runnable = create_rephrase_chain(llm)
self.search_workflow: CompiledStateGraph = create_search_workflow(llm)
# Create graph
@@ -44,48 +44,43 @@ def __init__(
self.uncompiled_graph: StateGraph = state_graph
- async def preprocess(
- self, state: ReactToMeState, config: RunnableConfig
- ) -> dict[str, str]:
- query: str = await self.rephrase_chain.ainvoke(state, config)
- return {"rephrased_input": query}
-
async def call_model(
self, state: ReactToMeState, config: RunnableConfig
- ) -> dict[str, Any]:
+ ) -> ReactToMeState:
result: dict[str, Any] = await self.reactome_rag.ainvoke(
{
"input": state["rephrased_input"],
- "user_input": state["user_input"],
"chat_history": state["chat_history"],
},
config,
)
- return {
- "chat_history": [
+ return ReactToMeState(
+ chat_history=[
HumanMessage(state["user_input"]),
AIMessage(result["answer"]),
],
- "context": result["context"],
- "answer": result["answer"],
- }
+ answer=result["answer"],
+ )
async def postprocess(
self, state: ReactToMeState, config: RunnableConfig
- ) -> dict[str, dict[str, list[WebSearchResult]]]:
+ ) -> ReactToMeState:
search_results: list[WebSearchResult] = []
if config["configurable"]["enable_postprocess"]:
- result: dict[str, Any] = await self.search_workflow.ainvoke(
- {"question": state["rephrased_input"], "generation": state["answer"]},
+ result: SearchState = await self.search_workflow.ainvoke(
+ SearchState(
+ input=state["rephrased_input"],
+ generation=state["answer"],
+ ),
config=RunnableConfig(callbacks=config["callbacks"]),
)
search_results = result["search_results"]
- return {
- "additional_content": {"search_results": search_results},
- }
+ return ReactToMeState(
+ additional_content=AdditionalContent(search_results=search_results)
+ )
-def create_reacttome_graph(
+def create_reactome_graph(
llm: BaseChatModel,
embedding: Embeddings,
) -> StateGraph:
diff --git a/src/agent/tasks/completeness_grader.py b/src/agent/tasks/completeness_grader.py
new file mode 100644
index 0000000..129866e
--- /dev/null
+++ b/src/agent/tasks/completeness_grader.py
@@ -0,0 +1,30 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable
+from pydantic import BaseModel, Field
+
+completeness_grader_message = """
+You are an expert grader with extensive knowledge in molecular biology and experience as a curator for both Reactome and UniProt knowledgebases.
+Your task is to evaluate whether a response generated by an LLM is complete, meaning it addresses the user’s question with necessary details, background information, and context.
+
+Provide a binary output as either:
+ - Yes: The response answers the user question and provides enough details and background.
+ - No: The response is incomplete, missing key details, or lacking sufficient context.
+"""
+
+completeness_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", completeness_grader_message),
+ ("human", "User question: \n\n {input} \n\n LLM generation: {generation}"),
+ ]
+)
+
+
+class CompletenessGrade(BaseModel):
+ binary_score: str = Field(
+ description="Answer is complete and provides all necessary background, 'Yes' or 'No'"
+ )
+
+
+def create_completeness_grader(llm: BaseChatModel) -> Runnable:
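+ # with_structured_output parses the LLM reply into the CompletenessGrade schema.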
+ return completeness_prompt | llm.with_structured_output(CompletenessGrade)
diff --git a/src/agent/tasks/cross_database/rewrite_reactome_with_uniprot.py b/src/agent/tasks/cross_database/rewrite_reactome_with_uniprot.py
new file mode 100644
index 0000000..ee91dad
--- /dev/null
+++ b/src/agent/tasks/cross_database/rewrite_reactome_with_uniprot.py
@@ -0,0 +1,47 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable
+
+reactome_rewriter_message = """
+You are a query optimization expert with deep knowledge of molecular biology and extensive experience curating the Reactome Pathway Knowledgebase. You are also skilled in leveraging UniProt data to enhance search precision.
+Your task is to generate a new, optimized search question that incorporates relevant UniProt-derived context to improve search results within Reactome.
+The Reactome Knowledgebase contains detailed information about human biological pathways, including specific pathways, related complexes, genes, proteins, and their roles in health and disease.
+
+The reformulated question must:
+ - Preserve the user’s intent while enriching it with relevant biological details.
+ - Integrate relevant insights from the UniProt response, such as protein names, functions, interactions, biological pathways and disease associations.
+ - Enhance search performance by optimizing for:
+ - Vector similarity search (semantic meaning).
+ - Case-sensitive keyword search (exact term matching).
+
+Task Breakdown
+ 1. Process Inputs:
+ - User’s Question: Understand the original query’s intent.
+ - UniProt Response: Extract key insights (protein function, interactions, pathway involvement, disease relevance).
+ 2. Reformulate the Question:
+ - Enhance with relevant biological context while keeping it concise.
+ - Avoid unnecessary details that dilute clarity.
+ 3. Optimize for Search Retrieval:
+ - Vector Search: Ensure the question captures semantic meaning for broad similarity matching.
+ - Optimize the query for Case-Sensitive Keyword Search
+
+Do NOT answer the question or provide explanations.
+"""
+
+
+reactome_rewriter_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", reactome_rewriter_message),
+ (
+ "human",
+ "Here is the initial question: \n\n {input} \n Here is UniProt-derived context:\n\n {uniprot_answer} ",
+ ),
+ ]
+)
+
+
+def create_reactome_rewriter_w_uniprot(llm: BaseChatModel) -> Runnable:
+ return (reactome_rewriter_prompt | llm | StrOutputParser()).with_config(
+ run_name="rewrite_reactome_query"
+ )
diff --git a/src/agent/tasks/cross_database/rewrite_uniprot_with_reactome.py b/src/agent/tasks/cross_database/rewrite_uniprot_with_reactome.py
new file mode 100644
index 0000000..5e5d65f
--- /dev/null
+++ b/src/agent/tasks/cross_database/rewrite_uniprot_with_reactome.py
@@ -0,0 +1,48 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable
+
+uniprot_rewriter_message = """
+You are a query optimization expert with deep knowledge of molecular biology and extensive experience curating the UniProt Knowledgebase. You are also skilled in leveraging Reactome Pathway Knowledgebase data to enhance search precision.
+Your task is to reformulate user questions to maximize retrieval efficiency within UniProt’s Knowledgebase, which contains comprehensive information on human genes, proteins, protein domains/motifs, and protein functions.
+
+
+The reformulated question must:
+ - Preserve the user’s intent while enriching it with relevant biological details.
+ - Incorporate Reactome-derived insights, such as:
+ - Protein names and functions
+ - Molecular interactions
+ - Disease associations
+ - Optimize for UniProt’s search retrieval, ensuring:
+ - Vector Similarity Search: Captures semantic meaning for broad relevance.
+ - Case-Sensitive Keyword Search: Improves retrieval of exact matches for key terms.
+Task Breakdown
+ 1. Process Inputs:
+ - User’s Question: Understand the original query’s intent.
+ - Reactome Response: Extract key insights (protein names, functions and interactions, pathway involvement, disease relevance etc.).
+ 2. Reformulate the Question:
+ - Enhance with relevant biological context while keeping it concise.
+ - Avoid unnecessary details that dilute clarity.
+ 3. Optimize for Search Retrieval:
+ - Vector Search: Ensure the question captures semantic meaning for broad similarity matching.
+ - Optimize the query for Case-Sensitive Keyword Search
+
+Do NOT answer the question or provide explanations.
+"""
+
+uniprot_rewriter_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", uniprot_rewriter_message),
+ (
+ "human",
+ "Here is the initial question: \n\n {input} \n Here is Reactome-derived context: \n\n{reactome_answer}",
+ ),
+ ]
+)
+
+
+def create_uniprot_rewriter_w_reactome(llm: BaseChatModel) -> Runnable:
+ return (uniprot_rewriter_prompt | llm | StrOutputParser()).with_config(
+ run_name="rewrite_uniprot_query"
+ )
diff --git a/src/agent/tasks/cross_database/summarize_reactome_uniprot.py b/src/agent/tasks/cross_database/summarize_reactome_uniprot.py
new file mode 100644
index 0000000..345c17d
--- /dev/null
+++ b/src/agent/tasks/cross_database/summarize_reactome_uniprot.py
@@ -0,0 +1,48 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable
+
+summarization_message = """
+You are an expert in molecular biology with significant experience as a curator for the UniProt Database and the Reactome Pathway Knowledgebase.
+Your task is to answer the user's question in a clear, accurate, comprehensive, and engaging manner, based strictly on the context provided from the UniProt and Reactome Pathway Knowledgebases.
+
+Instructions:
+ 1. Provide answers **strictly based on the given context from the Reactome and UniProt Knowledgebases**. Do **not** use or infer information from any external sources.
+ 2. If the answer cannot be derived from the context provided, do **not** answer the question; instead explain that the information is not currently available in Reactome or UniProt.
+ 3. Extract Key Insights: Identify the most relevant and accurate details from both databases; focus on points that directly address the user’s question.
+ 4. Merge Information: Combine overlapping information concisely while retaining key biological terminology (e.g., gene names, protein names, pathway names, disease involvement, etc.).
+ 5. Ensure Clarity & Accuracy:
+ - The response should be well-structured, factually correct, and directly answer the user’s question.
+ - Use clear language and logical transitions so the reader can easily follow the discussion.
+ 6. Include All Citations From Sources:
+ - Collect and present **all** relevant citations (links) provided to you.
+ - Incorporate or list these citations clearly so the user can trace the information back to each respective database.
+ - Example:
+ - Reactome Citations:
+ - Apoptosis
+ - Cell Cycle
+ - UniProt Citations:
+ - GATA6
+ - NR5A2
+
+ 7. Answer in the language requested.
+ 8. Write in a conversational and engaging tone suitable for a chatbot.
+ 9. Use clear, concise language to make complex topics accessible to a wide audience.
+"""
+
+summarizer_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", summarization_message),
+ (
+ "human",
+ "User question: {input} \n\n Language: {query_language} \n\n Reactome-drived information: \n {reactome_answer} \n\n UniProt-drived infromation: \n {uniprot_answer}.",
+ ),
+ ]
+)
+
+
+def create_reactome_uniprot_summarizer(llm: BaseChatModel) -> Runnable:
+ return (summarizer_prompt | llm | StrOutputParser()).with_config(
+ run_name="summarize_answer"
+ )
diff --git a/src/agent/tasks/detect_language.py b/src/agent/tasks/detect_language.py
new file mode 100644
index 0000000..0c1eba2
--- /dev/null
+++ b/src/agent/tasks/detect_language.py
@@ -0,0 +1,28 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable
+
+language_detection_message = """
+You are an expert linguist capable of identifying languages from text input.
+Your task is to determine the language of the user's question and return it as a single-word response.
+
+- Return only the language name in English (e.g., "French", "Farsi", "Simplified Chinese").
+- Do not return phrases, sentences, or explanations.
+- If the language is unknown or ambiguous, return "English".
+"""
+
+language_detection_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", language_detection_message),
+ ("human", "User question: \n\n {user_input}"),
+ ]
+)
+
+
+def create_language_detector(llm: BaseChatModel) -> Runnable:
+ return (language_detection_prompt | llm | StrOutputParser()).with_config(
+ run_name="detect_language"
+ )
diff --git a/src/agent/tasks/safety_checker.py b/src/agent/tasks/safety_checker.py
new file mode 100644
index 0000000..ac372e3
--- /dev/null
+++ b/src/agent/tasks/safety_checker.py
@@ -0,0 +1,30 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable
+from pydantic import BaseModel, Field
+
+safety_check_message = """
+You are an expert grader with a strong understanding of ethical guidelines and responsible AI usage.
+
+Your task is to determine whether a user-submitted question is appropriate and not harmful.
+Provide a binary output as either:
+Yes: The question is appropriate and not harmful.
+No: The question is harmful, offensive, inappropriate, or unethical.
+"""
+
+safety_check_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", safety_check_message),
+ ("human", "User question: \n\n {input}"),
+ ]
+)
+
+
+class SafetyCheck(BaseModel):
+ binary_score: str = Field(
+ description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'Yes' or 'No'."
+ )
+
+
+def create_safety_checker(llm: BaseChatModel) -> Runnable:
+ return safety_check_prompt | llm.with_structured_output(SafetyCheck)
diff --git a/src/data_generation/reactome/__init__.py b/src/data_generation/reactome/__init__.py
index fd73ec6..ce36442 100644
--- a/src/data_generation/reactome/__init__.py
+++ b/src/data_generation/reactome/__init__.py
@@ -52,9 +52,14 @@ def upload_to_chromadb(
docs = loader.load()
embeddings_instance: Embeddings
if hf_model is None: # Use OpenAI
- embeddings_instance = OpenAIEmbeddings()
+ embeddings_instance = OpenAIEmbeddings(
+ show_progress_bar=True,
+ )
elif hf_model.startswith("openai/text-embedding-"):
- embeddings_instance = OpenAIEmbeddings(model=hf_model[len("openai/") :])
+ embeddings_instance = OpenAIEmbeddings(
+ model=hf_model[len("openai/") :],
+ show_progress_bar=True,
+ )
elif "HUGGINGFACEHUB_API_TOKEN" in os.environ:
embeddings_instance = HuggingFaceEndpointEmbeddings(
huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
diff --git a/src/data_generation/uniprot/__init__.py b/src/data_generation/uniprot/__init__.py
new file mode 100644
index 0000000..f5b24fb
--- /dev/null
+++ b/src/data_generation/uniprot/__init__.py
@@ -0,0 +1,87 @@
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from langchain_community.vectorstores import Chroma
+from langchain_core.embeddings import Embeddings
+from langchain_huggingface import (HuggingFaceEmbeddings,
+ HuggingFaceEndpointEmbeddings)
+from langchain_openai import OpenAIEmbeddings
+
+from data_generation.metadata_csv_loader import MetaDataCSVLoader
+from data_generation.uniprot.csv_generator import generate_uniprot_csv
+
+
+def upload_to_chromadb(
+ embeddings_dir: str,
+ file: str,
+ embedding_table: str,
+ hf_model: Optional[str] = None,
+ device: Optional[str] = None,
+) -> Chroma:
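+ # Columns attached as document metadata so retrievers can filter on them.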
+ metadata_columns: dict[str, list] = {
+ "uniprot_data": [
+ "gene_names",
+ "short_protein_name",
+ "full_protein_name",
+ "protein_family",
+ "biological_pathways",
+ ],
+ }
+
+ loader = MetaDataCSVLoader(
+ file_path=file,
+ metadata_columns=metadata_columns[embedding_table],
+ encoding="utf-8",
+ )
+
+ docs = loader.load()
+ print(f"Loaded {len(docs)} documents from {file}")
+
+ embeddings_instance: Embeddings
+ if hf_model is None: # Use OpenAI
+ print("Using OpenAI embeddings")
+ embeddings_instance = OpenAIEmbeddings(
+ model="text-embedding-3-large",
+ chunk_size=800,
+ show_progress_bar=True,
+ )
+ elif hf_model.startswith("openai/text-embedding-"):
+ embeddings_instance = OpenAIEmbeddings(
+ model=hf_model[len("openai/") :],
+ chunk_size=800,
+ show_progress_bar=True,
+ )
+ elif "HUGGINGFACEHUB_API_TOKEN" in os.environ:
+ embeddings_instance = HuggingFaceEndpointEmbeddings(
+ huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
+ model=hf_model,
+ )
+ else:
+ if device == "cuda":
+ torch.cuda.empty_cache()
+ embeddings_instance = HuggingFaceEmbeddings(
+ model_name=hf_model,
+ model_kwargs={"device": device, "trust_remote_code": True},
+ encode_kwargs={"batch_size": 12, "normalize_embeddings": False},
+ )
+
+ return Chroma.from_documents(
+ documents=docs,
+ embedding=embeddings_instance,
+ persist_directory=os.path.join(embeddings_dir, embedding_table),
+ )
+
+
+def generate_uniprot_embeddings(
+ embedding_path: Path,
+ hf_model: Optional[str] = None,
+ device: Optional[str] = None,
+ **_,
+) -> None:
+ csv_path = generate_uniprot_csv(embedding_path)
+ db = upload_to_chromadb(
+ str(embedding_path), str(csv_path), "uniprot_data", hf_model, device
+ )
+ print(db._collection.count())
diff --git a/src/data_generation/uniprot/api_connector.py b/src/data_generation/uniprot/api_connector.py
new file mode 100644
index 0000000..50d0844
--- /dev/null
+++ b/src/data_generation/uniprot/api_connector.py
@@ -0,0 +1,52 @@
+import re
+
+import requests
+from requests.adapters import HTTPAdapter, Retry
+
+
+class UniProtAPIConnector:
+ BASE_URL = "https://rest.uniprot.org/uniprotkb/stream"
+
+ @staticmethod
+ def get_download_url():
+ """
+ Returns the UniProt API URL for downloading human-reviewed protein data.
+ """
+ query_params = (
+ "?fields=accession%2Cgene_names%2Cid%2Cprotein_name%2Cprotein_families%2Cmass"
+ "%2Cft_domain%2Ccc_domain%2Cft_motif%2Ccc_subunit%2Ccc_pathway%2Ccc_induction"
+ "%2Ccc_activity_regulation%2Ccc_subcellular_location%2Ccc_tissue_specificity"
+ "%2Ccc_disease%2Ccc_function%2Ccc_miscellaneous&format=xlsx"
+ "&query=%28reviewed%3Atrue%29+AND+%28model_organism%3A9606%29+AND+%28reviewed%3Atrue%29"
+ )
+ return UniProtAPIConnector.BASE_URL + query_params
+
+ def __init__(self):
+ self.session = self._initialize_session()
+
+ def _initialize_session(self):
+ """Creates a session with retry logic for robust downloading."""
+ retries = Retry(
+ total=5, backoff_factor=0.25, status_forcelist=[500, 502, 503, 504]
+ )
+ session = requests.Session()
+ session.mount("https://", HTTPAdapter(max_retries=retries))
+ return session
+
+ def get_next_link(self, headers):
+ """Parses the 'Link' header to find the URL for the next batch of data."""
+ re_next_link = re.compile(r'<(.+)>; rel="next"')
+ if "Link" in headers:
+ match = re_next_link.match(headers["Link"])
+ if match:
+ return match.group(1)
+ return None
+
+ def get_batch(self, batch_url):
+ """Generator to download data in batches."""
+ while batch_url:
+ response = self.session.get(batch_url)
+ response.raise_for_status() # Ensure we stop on HTTP errors
+ total = response.headers.get("x-total-results", 0)
+ yield response, total
+ batch_url = self.get_next_link(response.headers)
diff --git a/src/data_generation/uniprot/csv_generator.py b/src/data_generation/uniprot/csv_generator.py
new file mode 100644
index 0000000..5159226
--- /dev/null
+++ b/src/data_generation/uniprot/csv_generator.py
@@ -0,0 +1,164 @@
+import re
+from pathlib import Path
+
+import pandas as pd
+
+from data_generation.uniprot.api_connector import UniProtAPIConnector
+
+
+class UniProtDataCleaner:
+ def __init__(self, csv_dir: Path):
+ self.download_url = UniProtAPIConnector.get_download_url()
+ self.xlsx_path = csv_dir / "uniprot_data.xlsx"
+ self.csv_path = self.xlsx_path.with_suffix(".csv")
+ self.df = None
+ self.api = UniProtAPIConnector()
+
+ def download_data(self):
+ """Downloads data batch by batch using UniProt API connector."""
+ progress = 0
+ with open(self.xlsx_path, "wb") as f:
+ for response, total in self.api.get_batch(self.download_url):
+ f.write(response.content)
+ progress += 1
+ print(f"Downloaded {progress} batches; Total: {total}")
+ print(f"✅ UniProt data downloaded successfully to {self.xlsx_path}")
+
+ def load_data(self):
+ """Loads data from Excel file into a DataFrame."""
+ print(f"Loading data from {self.xlsx_path}")
+ self.df = pd.read_excel(self.xlsx_path)
+
+ def clean_data(self):
+ """Cleans the UniProt data using predefined processing steps."""
+ self.load_data()
+ self.remove_prefixes()
+ self.format_mass()
+ self.clean_evidence_codes()
+ self.clean_columns()
+ self.add_url()
+ self.format_names()
+ self.rename_columns()
+ self.df.to_csv(self.csv_path, index=False)
+ print(f"Cleaned data saved to {self.csv_path}")
+
+ def remove_prefixes(self):
+ """Remove prefixes from specified columns."""
+ prefix_map = {
+ "Entry Name": "_HUMAN",
+ "Pathway": "PATHWAY: ",
+ "Subunit structure": "SUBUNIT: ",
+ "Subcellular location [CC]": "SUBCELLULAR LOCATION: ",
+ "Domain [CC]": "DOMAIN: ",
+ "Tissue specificity": "TISSUE SPECIFICITY: ",
+ "Involvement in disease": "DISEASE: ",
+ "Function [CC]": "FUNCTION: ",
+ "Miscellaneous [CC]": "MISCELLANEOUS: ",
+ "Induction": "INDUCTION: ",
+ "Activity regulation": "ACTIVITY REGULATION:",
+ }
+ for column, prefix in prefix_map.items():
+ if column in self.df.columns:
+ self.df[column] = (
+ self.df[column].str.replace(prefix, "", regex=False).str.strip()
+ )
+
+ def add_url(self):
+ """Replace 'Entry' column with URLs constructed from entry IDs."""
+ base_url = "https://www.uniprot.org/uniprotkb/"
+ self.df["Entry"] = base_url + self.df["Entry"].astype(str) + "/entry"
+
+ def format_names(self):
+ """Format gene synonyms and protein names with semicolons and proper punctuation."""
+ self.df["Gene Names"] = (
+ self.df["Gene Names"].str.replace(" ", "; ", regex=False).str.strip("; ")
+ )
+ self.df["Protein names"] = self.df["Protein names"].apply(
+ lambda x: "; ".join(
+ [
+ item.strip()
+ for item in re.split(
+ r"\) \(", x.replace("(", "; ").replace(")", "")
+ )
+ ]
+ )
+ )
+
+ def format_mass(self):
+ """Format the 'Mass' column by appending ' Da' to each mass value."""
+ if "Mass" in self.df.columns:
+ self.df["Mass"] = self.df["Mass"].apply(lambda x: f"{x} Da")
+
+ def clean_evidence_codes(self):
+ """Remove citations and evidence codes from textual columns."""
+ patterns = [r"\{ECO:[^\}]*\}", r"\(PubMed:[^\)]*\)", r"\[MIM:[^\]]*\]", r" +"]
+ for column in self.df.columns:
+ for pattern in patterns:
+ self.df[column] = (
+ self.df[column].str.replace(pattern, "", regex=True).str.strip()
+ )
+
+ def clean_columns(self):
+ """Reformat entries in the 'Motif' column."""
+
+ def reformat_motif(entry):
+ if pd.isna(entry):
+ return entry
+ pattern = r"MOTIF (\d+\.\.\d+); /note=\"([^\"]*)\"; /evidence=\"[^\"]*\""
+ matches = re.findall(pattern, entry)
+ return "; ".join(
+ [
+ f"Has a {note} at position {pos.replace('..', '-')}"
+ for pos, note in matches
+ ]
+ )
+
+ def reformat_domain(entry):
+ if pd.isna(entry):
+ return entry
+ pattern = r"DOMAIN (\d+\.\.\d+); /note=\"([^\"]*)\"; /evidence=\"[^\"]*\""
+ matches = re.findall(pattern, entry)
+ return "; ".join(
+ [
+ f"Has a {note} domain at position {pos.replace('..', '-')}"
+ for pos, note in matches
+ ]
+ )
+
+ self.df["Motif"] = self.df["Motif"].apply(reformat_motif)
+ self.df["Domain [FT]"] = self.df["Domain [FT]"].apply(reformat_domain)
+
+ def rename_columns(self):
+ """Rename columns as specified."""
+ new_column_names = {
+ "Entry": "url",
+ "Gene Names": "gene_names",
+ "Entry Name": "short_protein_name",
+ "Protein names": "full_protein_name",
+ "Protein families": "protein_family",
+ "Mass": "molecular_weight",
+ "Domain [FT]": "protein_domains",
+ "Domain [CC]": "domain_annotations",
+ "Motif": "protein_motif",
+ "Subunit structure": "subunit_structure",
+ "Pathway": "biological_pathways",
+ "Induction": "expression_induction",
+ "Activity regulation": "activity_regulation",
+ "Subcellular location [CC]": "subcellular_localization",
+ "Tissue specificity": "tissue_expression",
+ "Involvement in disease": "disease_associations",
+ "Function [CC]": "protein_function",
+ "Miscellaneous [CC]": "additional_notes",
+ }
+ self.df.rename(columns=new_column_names, inplace=True)
+
+
+def generate_uniprot_csv(parent_dir: Path) -> Path:
+ csv_dir = Path(parent_dir) / "csv_files"
+ csv_dir.mkdir(parents=True, exist_ok=True)
+
+ cleaner = UniProtDataCleaner(csv_dir)
+ cleaner.download_data()
+ cleaner.clean_data()
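+ # Drop the intermediate Excel download once the cleaned CSV has been written.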
+ cleaner.xlsx_path.unlink(missing_ok=True)
+ return cleaner.csv_path
diff --git a/src/evaluation/evaluator.py b/src/evaluation/evaluator.py
index c82c6e8..9d4d6da 100644
--- a/src/evaluation/evaluator.py
+++ b/src/evaluation/evaluator.py
@@ -11,12 +11,15 @@
from langchain_community.retrievers import BM25Retriever
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from ragas import evaluate
-from ragas.metrics import (answer_relevancy, context_recall,
- context_utilization, faithfulness)
+from ragas.metrics import (ContextUtilization, answer_relevancy,
+ context_recall, faithfulness)
from retrievers.rag_chain import create_rag_chain
-from retrievers.reactome.metadata_info import descriptions_info, field_info
-from retrievers.reactome.prompt import qa_prompt
+from retrievers.reactome.metadata_info import (reactome_descriptions_info,
+ reactome_field_info)
+from retrievers.reactome.prompt import reactome_qa_prompt
+
+context_utilization = ContextUtilization()
def parse_arguments():
@@ -82,8 +85,8 @@ def initialize_rag_chain_with_memory(embeddings_directory, model_name, rag_type)
selfq_retriever = SelfQueryRetriever.from_llm(
llm=llm,
vectorstore=vectordb,
- document_contents=descriptions_info["summations"],
- metadata_field_info=field_info["summations"],
+ document_contents=reactome_descriptions_info["summations"],
+ metadata_field_info=reactome_field_info["summations"],
search_kwargs={"k": 7},
)
rrf_retriever = EnsembleRetriever(
@@ -99,7 +102,7 @@ def initialize_rag_chain_with_memory(embeddings_directory, model_name, rag_type)
qa = create_rag_chain(
retriever=reactome_retriever,
llm=llm,
- qa_prompt=qa_prompt,
+ qa_prompt=reactome_qa_prompt,
)
return qa
diff --git a/src/retrievers/csv_chroma.py b/src/retrievers/csv_chroma.py
new file mode 100644
index 0000000..dedfdcb
--- /dev/null
+++ b/src/retrievers/csv_chroma.py
@@ -0,0 +1,64 @@
+from pathlib import Path
+
+import chromadb.config
+from langchain.chains.query_constructor.schema import AttributeInfo
+from langchain.retrievers import EnsembleRetriever
+from langchain.retrievers.merger_retriever import MergerRetriever
+from langchain.retrievers.self_query.base import SelfQueryRetriever
+from langchain_chroma.vectorstores import Chroma
+from langchain_community.document_loaders.csv_loader import CSVLoader
+from langchain_community.retrievers import BM25Retriever
+from langchain_core.embeddings import Embeddings
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.retrievers import BaseRetriever
+
+chroma_settings = chromadb.config.Settings(anonymized_telemetry=False)
+
+
+def list_chroma_subdirectories(directory: Path) -> list[str]:
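+ """Each subdirectory containing a chroma.sqlite3 file holds one embedding table."""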
+ subdirectories = list(
+ chroma_file.parent.name for chroma_file in directory.glob("*/chroma.sqlite3")
+ )
+ return subdirectories
+
+
+def create_bm25_chroma_ensemble_retriever(
+ llm: BaseChatModel,
+ embedding: Embeddings,
+ embeddings_directory: Path,
+ *,
+ descriptions_info: dict[str, str],
+ field_info: dict[str, list[AttributeInfo]],
+) -> MergerRetriever:
+ retriever_list: list[BaseRetriever] = []
+ for subdirectory in list_chroma_subdirectories(embeddings_directory):
+ # set up BM25 retriever
+ csv_file_name = subdirectory + ".csv"
+ csvs_dir: Path = embeddings_directory / "csv_files"
+ loader = CSVLoader(file_path=csvs_dir / csv_file_name)
+ data = loader.load()
+ bm25_retriever = BM25Retriever.from_documents(data)
+ bm25_retriever.k = 10
+
+ # set up vectorstore SelfQuery retriever
+ vectordb = Chroma(
+ persist_directory=str(embeddings_directory / subdirectory),
+ embedding_function=embedding,
+ client_settings=chroma_settings,
+ )
+
+ selfq_retriever = SelfQueryRetriever.from_llm(
+ llm=llm,
+ vectorstore=vectordb,
+ document_contents=descriptions_info[subdirectory],
+ metadata_field_info=field_info[subdirectory],
+ search_kwargs={"k": 10},
+ )
+ rrf_retriever = EnsembleRetriever(
+ retrievers=[bm25_retriever, selfq_retriever], weights=[0.2, 0.8]
+ )
+ retriever_list.append(rrf_retriever)
+
+ return MergerRetriever(retrievers=retriever_list)
diff --git a/src/retrievers/reactome/metadata_info.py b/src/retrievers/reactome/metadata_info.py
index cb56350..b9fd251 100644
--- a/src/retrievers/reactome/metadata_info.py
+++ b/src/retrievers/reactome/metadata_info.py
@@ -7,7 +7,7 @@
The relationship between 'reaction_name' and 'pathway_name' is foundational, with each reaction serving as a step or component within the overarching pathway, contributing to its completion and functional outcome.\
This relationship is critical to understanding the biological processes and mechanisms within the Reactome Database."
-descriptions_info: dict[str, str] = {
+reactome_descriptions_info: dict[str, str] = {
"ewas": "Contains data on proteins and nucleic acids with known sequences. Includes entity names, IDs, canonical and synonymous gene names, and functions.",
"complexes": "Catalogs biological complexes, listing complex names and IDs along with the names and IDs of their components. ",
"reactions": "Documents biological pathways and their constituent reactions, detailing pathway and reaction names and IDs. It includes information on the inputs, outputs, and catalysts for each reaction, emphasizing the interconnected nature of cellular processes. Inputs and outputs, critical to the initiation and conclusion of reactions, along with catalysts that facilitate these processes, are cataloged to highlight their roles across various reactions and pathways",
@@ -15,7 +15,7 @@
}
-field_info: dict[str, list[AttributeInfo]] = {
+reactome_field_info: dict[str, list[AttributeInfo]] = {
"summations": [
AttributeInfo(
name="st_id",
diff --git a/src/retrievers/reactome/prompt.py b/src/retrievers/reactome/prompt.py
index 1a004ed..9a11526 100644
--- a/src/retrievers/reactome/prompt.py
+++ b/src/retrievers/reactome/prompt.py
@@ -1,7 +1,6 @@
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-# Answer generation prompt
-qa_system_prompt = """
+reactome_system_prompt = """
You are an expert in molecular biology with access to the Reactome Knowledgebase.
Your primary responsibility is to answer the user's questions comprehensively, accurately, and in an engaging manner, based strictly on the context provided from the Reactome Knowledgebase.
Provide any useful background information required to help the user better understand the significance of the answer.
@@ -23,10 +22,10 @@
8. Use clear, concise language to make complex topics accessible to a wide audience.
"""
-qa_prompt = ChatPromptTemplate.from_messages(
+reactome_qa_prompt = ChatPromptTemplate.from_messages(
[
- ("system", qa_system_prompt),
+ ("system", reactome_system_prompt),
MessagesPlaceholder(variable_name="chat_history"),
- ("user", "Context:\n{context}\n\nQuestion: {user_input}"),
+ ("user", "Context:\n{context}\n\nQuestion: {input}"),
]
)
diff --git a/src/retrievers/reactome/rag.py b/src/retrievers/reactome/rag.py
index 701cbac..485b6e5 100644
--- a/src/retrievers/reactome/rag.py
+++ b/src/retrievers/reactome/rag.py
@@ -1,31 +1,16 @@
from pathlib import Path
-import chromadb.config
-from langchain.retrievers import EnsembleRetriever
-from langchain.retrievers.merger_retriever import MergerRetriever
-from langchain.retrievers.self_query.base import SelfQueryRetriever
-from langchain_chroma.vectorstores import Chroma
-from langchain_community.document_loaders.csv_loader import CSVLoader
-from langchain_community.retrievers import BM25Retriever
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
+from retrievers.csv_chroma import create_bm25_chroma_ensemble_retriever
from retrievers.rag_chain import create_rag_chain
-from retrievers.reactome.metadata_info import descriptions_info, field_info
-from retrievers.reactome.prompt import qa_prompt
+from retrievers.reactome.metadata_info import (reactome_descriptions_info,
+ reactome_field_info)
+from retrievers.reactome.prompt import reactome_qa_prompt
from util.embedding_environment import EmbeddingEnvironment
-chroma_settings = chromadb.config.Settings(anonymized_telemetry=False)
-
-
-def list_chroma_subdirectories(directory: Path) -> list[str]:
- subdirectories = list(
- chroma_file.parent.name for chroma_file in directory.glob("*/chroma.sqlite3")
- )
- return subdirectories
-
def create_reactome_rag(
llm: BaseChatModel,
@@ -34,38 +19,15 @@ def create_reactome_rag(
*,
streaming: bool = False,
) -> Runnable:
- retriever_list: list[BaseRetriever] = []
- for subdirectory in list_chroma_subdirectories(embeddings_directory):
- # set up BM25 retriever
- csv_file_name = subdirectory + ".csv"
- reactome_csvs_dir: Path = embeddings_directory / "csv_files"
- loader = CSVLoader(file_path=reactome_csvs_dir / csv_file_name)
- data = loader.load()
- bm25_retriever = BM25Retriever.from_documents(data)
- bm25_retriever.k = 10
-
- # set up vectorstore SelfQuery retriever
- vectordb = Chroma(
- persist_directory=str(embeddings_directory / subdirectory),
- embedding_function=embedding,
- client_settings=chroma_settings,
- )
-
- selfq_retriever = SelfQueryRetriever.from_llm(
- llm=llm,
- vectorstore=vectordb,
- document_contents=descriptions_info[subdirectory],
- metadata_field_info=field_info[subdirectory],
- search_kwargs={"k": 10},
- )
- rrf_retriever = EnsembleRetriever(
- retrievers=[bm25_retriever, selfq_retriever], weights=[0.2, 0.8]
- )
- retriever_list.append(rrf_retriever)
-
- reactome_retriever = MergerRetriever(retrievers=retriever_list)
+ reactome_retriever = create_bm25_chroma_ensemble_retriever(
+ llm,
+ embedding,
+ embeddings_directory,
+ descriptions_info=reactome_descriptions_info,
+ field_info=reactome_field_info,
+ )
if streaming:
llm = llm.model_copy(update={"streaming": True})
- return create_rag_chain(llm, reactome_retriever, qa_prompt)
+ return create_rag_chain(llm, reactome_retriever, reactome_qa_prompt)
diff --git a/src/retrievers/uniprot/metadata_info.py b/src/retrievers/uniprot/metadata_info.py
new file mode 100644
index 0000000..0b7aa75
--- /dev/null
+++ b/src/retrievers/uniprot/metadata_info.py
@@ -0,0 +1,39 @@
+from langchain.chains.query_constructor.base import AttributeInfo
+
+uniprot_descriptions_info = {
+ "uniprot_data": "Contains detailed protein information about gene names, protein names, subcellular localizations, family classifications, biological pathway associations, domains, motifs, disease associations, and functional descriptions. ",
+}
+uniprot_field_info: dict[str, list[AttributeInfo]] = {
+ "uniprot_data": [
+ AttributeInfo(
+ name="gene_names",
+ description="The official gene name(s) associated with the protein. Gene names may include primary and alternative names \
+ used in different research contexts or species-specific databases.",
+ type="string",
+ ),
+ AttributeInfo(
+ name="short_protein_name",
+ description="The short, standardized name for the protein entry, often derived from its gene name or commonly used abbreviation. \
+ This provides a concise reference to the protein.",
+ type="string",
+ ),
+ AttributeInfo(
+ name="full_protein_name",
+ description="The complete and descriptive name of the protein, detailing its function, structure, or significant features. \
+ This name is derived from biological literature and protein function annotations.",
+ type="string",
+ ),
+ AttributeInfo(
+ name="protein_family",
+ description="The family or group of related proteins to which this protein belongs, based on sequence similarity, \
+ structural features, or shared functional characteristics.",
+ type="string",
+ ),
+ AttributeInfo(
+ name="biological_pathways",
+ description="The biological pathways in which the protein is involved, as curated from databases like Reactome or KEGG. \
+ This provides insights into the protein's role in metabolic, signaling, or regulatory networks.",
+ type="string",
+ ),
+ ]
+}
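These dictionaries are keyed by Chroma subdirectory name (here `uniprot_data`) and feed LangChain's self-query retriever inside the shared ensemble helper. A hypothetical standalone wiring, with placeholder model names and paths:

```python
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_chroma.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from retrievers.uniprot.metadata_info import (uniprot_descriptions_info,
                                              uniprot_field_info)

llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model
vectordb = Chroma(
    persist_directory="embeddings/uniprot/uniprot_data",  # placeholder path
    embedding_function=OpenAIEmbeddings(),
)

retriever = SelfQueryRetriever.from_llm(
    llm=llm,
    vectorstore=vectordb,
    document_contents=uniprot_descriptions_info["uniprot_data"],
    metadata_field_info=uniprot_field_info["uniprot_data"],
    search_kwargs={"k": 10},
)
# The LLM can now turn a question into a semantic query plus a metadata
# filter over fields such as protein_family and biological_pathways.
docs = retriever.invoke("Which frizzled-family proteins act in Wnt signaling?")
```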
diff --git a/src/retrievers/uniprot/prompt.py b/src/retrievers/uniprot/prompt.py
new file mode 100644
index 0000000..7cb0910
--- /dev/null
+++ b/src/retrievers/uniprot/prompt.py
@@ -0,0 +1,31 @@
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
+uniprot_system_prompt = """
+You are an expert in molecular biology with access to the UniProt Knowledgebase.
+Your primary responsibility is to answer the user's questions comprehensively, accurately, and in an engaging manner, based strictly on the context provided from the UniProt Knowledgebase.
+Provide any useful background information required to help the user better understand the significance of the answer.
+Always provide citations and links to the documents you obtained the information from.
+
+When providing answers, please adhere to the following guidelines:
+1. Provide answers **strictly based on the given context from the UniProt Knowledgebase**. Do **not** use or infer information from any external sources.
+2. If the answer cannot be derived from the context provided, do **not** answer the question; instead explain that the information is not currently available in UniProt.
+3. Answer the question comprehensively and accurately, providing useful background information based **only** on the context.
+4. Keep track of **all** the sources that are directly used to derive the final answer, ensuring **every** piece of information in your response is **explicitly cited**.
+5. Create citations for the sources used to generate the final answer according to the following:
+  - For UniProt, always format citations in the following format: *short_protein_name*.

+ Examples:
+ - GATA6
+ - NR5A2
+
+6. Always provide the citations you created in the format requested, in point-form at the end of the response paragraph, ensuring **every piece of information** provided in the final answer is cited.
+7. Write in a conversational and engaging tone suitable for a chatbot.
+8. Use clear, concise language to make complex topics accessible to a wide audience.
+"""
+
+uniprot_qa_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", uniprot_system_prompt),
+ MessagesPlaceholder(variable_name="chat_history"),
+ ("user", "Context:\n{context}\n\nQuestion: {input}"),
+ ]
+)
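A quick way to sanity-check the template is to format it directly; the variable names match what the RAG chain supplies (values below are illustrative only):

```python
from retrievers.uniprot.prompt import uniprot_qa_prompt

messages = uniprot_qa_prompt.format_messages(
    chat_history=[],
    context="short_protein_name: GATA6 | full_protein_name: Transcription factor GATA-6 ...",
    input="What is the function of GATA6?",
)
for message in messages:
    # Prints the rendered system and user messages in order.
    print(message.type, ":", message.content[:80])
```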
diff --git a/src/retrievers/uniprot/rag.py b/src/retrievers/uniprot/rag.py
new file mode 100644
index 0000000..99702d7
--- /dev/null
+++ b/src/retrievers/uniprot/rag.py
@@ -0,0 +1,33 @@
+from pathlib import Path
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.runnables import Runnable
+
+from retrievers.csv_chroma import create_bm25_chroma_ensemble_retriever
+from retrievers.rag_chain import create_rag_chain
+from retrievers.uniprot.metadata_info import (uniprot_descriptions_info,
+ uniprot_field_info)
+from retrievers.uniprot.prompt import uniprot_qa_prompt
+from util.embedding_environment import EmbeddingEnvironment
+
+
+def create_uniprot_rag(
+ llm: BaseChatModel,
+ embedding: Embeddings,
+ embeddings_directory: Path = EmbeddingEnvironment.get_dir("uniprot"),
+ *,
+ streaming: bool = False,
+) -> Runnable:
+    uniprot_retriever = create_bm25_chroma_ensemble_retriever(
+ llm,
+ embedding,
+ embeddings_directory,
+ descriptions_info=uniprot_descriptions_info,
+ field_info=uniprot_field_info,
+ )
+
+ if streaming:
+ llm = llm.model_copy(update={"streaming": True})
+
+    return create_rag_chain(llm, uniprot_retriever, uniprot_qa_prompt)
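Usage would mirror the Reactome factory. A sketch, assuming `create_rag_chain` follows LangChain's usual retrieval-chain convention of an `input`/`chat_history` payload (model choices are placeholders):

```python
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from retrievers.uniprot.rag import create_uniprot_rag

# Any BaseChatModel / Embeddings pair should work here.
chain = create_uniprot_rag(ChatOpenAI(model="gpt-4o-mini"), OpenAIEmbeddings())
result = chain.invoke(
    {"input": "What does UniProt record about NR5A2?", "chat_history": []}
)
```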
diff --git a/src/tools/external_search/completeness_grader.py b/src/tools/external_search/completeness_grader.py
deleted file mode 100644
index 9c33d2d..0000000
--- a/src/tools/external_search/completeness_grader.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.runnables import Runnable, RunnableConfig
-from pydantic import BaseModel, Field
-
-from tools.external_search.state import GraphState
-
-completeness_grader_message = """
-You are an expert grader with extensive knowledge in molecular biology and experience as a Reactome curator.
-Your task is to evaluate whether a response generated by an LLM is complete, meaning it fully addresses the user’s question with all necessary details, background information, and context.
-Additionally, assess whether the question is appropriate and directly related to molecular biology or molecular biology research.
-Based on this evaluation, determine whether an external search should be conducted.
-Provide a binary output as either:
-Yes: The response is incomplete, missing key details, or lacking sufficient context, AND the question is appropriate and directly related to molecular biology, therefore external search should be conducted.
-No: Either the response is complete (fully answers the query, provides background, and leaves no essential details missing), OR the question is inappropriate, harmful, or not related to molecular biology, therefore no external search should be conducted.
-Ensure your evaluation is based solely on the information requested in the query, the adequacy of the response, and the appropriateness of the question.
-"""
-
-completeness_prompt = ChatPromptTemplate.from_messages(
- [
- ("system", completeness_grader_message),
- ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
- ]
-)
-
-
-class GradeCompleteness(BaseModel):
- binary_score: str = Field(
- description="Answer is complete and provides all necessary background, 'Yes' or 'No'"
- )
-
-
-class CompletenessGrader:
- def __init__(self, llm: BaseChatModel):
- structured_completeness_grader: Runnable = llm.with_structured_output(
- GradeCompleteness
- )
- self.runnable: Runnable = completeness_prompt | structured_completeness_grader
-
- async def ainvoke(
- self, state: GraphState, config: RunnableConfig
- ) -> dict[str, str]:
- result: GradeCompleteness = await self.runnable.ainvoke(
- {
- "question": state["question"],
- "generation": state["generation"],
- },
- config,
- )
- return {"external_search": result.binary_score}
diff --git a/src/tools/external_search/state.py b/src/tools/external_search/state.py
index 71c0884..034e994 100644
--- a/src/tools/external_search/state.py
+++ b/src/tools/external_search/state.py
@@ -7,8 +7,8 @@ class WebSearchResult(TypedDict):
content: str
-class GraphState(TypedDict):
- question: str # User question
+class SearchState(TypedDict, total=False):
+    input: str  # LLM-enhanced user question
     generation: str  # LLM-generated response to the user question
-    external_search: str  # "Yes" or "No" to search for external resources
+    complete: str  # "Yes" if the generation fully answers the question, else "No"
search_results: list[WebSearchResult] # Results from searching the web
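`total=False` makes every key optional, which is what lets each graph node return only the keys it produces. For illustration:

```python
from tools.external_search.state import SearchState

# A node can emit a partial update; LangGraph merges it into the running state.
partial = SearchState(search_results=[])

# Fully populated state after the grader and search nodes have run.
full = SearchState(
    input="What regulates GATA6 expression?",
    generation="GATA6 expression is regulated by ...",
    complete="No",
    search_results=[],
)
```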
diff --git a/src/tools/external_search/tavily_wrapper.py b/src/tools/external_search/tavily_wrapper.py
index d3a768a..54e373b 100644
--- a/src/tools/external_search/tavily_wrapper.py
+++ b/src/tools/external_search/tavily_wrapper.py
@@ -4,7 +4,7 @@
from tavily import AsyncTavilyClient, MissingAPIKeyError
-from tools.external_search.state import GraphState, WebSearchResult
+from tools.external_search.state import SearchState, WebSearchResult
from util.logging import logging
@@ -64,10 +64,10 @@ async def search(self, query: str) -> list[WebSearchResult]:
if all(key in result for key in ["title", "url"])
]
- async def ainvoke(self, state: GraphState) -> dict[str, list[WebSearchResult]]:
- query: str = state["question"]
+ async def ainvoke(self, state: SearchState) -> SearchState:
+ query: str = state["input"]
search_results: list[WebSearchResult] = await self.search(query)
- return {"search_results": search_results}
+ return SearchState(search_results=search_results)
@staticmethod
def format_results(web_search_results: list[WebSearchResult]) -> str:
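With `ainvoke` now both taking and returning a `SearchState`, a standalone call looks roughly like this (hypothetical query; a Tavily API key must be available in the environment):

```python
import asyncio

from tools.external_search.state import SearchState
from tools.external_search.tavily_wrapper import TavilyWrapper

wrapper = TavilyWrapper(max_results=3)
# Returns a partial SearchState carrying only the search_results key.
update = asyncio.run(
    wrapper.ainvoke(SearchState(input="GATA6 role in pancreas development"))
)
print(TavilyWrapper.format_results(update["search_results"]))
```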
diff --git a/src/tools/external_search/workflow.py b/src/tools/external_search/workflow.py
index af746dd..0a409c8 100644
--- a/src/tools/external_search/workflow.py
+++ b/src/tools/external_search/workflow.py
@@ -1,33 +1,56 @@
+from typing import Literal
+
from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
+from langgraph.utils.runnable import RunnableLike
-from tools.external_search.completeness_grader import CompletenessGrader
-from tools.external_search.state import GraphState
+from agent.tasks.completeness_grader import (CompletenessGrade,
+ create_completeness_grader)
+from tools.external_search.state import SearchState
from tools.external_search.tavily_wrapper import TavilyWrapper
-def decide_next_steps(state: GraphState) -> str:
- if state["external_search"] == "Yes":
+def decide_next_steps(state: SearchState) -> Literal["perform_web_search", "no_search"]:
+ if state["complete"] == "No":
return "perform_web_search"
else:
return "no_search"
-def no_search(_) -> dict[str, list]:
- return {"search_results": []}
+def no_search(_) -> SearchState:
+ return SearchState(search_results=[])
+
+
+def run_completeness_grader(grader: Runnable) -> RunnableLike:
+ async def _run_completeness_grader(
+ state: SearchState, config: RunnableConfig
+ ) -> SearchState:
+ result: CompletenessGrade = await grader.ainvoke(
+ {
+ "input": state["input"],
+ "generation": state["generation"],
+ },
+ config,
+ )
+ return SearchState(complete=result.binary_score)
+
+ return _run_completeness_grader
def create_search_workflow(
llm: BaseChatModel, max_results: int = 3
) -> CompiledStateGraph:
- completeness_grader = CompletenessGrader(llm)
+ completeness_grader: Runnable = create_completeness_grader(llm)
tavily_wrapper = TavilyWrapper(max_results=max_results)
- workflow = StateGraph(GraphState)
+ workflow = StateGraph(SearchState)
# Add nodes
- workflow.add_node("assess_completeness", completeness_grader.ainvoke)
+ workflow.add_node(
+ "assess_completeness", run_completeness_grader(completeness_grader)
+ )
workflow.add_node("perform_web_search", tavily_wrapper.ainvoke)
workflow.add_node("no_search", no_search)