Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 141 additions & 26 deletions functions/chain.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,153 @@
import concurrent.futures
from os import environ

from get_google_docs import get_inital_prompt
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_anthropic import ChatAnthropic

# separated files
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import OpenAI
from meta import document_content_description, metadata_field_info
from store import get_vector_store

# def create_health_ai_chain(llm, vector_store):
# retriever = SelfQueryRetriever.from_llm(
# llm=llm,
# vectorstore=vector_store,
# document_content_description=document_content_description,
# metadata_field_info=metadata_field_info,
# document_contents='',
# )
# health_ai_template = """
# You are a health AI agent equipped with access to diverse sources of health data,
# including research articles, nutritional information, medical archives, and more.
# Your task is to provide informed answers to user queries based on the available data.
# If you cannot find relevant information, simply state that you do not have enough data
# to answer accurately. write your response in markdown form and also add reference url
# so user can know from which source you are answering the questions.

# CONTEXT:
# {context}

# QUESTION: {question}

# YOUR ANSWER:
# """
# health_ai_prompt = ChatPromptTemplate.from_template(health_ai_template)
# chain = (
# {'context': retriever, 'question': RunnablePassthrough()}
# | health_ai_prompt
# | llm
# | StrOutputParser()
# )
# return chain


def create_health_ai_chain(llm, vector_store):
def custom_history(entire_history: list, llm_name: str) -> list:
    """Convert a raw message-history list into LangChain message objects.

    Each entry of *entire_history* is a dict that may contain a 'user' key
    (the human turn) and/or a key named after a model (that model's reply).
    Only replies stored under *llm_name* are kept, so each model sees its
    own prior answers in the reconstructed conversation.

    Args:
        entire_history: List of dicts, e.g. [{'user': 'hi', 'gpt-4': 'hello'}].
        llm_name: Key identifying which model's replies to include.

    Returns:
        List of HumanMessage / AIMessage objects in conversation order.
    """
    chat_history = []
    for msg in entire_history:
        # append a single message directly instead of extend([...]) with a
        # one-element list — same result, idiomatic and slightly cheaper
        if 'user' in msg:
            chat_history.append(HumanMessage(content=msg['user']))
        if llm_name in msg:
            chat_history.append(AIMessage(content=msg[llm_name]))
    return chat_history


def process_llm(
    llm_name,
    input_string,
    message_history,
    vector_store,
    contextualize_q_system_prompt,
    health_ai_template,
):
    """Run one model's history-aware RAG chain for a single user input.

    Builds the requested chat model, a self-querying retriever over
    *vector_store*, and a retrieval chain that first rewrites the latest
    question into a standalone one using the chat history, then answers it
    from the retrieved context.

    Args:
        llm_name: One of 'gpt-4', 'gemini', or 'claude'.
        input_string: The user's latest question.
        message_history: Raw history list (see ``custom_history``).
        vector_store: Vector store to retrieve context documents from.
        contextualize_q_system_prompt: System prompt for question rewriting.
        health_ai_template: System prompt for answering; must contain a
            '{context}' placeholder for the retrieved documents.

    Returns:
        Tuple ``(llm_name, answer)`` so concurrent callers can map each
        answer back to the model that produced it.

    Raises:
        ValueError: If *llm_name* is not a supported model.
    """
    chat_history = custom_history(message_history, llm_name)

    # Each client library reads its own API key from the environment
    # (OPENAI_API_KEY / GOOGLE_API_KEY / ANTHROPIC_API_KEY); the previous
    # bare environ.get(...) calls discarded their result and did nothing.
    if llm_name == 'gpt-4':
        llm = OpenAI(temperature=0.2)
    elif llm_name == 'gemini':
        llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro-latest')
    elif llm_name == 'claude':
        llm = ChatAnthropic(model='claude-3-5-sonnet-20240620')
    else:
        # Previously an unknown name fell through and crashed later with an
        # UnboundLocalError on `llm`; fail fast with a clear message instead.
        raise ValueError(f'Unsupported llm_name: {llm_name!r}')

    retriever = SelfQueryRetriever.from_llm(
        llm=llm,
        vectorstore=vector_store,
        document_content_description=document_content_description,
        metadata_field_info=metadata_field_info,
        document_contents='',
    )

    # Rewrite the latest question into a standalone one using the history.
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ('system', contextualize_q_system_prompt),
            MessagesPlaceholder('chat_history'),
            ('human', '{input}'),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)

    qa_prompt = ChatPromptTemplate.from_messages(
        [('system', health_ai_template), MessagesPlaceholder('chat_history'), ('human', '{input}')]
    )

    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
    msg = rag_chain.invoke({'input': input_string, 'chat_history': chat_history})

    return llm_name, msg['answer']


