@@ -44,7 +44,7 @@ class ChatResponse(BaseModel):
def replace_references_with_links(text):
    """Replace numeric bracket references with Markdown anchor links.

    Each occurrence of ``[N]`` (where N is a run of digits) in *text*
    becomes ``[N](#ref-N)``, linking it to the matching reference anchor.
    Non-numeric citations such as ``[?]`` are left untouched.

    :param text: Text possibly containing ``[N]`` citations.
    :return: The text with citations rewritten as Markdown links.
    """
    pattern = r"\[(\d+)\]"
    # A replacement template with a backreference is simpler than a
    # callback function for this purely textual rewrite (and avoids the
    # PEP 8 E704 one-line def / E731 named-lambda smell).
    return re.sub(pattern, r"[\1](#ref-\1)", text)
4949
5050
@@ -85,7 +85,8 @@ def chat(
8585 self .extractor = self .knowledge_source .extractor
8686 else :
8787 raise ValueError ("Extractor must be set." )
88- logger .info (f"Chat: { query } on { self .knowledge_source } kwargs: { kwargs } , limit: { limit } " )
88+ logger .info (
89+ f"Chat: { query } on { self .knowledge_source } kwargs: { kwargs } , limit: { limit } " )
8990 if collection is None :
9091 collection = self .knowledge_source_collection
9192 kwargs ["collection" ] = collection
@@ -101,7 +102,8 @@ def chat(
101102 current_length = 0
102103 for obj , _ , _obj_meta in kb_results :
103104 i += 1
104- obj_text = yaml .dump ({k : v for k , v in obj .items () if v }, sort_keys = False )
105+ obj_text = yaml .dump (
106+ {k : v for k , v in obj .items () if v }, sort_keys = False )
105107 references [str (i )] = obj_text
106108 texts .append (f"## Reference { i } \n { obj_text } " )
107109 current_length += len (obj_text )
@@ -126,7 +128,8 @@ def chat(
126128 break
127129 else :
128130 # remove least relevant
129- logger .debug (f"Removing least relevant of { len (kb_results )} : { kb_results [- 1 ]} " )
131+ logger .debug (
132+ f"Removing least relevant of { len (kb_results )} : { kb_results [- 1 ]} " )
130133 if not kb_results :
131134 raise ValueError (f"Prompt too long: { prompt } ." )
132135 kb_results .pop ()
@@ -141,11 +144,13 @@ def chat(
141144 else :
142145 agent = model
143146 conversation_id = None
144- response = agent .prompt (prompt , system = "You are a scientist assistant." )
147+ response = agent .prompt (
148+ prompt , system = "You are a scientist assistant." )
145149 response_text = response .text ()
146150 pattern = r"\[(\d+|\?)\]"
147151 used_references = re .findall (pattern , response_text )
148- used_references_dict = {ref : references .get (ref , "NO REFERENCE" ) for ref in used_references }
152+ used_references_dict = {ref : references .get (
153+ ref , "NO REFERENCE" ) for ref in used_references }
149154 uncited_references_dict = {
150155 ref : ref_obj for ref , ref_obj in references .items () if ref not in used_references
151156 }
@@ -191,74 +196,181 @@ def chat(
191196 else :
192197 raise ValueError ("Extractor must be set." )
193198
194- logger .info (f"Chat: { query } on { self .knowledge_source } with limit: { limit } " )
199+ logger .info (
200+ f"Chat: { query } on { self .knowledge_source } with limit: { limit } " )
195201 if collection is None :
196202 collection = self .knowledge_source_collection
197203 kwargs ["collection" ] = collection
198204
199- # The search now returns dictionary results directly.
200- kb_results = list (self .knowledge_source .search (
201- query , relevance_factor = self .relevance_factor , limit = limit , expand = expand , ** kwargs
202- ))
203-
204- while True :
205- references = {}
206- texts = []
207- for i , result_tuple in enumerate (kb_results , start = 1 ):
208- # Extract the object from the standard tuple format (obj, distance, metadata)
209- obj , _ , _ = result_tuple
205+ # Set Alzheimer's system prompt if we are using paperqa
206+ if hasattr (self .knowledge_source , 'name' ) and self .knowledge_source .name == 'paperqa' :
207+ self .knowledge_source .settings .agent .agent_system_prompt = (
208+ """You are a specialized AI assistant for biomedical researchers and clinicians focused on
209+ Alzheimer's disease and related topics. I will ask a question and you will answer
210+ as best as possible, citing references. For any additional facts that you are
211+ sure of, but without a citation, write [?].
212+ """ )
210213
211- obj_text = yaml .dump ({k : v for k , v in obj .items () if v }, sort_keys = False )
212- references [str (i )] = obj_text
213- texts .append (f"## Reference { i } \n { obj_text } " )
214+ # The search now returns dictionary results directly.
215+ kb_results = self .knowledge_source .search (
216+ query ,
217+ relevance_factor = self .relevance_factor ,
218+ limit = limit ,
219+ expand = expand ,
220+ ** kwargs
221+ )
214222
215- model = self .extractor .model
216- prompt = (
217- "You are a specialized AI assistant for biomedical researchers and clinicians focused on "
218- "Alzheimer's disease and related topics. I will provide relevant background information, then ask "
219- "a question. Use this context to provide evidence-based answers with proper scientific citations.\n "
220- )
221- prompt += "---\n Background facts:\n " + "\n " .join (texts ) + "\n \n "
222- prompt += (
223- "I will ask a question and you will answer as best as possible, citing the references above.\n "
224- "Write references in square brackets, e.g. [1]. For any additional facts without a citation, write [?].\n "
225- )
226- prompt += f"---\n Here is the Question: { query } .\n "
227- logger .debug (f"Candidate Prompt: { prompt } " )
228- estimated_length = estimate_num_tokens ([prompt ])
229- logger .debug (f"Max tokens { model .model_id } : { max_tokens_by_model (model .model_id )} " )
223+ # Check if we're using PaperQA
224+ is_paperqa = hasattr (self .knowledge_source ,
225+ 'name' ) and self .knowledge_source .name == 'paperqa'
230226
231- if estimated_length + 300 < max_tokens_by_model (model .model_id ):
232- break
233- else :
234- logger .debug ("Prompt too long, removing least relevant result." )
235- if not kb_results :
236- raise ValueError (f"Prompt too long: { prompt } ." )
237- kb_results .pop ()
238-
239- logger .info ("Final prompt constructed for chat." )
227+ model = self .extractor .model
240228 if conversation :
241229 conversation .model = model
242230 agent = conversation
243231 conversation_id = conversation .id
244- logger .info (f"Using conversation context with ID: { conversation_id } " )
232+ logger .info (f"Conversation ID: { conversation_id } " )
245233 else :
246234 agent = model
247235 conversation_id = None
248236
249- response = agent .prompt (prompt , system = "You are a scientist assistant." )
250- response_text = response .text ()
251- pattern = r"\[(\d+|\?)\]"
252- used_references = re .findall (pattern , response_text )
253- used_references_dict = {ref : references .get (ref , "NO REFERENCE" ) for ref in used_references }
254- uncited_references_dict = {ref : ref_obj for ref , ref_obj in references .items () if ref not in used_references }
255- formatted_text = replace_references_with_links (response_text )
237+ if is_paperqa :
238+ session = kb_results .session
239+ response_text = session .answer .strip ()
240+ prompt = f"[PaperQA] Question: { session .question } "
241+ formatted_body , references = (
242+ self ._format_paperqa_references (response_text ,
243+ session .contexts ))
244+
245+ # formatted_refs = {k: yaml.dump(v, sort_keys=False) for k, v in references.items() if v}
246+
247+ def drop_empty_fields (d : dict ) -> dict :
248+ return {k : v for k , v in d .items () if isinstance (v , str ) and v .strip ()}
249+
250+ formatted_refs = {
251+ k : yaml .safe_dump (
252+ drop_empty_fields (v ),
253+ sort_keys = False ,
254+ allow_unicode = True ,
255+ default_flow_style = False ,
256+ width = 80 # prevents line wrapping madness
257+ )
258+ for k , v in references .items ()
259+ }
260+
261+ return ChatResponse (
262+ body = response_text ,
263+ formatted_body = formatted_body ,
264+ prompt = prompt ,
265+ references = formatted_refs ,
266+ uncited_references = {},
267+ conversation_id = conversation_id ,
268+ )
256269
257- return ChatResponse (
258- body = response_text ,
259- formatted_body = formatted_text ,
260- prompt = prompt ,
261- references = used_references_dict ,
262- uncited_references = uncited_references_dict ,
263- conversation_id = conversation_id ,
264- )
270+ else :
271+ kb_results = list (kb_results )
272+ # Regular processing for non-PaperQA sources
273+
274+ # For other sources, we need to format the results and create a prompt
275+ while True :
276+ i = 0
277+ references = {}
278+ texts = []
279+ current_length = 0
280+ for obj , _ , _obj_meta in kb_results :
281+ i += 1
282+ obj_text = yaml .dump (
283+ {k : v for k , v in obj .items () if v }, sort_keys = False )
284+ references [str (i )] = obj_text
285+ texts .append (f"## Reference { i } \n { obj_text } " )
286+ current_length += len (obj_text )
287+
288+ prompt = "I will first give background facts, then ask a question. Use the background fact to answer\n "
289+ prompt += "---\n Background facts:\n "
290+ prompt += "\n " .join (texts )
291+ prompt += "\n \n "
292+ prompt += "I will ask a question and you will answer as best as possible, citing the references above.\n "
293+ prompt += "Write references in square brackets, e.g. [1].\n "
294+ prompt += (
295+ "For additional facts you are sure of but a reference is not found, write [?].\n "
296+ )
297+ prompt += f"---\n Here is the Question: { query } .\n "
298+ logger .debug (f"Candidate Prompt: { prompt } " )
299+ estimated_length = estimate_num_tokens ([prompt ])
300+ logger .debug (
301+ f"Max tokens { self .extractor .model .model_id } : { max_tokens_by_model (self .extractor .model .model_id )} "
302+ )
303+ # TODO: use a more precise estimate of the length
304+ if estimated_length + 300 < max_tokens_by_model (self .extractor .model .model_id ):
305+ break
306+ else :
307+ # remove least relevant
308+ logger .debug (
309+ f"Removing least relevant of { len (kb_results )} : { kb_results [- 1 ]} " )
310+ if not kb_results :
311+ raise ValueError (f"Prompt too long: { prompt } ." )
312+ kb_results .pop ()
313+
314+ logger .info (f"Prompt: { prompt } " )
315+
316+ response = agent .prompt (
317+ prompt , system = "You are a scientist assistant." )
318+ response_text = response .text ()
319+ pattern = r"\[(\d+)\]"
320+ used_references = re .findall (pattern , response_text )
321+ used_references_dict = {ref : references .get (
322+ ref , "NO REFERENCE" ) for ref in used_references }
323+ uncited_references_dict = {
324+ ref : ref_obj for ref , ref_obj in references .items () if ref not in used_references
325+ }
326+ formatted_text = replace_references_with_links (response_text )
327+ return ChatResponse (
328+ body = response_text ,
329+ formatted_body = formatted_text ,
330+ prompt = prompt ,
331+ references = used_references_dict ,
332+ uncited_references = uncited_references_dict ,
333+ conversation_id = conversation_id ,
334+ )
335+
336+ @staticmethod
337+ def _format_paperqa_references (answer : str , contexts : list ) -> tuple [str , dict ]:
338+ import re
339+ from collections import OrderedDict
340+
341+ formatted_body = answer
342+ doc_key_to_num = OrderedDict ()
343+ references = {}
344+
345+ # Assign numbers to unique doc.key
346+ for ctx in contexts :
347+ doc = ctx .text .doc
348+ if doc .key not in doc_key_to_num :
349+ doc_key_to_num [doc .key ] = len (doc_key_to_num ) + 1
350+ references [str (doc_key_to_num [doc .key ])] = {
351+ "id" : doc .key if hasattr (doc , 'key' ) else "" ,
352+ "title" : doc .title if hasattr (doc , 'title' ) else "" ,
353+ "abstract" : doc .text if hasattr (doc , 'text' ) else "" ,
354+ "citation" : doc .citation if hasattr (doc , 'citation' ) else "" ,
355+ "url" : doc .doi_url if hasattr (doc , 'doi_url' ) else "" ,
356+ "doi" : doc .doi if hasattr (doc , 'doi' ) else ""
357+ }
358+
359+ used_pairs = set ()
360+ for ctx in contexts :
361+ text_name = ctx .text .name .strip () # e.g. melendez2024 pages 6–7
362+ doc_key = ctx .text .doc .key
363+ ref_num = doc_key_to_num [doc_key ]
364+ pages = text_name .split ("pages" )[
365+ - 1 ].strip () if "pages" in text_name else None
366+
367+ if (text_name , ref_num , pages ) in used_pairs :
368+ continue
369+ used_pairs .add ((text_name , ref_num , pages ))
370+
371+ markdown_link = f"[{ ref_num } ](#ref-{ ref_num } )"
372+ replacement = f"{ markdown_link } (pages { pages } )" if pages else markdown_link
373+ formatted_body = re .sub (re .escape (text_name ), replacement ,
374+ formatted_body )
375+
376+ return formatted_body , references
0 commit comments