Skip to content

Commit ae8f11a

Browse files
Merge pull request #136 from monarch-initiative/make_custom_app_targets
Add support for paperqa (and update app_alz.py to use it)
2 parents 525b7cb + 515d22b commit ae8f11a

File tree

11 files changed

+2494
-1645
lines changed

11 files changed

+2494
-1645
lines changed

.codespellrc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[codespell]
2+
ignore-words-list = aadd
3+
skip = tests/db

.github/workflows/qc.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
python-version: [ "3.9", "3.10", "3.11" ]
14+
python-version: [ "3.11" ]
1515

1616
steps:
1717
- uses: actions/checkout@v3.0.2
@@ -24,7 +24,7 @@ jobs:
2424
- name: Install Poetry
2525
uses: snok/install-poetry@v1.3.1
2626
- name: Install dependencies
27-
run: poetry install --no-interaction
27+
run: poetry install --no-interaction --extras "paperqa"
2828

2929
- name: Check common spelling errors
3030
run: poetry run tox -e codespell

poetry.lock

Lines changed: 2078 additions & 1591 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ license = "BSD-3"
77
readme = "README.md"
88

99
[tool.poetry.dependencies]
10-
python = "^3.9, !=3.9.7"
10+
python = "^3.11"
1111
click = "^8.1.7"
1212
importlib-metadata = ">=6"
1313
oaklib = "^0.6.9"
@@ -50,10 +50,12 @@ click-default-group = "^1.2.4"
5050
venomx = "^0.1.1"
5151
duckdb = "~1.0.0"
5252
python-dotenv = "^1.0.1"
53+
langchain-community = {version = "*", optional = true}
5354
onnxruntime = [
5455
{version = "<=1.19.2", python = "<3.10"},
5556
{version = "^1.20.0", python = ">=3.10"}
5657
]
58+
paper-qa = {version = "^5.20.0", optional = true, python = ">=3.11"}
5759

5860
[tool.poetry.group.dev.dependencies]
5961
pytest = ">=7.1.2"
@@ -88,14 +90,18 @@ docs = [
8890
"sphinx-autodoc-typehints",
8991
"sphinx-click",
9092
"myst-parser"
91-
]
93+
]
9294
bioc = [
9395
"bioc"
9496
]
9597
gpt4all = [
9698
"gpt4all",
9799
"llm-gpt4all"
98100
]
101+
paperqa = [
102+
"paper-qa",
103+
"langchain-community"
104+
]
99105

100106
[tool.poetry-dynamic-versioning]
101107
enable = false
@@ -148,4 +154,3 @@ quiet-level = 3
148154
[build-system]
149155
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning"]
150156
build-backend = "poetry_dynamic_versioning.backend"
151-