def hist_aware_answers(llm_list, input_string, message_history):
    """Query every model in *llm_list* concurrently and collect their answers.

    Args:
        llm_list: Model names understood by ``process_llm``.
        input_string: The user's latest question.
        message_history: Raw chat history (list of per-turn dicts).

    Returns:
        Dict mapping each model name to its answer string.
    """
    results = {}
    vector_store = get_vector_store()

    # Prepend the Google-Docs-sourced prompt when it is available.
    fetched_prompt = get_inital_prompt()
    init_prompt = fetched_prompt if fetched_prompt is not None else ''

    contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""

    context_str = """ You are a health AI agent equipped with
access to diverse sources of health data,
including research articles, nutritional information, medical archives, and more.
Your task is to provide informed answers to user queries based on the available data.
If you cannot find relevant information, simply state that you do not have enough data
to answer accurately. write your response in markdown form and also add reference url
so user can know from which source you are answering the questions.

CONTEXT:
{context}
"""

    health_ai_template = init_prompt + context_str

    # Fan out one worker per model; collect results as each finishes.
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = [
            pool.submit(
                process_llm,
                model_name,
                input_string,
                message_history,
                vector_store,
                contextualize_q_system_prompt,
                health_ai_template,
            )
            for model_name in llm_list
        ]
        for done in concurrent.futures.as_completed(pending):
            model_name, answer = done.result()
            results[model_name] = answer

    return results
81 changes: 81 additions & 0 deletions functions/get_google_docs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import io
import os
import pickle
import re

from google.auth.transport.requests import Request
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload


def extract_document_id_from_url(url):
    """Extract the Google Docs document ID from a docs/drive URL.

    The ID is the path segment following '/d/'. When several '/d/<id>'
    segments match, the longest one is returned (shorter matches are
    usually spurious).

    Args:
        url: A Google Docs/Drive URL such as
            'https://docs.google.com/document/d/<ID>/edit'.

    Returns:
        The document ID string.

    Raises:
        ValueError: If the URL contains no '/d/<id>' segment. (Previously
            this surfaced as an opaque ``max() arg is an empty sequence``.)
    """
    pattern = r'/d/([a-zA-Z0-9-_]+)'
    matches = re.findall(pattern, url)
    if not matches:
        raise ValueError(f'No document id found in URL: {url!r}')
    return max(matches, key=len)


def authenticate(credentials, scopes):
    """Obtain Google API credentials, caching them in ./token.pickle.

    On first run this opens a browser-based OAuth flow using the client
    secrets file at *credentials*; on later runs the cached (or refreshed)
    token is reused without user interaction.

    Args:
        credentials: Path to the OAuth client secrets JSON file.
        scopes: List of OAuth scope URLs to request.

    Returns:
        Authorized credentials usable with googleapiclient's build().
    """
    creds = None
    # The file token.pickle stores the user's access
    # and refresh tokens, and is created automatically
    # when the authorization flow completes for the first time.
    # NOTE(review): unpickling a file is unsafe if token.pickle could be
    # written by an untrusted party — acceptable only for a local cache.
    if os.path.exists('token.pickle'):
        with open('token.pickle', 'rb') as token:
            creds = pickle.load(token)
    # If there are no (valid) credentials available, let the user log in.
    if not creds or not creds.valid:
        if creds and creds.expired and creds.refresh_token:
            # Silently refresh an expired token instead of re-prompting.
            creds.refresh(Request())
        else:
            # Interactive OAuth: spins up a temporary local web server
            # (port=0 picks a free port) to receive the redirect.
            flow = InstalledAppFlow.from_client_secrets_file(credentials, scopes)
            creds = flow.run_local_server(port=0)
        # Save the credentials for the next run
        with open('token.pickle', 'wb') as token:
            pickle.dump(creds, token)

    return creds


