Skip to content

Commit 0eb7222

Browse files
Merge pull request #142 from monarch-initiative/tweaks_to_appalz
Fix issue with duplicate references in Alzheimer's app
2 parents ae8f11a + 1ace976 commit 0eb7222

File tree

3 files changed

+245
-268
lines changed

3 files changed

+245
-268
lines changed

src/curategpt/agents/chat_agent.py

Lines changed: 173 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class ChatResponse(BaseModel):
4444
def replace_references_with_links(text):
    """Rewrite numeric citation markers like ``[3]`` as Markdown anchor links.

    Each ``[N]`` in *text* becomes ``[N](#ref-N)`` so that it links to the
    matching reference entry rendered elsewhere on the page. Non-numeric
    markers such as ``[?]`` are left untouched.
    """
    def linkify(match):
        num = match.group(1)
        return f"[{num}](#ref-{num})"

    return re.sub(r"\[(\d+)\]", linkify, text)
4949

5050

@@ -85,7 +85,8 @@ def chat(
8585
self.extractor = self.knowledge_source.extractor
8686
else:
8787
raise ValueError("Extractor must be set.")
88-
logger.info(f"Chat: {query} on {self.knowledge_source} kwargs: {kwargs}, limit: {limit}")
88+
logger.info(
89+
f"Chat: {query} on {self.knowledge_source} kwargs: {kwargs}, limit: {limit}")
8990
if collection is None:
9091
collection = self.knowledge_source_collection
9192
kwargs["collection"] = collection
@@ -101,7 +102,8 @@ def chat(
101102
current_length = 0
102103
for obj, _, _obj_meta in kb_results:
103104
i += 1
104-
obj_text = yaml.dump({k: v for k, v in obj.items() if v}, sort_keys=False)
105+
obj_text = yaml.dump(
106+
{k: v for k, v in obj.items() if v}, sort_keys=False)
105107
references[str(i)] = obj_text
106108
texts.append(f"## Reference {i}\n{obj_text}")
107109
current_length += len(obj_text)
@@ -126,7 +128,8 @@ def chat(
126128
break
127129
else:
128130
# remove least relevant
129-
logger.debug(f"Removing least relevant of {len(kb_results)}: {kb_results[-1]}")
131+
logger.debug(
132+
f"Removing least relevant of {len(kb_results)}: {kb_results[-1]}")
130133
if not kb_results:
131134
raise ValueError(f"Prompt too long: {prompt}.")
132135
kb_results.pop()
@@ -141,11 +144,13 @@ def chat(
141144
else:
142145
agent = model
143146
conversation_id = None
144-
response = agent.prompt(prompt, system="You are a scientist assistant.")
147+
response = agent.prompt(
148+
prompt, system="You are a scientist assistant.")
145149
response_text = response.text()
146150
pattern = r"\[(\d+|\?)\]"
147151
used_references = re.findall(pattern, response_text)
148-
used_references_dict = {ref: references.get(ref, "NO REFERENCE") for ref in used_references}
152+
used_references_dict = {ref: references.get(
153+
ref, "NO REFERENCE") for ref in used_references}
149154
uncited_references_dict = {
150155
ref: ref_obj for ref, ref_obj in references.items() if ref not in used_references
151156
}
@@ -191,74 +196,181 @@ def chat(
191196
else:
192197
raise ValueError("Extractor must be set.")
193198

194-
logger.info(f"Chat: {query} on {self.knowledge_source} with limit: {limit}")
199+
logger.info(
200+
f"Chat: {query} on {self.knowledge_source} with limit: {limit}")
195201
if collection is None:
196202
collection = self.knowledge_source_collection
197203
kwargs["collection"] = collection
198204

199-
# The search now returns dictionary results directly.
200-
kb_results = list(self.knowledge_source.search(
201-
query, relevance_factor=self.relevance_factor, limit=limit, expand=expand, **kwargs
202-
))
203-
204-
while True:
205-
references = {}
206-
texts = []
207-
for i, result_tuple in enumerate(kb_results, start=1):
208-
# Extract the object from the standard tuple format (obj, distance, metadata)
209-
obj, _, _ = result_tuple
205+
# Set Alzheimer's system prompt if we are using paperqa
206+
if hasattr(self.knowledge_source, 'name') and self.knowledge_source.name == 'paperqa':
207+
self.knowledge_source.settings.agent.agent_system_prompt = (
208+
"""You are a specialized AI assistant for biomedical researchers and clinicians focused on
209+
Alzheimer's disease and related topics. I will ask a question and you will answer
210+
as best as possible, citing references. For any additional facts that you are
211+
sure of, but without a citation, write [?].
212+
""")
210213

211-
obj_text = yaml.dump({k: v for k, v in obj.items() if v}, sort_keys=False)
212-
references[str(i)] = obj_text
213-
texts.append(f"## Reference {i}\n{obj_text}")
214+
# The search now returns dictionary results directly.
215+
kb_results = self.knowledge_source.search(
216+
query,
217+
relevance_factor=self.relevance_factor,
218+
limit=limit,
219+
expand=expand,
220+
**kwargs
221+
)
214222

215-
model = self.extractor.model
216-
prompt = (
217-
"You are a specialized AI assistant for biomedical researchers and clinicians focused on "
218-
"Alzheimer's disease and related topics. I will provide relevant background information, then ask "
219-
"a question. Use this context to provide evidence-based answers with proper scientific citations.\n"
220-
)
221-
prompt += "---\nBackground facts:\n" + "\n".join(texts) + "\n\n"
222-
prompt += (
223-
"I will ask a question and you will answer as best as possible, citing the references above.\n"
224-
"Write references in square brackets, e.g. [1]. For any additional facts without a citation, write [?].\n"
225-
)
226-
prompt += f"---\nHere is the Question: {query}.\n"
227-
logger.debug(f"Candidate Prompt: {prompt}")
228-
estimated_length = estimate_num_tokens([prompt])
229-
logger.debug(f"Max tokens {model.model_id}: {max_tokens_by_model(model.model_id)}")
223+
# Check if we're using PaperQA
224+
is_paperqa = hasattr(self.knowledge_source,
225+
'name') and self.knowledge_source.name == 'paperqa'
230226

231-
if estimated_length + 300 < max_tokens_by_model(model.model_id):
232-
break
233-
else:
234-
logger.debug("Prompt too long, removing least relevant result.")
235-
if not kb_results:
236-
raise ValueError(f"Prompt too long: {prompt}.")
237-
kb_results.pop()
238-
239-
logger.info("Final prompt constructed for chat.")
227+
model = self.extractor.model
240228
if conversation:
241229
conversation.model = model
242230
agent = conversation
243231
conversation_id = conversation.id
244-
logger.info(f"Using conversation context with ID: {conversation_id}")
232+
logger.info(f"Conversation ID: {conversation_id}")
245233
else:
246234
agent = model
247235
conversation_id = None
248236

249-
response = agent.prompt(prompt, system="You are a scientist assistant.")
250-
response_text = response.text()
251-
pattern = r"\[(\d+|\?)\]"
252-
used_references = re.findall(pattern, response_text)
253-
used_references_dict = {ref: references.get(ref, "NO REFERENCE") for ref in used_references}
254-
uncited_references_dict = {ref: ref_obj for ref, ref_obj in references.items() if ref not in used_references}
255-
formatted_text = replace_references_with_links(response_text)
237+
if is_paperqa:
238+
session = kb_results.session
239+
response_text = session.answer.strip()
240+
prompt = f"[PaperQA] Question: {session.question}"
241+
formatted_body, references = (
242+
self._format_paperqa_references(response_text,
243+
session.contexts))
244+
245+
# formatted_refs = {k: yaml.dump(v, sort_keys=False) for k, v in references.items() if v}
246+
247+
def drop_empty_fields(d: dict) -> dict:
248+
return {k: v for k, v in d.items() if isinstance(v, str) and v.strip()}
249+
250+
formatted_refs = {
251+
k: yaml.safe_dump(
252+
drop_empty_fields(v),
253+
sort_keys=False,
254+
allow_unicode=True,
255+
default_flow_style=False,
256+
width=80 # prevents line wrapping madness
257+
)
258+
for k, v in references.items()
259+
}
260+
261+
return ChatResponse(
262+
body=response_text,
263+
formatted_body=formatted_body,
264+
prompt=prompt,
265+
references=formatted_refs,
266+
uncited_references={},
267+
conversation_id=conversation_id,
268+
)
256269

257-
return ChatResponse(
258-
body=response_text,
259-
formatted_body=formatted_text,
260-
prompt=prompt,
261-
references=used_references_dict,
262-
uncited_references=uncited_references_dict,
263-
conversation_id=conversation_id,
264-
)
270+
else:
271+
kb_results = list(kb_results)
272+
# Regular processing for non-PaperQA sources
273+
274+
# For other sources, we need to format the results and create a prompt
275+
while True:
276+
i = 0
277+
references = {}
278+
texts = []
279+
current_length = 0
280+
for obj, _, _obj_meta in kb_results:
281+
i += 1
282+
obj_text = yaml.dump(
283+
{k: v for k, v in obj.items() if v}, sort_keys=False)
284+
references[str(i)] = obj_text
285+
texts.append(f"## Reference {i}\n{obj_text}")
286+
current_length += len(obj_text)
287+
288+
prompt = "I will first give background facts, then ask a question. Use the background fact to answer\n"
289+
prompt += "---\nBackground facts:\n"
290+
prompt += "\n".join(texts)
291+
prompt += "\n\n"
292+
prompt += "I will ask a question and you will answer as best as possible, citing the references above.\n"
293+
prompt += "Write references in square brackets, e.g. [1].\n"
294+
prompt += (
295+
"For additional facts you are sure of but a reference is not found, write [?].\n"
296+
)
297+
prompt += f"---\nHere is the Question: {query}.\n"
298+
logger.debug(f"Candidate Prompt: {prompt}")
299+
estimated_length = estimate_num_tokens([prompt])
300+
logger.debug(
301+
f"Max tokens {self.extractor.model.model_id}: {max_tokens_by_model(self.extractor.model.model_id)}"
302+
)
303+
# TODO: use a more precise estimate of the length
304+
if estimated_length + 300 < max_tokens_by_model(self.extractor.model.model_id):
305+
break
306+
else:
307+
# remove least relevant
308+
logger.debug(
309+
f"Removing least relevant of {len(kb_results)}: {kb_results[-1]}")
310+
if not kb_results:
311+
raise ValueError(f"Prompt too long: {prompt}.")
312+
kb_results.pop()
313+
314+
logger.info(f"Prompt: {prompt}")
315+
316+
response = agent.prompt(
317+
prompt, system="You are a scientist assistant.")
318+
response_text = response.text()
319+
pattern = r"\[(\d+)\]"
320+
used_references = re.findall(pattern, response_text)
321+
used_references_dict = {ref: references.get(
322+
ref, "NO REFERENCE") for ref in used_references}
323+
uncited_references_dict = {
324+
ref: ref_obj for ref, ref_obj in references.items() if ref not in used_references
325+
}
326+
formatted_text = replace_references_with_links(response_text)
327+
return ChatResponse(
328+
body=response_text,
329+
formatted_body=formatted_text,
330+
prompt=prompt,
331+
references=used_references_dict,
332+
uncited_references=uncited_references_dict,
333+
conversation_id=conversation_id,
334+
)
335+
336+
@staticmethod
337+
def _format_paperqa_references(answer: str, contexts: list) -> tuple[str, dict]:
338+
import re
339+
from collections import OrderedDict
340+
341+
formatted_body = answer
342+
doc_key_to_num = OrderedDict()
343+
references = {}
344+
345+
# Assign numbers to unique doc.key
346+
for ctx in contexts:
347+
doc = ctx.text.doc
348+
if doc.key not in doc_key_to_num:
349+
doc_key_to_num[doc.key] = len(doc_key_to_num) + 1
350+
references[str(doc_key_to_num[doc.key])] = {
351+
"id": doc.key if hasattr(doc, 'key') else "",
352+
"title": doc.title if hasattr(doc, 'title') else "",
353+
"abstract": doc.text if hasattr(doc, 'text') else "",
354+
"citation": doc.citation if hasattr(doc, 'citation') else "",
355+
"url": doc.doi_url if hasattr(doc, 'doi_url') else "",
356+
"doi": doc.doi if hasattr(doc, 'doi') else ""
357+
}
358+
359+
used_pairs = set()
360+
for ctx in contexts:
361+
text_name = ctx.text.name.strip() # e.g. melendez2024 pages 6–7
362+
doc_key = ctx.text.doc.key
363+
ref_num = doc_key_to_num[doc_key]
364+
pages = text_name.split("pages")[
365+
-1].strip() if "pages" in text_name else None
366+
367+
if (text_name, ref_num, pages) in used_pairs:
368+
continue
369+
used_pairs.add((text_name, ref_num, pages))
370+
371+
markdown_link = f"[{ref_num}](#ref-{ref_num})"
372+
replacement = f"{markdown_link} (pages {pages})" if pages else markdown_link
373+
formatted_body = re.sub(re.escape(text_name), replacement,
374+
formatted_body)
375+
376+
return formatted_body, references

0 commit comments

Comments
 (0)