src/curategpt/agents/chat_agent.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,107 @@ def chat(
158158
uncited_references=uncited_references_dict,
159159
conversation_id=conversation_id,
160160
)
161+
162+
163+
@dataclass
class ChatAgentAlz(BaseAgent):
    """
    An agent that allows chat to a knowledge source.

    This implements a standard knowledgebase retrieval augmented generation pattern.
    The knowledge_source is queried for relevant objects (the source can be a local
    database or a remote source such as pubmed).
    The objects are provided as context to a LLM query
    """

    relevance_factor: float = 0.5
    """Relevance factor for diversifying search results using MMR."""

    # Optional identifier of an ongoing conversation; populated per-chat below.
    conversation_id: Optional[str] = None

    def chat(
        self,
        query: str,
        conversation: Optional[Any] = None,
        limit: int = 10,
        collection: Optional[str] = None,
        expand: bool = True,
        **kwargs,
    ) -> ChatResponse:
        """
        Answer a natural-language query using retrieval-augmented generation.

        Retrieves up to ``limit`` objects from the knowledge source, embeds them
        as numbered references in an Alzheimer's-focused prompt, and asks the
        extractor's LLM to answer with bracketed citations (e.g. ``[1]``).

        :param query: the user's question.
        :param conversation: optional conversation object; when provided, the
            prompt is sent through it (its ``model`` is set and its ``id`` is
            returned as ``conversation_id``).
        :param limit: maximum number of knowledge-base results to retrieve.
        :param collection: collection to search; defaults to
            ``self.knowledge_source_collection``.
        :param expand: whether to perform query expansion in the search.
        :param kwargs: passed through to the knowledge source's ``search``.
        :return: a :class:`ChatResponse` with the raw and formatted answer,
            the prompt used, and cited/uncited reference dictionaries.
        :raises ValueError: if no extractor is available, or if the prompt is
            still too long for the model after all results have been dropped.
        """
        # Fall back to the wrapper's own extractor when none was configured.
        if self.extractor is None:
            if isinstance(self.knowledge_source, BaseWrapper):
                self.extractor = self.knowledge_source.extractor
            else:
                raise ValueError("Extractor must be set.")

        logger.info(f"Chat: {query} on {self.knowledge_source} with limit: {limit}")
        if collection is None:
            collection = self.knowledge_source_collection
        kwargs["collection"] = collection

        # The search now returns dictionary results directly.
        kb_results = list(self.knowledge_source.search(
            query, relevance_factor=self.relevance_factor, limit=limit, expand=expand, **kwargs
        ))

        # Rebuild the prompt, dropping the least relevant result each pass,
        # until it fits within the model's token budget (with a 300-token
        # margin reserved for the response).
        while True:
            references = {}
            texts = []
            for i, result_tuple in enumerate(kb_results, start=1):
                # Extract the object from the standard tuple format (obj, distance, metadata)
                obj, _, _ = result_tuple

                # Serialize only truthy fields to keep the reference compact.
                obj_text = yaml.dump({k: v for k, v in obj.items() if v}, sort_keys=False)
                references[str(i)] = obj_text
                texts.append(f"## Reference {i}\n{obj_text}")

            model = self.extractor.model
            prompt = (
                "You are a specialized AI assistant for biomedical researchers and clinicians focused on "
                "Alzheimer's disease and related topics. I will provide relevant background information, then ask "
                "a question. Use this context to provide evidence-based answers with proper scientific citations.\n"
            )
            prompt += "---\nBackground facts:\n" + "\n".join(texts) + "\n\n"
            prompt += (
                "I will ask a question and you will answer as best as possible, citing the references above.\n"
                "Write references in square brackets, e.g. [1]. For any additional facts without a citation, write [?].\n"
            )
            prompt += f"---\nHere is the Question: {query}.\n"
            logger.debug(f"Candidate Prompt: {prompt}")
            estimated_length = estimate_num_tokens([prompt])
            logger.debug(f"Max tokens {model.model_id}: {max_tokens_by_model(model.model_id)}")

            if estimated_length + 300 < max_tokens_by_model(model.model_id):
                break
            else:
                logger.debug("Prompt too long, removing least relevant result.")
                # No results left to drop: the fixed prompt text alone exceeds the budget.
                if not kb_results:
                    raise ValueError(f"Prompt too long: {prompt}.")
                kb_results.pop()

        logger.info("Final prompt constructed for chat.")
        if conversation:
            conversation.model = model
            agent = conversation
            conversation_id = conversation.id
            logger.info(f"Using conversation context with ID: {conversation_id}")
        else:
            agent = model
            conversation_id = None

        response = agent.prompt(prompt, system="You are a scientist assistant.")
        response_text = response.text()
        # Citations appear as [N] or [?]; partition references into cited vs uncited.
        pattern = r"\[(\d+|\?)\]"
        used_references = re.findall(pattern, response_text)
        used_references_dict = {ref: references.get(ref, "NO REFERENCE") for ref in used_references}
        uncited_references_dict = {ref: ref_obj for ref, ref_obj in references.items() if ref not in used_references}
        formatted_text = replace_references_with_links(response_text)

        return ChatResponse(
            body=response_text,
            formatted_body=formatted_text,
            prompt=prompt,
            references=used_references_dict,
            uncited_references=uncited_references_dict,
            conversation_id=conversation_id,
        )

src/curategpt/app/app_alz.py

Lines changed: 84 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,24 @@
22

33
import json
44
import logging
5+
import os
56
from typing import List, Union
67

78
import streamlit as st
89
import yaml
910