def download_file(file_id, credentials_path):
    """Download a Google Docs file as plain text via the Drive API.

    Args:
        file_id: The Drive/Docs document ID to export.
        credentials_path: Path to the OAuth client secrets JSON file,
            forwarded to ``authenticate``.

    Returns:
        The document's full content decoded as a UTF-8 string.
    """
    # Read-only Drive scope is sufficient for exporting document content.
    scopes = ['https://www.googleapis.com/auth/drive.readonly']
    credentials = authenticate(credentials_path, scopes)
    drive_service = build('drive', 'v3', credentials=credentials)

    # Export the Google Docs file as plain text
    export_mime_type = 'text/plain'
    request = drive_service.files().export_media(fileId=file_id, mimeType=export_mime_type)

    # Use a BytesIO buffer to handle the file content in memory
    fh = io.BytesIO()
    downloader = MediaIoBaseDownload(fh, request)
    done = False
    while not done:
        # next_chunk() returns (status, done); loop until the download completes.
        status, done = downloader.next_chunk()
        print(f'Download {int(status.progress() * 100)}%.')

    # Reset the buffer's position to the beginning
    fh.seek(0)

    # Read the content of the buffer
    content = fh.read().decode('utf-8')

    return content


def get_inital_prompt():
    """Fetch the initial system-prompt text from a fixed Google Doc.

    Returns:
        The document's plain-text content, or None when the download fails
        for any reason (missing credentials, network error, bad URL, ...).
    """
    source_url = (
        'https://docs.google.com/document/d/1GtLyBqhk-cu8CSo4A15WTgGDbMbL4B9LLjdvBoU3234/edit'
    )
    doc_id = extract_document_id_from_url(source_url)
    credentials_file = 'credentials.json'

    try:
        return download_file(doc_id, credentials_file)
    except Exception as exc:
        # Best-effort: callers treat None as "no extra prompt available".
        print(f'An error occurred: {exc}')
        return None
24 changes: 17 additions & 7 deletions functions/main.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
from json import dumps

# from handlers import get_response_from_llm
from chain import hist_aware_answers
from firebase_functions import https_fn, options
from handlers import get_response_from_llm


@https_fn.on_request(cors=options.CorsOptions(cors_origins=['*']))
def get_response_url(req: https_fn.Request) -> https_fn.Response:
    """HTTP endpoint: answer *query* with each requested LLM.

    Expects a JSON body: {'query': str, 'llms': [str], 'history': [dict]}.
    Responds with a JSON object mapping model name -> answer.
    """
    # Parse the body once instead of calling get_json() per field.
    body = req.get_json()
    query = body.get('query', '')
    llms = body.get('llms', ['gpt-4', 'gemini', 'claude'])
    chat = body.get('history', [])
    # hist_aware_answers takes the full model list plus the chat history;
    # the previous code passed a single model name and dropped the history.
    responses = hist_aware_answers(llms, query, chat)
    return https_fn.Response(dumps(responses), mimetype='application/json')


@https_fn.on_call()
def get_response(req: https_fn.CallableRequest):
    """Callable endpoint: same contract as get_response_url.

    A CallableRequest carries its payload in ``req.data`` — it has no
    ``get_json()`` method, which the previous code called by mistake.
    Returns a dict mapping model name -> answer (serialized by the SDK).
    """
    query = req.data.get('query', '')
    llms = req.data.get('llms', ['gpt-4', 'gemini', 'claude'])
    chat = req.data.get('history', [])
    # Pass the full model list and the chat history through in one call.
    responses = hist_aware_answers(llms, query, chat)
    return responses


@https_fn.on_request(cors=options.CorsOptions(cors_origins=['*']))
def get_test(req: https_fn.Request) -> https_fn.Response:
    """Health-check endpoint: always replies with 'Hello World!'."""
    greeting = 'Hello World!'
    return https_fn.Response(greeting, mimetype='application/json')
7 changes: 7 additions & 0 deletions functions/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@ langchain-community
langchain-openai
langchain-astradb
lark
langchain-core
langchain-google-genai
langchain-anthropic
google-auth
google-auth-oauthlib
google-api-python-client
python-dotenv
Loading