1011
from curategpt import BasicExtractor
11-
from curategpt.agents.chat_agent import ChatAgent, ChatResponse
12+
from curategpt.agents.chat_agent import ChatAgentAlz, ChatResponse
1213
from curategpt.agents.evidence_agent import EvidenceAgent
13-
from curategpt.app.helper import get_applicable_examples
1414
from curategpt.app.state import get_state
1515
from curategpt.wrappers import BaseWrapper
1616
from curategpt.wrappers.literature import WikipediaWrapper
1717
from curategpt.wrappers.literature.pubmed_wrapper import PubmedWrapper
18+
from curategpt.wrappers.paperqa.paperqawrapper import PaperQAWrapper
1819

19-
PUBMED = "PubMed (via API)"
20-
WIKIPEDIA = "Wikipedia (via API)"
21-
# Removed JGI and ESS-Dive
22-
# JGI = "JGI (via API)"
23-
# ESSDIVE = "ESS-DeepDive (via API)"
20+
PUBMED = "PubMed"
21+
WIKIPEDIA = "Wikipedia"
22+
PAPERQA = "Alzheimers_Papers"
2423

2524
CHAT = "Chat"
2625
SEARCH = "Search"
@@ -57,15 +56,22 @@
5756
cart = state.cart


st.title("Alzheimer's AI Assistant")

# Warn when the PaperQA index location is not configured. The Alzheimer's
# Papers option is always offered in the sidebar, so the only condition that
# matters is whether PQA_HOME is set. (The previous guard
# `PAPERQA in [PUBMED, PAPERQA, WIKIPEDIA]` was a tautology — always True —
# so this simplification does not change behavior.)
if os.environ.get("PQA_HOME") is None:
    st.warning(
        "PQA_HOME environment variable is not set. To use the Alzheimer's Papers collection, "
        "you need to set PQA_HOME to the directory containing your indexed papers. "
        "Use 'curategpt paperqa index /path/to/papers' to create an index."
    )
if not db.list_collection_names():
    st.warning("No collections found. Please use command line to load one.")

# Pages offered in the sidebar navigation (Search was removed from this app).
PAGES = [
    CHAT,
    CITESEEK
]
7076

7177

@@ -92,12 +98,13 @@ def filtered_collection_names() -> List[str]:
9298

9399
collection = st.sidebar.selectbox(
94100
"Choose collection",
95-
[PUBMED, WIKIPEDIA] + filtered_collection_names(), # Removed JGI and ESSDIVE, put PubMed first
96-
index=0, # Set PubMed as default
101+
[PUBMED, PAPERQA, WIKIPEDIA] + filtered_collection_names() + ["No collection"],
102+
index=0, # Set PUBMED as default (index 0 since it's first in the list)
97103
help="""
98104
A collection is a knowledge base. It could be anything, but
99105
it's likely your instance has some bio-ontologies pre-loaded.
100-
Select 'About' to see details of each collection
106+
Select 'Alzheimer's Papers (via PaperQA)' for direct access to a trusted corpus of Alzheimer's research papers.
107+
Select 'No collection' to interact with the model directly without a knowledge base.
101108
""",
102109
)
103110

@@ -118,11 +125,12 @@ def filtered_collection_names() -> List[str]:
118125
# Add background_collection for CiteSeek functionality
119126
background_collection = st.sidebar.selectbox(
120127
"Background knowledge for CiteSeek",
121-
[NO_BACKGROUND_SELECTED, PUBMED, WIKIPEDIA],
128+
[NO_BACKGROUND_SELECTED, PUBMED, PAPERQA, WIKIPEDIA],
122129
index=1, # Set PubMed as default
123130
help="""
124131
Background databases provide evidence sources for CiteSeek.
125132
PubMed is recommended for verifying medical claims.
133+
Alzheimer's Papers provides specialized knowledge from trusted Alzheimer's research papers.
126134
""",
127135
)
128136

@@ -131,25 +139,43 @@ def filtered_collection_names() -> List[str]:
131139
st.sidebar.markdown("Developed by the Monarch Initiative")
132140

133141

134-
def get_chat_agent() -> Union[ChatAgentAlz, BaseWrapper]:
    """
    Build the chat agent appropriate for the currently selected collection.

    "No collection" yields a bare agent (direct model access); the named
    literature sources get their dedicated wrapper; anything else is treated
    as a collection in the local store.
    """
    # Direct-model mode: no knowledge source is attached at all.
    if collection == "No collection":
        return ChatAgentAlz(extractor=extractor)

    # Map each special source to a factory; unlisted collections use the local db.
    wrapper_factories = {
        PUBMED: lambda: PubmedWrapper(local_store=db, extractor=extractor),
        WIKIPEDIA: lambda: WikipediaWrapper(local_store=db, extractor=extractor),
        PAPERQA: lambda: PaperQAWrapper(extractor=extractor),
    }
    factory = wrapper_factories.get(collection)
    source = factory() if factory is not None else db

    agent = ChatAgentAlz(
        knowledge_source=source,
        knowledge_source_collection=collection,
        extractor=extractor,
    )

    # Sanity check: a retrieval-backed agent must have a knowledge source.
    if agent.knowledge_source is None:
        raise ValueError(f"Knowledge source is None for collection {collection}")

    return agent
164+
150165

151166
def ask_chatbot(query, expand=False) -> ChatResponse:
    """
    Route *query* to the appropriate chat path.

    With a knowledge base selected, delegate to the agent's RAG chat;
    otherwise prompt the model directly and wrap the reply in a
    reference-free ChatResponse.
    """
    agent = get_chat_agent()
    if collection != "No collection":
        return agent.chat(query, expand=expand)

    # Direct-model path: no retrieval, so there are no references to report.
    response = agent.extractor.model.prompt(
        query, system="You are a helpful Alzheimer's disease expert."
    )
    return ChatResponse(
        body=response.text(),
        formatted_body=response.text(),
        prompt=query,
        references={},
        uncited_references={}
    )
153179

154180

155181
def html_table(rows: List[dict]) -> str:
@@ -238,34 +264,44 @@ def _flat(obj: dict, limit=40) -> dict:
238264

239265
elif option == CHAT:
240266
page_state = state.get_page_state(CHAT)
241-
st.subheader("Chat with a knowledge base")
242-
query = st.text_area(
243-
f"Ask me anything (within the scope of {collection})!",
244-
help="You can query the current knowledge base using natural language.",
245-
)
267+
if collection == "No collection":
268+
st.subheader("Chat with the Alzheimer's AI assistant")
269+
query = st.text_area(
270+
"Ask me anything about Alzheimer's disease",
271+
help="Ask questions directly to the AI without using a knowledge base.",
272+
)
273+
else:
274+
query = st.text_area(
275+
f"Ask me anything about Alzheimer's disease (within the scope of {collection})",
276+
help="You can query the current knowledge base using natural language.",
277+
)
278+
279+
# Only show these controls if using a knowledge base
280+
if collection != "No collection":
281+
limit = st.slider(
282+
"Detail",
283+
min_value=0,
284+
max_value=30,
285+
value=10,
286+
step=1,
287+
help="""
288+
Behind the scenes, N entries are fetched from the knowledge base,
289+
and these are fed to the LLM. Selecting more examples may give more
290+
complete results, but may also exceed context windows for the model.
291+
""",
292+
)
293+
expand = st.checkbox(
294+
"Expand query",
295+
help="""
296+
If checked, perform query expansion (pubmed only).
297+
""",
298+
)
299+
else:
300+
# Set default values when not using a knowledge base
301+
limit = 0
302+
expand = False
246303

247-
limit = st.slider(
248-
"Detail",
249-
min_value=0,
250-
max_value=30,
251-
value=10,
252-
step=1,
253-
help="""
254-
Behind the scenes, N entries are fetched from the knowledge base,
255-
and these are fed to the LLM. Selecting more examples may give more
256-
complete results, but may also exceed context windows for the model.
257-
""",
258-
)
259-
expand = st.checkbox(
260-
"Expand query",
261-
help="""
262-
If checked, perform query expansion (pubmed only).
263-
""",
264-
)
265304
extractor.model_name = model_name
266-
examples = get_applicable_examples(collection, CHAT)
267-
st.write("Examples:")
268-
st.write(f"<details>{html_table(examples)}</details>", unsafe_allow_html=True)
269305

270306
if st.button(CHAT):
271307
response = ask_chatbot(query, expand=expand)

0 commit comments

Comments
 (0)