diff --git a/.env.template b/.env.template
deleted file mode 100644
index b007d62b..00000000
--- a/.env.template
+++ /dev/null
@@ -1,34 +0,0 @@
-# Supabase SQL
-SUPABASE_URL=
-SUPABASE_API_KEY=
-SUPABASE_READ_ONLY=
-SUPABASE_JWT_SECRET=
-
-MATERIALS_SUPABASE_TABLE=uiuc_chatbot
-SUPABASE_DOCUMENTS_TABLE=documents
-
-# QDRANT
-QDRANT_COLLECTION_NAME=uiuc-chatbot
-DEV_QDRANT_COLLECTION_NAME=dev
-QDRANT_URL=
-QDRANT_API_KEY=
-
-REFACTORED_MATERIALS_SUPABASE_TABLE=
-
-# AWS
-S3_BUCKET_NAME=uiuc-chatbot
-AWS_ACCESS_KEY_ID=
-AWS_SECRET_ACCESS_KEY=
-
-OPENAI_API_KEY=
-
-NOMIC_API_KEY=
-LINTRULE_SECRET=
-
-# Github Agent
-GITHUB_APP_ID=
-GITHUB_APP_PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY-----
-
------END RSA PRIVATE KEY-----"
-
-NUMEXPR_MAX_THREADS=2
diff --git a/.gitignore b/.gitignore
index b0391b88..ee82f76d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@ coursera-dl/
 wandb
 *.ipynb
 *.pem
+qdrant_data/*
 
 # don't expose env files
 .env
diff --git a/.trunk/configs/.isort.cfg b/.trunk/configs/.isort.cfg
index b9fb3f3e..5225d7a2 100644
--- a/.trunk/configs/.isort.cfg
+++ b/.trunk/configs/.isort.cfg
@@ -1,2 +1,2 @@
 [settings]
-profile=black
+profile=google
diff --git a/.trunk/configs/.style.yapf b/.trunk/configs/.style.yapf
index 3d4d13b2..3e0faa55 100644
--- a/.trunk/configs/.style.yapf
+++ b/.trunk/configs/.style.yapf
@@ -1,4 +1,4 @@
 [style]
 based_on_style = google
-column_limit = 120
+column_limit = 140
 indent_width = 2
diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml
index 4186a1e2..292c526c 100644
--- a/.trunk/trunk.yaml
+++ b/.trunk/trunk.yaml
@@ -43,7 +43,6 @@ lint:
   paths:
     - .github/**/*
     - .trunk/**/*
-    - mkdocs.yml
     - .DS_Store
     - .vscode/**/*
     - README.md
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 00000000..b96ac3d8
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,27 @@
+# Use an official Python runtime as a parent image
+FROM python:3.10-slim
+
+# Set the working directory in the container
+WORKDIR /usr/src/app
+
+
+# Copy the requirements file first to leverage the Docker build cache
+COPY ai_ta_backend/requirements.txt .
+
+# Install any needed packages specified in requirements.txt
+RUN pip install -r requirements.txt
+
+# Create the directory for the SQLite DB
+RUN mkdir -p /usr/src/app/db
+
+# Copy the rest of the local directory contents into the container
+COPY . .
+
+# Set the Python path to include the ai_ta_backend directory
+ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/ai_ta_backend"
+
+# Make port 8000 available to the world outside this container
+EXPOSE 8000
+
+# Run the application using Gunicorn with the specified configuration
+CMD ["gunicorn", "--workers=1", "--threads=100", "--worker-class=gthread", "ai_ta_backend.main:app", "--timeout=1800", "--bind=0.0.0.0:8000"]
diff --git a/README.md b/README.md
index 149e65ee..c2e2b895 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,14 @@ Architecture diagram of Flask + Next.js & React hosted on Vercel.
 
 Automatic [API Reference](https://uiuc-chatbot.github.io/ai-ta-backend/reference/)
 
+## Docker Deployment
+
+1. Run Docker Compose: `docker compose up --build`
+
+Tested with `Docker Compose version v2.27.1-desktop.1`.
+
+Works on Apple Silicon (M1, `aarch64`) and on `x86`.
+
 ## 📣 Development
 
 1. Rename `.env.template` to `.env` and fill in the required variables
@@ -36,3 +44,4 @@ The docs are auto-built and deployed to [our docs website](https://uiuc-chatbot.
 'url': doc.metadata.get('url'), # wouldn't this error out?
'base_url': doc.metadata.get('base_url'), ``` + diff --git a/ai_ta_backend/beam/OpenaiEmbeddings.py b/ai_ta_backend/beam/OpenaiEmbeddings.py index 2f0f64f7..eb7532db 100644 --- a/ai_ta_backend/beam/OpenaiEmbeddings.py +++ b/ai_ta_backend/beam/OpenaiEmbeddings.py @@ -1,550 +1,539 @@ -""" -API REQUEST PARALLEL PROCESSOR - -Using the OpenAI API to process lots of text quickly takes some care. -If you trickle in a million API requests one by one, they'll take days to complete. -If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors. -To maximize throughput, parallel requests need to be throttled to stay under rate limits. - -This script parallelizes requests to the OpenAI API while throttling to stay under rate limits. - -Features: -- Streams requests from file, to avoid running out of memory for giant jobs -- Makes requests concurrently, to maximize throughput -- Throttles request and token usage, to stay under rate limits -- Retries failed requests up to {max_attempts} times, to avoid missing data -- Logs errors, to diagnose problems with requests - -Example command to call script: -``` -python examples/api_request_parallel_processor.py \ - --requests_filepath examples/data/example_requests_to_parallel_process.jsonl \ - --save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \ - --request_url https://api.openai.com/v1/embeddings \ - --max_requests_per_minute 1500 \ - --max_tokens_per_minute 6250000 \ - --token_encoding_name cl100k_base \ - --max_attempts 5 \ - --logging_level 20 -``` - -Inputs: -- requests_filepath : str - - path to the file containing the requests to be processed - - file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field - - e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}} - - as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically) - - an example file is provided at examples/data/example_requests_to_parallel_process.jsonl - - the code to generate the example file is appended to the bottom of this script -- save_filepath : str, optional - - path to the file where the results will be saved - - file will be a jsonl file, where each line is an array with the original request plus the API response - - e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}] - - if omitted, results will be saved to {requests_filename}_results.jsonl -- request_url : str, optional - - URL of the API endpoint to call - - if omitted, will default to "https://api.openai.com/v1/embeddings" -- api_key : str, optional - - API key to use - - if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")} -- max_requests_per_minute : float, optional - - target number of requests to make per minute (will make less if limited by tokens) - - leave headroom by setting this to 50% or 75% of your limit - - if requests are limiting you, try batching multiple embeddings or completions into one request - - if omitted, will default to 1,500 -- max_tokens_per_minute : float, optional - - target number of tokens to use per minute (will use less if limited by requests) - - leave headroom by setting this to 50% or 75% of your limit - - if omitted, will default to 125,000 -- token_encoding_name : str, optional - - name of the token encoding used, as defined in the `tiktoken` package - - if omitted, will default to "cl100k_base" (used 
by `text-embedding-ada-002`) -- max_attempts : int, optional - - number of times to retry a failed request before giving up - - if omitted, will default to 5 -- logging_level : int, optional - - level of logging to use; higher numbers will log fewer messages - - 40 = ERROR; will log only when requests fail after all retries - - 30 = WARNING; will log when requests his rate limits or other errors - - 20 = INFO; will log when requests start and the status at finish - - 10 = DEBUG; will log various things as the loop runs to see when they occur - - if omitted, will default to 20 (INFO). - -The script is structured as follows: - - Imports - - Define main() - - Initialize things - - In main loop: - - Get next request if one is not already waiting for capacity - - Update available token & request capacity - - If enough capacity available, call API - - The loop pauses if a rate limit error is hit - - The loop breaks when no tasks remain - - Define dataclasses - - StatusTracker (stores script metadata counters; only one instance is created) - - APIRequest (stores API inputs, outputs, metadata; one method to call API) - - Define functions - - api_endpoint_from_url (extracts API endpoint from request URL) - - append_to_jsonl (writes to results file) - - num_tokens_consumed_from_request (bigger function to infer token usage from request) - - task_id_generator_function (yields 1, 2, 3, ...) - - Run main() -""" - -# import argparse -# import subprocess -# import tempfile -# from langchain.llms import OpenAI -import asyncio -import json -import logging - -# import os -import re -import time - -# for storing API inputs, outputs, and metadata -from dataclasses import dataclass, field -from typing import Any, List - -import aiohttp # for making API calls concurrently -import tiktoken # for counting tokens - -# from langchain.embeddings.openai import OpenAIEmbeddings -# from langchain.vectorstores import Qdrant -# from qdrant_client import QdrantClient, models - - -class OpenAIAPIProcessor: - - def __init__(self, input_prompts_list, request_url, api_key, max_requests_per_minute, max_tokens_per_minute, - token_encoding_name, max_attempts, logging_level): - self.request_url = request_url - self.api_key = api_key - self.max_requests_per_minute = max_requests_per_minute - self.max_tokens_per_minute = max_tokens_per_minute - self.token_encoding_name = token_encoding_name - self.max_attempts = max_attempts - self.logging_level = logging_level - self.input_prompts_list: List[dict] = input_prompts_list - self.results = [] - self.cleaned_results: List[str] = [] - - async def process_api_requests_from_file(self): - """Processes API requests in parallel, throttling to stay under rate limits.""" - # constants - seconds_to_pause_after_rate_limit_error = 15 - seconds_to_sleep_each_loop = 0.001 # 1 ms limits max throughput to 1,000 requests per second - - # initialize logging - logging.basicConfig(level=self.logging_level) - logging.debug(f"Logging initialized at level {self.logging_level}") - - # infer API endpoint and construct request header - api_endpoint = api_endpoint_from_url(self.request_url) - request_header = {"Authorization": f"Bearer {self.api_key}"} - - # initialize trackers - queue_of_requests_to_retry = asyncio.Queue() - task_id_generator = task_id_generator_function() # generates integer IDs of 1, 2, 3, ... 
- status_tracker = StatusTracker() # single instance to track a collection of variables - next_request = None # variable to hold the next request to call - - # initialize available capacity counts - available_request_capacity = self.max_requests_per_minute - available_token_capacity = self.max_tokens_per_minute - last_update_time = time.time() - - # initialize flags - file_not_finished = True # after file is empty, we'll skip reading it - logging.debug("Initialization complete.") - - requests = self.input_prompts_list.__iter__() - - logging.debug("File opened. Entering main loop") - - task_list = [] - - while True: - # get next request (if one is not already waiting for capacity) - if next_request is None: - if not queue_of_requests_to_retry.empty(): - next_request = queue_of_requests_to_retry.get_nowait() - logging.debug(f"Retrying request {next_request.task_id}: {next_request}") - elif file_not_finished: - try: - # get new request - # request_json = json.loads(next(requests)) - request_json = next(requests) - - next_request = APIRequest(task_id=next(task_id_generator), - request_json=request_json, - token_consumption=num_tokens_consumed_from_request( - request_json, api_endpoint, self.token_encoding_name), - attempts_left=self.max_attempts, - metadata=request_json.pop("metadata", None)) - status_tracker.num_tasks_started += 1 - status_tracker.num_tasks_in_progress += 1 - logging.debug(f"Reading request {next_request.task_id}: {next_request}") - except StopIteration: - # if file runs out, set flag to stop reading it - logging.debug("Read file exhausted") - file_not_finished = False - - # update available capacity - current_time = time.time() - seconds_since_update = current_time - last_update_time - available_request_capacity = min( - available_request_capacity + self.max_requests_per_minute * seconds_since_update / 60.0, - self.max_requests_per_minute, - ) - available_token_capacity = min( - available_token_capacity + self.max_tokens_per_minute * seconds_since_update / 60.0, - self.max_tokens_per_minute, - ) - last_update_time = current_time - - # if enough capacity available, call API - if next_request: - next_request_tokens = next_request.token_consumption - if (available_request_capacity >= 1 and available_token_capacity >= next_request_tokens): - # update counters - available_request_capacity -= 1 - available_token_capacity -= next_request_tokens - next_request.attempts_left -= 1 - - # call API - # TODO: NOT SURE RESPONSE WILL WORK HERE - task = asyncio.create_task( - next_request.call_api( - request_url=self.request_url, - request_header=request_header, - retry_queue=queue_of_requests_to_retry, - status_tracker=status_tracker, - )) - task_list.append(task) - next_request = None # reset next_request to empty - - # print("status_tracker.num_tasks_in_progress", status_tracker.num_tasks_in_progress) - # one_task_result = task.result() - # print("one_task_result", one_task_result) - - # if all tasks are finished, break - if status_tracker.num_tasks_in_progress == 0: - break - - # main loop sleeps briefly so concurrent tasks can run - await asyncio.sleep(seconds_to_sleep_each_loop) - - # if a rate limit error was hit recently, pause to cool down - seconds_since_rate_limit_error = (time.time() - status_tracker.time_of_last_rate_limit_error) - if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error: - remaining_seconds_to_pause = (seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error) - await asyncio.sleep(remaining_seconds_to_pause) - # ^e.g., if 
pause is 15 seconds and final limit was hit 5 seconds ago - logging.warn( - f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}" - ) - - # after finishing, log final status - logging.info("""Parallel processing complete. About to return.""") - if status_tracker.num_tasks_failed > 0: - logging.warning(f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed.") - if status_tracker.num_rate_limit_errors > 0: - logging.warning( - f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.") - - # asyncio wait for task_list - await asyncio.wait(task_list) - - for task in task_list: - openai_completion = task.result() - self.results.append(openai_completion) - - self.cleaned_results: List[str] = extract_context_from_results(self.results) - - -def extract_context_from_results(results: List[Any]) -> List[str]: - assistant_contents = [] - total_prompt_tokens = 0 - total_completion_tokens = 0 - - for element in results: - if element is not None: - for item in element: - if 'choices' in item: - for choice in item['choices']: - if choice['message']['role'] == 'assistant': - assistant_contents.append(choice['message']['content']) - total_prompt_tokens += item['usage']['prompt_tokens'] - total_completion_tokens += item['usage']['completion_tokens'] - # Note: I don't think the prompt_tokens or completion_tokens is working quite right... - - return assistant_contents - - -# dataclasses - - -@dataclass -class StatusTracker: - """Stores metadata about the script's progress. Only one instance is created.""" - - num_tasks_started: int = 0 - num_tasks_in_progress: int = 0 # script ends when this reaches 0 - num_tasks_succeeded: int = 0 - num_tasks_failed: int = 0 - num_rate_limit_errors: int = 0 - num_api_errors: int = 0 # excluding rate limit errors, counted above - num_other_errors: int = 0 - time_of_last_rate_limit_error: float = 0 # used to cool off after hitting rate limits - - -@dataclass -class APIRequest: - """Stores an API request's inputs, outputs, and other metadata. 
Contains a method to make an API call.""" - - task_id: int - request_json: dict - token_consumption: int - attempts_left: int - metadata: dict - result: list = field(default_factory=list) - - async def call_api( - self, - request_url: str, - request_header: dict, - retry_queue: asyncio.Queue, - status_tracker: StatusTracker, - ): - """Calls the OpenAI API and saves results.""" - # logging.info(f"Starting request #{self.task_id}") - error = None - try: - async with aiohttp.ClientSession() as session: - async with session.post(url=request_url, headers=request_header, json=self.request_json) as response: - response = await response.json() - if "error" in response: - logging.warning(f"Request {self.task_id} failed with error {response['error']}") - status_tracker.num_api_errors += 1 - error = response - if "Rate limit" in response["error"].get("message", ""): - status_tracker.time_of_last_rate_limit_error = time.time() - status_tracker.num_rate_limit_errors += 1 - status_tracker.num_api_errors -= 1 # rate limit errors are counted separately - - except Exception as e: # catching naked exceptions is bad practice, but in this case we'll log & save them - logging.warning(f"Request {self.task_id} failed with Exception {e}") - status_tracker.num_other_errors += 1 - error = e - if error: - self.result.append(error) - if self.attempts_left: - retry_queue.put_nowait(self) - else: - logging.error(f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}") - data = ([self.request_json, [str(e) for e in self.result], self.metadata] - if self.metadata else [self.request_json, [str(e) for e in self.result]]) - #append_to_jsonl(data, save_filepath) - status_tracker.num_tasks_in_progress -= 1 - status_tracker.num_tasks_failed += 1 - return data - else: - data = ([self.request_json, response, self.metadata] if self.metadata else [self.request_json, response] - ) # type: ignore - #append_to_jsonl(data, save_filepath) - status_tracker.num_tasks_in_progress -= 1 - status_tracker.num_tasks_succeeded += 1 - # logging.debug(f"Request {self.task_id} saved to {save_filepath}") - - return data - - -# functions - - -def api_endpoint_from_url(request_url: str): - """Extract the API endpoint from the request URL.""" - if 'text-embedding-ada-002' in request_url: - return 'embeddings' - else: - match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url) - return match[1] # type: ignore - - -def append_to_jsonl(data, filename: str) -> None: - """Append a json payload to the end of a jsonl file.""" - json_string = json.dumps(data) - with open(filename, "a") as f: - f.write(json_string + "\n") - - -def num_tokens_consumed_from_request( - request_json: dict, - api_endpoint: str, - token_encoding_name: str, -): - """Count the number of tokens in the request. 
Only supports completion and embedding requests.""" - encoding = tiktoken.get_encoding(token_encoding_name) - # if completions request, tokens = prompt + n * max_tokens - if api_endpoint.endswith("completions"): - max_tokens = request_json.get("max_tokens", 15) - n = request_json.get("n", 1) - completion_tokens = n * max_tokens - - # chat completions - if api_endpoint.startswith("chat/"): - num_tokens = 0 - for message in request_json["messages"]: - num_tokens += 4 # every message follows {role/name}\n{content}\n - for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": # if there's a name, the role is omitted - num_tokens -= 1 # role is always required and always 1 token - num_tokens += 2 # every reply is primed with assistant - return num_tokens + completion_tokens - # normal completions - else: - prompt = request_json["prompt"] - if isinstance(prompt, str): # single prompt - prompt_tokens = len(encoding.encode(prompt)) - num_tokens = prompt_tokens + completion_tokens - return num_tokens - elif isinstance(prompt, list): # multiple prompts - prompt_tokens = sum([len(encoding.encode(p)) for p in prompt]) - num_tokens = prompt_tokens + completion_tokens * len(prompt) - return num_tokens - else: - raise TypeError('Expecting either string or list of strings for "prompt" field in completion request') - # if embeddings request, tokens = input tokens - elif api_endpoint == "embeddings": - input = request_json["input"] - if isinstance(input, str): # single input - num_tokens = len(encoding.encode(input)) - return num_tokens - elif isinstance(input, list): # multiple inputs - num_tokens = sum([len(encoding.encode(i)) for i in input]) - return num_tokens - else: - raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request') - # more logic needed to support other API calls (e.g., edits, inserts, DALL-E) - else: - raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script') - - -def task_id_generator_function(): - """Generate integers 0, 1, 2, and so on.""" - task_id = 0 - while True: - yield task_id - task_id += 1 - - -if __name__ == '__main__': - pass - - # run script - # if __name__ == "__main__": - # qdrant_client = QdrantClient( - # url=os.getenv('QDRANT_URL'), - # api_key=os.getenv('QDRANT_API_KEY'), - # ) - # vectorstore = Qdrant( - # client=qdrant_client, - # collection_name=os.getenv('QDRANT_COLLECTION_NAME'), # type: ignore - # embeddings=OpenAIEmbeddings()) # type: ignore - - # user_question = "What is the significance of Six Sigma?" - # k = 4 - # fetch_k = 200 - # found_docs = vectorstore.max_marginal_relevance_search(user_question, k=k, fetch_k=200) - - # requests = [] - # for i, doc in enumerate(found_docs): - # dictionary = { - # "model": "gpt-3.5-turbo-0613", # 4k context - # "messages": [{ - # "role": "system", - # "content": "You are a factual summarizer of partial documents. Stick to the facts (including partial info when necessary to avoid making up potentially incorrect details), and say I don't know when necessary." - # }, { - # "role": - # "user", - # "content": - # f"What is a comprehensive summary of the given text, based on the question:\n{doc.page_content}\nQuestion: {user_question}\nThe summary should cover all the key points only relevant to the question, while also condensing the information into a concise and easy-to-understand format. 
Please ensure that the summary includes relevant details and examples that support the main ideas, while avoiding any unnecessary information or repetition. Feel free to include references, sentence fragments, keywords, or anything that could help someone learn about it, only as it relates to the given question. The length of the summary should be as short as possible, without losing relevant information.\n" - # }], - # "n": 1, - # "max_tokens": 500, - # "metadata": doc.metadata - # } - # requests.append(dictionary) - - # oai = OpenAIAPIProcessor( - # input_prompts_list=requests, - # request_url='https://api.openai.com/v1/chat/completions', - # api_key=os.getenv("OPENAI_API_KEY"), - # max_requests_per_minute=1500, - # max_tokens_per_minute=90000, - # token_encoding_name='cl100k_base', - # max_attempts=5, - # logging_level=20, - # ) - # # run script - # asyncio.run(oai.process_api_requests_from_file()) - - # assistant_contents = [] - # total_prompt_tokens = 0 - # total_completion_tokens = 0 - - # print("Results, end of main: ", oai.results) - # print("-"*50) - - # # jsonObject = json.loads(oai.results) - # for element in oai.results: - # for item in element: - # if 'choices' in item: - # for choice in item['choices']: - # if choice['message']['role'] == 'assistant': - # assistant_contents.append(choice['message']['content']) - # total_prompt_tokens += item['usage']['prompt_tokens'] - # total_completion_tokens += item['usage']['completion_tokens'] - - # print("Assistant Contents:", assistant_contents) - # print("Total Prompt Tokens:", total_prompt_tokens) - # print("Total Completion Tokens:", total_completion_tokens) - # turbo_total_cost = (total_prompt_tokens * 0.0015) + (total_completion_tokens * 0.002) - # print("Total cost (3.5-turbo):", (total_prompt_tokens * 0.0015), " + Completions: ", (total_completion_tokens * 0.002), " = ", turbo_total_cost) - - # gpt4_total_cost = (total_prompt_tokens * 0.03) + (total_completion_tokens * 0.06) - # print("Hypothetical cost for GPT-4:", (total_prompt_tokens * 0.03), " + Completions: ", (total_completion_tokens * 0.06), " = ", gpt4_total_cost) - # print("GPT-4 cost premium: ", (gpt4_total_cost / turbo_total_cost), "x") - ''' - Pricing: - GPT4: - * $0.03 prompt - * $0.06 completions - 3.5-turbo: - * $0.0015 prompt - * $0.002 completions - ''' -""" -APPENDIX - -The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002. - -It was generated with the following code: - -```python -import json - -filename = "data/example_requests_to_parallel_process.jsonl" -n_requests = 10_000 -jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)] -with open(filename, "w") as f: - for job in jobs: - json_string = json.dumps(job) - f.write(json_string + "\n") -``` - -As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically). -""" +# """ +# API REQUEST PARALLEL PROCESSOR + +# Using the OpenAI API to process lots of text quickly takes some care. +# If you trickle in a million API requests one by one, they'll take days to complete. +# If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors. +# To maximize throughput, parallel requests need to be throttled to stay under rate limits. + +# This script parallelizes requests to the OpenAI API while throttling to stay under rate limits. 
+ +# Features: +# - Streams requests from file, to avoid running out of memory for giant jobs +# - Makes requests concurrently, to maximize throughput +# - Throttles request and token usage, to stay under rate limits +# - Retries failed requests up to {max_attempts} times, to avoid missing data +# - Logs errors, to diagnose problems with requests + +# Example command to call script: +# ``` +# python examples/api_request_parallel_processor.py \ +# --requests_filepath examples/data/example_requests_to_parallel_process.jsonl \ +# --save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \ +# --request_url https://api.openai.com/v1/embeddings \ +# --max_requests_per_minute 1500 \ +# --max_tokens_per_minute 6250000 \ +# --token_encoding_name cl100k_base \ +# --max_attempts 5 \ +# --logging_level 20 +# ``` + +# Inputs: +# - requests_filepath : str +# - path to the file containing the requests to be processed +# - file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field +# - e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}} +# - as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically) +# - an example file is provided at examples/data/example_requests_to_parallel_process.jsonl +# - the code to generate the example file is appended to the bottom of this script +# - save_filepath : str, optional +# - path to the file where the results will be saved +# - file will be a jsonl file, where each line is an array with the original request plus the API response +# - e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}] +# - if omitted, results will be saved to {requests_filename}_results.jsonl +# - request_url : str, optional +# - URL of the API endpoint to call +# - if omitted, will default to "https://api.openai.com/v1/embeddings" +# - api_key : str, optional +# - API key to use +# - if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")} +# - max_requests_per_minute : float, optional +# - target number of requests to make per minute (will make less if limited by tokens) +# - leave headroom by setting this to 50% or 75% of your limit +# - if requests are limiting you, try batching multiple embeddings or completions into one request +# - if omitted, will default to 1,500 +# - max_tokens_per_minute : float, optional +# - target number of tokens to use per minute (will use less if limited by requests) +# - leave headroom by setting this to 50% or 75% of your limit +# - if omitted, will default to 125,000 +# - token_encoding_name : str, optional +# - name of the token encoding used, as defined in the `tiktoken` package +# - if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`) +# - max_attempts : int, optional +# - number of times to retry a failed request before giving up +# - if omitted, will default to 5 +# - logging_level : int, optional +# - level of logging to use; higher numbers will log fewer messages +# - 40 = ERROR; will log only when requests fail after all retries +# - 30 = WARNING; will log when requests his rate limits or other errors +# - 20 = INFO; will log when requests start and the status at finish +# - 10 = DEBUG; will log various things as the loop runs to see when they occur +# - if omitted, will default to 20 (INFO). 
+ +# The script is structured as follows: +# - Imports +# - Define main() +# - Initialize things +# - In main loop: +# - Get next request if one is not already waiting for capacity +# - Update available token & request capacity +# - If enough capacity available, call API +# - The loop pauses if a rate limit error is hit +# - The loop breaks when no tasks remain +# - Define dataclasses +# - StatusTracker (stores script metadata counters; only one instance is created) +# - APIRequest (stores API inputs, outputs, metadata; one method to call API) +# - Define functions +# - api_endpoint_from_url (extracts API endpoint from request URL) +# - append_to_jsonl (writes to results file) +# - num_tokens_consumed_from_request (bigger function to infer token usage from request) +# - task_id_generator_function (yields 1, 2, 3, ...) +# - Run main() +# """ + +# # import argparse +# # import subprocess +# # import tempfile +# # from langchain.llms import OpenAI +# import asyncio +# import json +# import logging + +# # import os +# import re +# import time + +# # for storing API inputs, outputs, and metadata +# from dataclasses import dataclass, field +# from typing import Any, List + +# import aiohttp # for making API calls concurrently +# import tiktoken # for counting tokens + +# # from langchain.embeddings.openai import OpenAIEmbeddings +# # from langchain.vectorstores import Qdrant +# # from qdrant_client import QdrantClient, models + +# class OpenAIAPIProcessor: + +# def __init__(self, input_prompts_list, request_url, api_key, max_requests_per_minute, max_tokens_per_minute, +# token_encoding_name, max_attempts, logging_level): +# self.request_url = request_url +# self.api_key = api_key +# self.max_requests_per_minute = max_requests_per_minute +# self.max_tokens_per_minute = max_tokens_per_minute +# self.token_encoding_name = token_encoding_name +# self.max_attempts = max_attempts +# self.logging_level = logging_level +# self.input_prompts_list: List[dict] = input_prompts_list +# self.results = [] +# self.cleaned_results: List[str] = [] + +# async def process_api_requests_from_file(self): +# """Processes API requests in parallel, throttling to stay under rate limits.""" +# # constants +# seconds_to_pause_after_rate_limit_error = 15 +# seconds_to_sleep_each_loop = 0.001 # 1 ms limits max throughput to 1,000 requests per second + +# # initialize logging +# logging.basicConfig(level=self.logging_level) +# logging.debug(f"Logging initialized at level {self.logging_level}") + +# # infer API endpoint and construct request header +# api_endpoint = api_endpoint_from_url(self.request_url) +# request_header = {"Authorization": f"Bearer {self.api_key}"} + +# # initialize trackers +# queue_of_requests_to_retry = asyncio.Queue() +# task_id_generator = task_id_generator_function() # generates integer IDs of 1, 2, 3, ... +# status_tracker = StatusTracker() # single instance to track a collection of variables +# next_request = None # variable to hold the next request to call + +# # initialize available capacity counts +# available_request_capacity = self.max_requests_per_minute +# available_token_capacity = self.max_tokens_per_minute +# last_update_time = time.time() + +# # initialize flags +# file_not_finished = True # after file is empty, we'll skip reading it +# logging.debug("Initialization complete.") + +# requests = self.input_prompts_list.__iter__() + +# logging.debug("File opened. 
Entering main loop") + +# task_list = [] + +# while True: +# # get next request (if one is not already waiting for capacity) +# if next_request is None: +# if not queue_of_requests_to_retry.empty(): +# next_request = queue_of_requests_to_retry.get_nowait() +# logging.debug(f"Retrying request {next_request.task_id}: {next_request}") +# elif file_not_finished: +# try: +# # get new request +# # request_json = json.loads(next(requests)) +# request_json = next(requests) + +# next_request = APIRequest(task_id=next(task_id_generator), +# request_json=request_json, +# token_consumption=num_tokens_consumed_from_request( +# request_json, api_endpoint, self.token_encoding_name), +# attempts_left=self.max_attempts, +# metadata=request_json.pop("metadata", None)) +# status_tracker.num_tasks_started += 1 +# status_tracker.num_tasks_in_progress += 1 +# logging.debug(f"Reading request {next_request.task_id}: {next_request}") +# except StopIteration: +# # if file runs out, set flag to stop reading it +# logging.debug("Read file exhausted") +# file_not_finished = False + +# # update available capacity +# current_time = time.time() +# seconds_since_update = current_time - last_update_time +# available_request_capacity = min( +# available_request_capacity + self.max_requests_per_minute * seconds_since_update / 60.0, +# self.max_requests_per_minute, +# ) +# available_token_capacity = min( +# available_token_capacity + self.max_tokens_per_minute * seconds_since_update / 60.0, +# self.max_tokens_per_minute, +# ) +# last_update_time = current_time + +# # if enough capacity available, call API +# if next_request: +# next_request_tokens = next_request.token_consumption +# if (available_request_capacity >= 1 and available_token_capacity >= next_request_tokens): +# # update counters +# available_request_capacity -= 1 +# available_token_capacity -= next_request_tokens +# next_request.attempts_left -= 1 + +# # call API +# # TODO: NOT SURE RESPONSE WILL WORK HERE +# task = asyncio.create_task( +# next_request.call_api( +# request_url=self.request_url, +# request_header=request_header, +# retry_queue=queue_of_requests_to_retry, +# status_tracker=status_tracker, +# )) +# task_list.append(task) +# next_request = None # reset next_request to empty + +# # logging.info("status_tracker.num_tasks_in_progress", status_tracker.num_tasks_in_progress) +# # one_task_result = task.result() +# # logging.info("one_task_result", one_task_result) + +# # if all tasks are finished, break +# if status_tracker.num_tasks_in_progress == 0: +# break + +# # main loop sleeps briefly so concurrent tasks can run +# await asyncio.sleep(seconds_to_sleep_each_loop) + +# # if a rate limit error was hit recently, pause to cool down +# seconds_since_rate_limit_error = (time.time() - status_tracker.time_of_last_rate_limit_error) +# if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error: +# remaining_seconds_to_pause = (seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error) +# await asyncio.sleep(remaining_seconds_to_pause) +# # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago +# logging.warn( +# f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}" +# ) + +# # after finishing, log final status +# logging.info("""Parallel processing complete. 
About to return.""") +# if status_tracker.num_tasks_failed > 0: +# logging.warning(f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed.") +# if status_tracker.num_rate_limit_errors > 0: +# logging.warning( +# f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.") + +# # asyncio wait for task_list +# await asyncio.wait(task_list) + +# for task in task_list: +# openai_completion = task.result() +# self.results.append(openai_completion) + +# self.cleaned_results: List[str] = extract_context_from_results(self.results) + +# def extract_context_from_results(results: List[Any]) -> List[str]: +# assistant_contents = [] +# total_prompt_tokens = 0 +# total_completion_tokens = 0 + +# for element in results: +# if element is not None: +# for item in element: +# if 'choices' in item: +# for choice in item['choices']: +# if choice['message']['role'] == 'assistant': +# assistant_contents.append(choice['message']['content']) +# total_prompt_tokens += item['usage']['prompt_tokens'] +# total_completion_tokens += item['usage']['completion_tokens'] +# # Note: I don't think the prompt_tokens or completion_tokens is working quite right... + +# return assistant_contents + +# # dataclasses + +# @dataclass +# class StatusTracker: +# """Stores metadata about the script's progress. Only one instance is created.""" + +# num_tasks_started: int = 0 +# num_tasks_in_progress: int = 0 # script ends when this reaches 0 +# num_tasks_succeeded: int = 0 +# num_tasks_failed: int = 0 +# num_rate_limit_errors: int = 0 +# num_api_errors: int = 0 # excluding rate limit errors, counted above +# num_other_errors: int = 0 +# time_of_last_rate_limit_error: float = 0 # used to cool off after hitting rate limits + +# @dataclass +# class APIRequest: +# """Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call.""" + +# task_id: int +# request_json: dict +# token_consumption: int +# attempts_left: int +# metadata: dict +# result: list = field(default_factory=list) + +# async def call_api( +# self, +# request_url: str, +# request_header: dict, +# retry_queue: asyncio.Queue, +# status_tracker: StatusTracker, +# ): +# """Calls the OpenAI API and saves results.""" +# # logging.info(f"Starting request #{self.task_id}") +# error = None +# try: +# async with aiohttp.ClientSession() as session: +# async with session.post(url=request_url, headers=request_header, json=self.request_json) as response: +# response = await response.json() +# if "error" in response: +# logging.warning(f"Request {self.task_id} failed with error {response['error']}") +# status_tracker.num_api_errors += 1 +# error = response +# if "Rate limit" in response["error"].get("message", ""): +# status_tracker.time_of_last_rate_limit_error = time.time() +# status_tracker.num_rate_limit_errors += 1 +# status_tracker.num_api_errors -= 1 # rate limit errors are counted separately + +# except Exception as e: # catching naked exceptions is bad practice, but in this case we'll log & save them +# logging.warning(f"Request {self.task_id} failed with Exception {e}") +# status_tracker.num_other_errors += 1 +# error = e +# if error: +# self.result.append(error) +# if self.attempts_left: +# retry_queue.put_nowait(self) +# else: +# logging.error(f"Request {self.request_json} failed after all attempts. 
Saving errors: {self.result}") +# data = ([self.request_json, [str(e) for e in self.result], self.metadata] +# if self.metadata else [self.request_json, [str(e) for e in self.result]]) +# #append_to_jsonl(data, save_filepath) +# status_tracker.num_tasks_in_progress -= 1 +# status_tracker.num_tasks_failed += 1 +# return data +# else: +# data = ([self.request_json, response, self.metadata] if self.metadata else [self.request_json, response] +# ) # type: ignore +# #append_to_jsonl(data, save_filepath) +# status_tracker.num_tasks_in_progress -= 1 +# status_tracker.num_tasks_succeeded += 1 +# # logging.debug(f"Request {self.task_id} saved to {save_filepath}") + +# return data + +# # functions + +# def api_endpoint_from_url(request_url: str): +# """Extract the API endpoint from the request URL.""" +# if 'text-embedding-ada-002' in request_url: +# return 'embeddings' +# else: +# match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url) +# return match[1] # type: ignore + +# def append_to_jsonl(data, filename: str) -> None: +# """Append a json payload to the end of a jsonl file.""" +# json_string = json.dumps(data) +# with open(filename, "a") as f: +# f.write(json_string + "\n") + +# def num_tokens_consumed_from_request( +# request_json: dict, +# api_endpoint: str, +# token_encoding_name: str, +# ): +# """Count the number of tokens in the request. Only supports completion and embedding requests.""" +# encoding = tiktoken.get_encoding(token_encoding_name) +# # if completions request, tokens = prompt + n * max_tokens +# if api_endpoint.endswith("completions"): +# max_tokens = request_json.get("max_tokens", 15) +# n = request_json.get("n", 1) +# completion_tokens = n * max_tokens + +# # chat completions +# if api_endpoint.startswith("chat/"): +# num_tokens = 0 +# for message in request_json["messages"]: +# num_tokens += 4 # every message follows {role/name}\n{content}\n +# for key, value in message.items(): +# num_tokens += len(encoding.encode(value)) +# if key == "name": # if there's a name, the role is omitted +# num_tokens -= 1 # role is always required and always 1 token +# num_tokens += 2 # every reply is primed with assistant +# return num_tokens + completion_tokens +# # normal completions +# else: +# prompt = request_json["prompt"] +# if isinstance(prompt, str): # single prompt +# prompt_tokens = len(encoding.encode(prompt)) +# num_tokens = prompt_tokens + completion_tokens +# return num_tokens +# elif isinstance(prompt, list): # multiple prompts +# prompt_tokens = sum([len(encoding.encode(p)) for p in prompt]) +# num_tokens = prompt_tokens + completion_tokens * len(prompt) +# return num_tokens +# else: +# raise TypeError('Expecting either string or list of strings for "prompt" field in completion request') +# # if embeddings request, tokens = input tokens +# elif api_endpoint == "embeddings": +# input = request_json["input"] +# if isinstance(input, str): # single input +# num_tokens = len(encoding.encode(input)) +# return num_tokens +# elif isinstance(input, list): # multiple inputs +# num_tokens = sum([len(encoding.encode(i)) for i in input]) +# return num_tokens +# else: +# raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request') +# # more logic needed to support other API calls (e.g., edits, inserts, DALL-E) +# else: +# raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script') + +# def task_id_generator_function(): +# """Generate integers 0, 1, 2, and so on.""" +# task_id = 0 +# while True: +# yield task_id +# 
task_id += 1 + +# if __name__ == '__main__': +# pass + +# # run script +# # if __name__ == "__main__": +# # qdrant_client = QdrantClient( +# # url=os.getenv('QDRANT_URL'), +# # api_key=os.getenv('QDRANT_API_KEY'), +# # ) +# # vectorstore = Qdrant( +# # client=qdrant_client, +# # collection_name=os.getenv('QDRANT_COLLECTION_NAME'), # type: ignore +# # embeddings=OpenAIEmbeddings()) # type: ignore + +# # user_question = "What is the significance of Six Sigma?" +# # k = 4 +# # fetch_k = 200 +# # found_docs = vectorstore.max_marginal_relevance_search(user_question, k=k, fetch_k=200) + +# # requests = [] +# # for i, doc in enumerate(found_docs): +# # dictionary = { +# # "model": "gpt-3.5-turbo-0613", # 4k context +# # "messages": [{ +# # "role": "system", +# # "content": "You are a factual summarizer of partial documents. Stick to the facts (including partial info when necessary to avoid making up potentially incorrect details), and say I don't know when necessary." +# # }, { +# # "role": +# # "user", +# # "content": +# # f"What is a comprehensive summary of the given text, based on the question:\n{doc.page_content}\nQuestion: {user_question}\nThe summary should cover all the key points only relevant to the question, while also condensing the information into a concise and easy-to-understand format. Please ensure that the summary includes relevant details and examples that support the main ideas, while avoiding any unnecessary information or repetition. Feel free to include references, sentence fragments, keywords, or anything that could help someone learn about it, only as it relates to the given question. The length of the summary should be as short as possible, without losing relevant information.\n" +# # }], +# # "n": 1, +# # "max_tokens": 500, +# # "metadata": doc.metadata +# # } +# # requests.append(dictionary) + +# # oai = OpenAIAPIProcessor( +# # input_prompts_list=requests, +# # request_url='https://api.openai.com/v1/chat/completions', +# # api_key=os.getenv("OPENAI_API_KEY"), +# # max_requests_per_minute=1500, +# # max_tokens_per_minute=90000, +# # token_encoding_name='cl100k_base', +# # max_attempts=5, +# # logging_level=20, +# # ) +# # # run script +# # asyncio.run(oai.process_api_requests_from_file()) + +# # assistant_contents = [] +# # total_prompt_tokens = 0 +# # total_completion_tokens = 0 + +# # logging.info("Results, end of main: ", oai.results) +# # logging.info("-"*50) + +# # # jsonObject = json.loads(oai.results) +# # for element in oai.results: +# # for item in element: +# # if 'choices' in item: +# # for choice in item['choices']: +# # if choice['message']['role'] == 'assistant': +# # assistant_contents.append(choice['message']['content']) +# # total_prompt_tokens += item['usage']['prompt_tokens'] +# # total_completion_tokens += item['usage']['completion_tokens'] + +# # logging.info("Assistant Contents:", assistant_contents) +# # logging.info("Total Prompt Tokens:", total_prompt_tokens) +# # logging.info("Total Completion Tokens:", total_completion_tokens) +# # turbo_total_cost = (total_prompt_tokens * 0.0015) + (total_completion_tokens * 0.002) +# # logging.info("Total cost (3.5-turbo):", (total_prompt_tokens * 0.0015), " + Completions: ", (total_completion_tokens * 0.002), " = ", turbo_total_cost) + +# # gpt4_total_cost = (total_prompt_tokens * 0.03) + (total_completion_tokens * 0.06) +# # logging.info("Hypothetical cost for GPT-4:", (total_prompt_tokens * 0.03), " + Completions: ", (total_completion_tokens * 0.06), " = ", gpt4_total_cost) +# # logging.info("GPT-4 cost 
premium: ", (gpt4_total_cost / turbo_total_cost), "x") +# ''' +# Pricing: +# GPT4: +# * $0.03 prompt +# * $0.06 completions +# 3.5-turbo: +# * $0.0015 prompt +# * $0.002 completions +# ''' +# """ +# APPENDIX + +# The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002. + +# It was generated with the following code: + +# ```python +# import json + +# filename = "data/example_requests_to_parallel_process.jsonl" +# n_requests = 10_000 +# jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)] +# with open(filename, "w") as f: +# for job in jobs: +# json_string = json.dumps(job) +# f.write(json_string + "\n") +# ``` + +# As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically). +# """ diff --git a/ai_ta_backend/beam/ingest.py b/ai_ta_backend/beam/ingest.py index f292f204..42174f27 100644 --- a/ai_ta_backend/beam/ingest.py +++ b/ai_ta_backend/beam/ingest.py @@ -1,1371 +1,1366 @@ -""" -To deploy: beam deploy ingest.py --profile caii-ncsa -Use CAII gmail to auth. -""" -import asyncio -import inspect -import json -import logging -import mimetypes -import os -import re -import shutil -import time -import traceback -import uuid -from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import Any, Callable, Dict, List, Optional, Union - -import beam -import boto3 -import fitz -import openai -import pytesseract -import pdfplumber -import sentry_sdk -import supabase -from beam import App, QueueDepthAutoscaler, Runtime # RequestLatencyAutoscaler, -from bs4 import BeautifulSoup -from git.repo import Repo -from langchain.document_loaders import ( - Docx2txtLoader, - GitLoader, - PythonLoader, - TextLoader, - UnstructuredExcelLoader, - UnstructuredPowerPointLoader, -) -from langchain.document_loaders.csv_loader import CSVLoader -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.schema import Document -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain.vectorstores import Qdrant -from nomic_logging import delete_from_document_map, log_to_document_map, rebuild_map -from OpenaiEmbeddings import OpenAIAPIProcessor -from PIL import Image -from posthog import Posthog -from pydub import AudioSegment -from qdrant_client import QdrantClient, models -from qdrant_client.models import PointStruct - -# from langchain.schema.output_parser import StrOutputParser -# from langchain.chat_models import AzureChatOpenAI - -requirements = [ - "openai<1.0", - "supabase==2.0.2", - "tiktoken==0.5.1", - "boto3==1.28.79", - "qdrant-client==1.7.3", - "langchain==0.0.331", - "posthog==3.1.0", - "pysrt==1.1.2", - "docx2txt==0.8", - "pydub==0.25.1", - "ffmpeg-python==0.2.0", - "ffprobe==0.5", - "ffmpeg==1.4", - "PyMuPDF==1.23.6", - "pytesseract==0.3.10", # image OCR" - "openpyxl==3.1.2", # excel" - "networkx==3.2.1", # unused part of excel partitioning :(" - "python-pptx==0.6.23", - "unstructured==0.10.29", - "GitPython==3.1.40", - "beautifulsoup4==4.12.2", - "sentry-sdk==1.39.1", - "nomic==2.0.14", - "pdfplumber==0.11.0", # PDF OCR, better performance than Fitz/PyMuPDF in my Gies PDF testing. -] - -# TODO: consider adding workers. 
They share CPU and memory https://docs.beam.cloud/deployment/autoscaling#worker-use-cases -app = App("ingest", - runtime=Runtime( - cpu=1, - memory="3Gi", - image=beam.Image( - python_version="python3.10", - python_packages=requirements, - commands=["apt-get update && apt-get install -y ffmpeg tesseract-ocr"], - ), - )) - -# MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation") -OPENAI_API_TYPE = "azure" # "openai" or "azure" - - -def loader(): - """ - The loader function will run once for each worker that starts up. https://docs.beam.cloud/deployment/loaders - """ - openai.api_key = os.getenv("VLADS_OPENAI_KEY") - - # vector DB - qdrant_client = QdrantClient( - url=os.getenv('QDRANT_URL'), - api_key=os.getenv('QDRANT_API_KEY'), - ) - - vectorstore = Qdrant(client=qdrant_client, - collection_name=os.environ['QDRANT_COLLECTION_NAME'], - embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE, - openai_api_key=os.getenv('VLADS_OPENAI_KEY'))) - - # S3 - s3_client = boto3.client( - 's3', - aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), - aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), - ) - - # Create a Supabase client - supabase_client = supabase.create_client( # type: ignore - supabase_url=os.environ['SUPABASE_URL'], supabase_key=os.environ['SUPABASE_API_KEY']) - - # llm = AzureChatOpenAI( - # temperature=0, - # deployment_name=os.getenv('AZURE_OPENAI_ENGINE'), #type:ignore - # openai_api_base=os.getenv('AZURE_OPENAI_ENDPOINT'), #type:ignore - # openai_api_key=os.getenv('AZURE_OPENAI_KEY'), #type:ignore - # openai_api_version=os.getenv('OPENAI_API_VERSION'), #type:ignore - # openai_api_type=OPENAI_API_TYPE) - - posthog = Posthog(sync_mode=True, project_api_key=os.environ['POSTHOG_API_KEY'], host='https://app.posthog.com') - sentry_sdk.init( - dsn="https://examplePublicKey@o0.ingest.sentry.io/0", - - # Enable performance monitoring - enable_tracing=True, - ) - - return qdrant_client, vectorstore, s3_client, supabase_client, posthog - - -# autoscaler = RequestLatencyAutoscaler(desired_latency=30, max_replicas=2) -autoscaler = QueueDepthAutoscaler(max_tasks_per_replica=300, max_replicas=3) - - -# Triggers determine how your app is deployed -# @app.rest_api( -@app.task_queue( - workers=4, - callback_url='https://uiuc-chat-git-ingestprogresstracking-kastanday.vercel.app/api/UIUC-api/ingestTaskCallback', - max_pending_tasks=15_000, - max_retries=3, - timeout=-1, - loader=loader, - autoscaler=autoscaler) -def ingest(**inputs: Dict[str, Any]): - - qdrant_client, vectorstore, s3_client, supabase_client, posthog = inputs["context"] - - course_name: List[str] | str = inputs.get('course_name', '') - s3_paths: List[str] | str = inputs.get('s3_paths', '') - url: List[str] | str | None = inputs.get('url', None) - base_url: List[str] | str | None = inputs.get('base_url', None) - readable_filename: List[str] | str = inputs.get('readable_filename', '') - content: str | None = inputs.get('content', None) # is webtext if content exists - - print( - f"In top of /ingest route. 
course: {course_name}, s3paths: {s3_paths}, readable_filename: {readable_filename}, base_url: {base_url}, url: {url}, content: {content}" - ) - - ingester = Ingest(qdrant_client, vectorstore, s3_client, supabase_client, posthog) - - def run_ingest(course_name, s3_paths, base_url, url, readable_filename, content): - if content: - return ingester.ingest_single_web_text(course_name, base_url, url, content, readable_filename) - elif readable_filename == '': - return ingester.bulk_ingest(course_name, s3_paths, base_url=base_url, url=url) - else: - return ingester.bulk_ingest(course_name, - s3_paths, - readable_filename=readable_filename, - base_url=base_url, - url=url) - - # First try - success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) - - # retries - num_retires = 5 - for retry_num in range(1, num_retires): - if isinstance(success_fail_dict, str): - print(f"STRING ERROR: {success_fail_dict = }") - success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) - time.sleep(13 * retry_num) # max is 65 - elif success_fail_dict['failure_ingest']: - print(f"Ingest failure -- Retry attempt {retry_num}. File: {success_fail_dict}") - # s3_paths = success_fail_dict['failure_ingest'] # retry only failed paths.... what if this is a URL instead? - success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) - time.sleep(13 * retry_num) # max is 65 - else: - break - - # Final failure / success check - if success_fail_dict['failure_ingest']: - print(f"INGEST FAILURE -- About to send to supabase. success_fail_dict: {success_fail_dict}") - document = { - "course_name": - course_name, - "s3_path": - s3_paths, - "readable_filename": - readable_filename, - "url": - url, - "base_url": - base_url, - "error": - success_fail_dict['failure_ingest']['error'] - if isinstance(success_fail_dict['failure_ingest'], dict) else success_fail_dict['failure_ingest'] - } - response = supabase_client.table('documents_failed').insert(document).execute() # type: ignore - print(f"Supabase ingest failure response: {response}") - else: - # Success case: rebuild nomic document map after all ingests are done - # rebuild_status = rebuild_map(str(course_name), map_type='document') - pass - - print(f"Final success_fail_dict: {success_fail_dict}") - return json.dumps(success_fail_dict) - - -class Ingest(): - - def __init__(self, qdrant_client, vectorstore, s3_client, supabase_client, posthog): - self.qdrant_client = qdrant_client - self.vectorstore = vectorstore - self.s3_client = s3_client - self.supabase_client = supabase_client - self.posthog = posthog - - def bulk_ingest(self, course_name: str, s3_paths: Union[str, List[str]], - **kwargs) -> Dict[str, None | str | Dict[str, str]]: - """ - Bulk ingest a list of s3 paths into the vectorstore, and also into the supabase database. 
- -> Dict[str, str | Dict[str, str]] - """ - - def _ingest_single(ingest_method: Callable, s3_path, *args, **kwargs): - """Handle running an arbitrary ingest function for an individual file.""" - # RUN INGEST METHOD - ret = ingest_method(s3_path, *args, **kwargs) - if ret == "Success": - success_status['success_ingest'] = str(s3_path) - else: - success_status['failure_ingest'] = {'s3_path': str(s3_path), 'error': str(ret)} - - # πŸ‘‡πŸ‘‡πŸ‘‡πŸ‘‡ ADD NEW INGEST METHODS HERE πŸ‘‡πŸ‘‡πŸ‘‡πŸ‘‡πŸŽ‰ - file_ingest_methods = { - '.html': self._ingest_html, - '.py': self._ingest_single_py, - '.pdf': self._ingest_single_pdf, - '.txt': self._ingest_single_txt, - '.md': self._ingest_single_txt, - '.srt': self._ingest_single_srt, - '.vtt': self._ingest_single_vtt, - '.docx': self._ingest_single_docx, - '.ppt': self._ingest_single_ppt, - '.pptx': self._ingest_single_ppt, - '.xlsx': self._ingest_single_excel, - '.xls': self._ingest_single_excel, - '.csv': self._ingest_single_csv, - '.png': self._ingest_single_image, - '.jpg': self._ingest_single_image, - } - - # Ingest methods via MIME type (more general than filetype) - mimetype_ingest_methods = { - 'video': self._ingest_single_video, - 'audio': self._ingest_single_video, - 'text': self._ingest_single_txt, - 'image': self._ingest_single_image, - } - # πŸ‘†πŸ‘†πŸ‘†πŸ‘† ADD NEW INGEST METHODhe πŸ‘†πŸ‘†πŸ‘†πŸ‘†πŸŽ‰ - - print(f"Top of ingest, Course_name {course_name}. S3 paths {s3_paths}") - success_status: Dict[str, None | str | Dict[str, str]] = {"success_ingest": None, "failure_ingest": None} - try: - if isinstance(s3_paths, str): - s3_paths = [s3_paths] - - for s3_path in s3_paths: - file_extension = Path(s3_path).suffix - with NamedTemporaryFile(suffix=file_extension) as tmpfile: - self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) - mime_type = str(mimetypes.guess_type(tmpfile.name, strict=False)[0]) - mime_category = mime_type.split('/')[0] if '/' in mime_type else mime_type - - if file_extension in file_ingest_methods: - # Use specialized functions when possible, fallback to mimetype. Else raise error. - ingest_method = file_ingest_methods[file_extension] - _ingest_single(ingest_method, s3_path, course_name, **kwargs) - elif mime_category in mimetype_ingest_methods: - # fallback to MimeType - print("mime category", mime_category) - ingest_method = mimetype_ingest_methods[mime_category] - _ingest_single(ingest_method, s3_path, course_name, **kwargs) - else: - # No supported ingest... Fallback to attempting utf-8 decoding, otherwise fail. - try: - self._ingest_single_txt(s3_path, course_name) - success_status['success_ingest'] = s3_path - print(f"No ingest methods -- Falling back to UTF-8 INGEST... s3_path = {s3_path}") - except Exception as e: - print( - f"We don't have a ingest method for this filetype: {file_extension}. As a last-ditch effort, we tried to ingest the file as utf-8 text, but that failed too. File is unsupported: {s3_path}. 
UTF-8 ingest error: {e}" - ) - success_status['failure_ingest'] = { - 's3_path': - s3_path, - 'error': - f"We don't have a ingest method for this filetype: {file_extension} (with generic type {mime_type}), for file: {s3_path}" - } - self.posthog.capture( - 'distinct_id_of_the_user', - event='ingest_failure', - properties={ - 'course_name': - course_name, - 's3_path': - s3_paths, - 'kwargs': - kwargs, - 'error': - f"We don't have a ingest method for this filetype: {file_extension} (with generic type {mime_type}), for file: {s3_path}" - }) - - return success_status - except Exception as e: - err = f"❌❌ Error in /ingest: `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) # type: ignore - - success_status['failure_ingest'] = {'s3_path': s3_path, 'error': f"MAJOR ERROR DURING INGEST: {err}"} - self.posthog.capture('distinct_id_of_the_user', - event='ingest_failure', - properties={ - 'course_name': course_name, - 's3_path': s3_paths, - 'kwargs': kwargs, - 'error': err - }) - - sentry_sdk.capture_exception(e) - print(f"MAJOR ERROR IN /bulk_ingest: {str(e)}") - return success_status - - def ingest_single_web_text(self, course_name: str, base_url: str, url: str, content: str, readable_filename: str): - """Crawlee integration - """ - self.posthog.capture('distinct_id_of_the_user', - event='ingest_single_web_text_invoked', - properties={ - 'course_name': course_name, - 'base_url': base_url, - 'url': url, - 'content': content, - 'title': readable_filename - }) - success_or_failure: Dict[str, None | str | Dict[str, str]] = {"success_ingest": None, "failure_ingest": None} - try: - # if not, ingest the text - text = [content] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': '', - 'readable_filename': readable_filename, - 'pagenumber': '', - 'timestamp': '', - 'url': url, - 'base_url': base_url, - }] - self.split_and_upload(texts=text, metadatas=metadatas) - self.posthog.capture('distinct_id_of_the_user', - event='ingest_single_web_text_succeeded', - properties={ - 'course_name': course_name, - 'base_url': base_url, - 'url': url, - 'title': readable_filename - }) - - success_or_failure['success_ingest'] = url - return success_or_failure - except Exception as e: - err = f"❌❌ Error in (web text ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) # type: ignore - print(err) - sentry_sdk.capture_exception(e) - success_or_failure['failure_ingest'] = {'url': url, 'error': str(err)} - return success_or_failure - - def _ingest_single_py(self, s3_path: str, course_name: str, **kwargs): - try: - file_name = s3_path.split("/")[-1] - file_path = "media/" + file_name # download from s3 to local folder for ingest - - self.s3_client.download_file(os.getenv('S3_BUCKET_NAME'), s3_path, file_path) - - loader = PythonLoader(file_path) - documents = loader.load() - - texts = [doc.page_content for doc in documents] - - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - #print(texts) - os.remove(file_path) - - success_or_failure = self.split_and_upload(texts=texts, metadatas=metadatas) - print("Python ingest: ", success_or_failure) - return success_or_failure - - except Exception as e: - err = f"❌❌ Error in (Python ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", 
traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return err - - def _ingest_single_vtt(self, s3_path: str, course_name: str, **kwargs): - """ - Ingest a single .vtt file from S3. - """ - try: - with NamedTemporaryFile() as tmpfile: - # download from S3 into vtt_tmpfile - self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) - loader = TextLoader(tmpfile.name) - documents = loader.load() - texts = [doc.page_content for doc in documents] - - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - success_or_failure = self.split_and_upload(texts=texts, metadatas=metadatas) - return success_or_failure - except Exception as e: - err = f"❌❌ Error in (VTT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return err - - def _ingest_html(self, s3_path: str, course_name: str, **kwargs) -> str: - print(f"IN _ingest_html s3_path `{s3_path}` kwargs: {kwargs}") - try: - response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) - raw_html = response['Body'].read().decode('utf-8') - - soup = BeautifulSoup(raw_html, 'html.parser') - title = s3_path.replace("courses/" + course_name, "") - title = title.replace(".html", "") - title = title.replace("_", " ") - title = title.replace("/", " ") - title = title.strip() - title = title[37:] # removing the uuid prefix - text = [soup.get_text()] - - metadata: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': str(title), # adding str to avoid error: unhashable type 'slice' - 'url': kwargs.get('url', ''), - 'base_url': kwargs.get('base_url', ''), - 'pagenumber': '', - 'timestamp': '', - }] - - success_or_failure = self.split_and_upload(text, metadata) - print(f"_ingest_html: {success_or_failure}") - return success_or_failure - except Exception as e: - err: str = f"ERROR IN _ingest_html: {e}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore - print(err) - sentry_sdk.capture_exception(e) - return err - - def _ingest_single_video(self, s3_path: str, course_name: str, **kwargs) -> str: - """ - Ingest a single video file from S3. 
- """ - print("Starting ingest video or audio") - try: - # Ensure the media directory exists - media_dir = "media" - if not os.path.exists(media_dir): - os.makedirs(media_dir) - - # check for file extension - file_ext = Path(s3_path).suffix - openai.api_key = os.getenv('VLADS_OPENAI_KEY') - transcript_list = [] - with NamedTemporaryFile(suffix=file_ext) as video_tmpfile: - # download from S3 into an video tmpfile - self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=video_tmpfile) - # extract audio from video tmpfile - mp4_version = AudioSegment.from_file(video_tmpfile.name, file_ext[1:]) - - # save the extracted audio as a temporary webm file - with NamedTemporaryFile(suffix=".webm", dir=media_dir, delete=False) as webm_tmpfile: - mp4_version.export(webm_tmpfile, format="webm") - - # check file size - file_size = os.path.getsize(webm_tmpfile.name) - # split the audio into 25MB chunks - if file_size > 26214400: - # load the webm file into audio object - full_audio = AudioSegment.from_file(webm_tmpfile.name, "webm") - file_count = file_size // 26214400 + 1 - split_segment = 35 * 60 * 1000 - start = 0 - count = 0 - - while count < file_count: - with NamedTemporaryFile(suffix=".webm", dir=media_dir, delete=False) as split_tmp: - if count == file_count - 1: - # last segment - audio_chunk = full_audio[start:] - else: - audio_chunk = full_audio[start:split_segment] - - audio_chunk.export(split_tmp.name, format="webm") - - # transcribe the split file and store the text in dictionary - with open(split_tmp.name, "rb") as f: - transcript = openai.Audio.transcribe("whisper-1", f) - transcript_list.append(transcript['text']) # type: ignore - start += split_segment - split_segment += split_segment - count += 1 - os.remove(split_tmp.name) - else: - # transcribe the full audio - with open(webm_tmpfile.name, "rb") as f: - transcript = openai.Audio.transcribe("whisper-1", f) - transcript_list.append(transcript['text']) # type: ignore - - os.remove(webm_tmpfile.name) - - text = [txt for txt in transcript_list] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': text.index(txt), - 'url': '', - 'base_url': '', - } for txt in text] - - self.split_and_upload(texts=text, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (VIDEO ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_docx(self, s3_path: str, course_name: str, **kwargs) -> str: - try: - with NamedTemporaryFile() as tmpfile: - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) - - loader = Docx2txtLoader(tmpfile.name) - documents = loader.load() - - texts = [doc.page_content for doc in documents] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (DOCX ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return 
str(err) - - def _ingest_single_srt(self, s3_path: str, course_name: str, **kwargs) -> str: - try: - import pysrt - - # NOTE: slightly different method for .txt files, no need for download. It's part of the 'body' - response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) - raw_text = response['Body'].read().decode('utf-8') - - print("UTF-8 text to ingest as SRT:", raw_text) - parsed_info = pysrt.from_string(raw_text) - text = " ".join([t.text for t in parsed_info]) # type: ignore - print(f"Final SRT ingest: {text}") - - texts = [text] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - }] - if len(text) == 0: - return "Error: SRT file appears empty. Skipping." - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (SRT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_excel(self, s3_path: str, course_name: str, **kwargs) -> str: - try: - with NamedTemporaryFile() as tmpfile: - # download from S3 into pdf_tmpfile - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) - - loader = UnstructuredExcelLoader(tmpfile.name, mode="elements") - # loader = SRTLoader(tmpfile.name) - documents = loader.load() - - texts = [doc.page_content for doc in documents] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (Excel/xlsx ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_image(self, s3_path: str, course_name: str, **kwargs) -> str: - try: - with NamedTemporaryFile() as tmpfile: - # download from S3 into pdf_tmpfile - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) - """ - # Unstructured image loader makes the install too large (700MB --> 6GB. 3min -> 12 min build times). AND nobody uses it. - # The "hi_res" strategy will identify the layout of the document using detectron2. "ocr_only" uses pdfminer.six. 
https://unstructured-io.github.io/unstructured/core/partition.html#partition-image - loader = UnstructuredImageLoader(tmpfile.name, unstructured_kwargs={'strategy': "ocr_only"}) - documents = loader.load() - """ - - res_str = pytesseract.image_to_string(Image.open(tmpfile.name)) - print("IMAGE PARSING RESULT:", res_str) - documents = [Document(page_content=res_str)] - - texts = [doc.page_content for doc in documents] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (png/jpg ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_csv(self, s3_path: str, course_name: str, **kwargs) -> str: - try: - with NamedTemporaryFile() as tmpfile: - # download from S3 into pdf_tmpfile - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) - - loader = CSVLoader(file_path=tmpfile.name) - documents = loader.load() - - texts = [doc.page_content for doc in documents] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (CSV ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_pdf(self, s3_path: str, course_name: str, **kwargs): - """ - Both OCR the PDF. And grab the first image as a PNG. - LangChain `Documents` have .metadata and .page_content attributes. - Be sure to use TemporaryFile() to avoid memory leaks! - """ - print("IN PDF ingest: s3_path: ", s3_path, "and kwargs:", kwargs) - - try: - with NamedTemporaryFile() as pdf_tmpfile: - # download from S3 into pdf_tmpfile - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=pdf_tmpfile) - ### READ OCR of PDF - try: - doc = fitz.open(pdf_tmpfile.name) # type: ignore - except fitz.fitz.EmptyFileError as e: - print(f"Empty PDF file: {s3_path}") - return "Failed ingest: Could not detect ANY text in the PDF. OCR did not help. PDF appears empty of text." 
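The PDF branch above downloads the S3 object into a `NamedTemporaryFile` and opens it with PyMuPDF, treating `fitz.fitz.EmptyFileError` as a hard failure before any page iteration begins. A minimal, self-contained sketch of that guard (the helper name is mine; it assumes a configured boto3 `s3_client` and the `S3_BUCKET_NAME` environment variable used throughout this module):

```python
import os
from tempfile import NamedTemporaryFile

import fitz  # PyMuPDF, pinned as PyMuPDF==1.23.6 in this module's requirements


def extract_pdf_text(s3_client, s3_path: str) -> list[str] | str:
  """Download a PDF from S3 and return its plain text, one string per page."""
  with NamedTemporaryFile(suffix=".pdf") as pdf_tmpfile:
    s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=pdf_tmpfile)
    try:
      doc = fitz.open(pdf_tmpfile.name)
    except fitz.fitz.EmptyFileError:
      return "Failed ingest: Could not detect ANY text in the PDF. PDF appears empty of text."
    # Extract text while the temp file still exists on disk (it is deleted on exit).
    return [page.get_text() for page in doc]
```

If every page comes back essentially empty, the real implementation does not give up here; it falls through to the pdfplumber + pytesseract OCR path shown further down.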
- - # improve quality of the image - zoom_x = 2.0 # horizontal zoom - zoom_y = 2.0 # vertical zoom - mat = fitz.Matrix(zoom_x, zoom_y) # zoom factor 2 in each dimension - - pdf_pages_no_OCR: List[Dict] = [] - for i, page in enumerate(doc): # type: ignore - - # UPLOAD FIRST PAGE IMAGE to S3 - if i == 0: - with NamedTemporaryFile(suffix=".png") as first_page_png: - pix = page.get_pixmap(matrix=mat) - pix.save(first_page_png) # store image as a PNG - - s3_upload_path = str(Path(s3_path)).rsplit('.pdf')[0] + "-pg1-thumb.png" - first_page_png.seek(0) # Seek the file pointer back to the beginning - with open(first_page_png.name, 'rb') as f: - print("Uploading image png to S3") - self.s3_client.upload_fileobj(f, os.getenv('S3_BUCKET_NAME'), s3_upload_path) - - # Extract text - text = page.get_text().encode("utf8").decode("utf8", errors='ignore') # get plain text (is in UTF-8) - pdf_pages_no_OCR.append(dict(text=text, page_number=i, readable_filename=Path(s3_path).name[37:])) - - metadatas: List[Dict[str, Any]] = [ - { - 'course_name': course_name, - 's3_path': s3_path, - 'pagenumber': page['page_number'] + 1, # +1 for human indexing - 'timestamp': '', - 'readable_filename': kwargs.get('readable_filename', page['readable_filename']), - 'url': kwargs.get('url', ''), - 'base_url': kwargs.get('base_url', ''), - } for page in pdf_pages_no_OCR - ] - pdf_texts = [page['text'] for page in pdf_pages_no_OCR] - - # count the total number of words in the pdf_texts. If it's less than 100, we'll OCR the PDF - has_words = any(text.strip() for text in pdf_texts) - if has_words: - success_or_failure = self.split_and_upload(texts=pdf_texts, metadatas=metadatas) - else: - print("⚠️ PDF IS EMPTY -- OCR-ing the PDF.") - success_or_failure = self._ocr_pdf(s3_path=s3_path, course_name=course_name, **kwargs) - - return success_or_failure - except Exception as e: - err = f"❌❌ Error in PDF ingest (no OCR): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) # type: ignore - print(err) - sentry_sdk.capture_exception(e) - return err - return "Success" - - def _ocr_pdf(self, s3_path: str, course_name: str, **kwargs): - self.posthog.capture('distinct_id_of_the_user', - event='ocr_pdf_invoked', - properties={ - 'course_name': course_name, - 's3_path': s3_path, - }) - - pdf_pages_OCRed: List[Dict] = [] - try: - with NamedTemporaryFile() as pdf_tmpfile: - # download from S3 into pdf_tmpfile - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=pdf_tmpfile) - - with pdfplumber.open(pdf_tmpfile.name) as pdf: - # for page in : - for i, page in enumerate(pdf.pages): - im = page.to_image() - text = pytesseract.image_to_string(im.original) - print("Page number: ", i, "Text: ", text[:100]) - pdf_pages_OCRed.append(dict(text=text, page_number=i, readable_filename=Path(s3_path).name[37:])) - - metadatas: List[Dict[str, Any]] = [ - { - 'course_name': course_name, - 's3_path': s3_path, - 'pagenumber': page['page_number'] + 1, # +1 for human indexing - 'timestamp': '', - 'readable_filename': kwargs.get('readable_filename', page['readable_filename']), - 'url': kwargs.get('url', ''), - 'base_url': kwargs.get('base_url', ''), - } for page in pdf_pages_OCRed - ] - pdf_texts = [page['text'] for page in pdf_pages_OCRed] - self.posthog.capture('distinct_id_of_the_user', - event='ocr_pdf_succeeded', - properties={ - 'course_name': course_name, - 's3_path': s3_path, - }) - - has_words = any(text.strip() for text in pdf_texts) - if not has_words: - raise ValueError("Failed 
ingest: Could not detect ANY text in the PDF. OCR did not help. PDF appears empty of text.") - - success_or_failure = self.split_and_upload(texts=pdf_texts, metadatas=metadatas) - return success_or_failure - except Exception as e: - err = f"❌❌ Error in PDF ingest (with OCR): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc() - print(err) - sentry_sdk.capture_exception(e) - return err - - def _ingest_single_txt(self, s3_path: str, course_name: str, **kwargs) -> str: - """Ingest a single .txt or .md file from S3. - Args: - s3_path (str): A path to a .txt file in S3 - course_name (str): The name of the course - Returns: - str: "Success" or an error message - """ - print("In text ingest, UTF-8") - try: - # NOTE: slightly different method for .txt files, no need for download. It's part of the 'body' - response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) - text = response['Body'].read().decode('utf-8') - print("UTF-8 text to ignest (from s3)", text) - text = [text] - - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - }] - print("Prior to ingest", metadatas) - - success_or_failure = self.split_and_upload(texts=text, metadatas=metadatas) - return success_or_failure - except Exception as e: - err = f"❌❌ Error in (TXT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_ppt(self, s3_path: str, course_name: str, **kwargs) -> str: - """ - Ingest a single .ppt or .pptx file from S3. - """ - try: - with NamedTemporaryFile() as tmpfile: - # download from S3 into pdf_tmpfile - #print("in ingest PPTX") - self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) - - loader = UnstructuredPowerPointLoader(tmpfile.name) - documents = loader.load() - - texts = [doc.page_content for doc in documents] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (PPTX ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def ingest_github(self, github_url: str, course_name: str) -> str: - """ - Clones the given GitHub URL and uses Langchain to load data. - 1. Clone the repo - 2. Use Langchain to load the data - 3. Pass to split_and_upload() - Args: - github_url (str): The Github Repo URL to be ingested. - course_name (str): The name of the course in our system. - - Returns: - _type_: Success or error message. 
- """ - try: - repo_path = "media/cloned_repo" - repo = Repo.clone_from(github_url, to_path=repo_path, depth=1, clone_submodules=False) - branch = repo.head.reference - - loader = GitLoader(repo_path="media/cloned_repo", branch=str(branch)) - data = loader.load() - shutil.rmtree("media/cloned_repo") - # create metadata for each file in data - - for doc in data: - texts = doc.page_content - metadatas: Dict[str, Any] = { - 'course_name': course_name, - 's3_path': '', - 'readable_filename': doc.metadata['file_name'], - 'url': f"{github_url}/blob/main/{doc.metadata['file_path']}", - 'pagenumber': '', - 'timestamp': '', - } - self.split_and_upload(texts=[texts], metadatas=[metadatas]) - return "Success" - except Exception as e: - err = f"❌❌ Error in (GITHUB ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n{traceback.format_exc()}" - print(err) - sentry_sdk.capture_exception(e) - return err - - def split_and_upload(self, texts: List[str], metadatas: List[Dict[str, Any]]): - """ This is usually the last step of document ingest. Chunk & upload to Qdrant (and Supabase.. todo). - Takes in Text and Metadata (from Langchain doc loaders) and splits / uploads to Qdrant. - - good examples here: https://langchain.readthedocs.io/en/latest/modules/utils/combine_docs_examples/textsplitter.html - - Args: - texts (List[str]): _description_ - metadatas (List[Dict[str, Any]]): _description_ - """ - # return "Success" - self.posthog.capture('distinct_id_of_the_user', - event='split_and_upload_invoked', - properties={ - 'course_name': metadatas[0].get('course_name', None), - 's3_path': metadatas[0].get('s3_path', None), - 'readable_filename': metadatas[0].get('readable_filename', None), - 'url': metadatas[0].get('url', None), - 'base_url': metadatas[0].get('base_url', None), - }) - - print(f"In split and upload. Metadatas: {metadatas}") - print(f"Texts: {texts}") - assert len(texts) == len( - metadatas - ), f'must have equal number of text strings and metadata dicts. len(texts) is {len(texts)}. len(metadatas) is {len(metadatas)}' - - try: - text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( - chunk_size=1000, - chunk_overlap=150, - separators=[ - "\n\n", "\n", ". ", " ", "" - ] # try to split on paragraphs... 
fallback to sentences, then chars, ensure we always fit in context window - ) - contexts: List[Document] = text_splitter.create_documents(texts=texts, metadatas=metadatas) - input_texts = [{'input': context.page_content, 'model': 'text-embedding-ada-002'} for context in contexts] - - # check for duplicates - is_duplicate = self.check_for_duplicates(input_texts, metadatas) - if is_duplicate: - self.posthog.capture('distinct_id_of_the_user', - event='split_and_upload_succeeded', - properties={ - 'course_name': metadatas[0].get('course_name', None), - 's3_path': metadatas[0].get('s3_path', None), - 'readable_filename': metadatas[0].get('readable_filename', None), - 'url': metadatas[0].get('url', None), - 'base_url': metadatas[0].get('base_url', None), - 'is_duplicate': True, - }) - return "Success" - - # adding chunk index to metadata for parent doc retrieval - for i, context in enumerate(contexts): - context.metadata['chunk_index'] = i - - oai = OpenAIAPIProcessor( - input_prompts_list=input_texts, - request_url='https://api.openai.com/v1/embeddings', - api_key=os.getenv('VLADS_OPENAI_KEY'), - # request_url='https://uiuc-chat-canada-east.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2023-05-15', - # api_key=os.getenv('AZURE_OPENAI_KEY'), - max_requests_per_minute=5_000, - max_tokens_per_minute=300_000, - max_attempts=20, - logging_level=logging.INFO, - token_encoding_name='cl100k_base') # nosec -- reasonable bandit error suppression - asyncio.run(oai.process_api_requests_from_file()) - # parse results into dict of shape page_content -> embedding - embeddings_dict: dict[str, List[float]] = { - item[0]['input']: item[1]['data'][0]['embedding'] for item in oai.results - } - - ### BULK upload to Qdrant ### - vectors: list[PointStruct] = [] - for context in contexts: - # !DONE: Updated the payload so each key is top level (no more payload.metadata.course_name. Instead, use payload.course_name), great for creating indexes. 
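The comment above refers to flattening each chunk's metadata into top-level Qdrant payload keys (`payload.course_name` instead of `payload.metadata.course_name`), which is what makes keyword payload indexes and per-field filters cheap. A hedged sketch of what that flattening enables, reusing the `qdrant_client` and `QDRANT_COLLECTION_NAME` already set up in this module (the filter value is a hypothetical course name, and the index call is an illustration, not something this diff performs):

```python
import os

from qdrant_client import QdrantClient, models

qdrant_client = QdrantClient(url=os.getenv('QDRANT_URL'), api_key=os.getenv('QDRANT_API_KEY'))

# A keyword index on the flattened course_name key keeps filtered searches from
# scanning the whole collection.
qdrant_client.create_payload_index(
    collection_name=os.environ['QDRANT_COLLECTION_NAME'],
    field_name="course_name",
    field_schema=models.PayloadSchemaType.KEYWORD,
)

# An example filter that relies on the flattened key; the delete path below
# filters the same way on top-level 's3_path' and 'url' keys.
course_filter = models.Filter(must=[
    models.FieldCondition(key="course_name", match=models.MatchValue(value="example-course")),
])
```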
- upload_metadata = {**context.metadata, "page_content": context.page_content} - vectors.append( - PointStruct(id=str(uuid.uuid4()), vector=embeddings_dict[context.page_content], payload=upload_metadata)) - - self.qdrant_client.upsert( - collection_name=os.environ['QDRANT_COLLECTION_NAME'], # type: ignore - points=vectors # type: ignore - ) - ### Supabase SQL ### - contexts_for_supa = [{ - "text": context.page_content, - "pagenumber": context.metadata.get('pagenumber'), - "timestamp": context.metadata.get('timestamp'), - "chunk_index": context.metadata.get('chunk_index'), - "embedding": embeddings_dict[context.page_content] - } for context in contexts] - - document = { - "course_name": contexts[0].metadata.get('course_name'), - "s3_path": contexts[0].metadata.get('s3_path'), - "readable_filename": contexts[0].metadata.get('readable_filename'), - "url": contexts[0].metadata.get('url'), - "base_url": contexts[0].metadata.get('base_url'), - "contexts": contexts_for_supa, - } - - response = self.supabase_client.table( - os.getenv('SUPABASE_DOCUMENTS_TABLE')).insert(document).execute() # type: ignore - - # add to Nomic document map - if len(response.data) > 0: - course_name = contexts[0].metadata.get('course_name') - log_to_document_map(course_name) - - self.posthog.capture('distinct_id_of_the_user', - event='split_and_upload_succeeded', - properties={ - 'course_name': metadatas[0].get('course_name', None), - 's3_path': metadatas[0].get('s3_path', None), - 'readable_filename': metadatas[0].get('readable_filename', None), - 'url': metadatas[0].get('url', None), - 'base_url': metadatas[0].get('base_url', None), - }) - print("successful END OF split_and_upload") - return "Success" - except Exception as e: - err: str = f"ERROR IN split_and_upload(): Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore - print(err) - sentry_sdk.capture_exception(e) - return err - - def check_for_duplicates(self, texts: List[Dict], metadatas: List[Dict[str, Any]]) -> bool: - """ - For given metadata, fetch docs from Supabase based on S3 path or URL. - If docs exists, concatenate the texts and compare with current texts, if same, return True. - """ - doc_table = os.getenv('SUPABASE_DOCUMENTS_TABLE', '') - course_name = metadatas[0]['course_name'] - incoming_s3_path = metadatas[0]['s3_path'] - url = metadatas[0]['url'] - original_filename = incoming_s3_path.split('/')[-1][37:] # remove the 37-char uuid prefix - - # check if uuid exists in s3_path -- not all s3_paths have uuids! - incoming_filename = incoming_s3_path.split('/')[-1] - pattern = re.compile(r'[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}', - re.I) # uuid V4 pattern, and v4 only. 
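`check_for_duplicates` keys its Supabase lookup off the human-readable filename, so the UUID v4 prefix that uploads prepend to S3 object names (36 characters plus a separator, 37 in total) has to be stripped first, and the regex above is what decides whether such a prefix is actually present. A small illustrative helper capturing just that step (the function name is mine):

```python
import re

# Same v4-only pattern as in check_for_duplicates above.
UUID_V4_RE = re.compile(r'[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}', re.I)


def strip_uuid_prefix(s3_path: str) -> str:
  """Return the original filename, dropping the 37-char UUID prefix when one is present."""
  filename = s3_path.split('/')[-1]
  return filename[37:] if UUID_V4_RE.search(filename) else filename
```

Only after this normalization does the method compare the concatenated stored contexts against the incoming text to decide whether to skip the file, re-ingest it, or delete the stale copy first.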
- if bool(pattern.search(incoming_filename)): - # uuid pattern exists -- remove the uuid and proceed with duplicate checking - original_filename = incoming_filename[37:] - else: - # do not remove anything and proceed with duplicate checking - original_filename = incoming_filename - - if incoming_s3_path: - filename = incoming_s3_path - supabase_contents = self.supabase_client.table(doc_table).select('id', 'contexts', 's3_path').eq( - 'course_name', course_name).like('s3_path', '%' + original_filename + '%').order('id', desc=True).execute() - supabase_contents = supabase_contents.data - elif url: - filename = url - supabase_contents = self.supabase_client.table(doc_table).select('id', 'contexts', 's3_path').eq( - 'course_name', course_name).eq('url', url).order('id', desc=True).execute() - supabase_contents = supabase_contents.data - else: - filename = None - supabase_contents = [] - - supabase_whole_text = "" - if len(supabase_contents) > 0: # if a doc with same filename exists in Supabase - # concatenate texts - supabase_contexts = supabase_contents[0] - for text in supabase_contexts['contexts']: - supabase_whole_text += text['text'] - - current_whole_text = "" - for text in texts: - current_whole_text += text['input'] - - if supabase_whole_text == current_whole_text: # matches the previous file - print(f"Duplicate ingested! πŸ“„ s3_path: {filename}.") - return True - - else: # the file is updated - print(f"Updated file detected! Same filename, new contents. πŸ“„ s3_path: {filename}") - - # call the delete function on older docs - for content in supabase_contents: - print("older s3_path to be deleted: ", content['s3_path']) - delete_status = self.delete_data(course_name, content['s3_path'], '') - print("delete_status: ", delete_status) - return False - - else: # filename does not already exist in Supabase, so its a brand new file - print(f"NOT a duplicate! πŸ“„s3_path: {filename}") - return False - - def delete_data(self, course_name: str, s3_path: str, source_url: str): - """Delete file from S3, Qdrant, and Supabase.""" - print(f"Deleting {s3_path} from S3, Qdrant, and Supabase for course {course_name}") - # add delete from doc map logic here - try: - # Delete file from S3 - bucket_name = os.getenv('S3_BUCKET_NAME') - - # Delete files by S3 path - if s3_path: - try: - self.s3_client.delete_object(Bucket=bucket_name, Key=s3_path) - except Exception as e: - print("Error in deleting file from s3:", e) - sentry_sdk.capture_exception(e) - # Delete from Qdrant - # docs for nested keys: https://qdrant.tech/documentation/concepts/filtering/#nested-key - # Qdrant "points" look like this: Record(id='000295ca-bd28-ac4a-6f8d-c245f7377f90', payload={'metadata': {'course_name': 'zotero-extreme', 'pagenumber_or_timestamp': 15, 'readable_filename': 'Dunlosky et al. - 2013 - Improving Students’ Learning With Effective Learni.pdf', 's3_path': 'courses/zotero-extreme/Dunlosky et al. - 2013 - Improving Students’ Learning With Effective Learni.pdf'}, 'page_content': '18 \nDunlosky et al.\n3.3 Effects in representative educational contexts. Sev-\neral of the large summarization-training studies have been \nconducted in regular classrooms, indicating the feasibility of \ndoing so. For example, the study by A. King (1992) took place \nin the context of a remedial study-skills course for undergrad-\nuates, and the study by Rinehart et al. (1986) took place in \nsixth-grade classrooms, with the instruction led by students \nregular teachers. 
In these and other cases, students benefited \nfrom the classroom training. We suspect it may actually be \nmore feasible to conduct these kinds of training ... - try: - self.qdrant_client.delete( - collection_name=os.environ['QDRANT_COLLECTION_NAME'], - points_selector=models.Filter(must=[ - models.FieldCondition( - key="s3_path", - match=models.MatchValue(value=s3_path), - ), - ]), - ) - except Exception as e: - if "timed out" in str(e): - # Timed out is fine. Still deletes. - # https://github.com/qdrant/qdrant/issues/3654#issuecomment-1955074525 - pass - else: - print("Error in deleting file from Qdrant:", e) - sentry_sdk.capture_exception(e) - try: - # delete from Nomic - response = self.supabase_client.from_( - os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq('s3_path', s3_path).eq( - 'course_name', course_name).execute() - data = response.data[0] #single record fetched - nomic_ids_to_delete = [] - context_count = len(data['contexts']) - for i in range(1, context_count + 1): - nomic_ids_to_delete.append(str(data['id']) + "_" + str(i)) - - # delete from Nomic - delete_from_document_map(course_name, nomic_ids_to_delete) - except Exception as e: - print("Error in deleting file from Nomic:", e) - sentry_sdk.capture_exception(e) - - try: - self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('s3_path', s3_path).eq( - 'course_name', course_name).execute() - except Exception as e: - print("Error in deleting file from supabase:", e) - sentry_sdk.capture_exception(e) - - # Delete files by their URL identifier - elif source_url: - try: - # Delete from Qdrant - self.qdrant_client.delete( - collection_name=os.environ['QDRANT_COLLECTION_NAME'], - points_selector=models.Filter(must=[ - models.FieldCondition( - key="url", - match=models.MatchValue(value=source_url), - ), - ]), - ) - except Exception as e: - if "timed out" in str(e): - # Timed out is fine. Still deletes. - # https://github.com/qdrant/qdrant/issues/3654#issuecomment-1955074525 - pass - else: - print("Error in deleting file from Qdrant:", e) - sentry_sdk.capture_exception(e) - try: - # delete from Nomic - response = self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, url, contexts").eq( - 'url', source_url).eq('course_name', course_name).execute() - data = response.data[0] #single record fetched - nomic_ids_to_delete = [] - context_count = len(data['contexts']) - for i in range(1, context_count + 1): - nomic_ids_to_delete.append(str(data['id']) + "_" + str(i)) - - # delete from Nomic - delete_from_document_map(course_name, nomic_ids_to_delete) - except Exception as e: - print("Error in deleting file from Nomic:", e) - sentry_sdk.capture_exception(e) - - try: - # delete from Supabase - self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('url', source_url).eq( - 'course_name', course_name).execute() - except Exception as e: - print("Error in deleting file from supabase:", e) - sentry_sdk.capture_exception(e) - - # Delete from Supabase - return "Success" - except Exception as e: - err: str = f"ERROR IN delete_data: Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore - print(err) - sentry_sdk.capture_exception(e) - return err - - # def ingest_coursera(self, coursera_course_name: str, course_name: str) -> str: - # """ Download all the files from a coursera course and ingest them. - - # 1. Download the coursera content. - # 2. Upload to S3 (so users can view it) - # 3. 
Run everything through the ingest_bulk method. - - # Args: - # coursera_course_name (str): The name of the coursera course. - # course_name (str): The name of the course in our system. - - # Returns: - # _type_: Success or error message. - # """ - # certificate = "-ca 'FVhVoDp5cb-ZaoRr5nNJLYbyjCLz8cGvaXzizqNlQEBsG5wSq7AHScZGAGfC1nI0ehXFvWy1NG8dyuIBF7DLMA.X3cXsDvHcOmSdo3Fyvg27Q.qyGfoo0GOHosTVoSMFy-gc24B-_BIxJtqblTzN5xQWT3hSntTR1DMPgPQKQmfZh_40UaV8oZKKiF15HtZBaLHWLbpEpAgTg3KiTiU1WSdUWueo92tnhz-lcLeLmCQE2y3XpijaN6G4mmgznLGVsVLXb-P3Cibzz0aVeT_lWIJNrCsXrTFh2HzFEhC4FxfTVqS6cRsKVskPpSu8D9EuCQUwJoOJHP_GvcME9-RISBhi46p-Z1IQZAC4qHPDhthIJG4bJqpq8-ZClRL3DFGqOfaiu5y415LJcH--PRRKTBnP7fNWPKhcEK2xoYQLr9RxBVL3pzVPEFyTYtGg6hFIdJcjKOU11AXAnQ-Kw-Gb_wXiHmu63veM6T8N2dEkdqygMre_xMDT5NVaP3xrPbA4eAQjl9yov4tyX4AQWMaCS5OCbGTpMTq2Y4L0Mbz93MHrblM2JL_cBYa59bq7DFK1IgzmOjFhNG266mQlC9juNcEhc'" - # always_use_flags = "-u kastanvday@gmail.com -p hSBsLaF5YM469# --ignore-formats mp4 --subtitle-language en --path ./coursera-dl" - - # try: - # subprocess.run( - # f"coursera-dl {always_use_flags} {certificate} {coursera_course_name}", - # check=True, - # shell=True, # nosec -- reasonable bandit error suppression - # stdout=subprocess.PIPE, - # stderr=subprocess.PIPE) # capture_output=True, - # dl_results_path = os.path.join('coursera-dl', coursera_course_name) - # s3_paths: Union[List, None] = upload_data_files_to_s3(course_name, dl_results_path) - - # if s3_paths is None: - # return "Error: No files found in the coursera-dl directory" - - # print("starting bulk ingest") - # start_time = time.monotonic() - # self.bulk_ingest(s3_paths, course_name) - # print("completed bulk ingest") - # print(f"⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds") - - # # Cleanup the coursera downloads - # shutil.rmtree(dl_results_path) - - # return "Success" - # except Exception as e: - # err: str = f"Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore - # print(err) - # return err - - # def list_files_recursively(self, bucket, prefix): - # all_files = [] - # continuation_token = None - - # while True: - # list_objects_kwargs = { - # 'Bucket': bucket, - # 'Prefix': prefix, - # } - # if continuation_token: - # list_objects_kwargs['ContinuationToken'] = continuation_token - - # response = self.s3_client.list_objects_v2(**list_objects_kwargs) - - # if 'Contents' in response: - # for obj in response['Contents']: - # all_files.append(obj['Key']) - - # if response['IsTruncated']: - # continuation_token = response['NextContinuationToken'] - # else: - # break - - # return all_files - - -if __name__ == "__main__": - raise NotImplementedError("This file is not meant to be run directly") - text = "Testing 123" - # ingest(text=text) +# """ +# To deploy: beam deploy ingest.py --profile caii-ncsa +# Use CAII gmail to auth. 
+# """ +# import asyncio +# import inspect +# import json +# import logging +# import mimetypes +# import os +# import re +# import shutil +# import time +# import traceback +# import uuid +# from pathlib import Path +# from tempfile import NamedTemporaryFile +# from typing import Any, Callable, Dict, List, Optional, Union + +# import beam +# import boto3 +# import fitz +# import openai +# import pytesseract +# import pdfplumber +# import sentry_sdk +# import supabase +# from beam import App, QueueDepthAutoscaler, Runtime # RequestLatencyAutoscaler, +# from bs4 import BeautifulSoup +# from git.repo import Repo +# from langchain.document_loaders import ( +# Docx2txtLoader, +# GitLoader, +# PythonLoader, +# TextLoader, +# UnstructuredExcelLoader, +# UnstructuredPowerPointLoader, +# ) +# from langchain.document_loaders.csv_loader import CSVLoader +# from langchain.embeddings.openai import OpenAIEmbeddings +# from langchain.schema import Document +# from langchain.text_splitter import RecursiveCharacterTextSplitter +# from langchain.vectorstores import Qdrant +# from nomic_logging import delete_from_document_map, log_to_document_map, rebuild_map +# from OpenaiEmbeddings import OpenAIAPIProcessor +# from PIL import Image +# from posthog import Posthog +# from pydub import AudioSegment +# from qdrant_client import QdrantClient, models +# from qdrant_client.models import PointStruct + +# # from langchain.schema.output_parser import StrOutputParser +# # from langchain.chat_models import AzureChatOpenAI + +# requirements = [ +# "openai<1.0", +# "supabase==2.0.2", +# "tiktoken==0.5.1", +# "boto3==1.28.79", +# "qdrant-client==1.7.3", +# "langchain==0.0.331", +# "posthog==3.1.0", +# "pysrt==1.1.2", +# "docx2txt==0.8", +# "pydub==0.25.1", +# "ffmpeg-python==0.2.0", +# "ffprobe==0.5", +# "ffmpeg==1.4", +# "PyMuPDF==1.23.6", +# "pytesseract==0.3.10", # image OCR" +# "openpyxl==3.1.2", # excel" +# "networkx==3.2.1", # unused part of excel partitioning :(" +# "python-pptx==0.6.23", +# "unstructured==0.10.29", +# "GitPython==3.1.40", +# "beautifulsoup4==4.12.2", +# "sentry-sdk==1.39.1", +# "nomic==2.0.14", +# "pdfplumber==0.11.0", # PDF OCR, better performance than Fitz/PyMuPDF in my Gies PDF testing. +# ] + +# # TODO: consider adding workers. They share CPU and memory https://docs.beam.cloud/deployment/autoscaling#worker-use-cases +# app = App("ingest", +# runtime=Runtime( +# cpu=1, +# memory="3Gi", +# image=beam.Image( +# python_version="python3.10", +# python_packages=requirements, +# commands=["apt-get update && apt-get install -y ffmpeg tesseract-ocr"], +# ), +# )) + +# # MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation") +# OPENAI_API_TYPE = "azure" # "openai" or "azure" + +# def loader(): +# """ +# The loader function will run once for each worker that starts up. 
https://docs.beam.cloud/deployment/loaders +# """ +# openai.api_key = os.getenv("VLADS_OPENAI_KEY") + +# # vector DB +# qdrant_client = QdrantClient( +# url=os.getenv('QDRANT_URL'), +# api_key=os.getenv('QDRANT_API_KEY'), +# ) + +# vectorstore = Qdrant(client=qdrant_client, +# collection_name=os.environ['QDRANT_COLLECTION_NAME'], +# embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE, +# openai_api_key=os.getenv('VLADS_OPENAI_KEY'))) + +# # S3 +# s3_client = boto3.client( +# 's3', +# aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), +# aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), +# ) + +# # Create a Supabase client +# supabase_client = supabase.create_client( # type: ignore +# supabase_url=os.environ['SUPABASE_URL'], supabase_key=os.environ['SUPABASE_API_KEY']) + +# # llm = AzureChatOpenAI( +# # temperature=0, +# # deployment_name=os.getenv('AZURE_OPENAI_ENGINE'), #type:ignore +# # openai_api_base=os.getenv('AZURE_OPENAI_ENDPOINT'), #type:ignore +# # openai_api_key=os.getenv('AZURE_OPENAI_KEY'), #type:ignore +# # openai_api_version=os.getenv('OPENAI_API_VERSION'), #type:ignore +# # openai_api_type=OPENAI_API_TYPE) + +# posthog = Posthog(sync_mode=True, project_api_key=os.environ['POSTHOG_API_KEY'], host='https://app.posthog.com') +# sentry_sdk.init( +# dsn="https://examplePublicKey@o0.ingest.sentry.io/0", + +# # Enable performance monitoring +# enable_tracing=True, +# ) + +# return qdrant_client, vectorstore, s3_client, supabase_client, posthog + +# # autoscaler = RequestLatencyAutoscaler(desired_latency=30, max_replicas=2) +# autoscaler = QueueDepthAutoscaler(max_tasks_per_replica=300, max_replicas=3) + +# # Triggers determine how your app is deployed +# # @app.rest_api( +# @app.task_queue( +# workers=4, +# callback_url='https://uiuc-chat-git-ingestprogresstracking-kastanday.vercel.app/api/UIUC-api/ingestTaskCallback', +# max_pending_tasks=15_000, +# max_retries=3, +# timeout=-1, +# loader=loader, +# autoscaler=autoscaler) +# def ingest(**inputs: Dict[str, Any]): + +# qdrant_client, vectorstore, s3_client, supabase_client, posthog = inputs["context"] + +# course_name: List[str] | str = inputs.get('course_name', '') +# s3_paths: List[str] | str = inputs.get('s3_paths', '') +# url: List[str] | str | None = inputs.get('url', None) +# base_url: List[str] | str | None = inputs.get('base_url', None) +# readable_filename: List[str] | str = inputs.get('readable_filename', '') +# content: str | None = inputs.get('content', None) # is webtext if content exists + +# logging.info( +# f"In top of /ingest route. 
course: {course_name}, s3paths: {s3_paths}, readable_filename: {readable_filename}, base_url: {base_url}, url: {url}, content: {content}" +# ) + +# ingester = Ingest(qdrant_client, vectorstore, s3_client, supabase_client, posthog) + +# def run_ingest(course_name, s3_paths, base_url, url, readable_filename, content): +# if content: +# return ingester.ingest_single_web_text(course_name, base_url, url, content, readable_filename) +# elif readable_filename == '': +# return ingester.bulk_ingest(course_name, s3_paths, base_url=base_url, url=url) +# else: +# return ingester.bulk_ingest(course_name, +# s3_paths, +# readable_filename=readable_filename, +# base_url=base_url, +# url=url) + +# # First try +# success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) + +# # retries +# num_retires = 5 +# for retry_num in range(1, num_retires): +# if isinstance(success_fail_dict, str): +# logging.info(f"STRING ERROR: {success_fail_dict = }") +# success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) +# time.sleep(13 * retry_num) # max is 65 +# elif success_fail_dict['failure_ingest']: +# logging.info(f"Ingest failure -- Retry attempt {retry_num}. File: {success_fail_dict}") +# # s3_paths = success_fail_dict['failure_ingest'] # retry only failed paths.... what if this is a URL instead? +# success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) +# time.sleep(13 * retry_num) # max is 65 +# else: +# break + +# # Final failure / success check +# if success_fail_dict['failure_ingest']: +# logging.info(f"INGEST FAILURE -- About to send to supabase. success_fail_dict: {success_fail_dict}") +# document = { +# "course_name": +# course_name, +# "s3_path": +# s3_paths, +# "readable_filename": +# readable_filename, +# "url": +# url, +# "base_url": +# base_url, +# "error": +# success_fail_dict['failure_ingest']['error'] +# if isinstance(success_fail_dict['failure_ingest'], dict) else success_fail_dict['failure_ingest'] +# } +# response = supabase_client.table('documents_failed').insert(document).execute() # type: ignore +# logging.info(f"Supabase ingest failure response: {response}") +# else: +# # Success case: rebuild nomic document map after all ingests are done +# # rebuild_status = rebuild_map(str(course_name), map_type='document') +# pass + +# logging.info(f"Final success_fail_dict: {success_fail_dict}") +# return json.dumps(success_fail_dict) + +# class Ingest(): + +# def __init__(self, qdrant_client, vectorstore, s3_client, supabase_client, posthog): +# self.qdrant_client = qdrant_client +# self.vectorstore = vectorstore +# self.s3_client = s3_client +# self.supabase_client = supabase_client +# self.posthog = posthog + +# def bulk_ingest(self, course_name: str, s3_paths: Union[str, List[str]], +# **kwargs) -> Dict[str, None | str | Dict[str, str]]: +# """ +# Bulk ingest a list of s3 paths into the vectorstore, and also into the supabase database. 
+# -> Dict[str, str | Dict[str, str]] +# """ + +# def _ingest_single(ingest_method: Callable, s3_path, *args, **kwargs): +# """Handle running an arbitrary ingest function for an individual file.""" +# # RUN INGEST METHOD +# ret = ingest_method(s3_path, *args, **kwargs) +# if ret == "Success": +# success_status['success_ingest'] = str(s3_path) +# else: +# success_status['failure_ingest'] = {'s3_path': str(s3_path), 'error': str(ret)} + +# # πŸ‘‡πŸ‘‡πŸ‘‡πŸ‘‡ ADD NEW INGEST METHODS HERE πŸ‘‡πŸ‘‡πŸ‘‡πŸ‘‡πŸŽ‰ +# file_ingest_methods = { +# '.html': self._ingest_html, +# '.py': self._ingest_single_py, +# '.pdf': self._ingest_single_pdf, +# '.txt': self._ingest_single_txt, +# '.md': self._ingest_single_txt, +# '.srt': self._ingest_single_srt, +# '.vtt': self._ingest_single_vtt, +# '.docx': self._ingest_single_docx, +# '.ppt': self._ingest_single_ppt, +# '.pptx': self._ingest_single_ppt, +# '.xlsx': self._ingest_single_excel, +# '.xls': self._ingest_single_excel, +# '.csv': self._ingest_single_csv, +# '.png': self._ingest_single_image, +# '.jpg': self._ingest_single_image, +# } + +# # Ingest methods via MIME type (more general than filetype) +# mimetype_ingest_methods = { +# 'video': self._ingest_single_video, +# 'audio': self._ingest_single_video, +# 'text': self._ingest_single_txt, +# 'image': self._ingest_single_image, +# } +# # πŸ‘†πŸ‘†πŸ‘†πŸ‘† ADD NEW INGEST METHODhe πŸ‘†πŸ‘†πŸ‘†πŸ‘†πŸŽ‰ + +# logging.info(f"Top of ingest, Course_name {course_name}. S3 paths {s3_paths}") +# success_status: Dict[str, None | str | Dict[str, str]] = {"success_ingest": None, "failure_ingest": None} +# try: +# if isinstance(s3_paths, str): +# s3_paths = [s3_paths] + +# for s3_path in s3_paths: +# file_extension = Path(s3_path).suffix +# with NamedTemporaryFile(suffix=file_extension) as tmpfile: +# self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) +# mime_type = str(mimetypes.guess_type(tmpfile.name, strict=False)[0]) +# mime_category = mime_type.split('/')[0] if '/' in mime_type else mime_type + +# if file_extension in file_ingest_methods: +# # Use specialized functions when possible, fallback to mimetype. Else raise error. +# ingest_method = file_ingest_methods[file_extension] +# _ingest_single(ingest_method, s3_path, course_name, **kwargs) +# elif mime_category in mimetype_ingest_methods: +# # fallback to MimeType +# logging.info("mime category", mime_category) +# ingest_method = mimetype_ingest_methods[mime_category] +# _ingest_single(ingest_method, s3_path, course_name, **kwargs) +# else: +# # No supported ingest... Fallback to attempting utf-8 decoding, otherwise fail. +# try: +# self._ingest_single_txt(s3_path, course_name) +# success_status['success_ingest'] = s3_path +# logging.info(f"No ingest methods -- Falling back to UTF-8 INGEST... s3_path = {s3_path}") +# except Exception as e: +# logging.info( +# f"We don't have a ingest method for this filetype: {file_extension}. As a last-ditch effort, we tried to ingest the file as utf-8 text, but that failed too. File is unsupported: {s3_path}. 
UTF-8 ingest error: {e}" +# ) +# success_status['failure_ingest'] = { +# 's3_path': +# s3_path, +# 'error': +# f"We don't have a ingest method for this filetype: {file_extension} (with generic type {mime_type}), for file: {s3_path}" +# } +# self.posthog.capture( +# 'distinct_id_of_the_user', +# event='ingest_failure', +# properties={ +# 'course_name': +# course_name, +# 's3_path': +# s3_paths, +# 'kwargs': +# kwargs, +# 'error': +# f"We don't have a ingest method for this filetype: {file_extension} (with generic type {mime_type}), for file: {s3_path}" +# }) + +# return success_status +# except Exception as e: +# err = f"❌❌ Error in /ingest: `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) # type: ignore + +# success_status['failure_ingest'] = {'s3_path': s3_path, 'error': f"MAJOR ERROR DURING INGEST: {err}"} +# self.posthog.capture('distinct_id_of_the_user', +# event='ingest_failure', +# properties={ +# 'course_name': course_name, +# 's3_path': s3_paths, +# 'kwargs': kwargs, +# 'error': err +# }) + +# sentry_sdk.capture_exception(e) +# logging.info(f"MAJOR ERROR IN /bulk_ingest: {str(e)}") +# return success_status + +# def ingest_single_web_text(self, course_name: str, base_url: str, url: str, content: str, readable_filename: str): +# """Crawlee integration +# """ +# self.posthog.capture('distinct_id_of_the_user', +# event='ingest_single_web_text_invoked', +# properties={ +# 'course_name': course_name, +# 'base_url': base_url, +# 'url': url, +# 'content': content, +# 'title': readable_filename +# }) +# success_or_failure: Dict[str, None | str | Dict[str, str]] = {"success_ingest": None, "failure_ingest": None} +# try: +# # if not, ingest the text +# text = [content] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': '', +# 'readable_filename': readable_filename, +# 'pagenumber': '', +# 'timestamp': '', +# 'url': url, +# 'base_url': base_url, +# }] +# self.split_and_upload(texts=text, metadatas=metadatas) +# self.posthog.capture('distinct_id_of_the_user', +# event='ingest_single_web_text_succeeded', +# properties={ +# 'course_name': course_name, +# 'base_url': base_url, +# 'url': url, +# 'title': readable_filename +# }) + +# success_or_failure['success_ingest'] = url +# return success_or_failure +# except Exception as e: +# err = f"❌❌ Error in (web text ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) # type: ignore +# logging.info(err) +# sentry_sdk.capture_exception(e) +# success_or_failure['failure_ingest'] = {'url': url, 'error': str(err)} +# return success_or_failure + +# def _ingest_single_py(self, s3_path: str, course_name: str, **kwargs): +# try: +# file_name = s3_path.split("/")[-1] +# file_path = "media/" + file_name # download from s3 to local folder for ingest + +# self.s3_client.download_file(os.getenv('S3_BUCKET_NAME'), s3_path, file_path) + +# loader = PythonLoader(file_path) +# documents = loader.load() + +# texts = [doc.page_content for doc in documents] + +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] +# #logging.info(texts) +# os.remove(file_path) + +# success_or_failure = self.split_and_upload(texts=texts, metadatas=metadatas) +# logging.info("Python ingest: ", success_or_failure) +# return success_or_failure + +# 
except Exception as e: +# err = f"❌❌ Error in (Python ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return err + +# def _ingest_single_vtt(self, s3_path: str, course_name: str, **kwargs): +# """ +# Ingest a single .vtt file from S3. +# """ +# try: +# with NamedTemporaryFile() as tmpfile: +# # download from S3 into vtt_tmpfile +# self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) +# loader = TextLoader(tmpfile.name) +# documents = loader.load() +# texts = [doc.page_content for doc in documents] + +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# success_or_failure = self.split_and_upload(texts=texts, metadatas=metadatas) +# return success_or_failure +# except Exception as e: +# err = f"❌❌ Error in (VTT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return err + +# def _ingest_html(self, s3_path: str, course_name: str, **kwargs) -> str: +# logging.info(f"IN _ingest_html s3_path `{s3_path}` kwargs: {kwargs}") +# try: +# response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) +# raw_html = response['Body'].read().decode('utf-8') + +# soup = BeautifulSoup(raw_html, 'html.parser') +# title = s3_path.replace("courses/" + course_name, "") +# title = title.replace(".html", "") +# title = title.replace("_", " ") +# title = title.replace("/", " ") +# title = title.strip() +# title = title[37:] # removing the uuid prefix +# text = [soup.get_text()] + +# metadata: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': str(title), # adding str to avoid error: unhashable type 'slice' +# 'url': kwargs.get('url', ''), +# 'base_url': kwargs.get('base_url', ''), +# 'pagenumber': '', +# 'timestamp': '', +# }] + +# success_or_failure = self.split_and_upload(text, metadata) +# logging.info(f"_ingest_html: {success_or_failure}") +# return success_or_failure +# except Exception as e: +# err: str = f"ERROR IN _ingest_html: {e}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return err + +# def _ingest_single_video(self, s3_path: str, course_name: str, **kwargs) -> str: +# """ +# Ingest a single video file from S3. 
+# """ +# logging.info("Starting ingest video or audio") +# try: +# # Ensure the media directory exists +# media_dir = "media" +# if not os.path.exists(media_dir): +# os.makedirs(media_dir) + +# # check for file extension +# file_ext = Path(s3_path).suffix +# openai.api_key = os.getenv('VLADS_OPENAI_KEY') +# transcript_list = [] +# with NamedTemporaryFile(suffix=file_ext) as video_tmpfile: +# # download from S3 into an video tmpfile +# self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=video_tmpfile) +# # extract audio from video tmpfile +# mp4_version = AudioSegment.from_file(video_tmpfile.name, file_ext[1:]) + +# # save the extracted audio as a temporary webm file +# with NamedTemporaryFile(suffix=".webm", dir=media_dir, delete=False) as webm_tmpfile: +# mp4_version.export(webm_tmpfile, format="webm") + +# # check file size +# file_size = os.path.getsize(webm_tmpfile.name) +# # split the audio into 25MB chunks +# if file_size > 26214400: +# # load the webm file into audio object +# full_audio = AudioSegment.from_file(webm_tmpfile.name, "webm") +# file_count = file_size // 26214400 + 1 +# split_segment = 35 * 60 * 1000 +# start = 0 +# count = 0 + +# while count < file_count: +# with NamedTemporaryFile(suffix=".webm", dir=media_dir, delete=False) as split_tmp: +# if count == file_count - 1: +# # last segment +# audio_chunk = full_audio[start:] +# else: +# audio_chunk = full_audio[start:split_segment] + +# audio_chunk.export(split_tmp.name, format="webm") + +# # transcribe the split file and store the text in dictionary +# with open(split_tmp.name, "rb") as f: +# transcript = openai.Audio.transcribe("whisper-1", f) +# transcript_list.append(transcript['text']) # type: ignore +# start += split_segment +# split_segment += split_segment +# count += 1 +# os.remove(split_tmp.name) +# else: +# # transcribe the full audio +# with open(webm_tmpfile.name, "rb") as f: +# transcript = openai.Audio.transcribe("whisper-1", f) +# transcript_list.append(transcript['text']) # type: ignore + +# os.remove(webm_tmpfile.name) + +# text = [txt for txt in transcript_list] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': text.index(txt), +# 'url': '', +# 'base_url': '', +# } for txt in text] + +# self.split_and_upload(texts=text, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (VIDEO ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_docx(self, s3_path: str, course_name: str, **kwargs) -> str: +# try: +# with NamedTemporaryFile() as tmpfile: +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) + +# loader = Docx2txtLoader(tmpfile.name) +# documents = loader.load() + +# texts = [doc.page_content for doc in documents] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (DOCX ingest): `{inspect.currentframe().f_code.co_name}`: 
{e}\nTraceback:\n", traceback.format_exc( +# ) +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_srt(self, s3_path: str, course_name: str, **kwargs) -> str: +# try: +# import pysrt + +# # NOTE: slightly different method for .txt files, no need for download. It's part of the 'body' +# response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) +# raw_text = response['Body'].read().decode('utf-8') + +# logging.info("UTF-8 text to ingest as SRT:", raw_text) +# parsed_info = pysrt.from_string(raw_text) +# text = " ".join([t.text for t in parsed_info]) # type: ignore +# logging.info(f"Final SRT ingest: {text}") + +# texts = [text] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# }] +# if len(text) == 0: +# return "Error: SRT file appears empty. Skipping." + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (SRT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_excel(self, s3_path: str, course_name: str, **kwargs) -> str: +# try: +# with NamedTemporaryFile() as tmpfile: +# # download from S3 into pdf_tmpfile +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) + +# loader = UnstructuredExcelLoader(tmpfile.name, mode="elements") +# # loader = SRTLoader(tmpfile.name) +# documents = loader.load() + +# texts = [doc.page_content for doc in documents] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (Excel/xlsx ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_image(self, s3_path: str, course_name: str, **kwargs) -> str: +# try: +# with NamedTemporaryFile() as tmpfile: +# # download from S3 into pdf_tmpfile +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) +# """ +# # Unstructured image loader makes the install too large (700MB --> 6GB. 3min -> 12 min build times). AND nobody uses it. +# # The "hi_res" strategy will identify the layout of the document using detectron2. "ocr_only" uses pdfminer.six. 
https://unstructured-io.github.io/unstructured/core/partition.html#partition-image +# loader = UnstructuredImageLoader(tmpfile.name, unstructured_kwargs={'strategy': "ocr_only"}) +# documents = loader.load() +# """ + +# res_str = pytesseract.image_to_string(Image.open(tmpfile.name)) +# logging.info("IMAGE PARSING RESULT:", res_str) +# documents = [Document(page_content=res_str)] + +# texts = [doc.page_content for doc in documents] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (png/jpg ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_csv(self, s3_path: str, course_name: str, **kwargs) -> str: +# try: +# with NamedTemporaryFile() as tmpfile: +# # download from S3 into pdf_tmpfile +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) + +# loader = CSVLoader(file_path=tmpfile.name) +# documents = loader.load() + +# texts = [doc.page_content for doc in documents] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (CSV ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_pdf(self, s3_path: str, course_name: str, **kwargs): +# """ +# Both OCR the PDF. And grab the first image as a PNG. +# LangChain `Documents` have .metadata and .page_content attributes. +# Be sure to use TemporaryFile() to avoid memory leaks! +# """ +# logging.info("IN PDF ingest: s3_path: ", s3_path, "and kwargs:", kwargs) + +# try: +# with NamedTemporaryFile() as pdf_tmpfile: +# # download from S3 into pdf_tmpfile +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=pdf_tmpfile) +# ### READ OCR of PDF +# try: +# doc = fitz.open(pdf_tmpfile.name) # type: ignore +# except fitz.fitz.EmptyFileError as e: +# logging.info(f"Empty PDF file: {s3_path}") +# return "Failed ingest: Could not detect ANY text in the PDF. OCR did not help. PDF appears empty of text." 
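The image path above deliberately avoids the heavyweight Unstructured image loader and runs Tesseract directly on the downloaded file. A small sketch of that OCR step in isolation, assuming `pytesseract` and Pillow are installed and the `tesseract` binary is on PATH; the helper name is illustrative only.

```python
# Standalone sketch of the Tesseract OCR step used for png/jpg ingest above.
# Assumes the `tesseract` binary is installed and S3_BUCKET_NAME is set.
import os
from pathlib import Path
from tempfile import NamedTemporaryFile

import boto3
import pytesseract
from PIL import Image


def ocr_image_from_s3(s3_path: str) -> str:
  s3 = boto3.client("s3")
  with NamedTemporaryFile(suffix=Path(s3_path).suffix) as tmp:
    s3.download_fileobj(Bucket=os.environ["S3_BUCKET_NAME"], Key=s3_path, Fileobj=tmp)
    tmp.seek(0)
    # image_to_string returns all recognized text as one string; empty output means nothing was found.
    return pytesseract.image_to_string(Image.open(tmp.name))
```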
+ +# # improve quality of the image +# zoom_x = 2.0 # horizontal zoom +# zoom_y = 2.0 # vertical zoom +# mat = fitz.Matrix(zoom_x, zoom_y) # zoom factor 2 in each dimension + +# pdf_pages_no_OCR: List[Dict] = [] +# for i, page in enumerate(doc): # type: ignore + +# # UPLOAD FIRST PAGE IMAGE to S3 +# if i == 0: +# with NamedTemporaryFile(suffix=".png") as first_page_png: +# pix = page.get_pixmap(matrix=mat) +# pix.save(first_page_png) # store image as a PNG + +# s3_upload_path = str(Path(s3_path)).rsplit('.pdf')[0] + "-pg1-thumb.png" +# first_page_png.seek(0) # Seek the file pointer back to the beginning +# with open(first_page_png.name, 'rb') as f: +# logging.info("Uploading image png to S3") +# self.s3_client.upload_fileobj(f, os.getenv('S3_BUCKET_NAME'), s3_upload_path) + +# # Extract text +# text = page.get_text().encode("utf8").decode("utf8", errors='ignore') # get plain text (is in UTF-8) +# pdf_pages_no_OCR.append(dict(text=text, page_number=i, readable_filename=Path(s3_path).name[37:])) + +# metadatas: List[Dict[str, Any]] = [ +# { +# 'course_name': course_name, +# 's3_path': s3_path, +# 'pagenumber': page['page_number'] + 1, # +1 for human indexing +# 'timestamp': '', +# 'readable_filename': kwargs.get('readable_filename', page['readable_filename']), +# 'url': kwargs.get('url', ''), +# 'base_url': kwargs.get('base_url', ''), +# } for page in pdf_pages_no_OCR +# ] +# pdf_texts = [page['text'] for page in pdf_pages_no_OCR] + +# # count the total number of words in the pdf_texts. If it's less than 100, we'll OCR the PDF +# has_words = any(text.strip() for text in pdf_texts) +# if has_words: +# success_or_failure = self.split_and_upload(texts=pdf_texts, metadatas=metadatas) +# else: +# logging.info("⚠️ PDF IS EMPTY -- OCR-ing the PDF.") +# success_or_failure = self._ocr_pdf(s3_path=s3_path, course_name=course_name, **kwargs) + +# return success_or_failure +# except Exception as e: +# err = f"❌❌ Error in PDF ingest (no OCR): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) # type: ignore +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return err +# return "Success" + +# def _ocr_pdf(self, s3_path: str, course_name: str, **kwargs): +# self.posthog.capture('distinct_id_of_the_user', +# event='ocr_pdf_invoked', +# properties={ +# 'course_name': course_name, +# 's3_path': s3_path, +# }) + +# pdf_pages_OCRed: List[Dict] = [] +# try: +# with NamedTemporaryFile() as pdf_tmpfile: +# # download from S3 into pdf_tmpfile +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=pdf_tmpfile) + +# with pdfplumber.open(pdf_tmpfile.name) as pdf: +# # for page in : +# for i, page in enumerate(pdf.pages): +# im = page.to_image() +# text = pytesseract.image_to_string(im.original) +# logging.info("Page number: ", i, "Text: ", text[:100]) +# pdf_pages_OCRed.append(dict(text=text, page_number=i, readable_filename=Path(s3_path).name[37:])) + +# metadatas: List[Dict[str, Any]] = [ +# { +# 'course_name': course_name, +# 's3_path': s3_path, +# 'pagenumber': page['page_number'] + 1, # +1 for human indexing +# 'timestamp': '', +# 'readable_filename': kwargs.get('readable_filename', page['readable_filename']), +# 'url': kwargs.get('url', ''), +# 'base_url': kwargs.get('base_url', ''), +# } for page in pdf_pages_OCRed +# ] +# pdf_texts = [page['text'] for page in pdf_pages_OCRed] +# self.posthog.capture('distinct_id_of_the_user', +# event='ocr_pdf_succeeded', +# properties={ +# 'course_name': course_name, +# 's3_path': s3_path, +# 
}) + +# has_words = any(text.strip() for text in pdf_texts) +# if not has_words: +# raise ValueError("Failed ingest: Could not detect ANY text in the PDF. OCR did not help. PDF appears empty of text.") + +# success_or_failure = self.split_and_upload(texts=pdf_texts, metadatas=metadatas) +# return success_or_failure +# except Exception as e: +# err = f"❌❌ Error in PDF ingest (with OCR): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc() +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return err + +# def _ingest_single_txt(self, s3_path: str, course_name: str, **kwargs) -> str: +# """Ingest a single .txt or .md file from S3. +# Args: +# s3_path (str): A path to a .txt file in S3 +# course_name (str): The name of the course +# Returns: +# str: "Success" or an error message +# """ +# logging.info("In text ingest, UTF-8") +# try: +# # NOTE: slightly different method for .txt files, no need for download. It's part of the 'body' +# response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) +# text = response['Body'].read().decode('utf-8') +# logging.info("UTF-8 text to ignest (from s3)", text) +# text = [text] + +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# }] +# logging.info("Prior to ingest", metadatas) + +# success_or_failure = self.split_and_upload(texts=text, metadatas=metadatas) +# return success_or_failure +# except Exception as e: +# err = f"❌❌ Error in (TXT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_ppt(self, s3_path: str, course_name: str, **kwargs) -> str: +# """ +# Ingest a single .ppt or .pptx file from S3. +# """ +# try: +# with NamedTemporaryFile() as tmpfile: +# # download from S3 into pdf_tmpfile +# #logging.info("in ingest PPTX") +# self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) + +# loader = UnstructuredPowerPointLoader(tmpfile.name) +# documents = loader.load() + +# texts = [doc.page_content for doc in documents] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (PPTX ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def ingest_github(self, github_url: str, course_name: str) -> str: +# """ +# Clones the given GitHub URL and uses Langchain to load data. +# 1. Clone the repo +# 2. Use Langchain to load the data +# 3. Pass to split_and_upload() +# Args: +# github_url (str): The Github Repo URL to be ingested. +# course_name (str): The name of the course in our system. + +# Returns: +# _type_: Success or error message. 
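The PDF path above extracts text per page with PyMuPDF and only falls back to OCR (pdfplumber + Tesseract) when no page yields any text at all. A condensed sketch of that decision, assuming the PDF is already on local disk; this is a simplification of the two functions above, not a drop-in replacement.

```python
# Sketch of the "native text first, OCR only if the PDF looks empty" decision above.
# Assumes PyMuPDF (fitz), pdfplumber and pytesseract are installed; operates on a local file.
import fitz  # PyMuPDF
import pdfplumber
import pytesseract


def extract_pdf_text(pdf_path: str) -> list[str]:
  doc = fitz.open(pdf_path)
  pages = [page.get_text() for page in doc]

  if any(text.strip() for text in pages):
    return pages  # an embedded text layer exists, no OCR needed

  # Fallback: rasterize each page and OCR it, mirroring _ocr_pdf above.
  with pdfplumber.open(pdf_path) as pdf:
    return [pytesseract.image_to_string(page.to_image().original) for page in pdf.pages]
```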
+# """ +# try: +# repo_path = "media/cloned_repo" +# repo = Repo.clone_from(github_url, to_path=repo_path, depth=1, clone_submodules=False) +# branch = repo.head.reference + +# loader = GitLoader(repo_path="media/cloned_repo", branch=str(branch)) +# data = loader.load() +# shutil.rmtree("media/cloned_repo") +# # create metadata for each file in data + +# for doc in data: +# texts = doc.page_content +# metadatas: Dict[str, Any] = { +# 'course_name': course_name, +# 's3_path': '', +# 'readable_filename': doc.metadata['file_name'], +# 'url': f"{github_url}/blob/main/{doc.metadata['file_path']}", +# 'pagenumber': '', +# 'timestamp': '', +# } +# self.split_and_upload(texts=[texts], metadatas=[metadatas]) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (GITHUB ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n{traceback.format_exc()}" +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return err + +# def split_and_upload(self, texts: List[str], metadatas: List[Dict[str, Any]]): +# """ This is usually the last step of document ingest. Chunk & upload to Qdrant (and Supabase.. todo). +# Takes in Text and Metadata (from Langchain doc loaders) and splits / uploads to Qdrant. + +# good examples here: https://langchain.readthedocs.io/en/latest/modules/utils/combine_docs_examples/textsplitter.html + +# Args: +# texts (List[str]): _description_ +# metadatas (List[Dict[str, Any]]): _description_ +# """ +# # return "Success" +# self.posthog.capture('distinct_id_of_the_user', +# event='split_and_upload_invoked', +# properties={ +# 'course_name': metadatas[0].get('course_name', None), +# 's3_path': metadatas[0].get('s3_path', None), +# 'readable_filename': metadatas[0].get('readable_filename', None), +# 'url': metadatas[0].get('url', None), +# 'base_url': metadatas[0].get('base_url', None), +# }) + +# logging.info(f"In split and upload. Metadatas: {metadatas}") +# logging.info(f"Texts: {texts}") +# assert len(texts) == len( +# metadatas +# ), f'must have equal number of text strings and metadata dicts. len(texts) is {len(texts)}. len(metadatas) is {len(metadatas)}' + +# try: +# text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( +# chunk_size=1000, +# chunk_overlap=150, +# separators=[ +# "\n\n", "\n", ". ", " ", "" +# ] # try to split on paragraphs... 
fallback to sentences, then chars, ensure we always fit in context window +# ) +# contexts: List[Document] = text_splitter.create_documents(texts=texts, metadatas=metadatas) +# input_texts = [{'input': context.page_content, 'model': 'text-embedding-ada-002'} for context in contexts] + +# # check for duplicates +# is_duplicate = self.check_for_duplicates(input_texts, metadatas) +# if is_duplicate: +# self.posthog.capture('distinct_id_of_the_user', +# event='split_and_upload_succeeded', +# properties={ +# 'course_name': metadatas[0].get('course_name', None), +# 's3_path': metadatas[0].get('s3_path', None), +# 'readable_filename': metadatas[0].get('readable_filename', None), +# 'url': metadatas[0].get('url', None), +# 'base_url': metadatas[0].get('base_url', None), +# 'is_duplicate': True, +# }) +# return "Success" + +# # adding chunk index to metadata for parent doc retrieval +# for i, context in enumerate(contexts): +# context.metadata['chunk_index'] = i + +# oai = OpenAIAPIProcessor( +# input_prompts_list=input_texts, +# request_url='https://api.openai.com/v1/embeddings', +# api_key=os.getenv('VLADS_OPENAI_KEY'), +# # request_url='https://uiuc-chat-canada-east.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2023-05-15', +# # api_key=os.getenv('AZURE_OPENAI_KEY'), +# max_requests_per_minute=5_000, +# max_tokens_per_minute=300_000, +# max_attempts=20, +# logging_level=logging.INFO, +# token_encoding_name='cl100k_base') # nosec -- reasonable bandit error suppression +# asyncio.run(oai.process_api_requests_from_file()) +# # parse results into dict of shape page_content -> embedding +# embeddings_dict: dict[str, List[float]] = { +# item[0]['input']: item[1]['data'][0]['embedding'] for item in oai.results +# } + +# ### BULK upload to Qdrant ### +# vectors: list[PointStruct] = [] +# for context in contexts: +# # !DONE: Updated the payload so each key is top level (no more payload.metadata.course_name. Instead, use payload.course_name), great for creating indexes. 
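`split_and_upload` above chunks with a tiktoken-aware splitter, embeds each chunk, and keeps every metadata key at the top level of the Qdrant payload (plus `page_content` and a `chunk_index` for parent-document retrieval). A minimal sketch of that shape; `embed` is a stand-in callable for whatever embedding client is in use, and the collection name is a parameter rather than the repo's env var.

```python
# Sketch of the chunk -> flat-payload -> upsert flow above. embed() is a stand-in for an
# embedding client; payload keys stay top-level so Qdrant indexes can be built on them.
import uuid
from typing import Any, Dict, List

from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct


def chunk_and_upsert(texts: List[str], metadatas: List[Dict[str, Any]], embed, client: QdrantClient, collection: str):
  splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
      chunk_size=1000, chunk_overlap=150, separators=["\n\n", "\n", ". ", " ", ""])
  contexts = splitter.create_documents(texts=texts, metadatas=metadatas)

  points = []
  for i, context in enumerate(contexts):
    context.metadata["chunk_index"] = i  # enables parent-document retrieval later
    payload = {**context.metadata, "page_content": context.page_content}
    points.append(PointStruct(id=str(uuid.uuid4()), vector=embed(context.page_content), payload=payload))

  client.upsert(collection_name=collection, points=points)
```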
+# upload_metadata = {**context.metadata, "page_content": context.page_content} +# vectors.append( +# PointStruct(id=str(uuid.uuid4()), vector=embeddings_dict[context.page_content], payload=upload_metadata)) + +# self.qdrant_client.upsert( +# collection_name=os.environ['QDRANT_COLLECTION_NAME'], # type: ignore +# points=vectors # type: ignore +# ) +# ### Supabase SQL ### +# contexts_for_supa = [{ +# "text": context.page_content, +# "pagenumber": context.metadata.get('pagenumber'), +# "timestamp": context.metadata.get('timestamp'), +# "chunk_index": context.metadata.get('chunk_index'), +# "embedding": embeddings_dict[context.page_content] +# } for context in contexts] + +# document = { +# "course_name": contexts[0].metadata.get('course_name'), +# "s3_path": contexts[0].metadata.get('s3_path'), +# "readable_filename": contexts[0].metadata.get('readable_filename'), +# "url": contexts[0].metadata.get('url'), +# "base_url": contexts[0].metadata.get('base_url'), +# "contexts": contexts_for_supa, +# } + +# response = self.supabase_client.table( +# os.getenv('SUPABASE_DOCUMENTS_TABLE')).insert(document).execute() # type: ignore + +# # add to Nomic document map +# if len(response.data) > 0: +# course_name = contexts[0].metadata.get('course_name') +# log_to_document_map(course_name) + +# self.posthog.capture('distinct_id_of_the_user', +# event='split_and_upload_succeeded', +# properties={ +# 'course_name': metadatas[0].get('course_name', None), +# 's3_path': metadatas[0].get('s3_path', None), +# 'readable_filename': metadatas[0].get('readable_filename', None), +# 'url': metadatas[0].get('url', None), +# 'base_url': metadatas[0].get('base_url', None), +# }) +# logging.info("successful END OF split_and_upload") +# return "Success" +# except Exception as e: +# err: str = f"ERROR IN split_and_upload(): Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return err + +# def check_for_duplicates(self, texts: List[Dict], metadatas: List[Dict[str, Any]]) -> bool: +# """ +# For given metadata, fetch docs from Supabase based on S3 path or URL. +# If docs exists, concatenate the texts and compare with current texts, if same, return True. +# """ +# doc_table = os.getenv('SUPABASE_DOCUMENTS_TABLE', '') +# course_name = metadatas[0]['course_name'] +# incoming_s3_path = metadatas[0]['s3_path'] +# url = metadatas[0]['url'] +# original_filename = incoming_s3_path.split('/')[-1][37:] # remove the 37-char uuid prefix + +# # check if uuid exists in s3_path -- not all s3_paths have uuids! +# incoming_filename = incoming_s3_path.split('/')[-1] +# pattern = re.compile(r'[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}', +# re.I) # uuid V4 pattern, and v4 only. 
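Duplicate detection above keys off the original filename, so it first strips the 36-character UUID-v4 prefix, but only when the key actually carries one. A tiny sketch of that normalization using the same regex, with two illustrative assertions (the example paths are hypothetical).

```python
# Sketch of the filename normalization used by check_for_duplicates above:
# strip "<uuid4>-" (37 chars including the separator) only when a UUID prefix is present.
import re

UUID_V4 = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}", re.I)


def original_filename(s3_path: str) -> str:
  incoming = s3_path.split("/")[-1]
  return incoming[37:] if UUID_V4.search(incoming) else incoming


assert original_filename("courses/demo/123e4567-e89b-42d3-a456-426614174000-notes.pdf") == "notes.pdf"
assert original_filename("courses/demo/notes.pdf") == "notes.pdf"
```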
+# if bool(pattern.search(incoming_filename)): +# # uuid pattern exists -- remove the uuid and proceed with duplicate checking +# original_filename = incoming_filename[37:] +# else: +# # do not remove anything and proceed with duplicate checking +# original_filename = incoming_filename + +# if incoming_s3_path: +# filename = incoming_s3_path +# supabase_contents = self.supabase_client.table(doc_table).select('id', 'contexts', 's3_path').eq( +# 'course_name', course_name).like('s3_path', '%' + original_filename + '%').order('id', desc=True).execute() +# supabase_contents = supabase_contents.data +# elif url: +# filename = url +# supabase_contents = self.supabase_client.table(doc_table).select('id', 'contexts', 's3_path').eq( +# 'course_name', course_name).eq('url', url).order('id', desc=True).execute() +# supabase_contents = supabase_contents.data +# else: +# filename = None +# supabase_contents = [] + +# supabase_whole_text = "" +# if len(supabase_contents) > 0: # if a doc with same filename exists in Supabase +# # concatenate texts +# supabase_contexts = supabase_contents[0] +# for text in supabase_contexts['contexts']: +# supabase_whole_text += text['text'] + +# current_whole_text = "" +# for text in texts: +# current_whole_text += text['input'] + +# if supabase_whole_text == current_whole_text: # matches the previous file +# logging.info(f"Duplicate ingested! πŸ“„ s3_path: {filename}.") +# return True + +# else: # the file is updated +# logging.info(f"Updated file detected! Same filename, new contents. πŸ“„ s3_path: {filename}") + +# # call the delete function on older docs +# for content in supabase_contents: +# logging.info("older s3_path to be deleted: ", content['s3_path']) +# delete_status = self.delete_data(course_name, content['s3_path'], '') +# logging.info("delete_status: ", delete_status) +# return False + +# else: # filename does not already exist in Supabase, so its a brand new file +# logging.info(f"NOT a duplicate! πŸ“„s3_path: {filename}") +# return False + +# def delete_data(self, course_name: str, s3_path: str, source_url: str): +# """Delete file from S3, Qdrant, and Supabase.""" +# logging.info(f"Deleting {s3_path} from S3, Qdrant, and Supabase for course {course_name}") +# # add delete from doc map logic here +# try: +# # Delete file from S3 +# bucket_name = os.getenv('S3_BUCKET_NAME') + +# # Delete files by S3 path +# if s3_path: +# try: +# self.s3_client.delete_object(Bucket=bucket_name, Key=s3_path) +# except Exception as e: +# logging.info("Error in deleting file from s3:", e) +# sentry_sdk.capture_exception(e) +# # Delete from Qdrant +# # docs for nested keys: https://qdrant.tech/documentation/concepts/filtering/#nested-key +# # Qdrant "points" look like this: Record(id='000295ca-bd28-ac4a-6f8d-c245f7377f90', payload={'metadata': {'course_name': 'zotero-extreme', 'pagenumber_or_timestamp': 15, 'readable_filename': 'Dunlosky et al. - 2013 - Improving Students’ Learning With Effective Learni.pdf', 's3_path': 'courses/zotero-extreme/Dunlosky et al. - 2013 - Improving Students’ Learning With Effective Learni.pdf'}, 'page_content': '18 \nDunlosky et al.\n3.3 Effects in representative educational contexts. Sev-\neral of the large summarization-training studies have been \nconducted in regular classrooms, indicating the feasibility of \ndoing so. For example, the study by A. King (1992) took place \nin the context of a remedial study-skills course for undergrad-\nuates, and the study by Rinehart et al. 
(1986) took place in \nsixth-grade classrooms, with the instruction led by students \nregular teachers. In these and other cases, students benefited \nfrom the classroom training. We suspect it may actually be \nmore feasible to conduct these kinds of training ... +# try: +# self.qdrant_client.delete( +# collection_name=os.environ['QDRANT_COLLECTION_NAME'], +# points_selector=models.Filter(must=[ +# models.FieldCondition( +# key="s3_path", +# match=models.MatchValue(value=s3_path), +# ), +# ]), +# ) +# except Exception as e: +# if "timed out" in str(e): +# # Timed out is fine. Still deletes. +# # https://github.com/qdrant/qdrant/issues/3654#issuecomment-1955074525 +# pass +# else: +# logging.info("Error in deleting file from Qdrant:", e) +# sentry_sdk.capture_exception(e) +# try: +# # delete from Nomic +# response = self.supabase_client.from_( +# os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq('s3_path', s3_path).eq( +# 'course_name', course_name).execute() +# data = response.data[0] #single record fetched +# nomic_ids_to_delete = [] +# context_count = len(data['contexts']) +# for i in range(1, context_count + 1): +# nomic_ids_to_delete.append(str(data['id']) + "_" + str(i)) + +# # delete from Nomic +# delete_from_document_map(course_name, nomic_ids_to_delete) +# except Exception as e: +# logging.info("Error in deleting file from Nomic:", e) +# sentry_sdk.capture_exception(e) + +# try: +# self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('s3_path', s3_path).eq( +# 'course_name', course_name).execute() +# except Exception as e: +# logging.info("Error in deleting file from supabase:", e) +# sentry_sdk.capture_exception(e) + +# # Delete files by their URL identifier +# elif source_url: +# try: +# # Delete from Qdrant +# self.qdrant_client.delete( +# collection_name=os.environ['QDRANT_COLLECTION_NAME'], +# points_selector=models.Filter(must=[ +# models.FieldCondition( +# key="url", +# match=models.MatchValue(value=source_url), +# ), +# ]), +# ) +# except Exception as e: +# if "timed out" in str(e): +# # Timed out is fine. Still deletes. 
+# # https://github.com/qdrant/qdrant/issues/3654#issuecomment-1955074525 +# pass +# else: +# logging.info("Error in deleting file from Qdrant:", e) +# sentry_sdk.capture_exception(e) +# try: +# # delete from Nomic +# response = self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, url, contexts").eq( +# 'url', source_url).eq('course_name', course_name).execute() +# data = response.data[0] #single record fetched +# nomic_ids_to_delete = [] +# context_count = len(data['contexts']) +# for i in range(1, context_count + 1): +# nomic_ids_to_delete.append(str(data['id']) + "_" + str(i)) + +# # delete from Nomic +# delete_from_document_map(course_name, nomic_ids_to_delete) +# except Exception as e: +# logging.info("Error in deleting file from Nomic:", e) +# sentry_sdk.capture_exception(e) + +# try: +# # delete from Supabase +# self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('url', source_url).eq( +# 'course_name', course_name).execute() +# except Exception as e: +# logging.info("Error in deleting file from supabase:", e) +# sentry_sdk.capture_exception(e) + +# # Delete from Supabase +# return "Success" +# except Exception as e: +# err: str = f"ERROR IN delete_data: Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore +# logging.info(err) +# sentry_sdk.capture_exception(e) +# return err + +# # def ingest_coursera(self, coursera_course_name: str, course_name: str) -> str: +# # """ Download all the files from a coursera course and ingest them. + +# # 1. Download the coursera content. +# # 2. Upload to S3 (so users can view it) +# # 3. Run everything through the ingest_bulk method. + +# # Args: +# # coursera_course_name (str): The name of the coursera course. +# # course_name (str): The name of the course in our system. + +# # Returns: +# # _type_: Success or error message. 
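`delete_data` above removes points from Qdrant by matching a top-level payload field (either `s3_path` or `url`) and deliberately treats client-side timeouts as success, since the server finishes the delete anyway. A minimal sketch of that call, assuming the flat payload layout described earlier; the helper name and parameters are illustrative.

```python
# Sketch of the Qdrant payload-filter delete used above; timeouts are swallowed because the
# server-side delete still completes (see qdrant/qdrant issue 3654 referenced above).
from qdrant_client import QdrantClient, models


def delete_points_by_field(client: QdrantClient, collection: str, field: str, value: str) -> None:
  try:
    client.delete(
        collection_name=collection,
        points_selector=models.Filter(must=[
            models.FieldCondition(key=field, match=models.MatchValue(value=value)),
        ]),
    )
  except Exception as e:
    if "timed out" not in str(e):
      raise  # anything other than a timeout is a real failure
```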
+# # """ +# # certificate = "-ca 'FVhVoDp5cb-ZaoRr5nNJLYbyjCLz8cGvaXzizqNlQEBsG5wSq7AHScZGAGfC1nI0ehXFvWy1NG8dyuIBF7DLMA.X3cXsDvHcOmSdo3Fyvg27Q.qyGfoo0GOHosTVoSMFy-gc24B-_BIxJtqblTzN5xQWT3hSntTR1DMPgPQKQmfZh_40UaV8oZKKiF15HtZBaLHWLbpEpAgTg3KiTiU1WSdUWueo92tnhz-lcLeLmCQE2y3XpijaN6G4mmgznLGVsVLXb-P3Cibzz0aVeT_lWIJNrCsXrTFh2HzFEhC4FxfTVqS6cRsKVskPpSu8D9EuCQUwJoOJHP_GvcME9-RISBhi46p-Z1IQZAC4qHPDhthIJG4bJqpq8-ZClRL3DFGqOfaiu5y415LJcH--PRRKTBnP7fNWPKhcEK2xoYQLr9RxBVL3pzVPEFyTYtGg6hFIdJcjKOU11AXAnQ-Kw-Gb_wXiHmu63veM6T8N2dEkdqygMre_xMDT5NVaP3xrPbA4eAQjl9yov4tyX4AQWMaCS5OCbGTpMTq2Y4L0Mbz93MHrblM2JL_cBYa59bq7DFK1IgzmOjFhNG266mQlC9juNcEhc'" +# # always_use_flags = "-u kastanvday@gmail.com -p hSBsLaF5YM469# --ignore-formats mp4 --subtitle-language en --path ./coursera-dl" + +# # try: +# # subprocess.run( +# # f"coursera-dl {always_use_flags} {certificate} {coursera_course_name}", +# # check=True, +# # shell=True, # nosec -- reasonable bandit error suppression +# # stdout=subprocess.PIPE, +# # stderr=subprocess.PIPE) # capture_output=True, +# # dl_results_path = os.path.join('coursera-dl', coursera_course_name) +# # s3_paths: Union[List, None] = upload_data_files_to_s3(course_name, dl_results_path) + +# # if s3_paths is None: +# # return "Error: No files found in the coursera-dl directory" + +# # logging.info("starting bulk ingest") +# # start_time = time.monotonic() +# # self.bulk_ingest(s3_paths, course_name) +# # logging.info("completed bulk ingest") +# # logging.info(f"⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds") + +# # # Cleanup the coursera downloads +# # shutil.rmtree(dl_results_path) + +# # return "Success" +# # except Exception as e: +# # err: str = f"Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore +# # logging.info(err) +# # return err + +# # def list_files_recursively(self, bucket, prefix): +# # all_files = [] +# # continuation_token = None + +# # while True: +# # list_objects_kwargs = { +# # 'Bucket': bucket, +# # 'Prefix': prefix, +# # } +# # if continuation_token: +# # list_objects_kwargs['ContinuationToken'] = continuation_token + +# # response = self.s3_client.list_objects_v2(**list_objects_kwargs) + +# # if 'Contents' in response: +# # for obj in response['Contents']: +# # all_files.append(obj['Key']) + +# # if response['IsTruncated']: +# # continuation_token = response['NextContinuationToken'] +# # else: +# # break + +# # return all_files + +# if __name__ == "__main__": +# raise NotImplementedError("This file is not meant to be run directly") +# text = "Testing 123" +# # ingest(text=text) diff --git a/ai_ta_backend/beam/nomic_logging.py b/ai_ta_backend/beam/nomic_logging.py index 92db8a62..7396450a 100644 --- a/ai_ta_backend/beam/nomic_logging.py +++ b/ai_ta_backend/beam/nomic_logging.py @@ -1,438 +1,431 @@ -import datetime -import os - -import nomic -import numpy as np -import pandas as pd -import sentry_sdk -import supabase -from langchain.embeddings import OpenAIEmbeddings -from nomic import AtlasProject, atlas - -OPENAI_API_TYPE = "azure" - -SUPABASE_CLIENT = supabase.create_client( # type: ignore - supabase_url=os.getenv('SUPABASE_URL'), # type: ignore - supabase_key=os.getenv('SUPABASE_API_KEY')) # type: ignore - -NOMIC_MAP_NAME_PREFIX = 'Document Map for ' - -## -------------------------------- DOCUMENT MAP FUNCTIONS --------------------------------- ## - -def create_document_map(course_name: str): - """ - This is a function which creates a document map for a given course from scratch - 1. 
Gets count of documents for the course - 2. If less than 20, returns a message that a map cannot be created - 3. If greater than 20, iteratively fetches documents in batches of 25 - 4. Prepares metadata and embeddings for nomic upload - 5. Creates a new map and uploads the data - - Args: - course_name: str - Returns: - str: success or failed - """ - print("in create_document_map()") - nomic.login(os.getenv('NOMIC_API_KEY')) - - try: - # check if map exists - response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() - if response.data: - if response.data[0]['doc_map_id']: - return "Map already exists for this course." - - # fetch relevant document data from Supabase - response = SUPABASE_CLIENT.table("documents").select("id", - count="exact").eq("course_name", - course_name).order('id', - desc=False).execute() - if not response.count: - return "No documents found for this course." - - total_doc_count = response.count - print("Total number of documents in Supabase: ", total_doc_count) - - # minimum 20 docs needed to create map - if total_doc_count < 20: - return "Cannot create a map because there are less than 20 documents in the course." - - first_id = response.data[0]['id'] - - combined_dfs = [] - curr_total_doc_count = 0 - doc_count = 0 - first_batch = True - - # iteratively query in batches of 25 - while curr_total_doc_count < total_doc_count: - - response = SUPABASE_CLIENT.table("documents").select( - "id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gte( - 'id', first_id).order('id', desc=False).limit(25).execute() - df = pd.DataFrame(response.data) - combined_dfs.append(df) # list of dfs - - curr_total_doc_count += len(response.data) - doc_count += len(response.data) - - if doc_count >= 1000: # upload to Nomic in batches of 1000 - - # concat all dfs from the combined_dfs list - final_df = pd.concat(combined_dfs, ignore_index=True) - - # prep data for nomic upload - embeddings, metadata = data_prep_for_doc_map(final_df) - - if first_batch: - # create a new map - print("Creating new map...") - project_name = NOMIC_MAP_NAME_PREFIX + course_name - index_name = course_name + "_doc_index" - topic_label_field = "text" - colorable_fields = ["readable_filename", "text", "base_url", "created_at"] - result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) - - if result == "success": - # update flag - first_batch = False - # log project info to supabase - project = AtlasProject(name=project_name, add_datums_if_exists=True) - project_id = project.id - last_id = int(final_df['id'].iloc[-1]) - project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} - project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() - if project_response.data: - update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) - else: - insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() - print("Insert Response from supabase: ", insert_response) - - - else: - # append to existing map - print("Appending data to existing map...") - project_name = NOMIC_MAP_NAME_PREFIX + course_name - # add project lock logic here - result = append_to_map(embeddings, metadata, project_name) - if result == "success": - # update the last uploaded id in supabase - last_id = 
int(final_df['id'].iloc[-1]) - info = {'last_uploaded_doc_id': last_id} - update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) - - # reset variables - combined_dfs = [] - doc_count = 0 - print("Records uploaded: ", curr_total_doc_count) - - # set first_id for next iteration - first_id = response.data[-1]['id'] + 1 - - # upload last set of docs - if doc_count > 0: - final_df = pd.concat(combined_dfs, ignore_index=True) - embeddings, metadata = data_prep_for_doc_map(final_df) - project_name = NOMIC_MAP_NAME_PREFIX + course_name - if first_batch: - index_name = course_name + "_doc_index" - topic_label_field = "text" - colorable_fields = ["readable_filename", "text", "base_url", "created_at"] - result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) - else: - result = append_to_map(embeddings, metadata, project_name) - - # update the last uploaded id in supabase - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - project = AtlasProject(name=project_name, add_datums_if_exists=True) - project_id = project.id - project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} - print("project_info: ", project_info) - project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() - if project_response.data: - update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) - else: - insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() - print("Insert Response from supabase: ", insert_response) - - - # rebuild the map - rebuild_map(course_name, "document") - - except Exception as e: - print(e) - sentry_sdk.capture_exception(e) - return "failed" - -def delete_from_document_map(course_name: str, ids: list): - """ - This function is used to delete datapoints from a document map. - Currently used within the delete_data() function in vector_database.py - Args: - course_name: str - ids: list of str - """ - print("in delete_from_document_map()") - - try: - # check if project exists - response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() - if response.data: - project_id = response.data[0]['doc_map_id'] - else: - return "No document map found for this course" - - # fetch project from Nomic - project = AtlasProject(project_id=project_id, add_datums_if_exists=True) - - # delete the ids from Nomic - print("Deleting point from document map:", project.delete_data(ids)) - with project.wait_for_project_lock(): - project.rebuild_maps() - return "success" - except Exception as e: - print(e) - sentry_sdk.capture_exception(e) - return "Error in deleting from document map: {e}" - - -def log_to_document_map(course_name: str): - """ - This is a function which appends new documents to an existing document map. It's called - at the end of split_and_upload() after inserting data to Supabase. 
- Args: - data: dict - the response data from Supabase insertion - """ - print("in add_to_document_map()") - - try: - # check if map exists - response = SUPABASE_CLIENT.table("projects").select("doc_map_id, last_uploaded_doc_id").eq("course_name", course_name).execute() - if response.data: - if response.data[0]['doc_map_id']: - project_id = response.data[0]['doc_map_id'] - last_uploaded_doc_id = response.data[0]['last_uploaded_doc_id'] - else: - # entry present in supabase, but doc map not present - create_document_map(course_name) - return "Document map not present, triggering map creation." - - else: - # create a map - create_document_map(course_name) - return "Document map not present, triggering map creation." - - project = AtlasProject(project_id=project_id, add_datums_if_exists=True) - project_name = "Document Map for " + course_name - - # check if project is LOCKED, if yes -> skip logging - if not project.is_accepting_data: - return "Skipping Nomic logging because project is locked." - - # fetch count of records greater than last_uploaded_doc_id - print("last uploaded doc id: ", last_uploaded_doc_id) - response = SUPABASE_CLIENT.table("documents").select("id", count="exact").eq("course_name", course_name).gt("id", last_uploaded_doc_id).execute() - print("Number of new documents: ", response.count) - - total_doc_count = response.count - current_doc_count = 0 - combined_dfs = [] - doc_count = 0 - first_id = last_uploaded_doc_id - while current_doc_count < total_doc_count: - # fetch all records from supabase greater than last_uploaded_doc_id - response = SUPABASE_CLIENT.table("documents").select("id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gt("id", first_id).limit(25).execute() - df = pd.DataFrame(response.data) - combined_dfs.append(df) # list of dfs - - current_doc_count += len(response.data) - doc_count += len(response.data) - - if doc_count >= 1000: # upload to Nomic in batches of 1000 - # concat all dfs from the combined_dfs list - final_df = pd.concat(combined_dfs, ignore_index=True) - # prep data for nomic upload - embeddings, metadata = data_prep_for_doc_map(final_df) - - # append to existing map - print("Appending data to existing map...") - - result = append_to_map(embeddings, metadata, project_name) - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - info = {'last_uploaded_doc_id': last_id} - update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) - - # reset variables - combined_dfs = [] - doc_count = 0 - print("Records uploaded: ", current_doc_count) - - # set first_id for next iteration - first_id = response.data[-1]['id'] + 1 - - # upload last set of docs - if doc_count > 0: - final_df = pd.concat(combined_dfs, ignore_index=True) - embeddings, metadata = data_prep_for_doc_map(final_df) - result = append_to_map(embeddings, metadata, project_name) - - # update the last uploaded id in supabase - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - project_info = {'last_uploaded_doc_id': last_id} - update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) - - return "success" - except Exception as e: - print(e) - return "failed" - - -def create_map(embeddings, metadata, map_name, index_name, 
topic_label_field, colorable_fields): - """ - Generic function to create a Nomic map from given parameters. - Args: - embeddings: np.array of embeddings - metadata: pd.DataFrame of metadata - map_name: str - index_name: str - topic_label_field: str - colorable_fields: list of str - """ - nomic.login(os.getenv('NOMIC_API_KEY')) - try: - project = atlas.map_embeddings(embeddings=embeddings, - data=metadata, - id_field="id", - build_topic_model=True, - topic_label_field=topic_label_field, - name=map_name, - colorable_fields=colorable_fields, - add_datums_if_exists=True) - project.create_index(name=index_name, build_topic_model=True) - return "success" - except Exception as e: - print(e) - return "Error in creating map: {e}" - -def append_to_map(embeddings, metadata, map_name): - """ - Generic function to append new data to an existing Nomic map. - Args: - embeddings: np.array of embeddings - metadata: pd.DataFrame of Nomic upload metadata - map_name: str - """ - - nomic.login(os.getenv('NOMIC_API_KEY')) - try: - project = atlas.AtlasProject(name=map_name, add_datums_if_exists=True) - with project.wait_for_project_lock(): - project.add_embeddings(embeddings=embeddings, data=metadata) - return "success" - except Exception as e: - print(e) - return "Error in appending to map: {e}" - -def data_prep_for_doc_map(df: pd.DataFrame): - """ - This function prepares embeddings and metadata for nomic upload in document map creation. - Args: - df: pd.DataFrame - the dataframe of documents from Supabase - Returns: - embeddings: np.array of embeddings - metadata: pd.DataFrame of metadata - """ - print("in data_prep_for_doc_map()") - - metadata = [] - embeddings = [] - - texts = [] - - for index, row in df.iterrows(): - current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - created_at = datetime.datetime.strptime(row['created_at'], "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y-%m-%d %H:%M:%S") - if row['url'] == None: - row['url'] = "" - if row['base_url'] == None: - row['base_url'] = "" - # iterate through all contexts and create separate entries for each - context_count = 0 - for context in row['contexts']: - context_count += 1 - text_row = context['text'] - embeddings_row = context['embedding'] - - meta_row = { - "id": str(row['id']) + "_" + str(context_count), - "created_at": created_at, - "s3_path": row['s3_path'], - "url": row['url'], - "base_url": row['base_url'], - "readable_filename": row['readable_filename'], - "modified_at": current_time, - "text": text_row - } - - embeddings.append(embeddings_row) - metadata.append(meta_row) - texts.append(text_row) - - embeddings_np = np.array(embeddings, dtype=object) - print("Shape of embeddings: ", embeddings_np.shape) - - # check dimension if embeddings_np is (n, 1536) - if len(embeddings_np.shape) < 2: - print("Creating new embeddings...") - - embeddings_model = OpenAIEmbeddings(openai_api_type="openai", - openai_api_base="https://api.openai.com/v1/", - openai_api_key=os.getenv('VLADS_OPENAI_KEY')) # type: ignore - embeddings = embeddings_model.embed_documents(texts) - - metadata = pd.DataFrame(metadata) - embeddings = np.array(embeddings) - - return embeddings, metadata - -def rebuild_map(course_name:str, map_type:str): - """ - This function rebuilds a given map in Nomic. 
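`data_prep_for_doc_map` above flattens each Supabase document row into one Nomic datum per context chunk, reusing the stored embeddings. A reduced sketch of that flattening, assuming rows shaped like the `documents` table (an `id` plus a `contexts` list of `{text, embedding}` dicts); the function name is a stand-in and the stored embeddings are assumed complete and uniformly sized.

```python
# Sketch of the per-context flattening done by data_prep_for_doc_map above.
# Assumes rows look like {"id": ..., "readable_filename": ..., "contexts": [{"text": ..., "embedding": [...]}, ...]}.
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd


def flatten_for_nomic(rows: List[Dict[str, Any]]) -> Tuple[np.ndarray, pd.DataFrame]:
  embeddings, metadata = [], []
  for row in rows:
    for i, context in enumerate(row["contexts"], start=1):
      embeddings.append(context["embedding"])
      metadata.append({
          "id": f"{row['id']}_{i}",  # one Nomic datum per chunk, matching the "<doc_id>_<n>" ids above
          "readable_filename": row.get("readable_filename", ""),
          "text": context["text"],
      })
  # Assumes every context carries an embedding of the same length.
  return np.array(embeddings, dtype=np.float32), pd.DataFrame(metadata)
```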
- """ - print("in rebuild_map()") - nomic.login(os.getenv('NOMIC_API_KEY')) - if map_type.lower() == 'document': - NOMIC_MAP_NAME_PREFIX = 'Document Map for ' - else: - NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' - - try: - # fetch project from Nomic - project_name = NOMIC_MAP_NAME_PREFIX + course_name - project = AtlasProject(name=project_name, add_datums_if_exists=True) - - if project.is_accepting_data: # temporary fix - will skip rebuilding if project is locked - project.rebuild_maps() - return "success" - except Exception as e: - print(e) - sentry_sdk.capture_exception(e) - return "Error in rebuilding map: {e}" - - - -if __name__ == '__main__': - pass - +# import datetime +# import os + +# import nomic +# import numpy as np +# import pandas as pd +# import sentry_sdk +# import supabase +# from langchain.embeddings import OpenAIEmbeddings +# from nomic import AtlasProject, atlas + +# OPENAI_API_TYPE = "azure" + +# SUPABASE_CLIENT = supabase.create_client( # type: ignore +# supabase_url=os.getenv('SUPABASE_URL'), # type: ignore +# supabase_key=os.getenv('SUPABASE_API_KEY')) # type: ignore + +# NOMIC_MAP_NAME_PREFIX = 'Document Map for ' + +# ## -------------------------------- DOCUMENT MAP FUNCTIONS --------------------------------- ## + +# def create_document_map(course_name: str): +# """ +# This is a function which creates a document map for a given course from scratch +# 1. Gets count of documents for the course +# 2. If less than 20, returns a message that a map cannot be created +# 3. If greater than 20, iteratively fetches documents in batches of 25 +# 4. Prepares metadata and embeddings for nomic upload +# 5. Creates a new map and uploads the data + +# Args: +# course_name: str +# Returns: +# str: success or failed +# """ +# logging.info("in create_document_map()") +# nomic.login(os.getenv('NOMIC_API_KEY')) + +# try: +# # check if map exists +# response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() +# if response.data: +# if response.data[0]['doc_map_id']: +# return "Map already exists for this course." + +# # fetch relevant document data from Supabase +# response = SUPABASE_CLIENT.table("documents").select("id", +# count="exact").eq("course_name", +# course_name).order('id', +# desc=False).execute() +# if not response.count: +# return "No documents found for this course." + +# total_doc_count = response.count +# logging.info("Total number of documents in Supabase: ", total_doc_count) + +# # minimum 20 docs needed to create map +# if total_doc_count < 20: +# return "Cannot create a map because there are less than 20 documents in the course." 
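`create_document_map` above pages through Supabase 25 rows at a time (ordered by id) and only pushes to Nomic once roughly 1000 rows have accumulated, creating the map on the first flush and appending afterwards. A schematic of that batching loop with the Supabase and Nomic calls abstracted behind stand-in callables; none of these names exist in the repo.

```python
# Schematic of the fetch-in-25s / upload-in-1000s loop above. fetch_batch(), create() and
# append() are stand-ins for the Supabase query and Nomic map calls, not real APIs.
def build_map_in_batches(fetch_batch, create, append, total_count: int, first_id: int, flush_size: int = 1000):
  buffered, uploaded, first_batch = [], 0, True
  while uploaded < total_count:
    rows = fetch_batch(first_id, limit=25)  # rows ordered by id ascending
    if not rows:
      break
    buffered.extend(rows)
    uploaded += len(rows)
    first_id = rows[-1]["id"] + 1  # the next page starts after the last id we saw

    if len(buffered) >= flush_size:
      (create if first_batch else append)(buffered)
      first_batch, buffered = False, []

  if buffered:  # flush the final partial batch
    (create if first_batch else append)(buffered)
```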
+ +# first_id = response.data[0]['id'] + +# combined_dfs = [] +# curr_total_doc_count = 0 +# doc_count = 0 +# first_batch = True + +# # iteratively query in batches of 25 +# while curr_total_doc_count < total_doc_count: + +# response = SUPABASE_CLIENT.table("documents").select( +# "id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gte( +# 'id', first_id).order('id', desc=False).limit(25).execute() +# df = pd.DataFrame(response.data) +# combined_dfs.append(df) # list of dfs + +# curr_total_doc_count += len(response.data) +# doc_count += len(response.data) + +# if doc_count >= 1000: # upload to Nomic in batches of 1000 + +# # concat all dfs from the combined_dfs list +# final_df = pd.concat(combined_dfs, ignore_index=True) + +# # prep data for nomic upload +# embeddings, metadata = data_prep_for_doc_map(final_df) + +# if first_batch: +# # create a new map +# logging.info("Creating new map...") +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# index_name = course_name + "_doc_index" +# topic_label_field = "text" +# colorable_fields = ["readable_filename", "text", "base_url", "created_at"] +# result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) + +# if result == "success": +# # update flag +# first_batch = False +# # log project info to supabase +# project = AtlasProject(name=project_name, add_datums_if_exists=True) +# project_id = project.id +# last_id = int(final_df['id'].iloc[-1]) +# project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} +# project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() +# if project_response.data: +# update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() +# logging.info("Response from supabase: ", update_response) +# else: +# insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() +# logging.info("Insert Response from supabase: ", insert_response) + +# else: +# # append to existing map +# logging.info("Appending data to existing map...") +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# # add project lock logic here +# result = append_to_map(embeddings, metadata, project_name) +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# info = {'last_uploaded_doc_id': last_id} +# update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() +# logging.info("Response from supabase: ", update_response) + +# # reset variables +# combined_dfs = [] +# doc_count = 0 +# logging.info("Records uploaded: ", curr_total_doc_count) + +# # set first_id for next iteration +# first_id = response.data[-1]['id'] + 1 + +# # upload last set of docs +# if doc_count > 0: +# final_df = pd.concat(combined_dfs, ignore_index=True) +# embeddings, metadata = data_prep_for_doc_map(final_df) +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# if first_batch: +# index_name = course_name + "_doc_index" +# topic_label_field = "text" +# colorable_fields = ["readable_filename", "text", "base_url", "created_at"] +# result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) +# else: +# result = append_to_map(embeddings, metadata, project_name) + +# # update the last uploaded id in supabase +# if result == "success": +# # update the last uploaded id in 
supabase +# last_id = int(final_df['id'].iloc[-1]) +# project = AtlasProject(name=project_name, add_datums_if_exists=True) +# project_id = project.id +# project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} +# logging.info("project_info: ", project_info) +# project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() +# if project_response.data: +# update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() +# logging.info("Response from supabase: ", update_response) +# else: +# insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() +# logging.info("Insert Response from supabase: ", insert_response) + +# # rebuild the map +# rebuild_map(course_name, "document") + +# except Exception as e: +# logging.info(e) +# sentry_sdk.capture_exception(e) +# return "failed" + +# def delete_from_document_map(course_name: str, ids: list): +# """ +# This function is used to delete datapoints from a document map. +# Currently used within the delete_data() function in vector_database.py +# Args: +# course_name: str +# ids: list of str +# """ +# logging.info("in delete_from_document_map()") + +# try: +# # check if project exists +# response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() +# if response.data: +# project_id = response.data[0]['doc_map_id'] +# else: +# return "No document map found for this course" + +# # fetch project from Nomic +# project = AtlasProject(project_id=project_id, add_datums_if_exists=True) + +# # delete the ids from Nomic +# logging.info("Deleting point from document map:", project.delete_data(ids)) +# with project.wait_for_project_lock(): +# project.rebuild_maps() +# return "success" +# except Exception as e: +# logging.info(e) +# sentry_sdk.capture_exception(e) +# return "Error in deleting from document map: {e}" + +# def log_to_document_map(course_name: str): +# """ +# This is a function which appends new documents to an existing document map. It's called +# at the end of split_and_upload() after inserting data to Supabase. +# Args: +# data: dict - the response data from Supabase insertion +# """ +# logging.info("in add_to_document_map()") + +# try: +# # check if map exists +# response = SUPABASE_CLIENT.table("projects").select("doc_map_id, last_uploaded_doc_id").eq("course_name", course_name).execute() +# if response.data: +# if response.data[0]['doc_map_id']: +# project_id = response.data[0]['doc_map_id'] +# last_uploaded_doc_id = response.data[0]['last_uploaded_doc_id'] +# else: +# # entry present in supabase, but doc map not present +# create_document_map(course_name) +# return "Document map not present, triggering map creation." + +# else: +# # create a map +# create_document_map(course_name) +# return "Document map not present, triggering map creation." + +# project = AtlasProject(project_id=project_id, add_datums_if_exists=True) +# project_name = "Document Map for " + course_name + +# # check if project is LOCKED, if yes -> skip logging +# if not project.is_accepting_data: +# return "Skipping Nomic logging because project is locked." 
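`log_to_document_map` above refuses to append while the Atlas project is locked, and the append itself runs under the project lock so a rebuild cannot race with the upload. A minimal sketch of that pattern using the same nomic client calls, assuming `NOMIC_API_KEY` is set and the named project already exists.

```python
# Sketch of the "skip if locked, otherwise append under the project lock" pattern above.
# Assumes NOMIC_API_KEY is set and the named Atlas project already exists.
import os

import nomic
from nomic import AtlasProject


def append_if_unlocked(map_name: str, embeddings, metadata) -> str:
  nomic.login(os.environ["NOMIC_API_KEY"])
  project = AtlasProject(name=map_name, add_datums_if_exists=True)

  if not project.is_accepting_data:  # project is mid-rebuild; don't queue more data
    return "skipped: project locked"

  with project.wait_for_project_lock():
    project.add_embeddings(embeddings=embeddings, data=metadata)
  return "success"
```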
+ +# # fetch count of records greater than last_uploaded_doc_id +# logging.info("last uploaded doc id: ", last_uploaded_doc_id) +# response = SUPABASE_CLIENT.table("documents").select("id", count="exact").eq("course_name", course_name).gt("id", last_uploaded_doc_id).execute() +# logging.info("Number of new documents: ", response.count) + +# total_doc_count = response.count +# current_doc_count = 0 +# combined_dfs = [] +# doc_count = 0 +# first_id = last_uploaded_doc_id +# while current_doc_count < total_doc_count: +# # fetch all records from supabase greater than last_uploaded_doc_id +# response = SUPABASE_CLIENT.table("documents").select("id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gt("id", first_id).limit(25).execute() +# df = pd.DataFrame(response.data) +# combined_dfs.append(df) # list of dfs + +# current_doc_count += len(response.data) +# doc_count += len(response.data) + +# if doc_count >= 1000: # upload to Nomic in batches of 1000 +# # concat all dfs from the combined_dfs list +# final_df = pd.concat(combined_dfs, ignore_index=True) +# # prep data for nomic upload +# embeddings, metadata = data_prep_for_doc_map(final_df) + +# # append to existing map +# logging.info("Appending data to existing map...") + +# result = append_to_map(embeddings, metadata, project_name) +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# info = {'last_uploaded_doc_id': last_id} +# update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() +# logging.info("Response from supabase: ", update_response) + +# # reset variables +# combined_dfs = [] +# doc_count = 0 +# logging.info("Records uploaded: ", current_doc_count) + +# # set first_id for next iteration +# first_id = response.data[-1]['id'] + 1 + +# # upload last set of docs +# if doc_count > 0: +# final_df = pd.concat(combined_dfs, ignore_index=True) +# embeddings, metadata = data_prep_for_doc_map(final_df) +# result = append_to_map(embeddings, metadata, project_name) + +# # update the last uploaded id in supabase +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# project_info = {'last_uploaded_doc_id': last_id} +# update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() +# logging.info("Response from supabase: ", update_response) + +# return "success" +# except Exception as e: +# logging.info(e) +# return "failed" + +# def create_map(embeddings, metadata, map_name, index_name, topic_label_field, colorable_fields): +# """ +# Generic function to create a Nomic map from given parameters. +# Args: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of metadata +# map_name: str +# index_name: str +# topic_label_field: str +# colorable_fields: list of str +# """ +# nomic.login(os.getenv('NOMIC_API_KEY')) +# try: +# project = atlas.map_embeddings(embeddings=embeddings, +# data=metadata, +# id_field="id", +# build_topic_model=True, +# topic_label_field=topic_label_field, +# name=map_name, +# colorable_fields=colorable_fields, +# add_datums_if_exists=True) +# project.create_index(name=index_name, build_topic_model=True) +# return "success" +# except Exception as e: +# logging.info(e) +# return "Error in creating map: {e}" + +# def append_to_map(embeddings, metadata, map_name): +# """ +# Generic function to append new data to an existing Nomic map. 
+# Args: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of Nomic upload metadata +# map_name: str +# """ + +# nomic.login(os.getenv('NOMIC_API_KEY')) +# try: +# project = atlas.AtlasProject(name=map_name, add_datums_if_exists=True) +# with project.wait_for_project_lock(): +# project.add_embeddings(embeddings=embeddings, data=metadata) +# return "success" +# except Exception as e: +# logging.info(e) +# return "Error in appending to map: {e}" + +# def data_prep_for_doc_map(df: pd.DataFrame): +# """ +# This function prepares embeddings and metadata for nomic upload in document map creation. +# Args: +# df: pd.DataFrame - the dataframe of documents from Supabase +# Returns: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of metadata +# """ +# logging.info("in data_prep_for_doc_map()") + +# metadata = [] +# embeddings = [] + +# texts = [] + +# for index, row in df.iterrows(): +# current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") +# created_at = datetime.datetime.strptime(row['created_at'], "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y-%m-%d %H:%M:%S") +# if row['url'] == None: +# row['url'] = "" +# if row['base_url'] == None: +# row['base_url'] = "" +# # iterate through all contexts and create separate entries for each +# context_count = 0 +# for context in row['contexts']: +# context_count += 1 +# text_row = context['text'] +# embeddings_row = context['embedding'] + +# meta_row = { +# "id": str(row['id']) + "_" + str(context_count), +# "created_at": created_at, +# "s3_path": row['s3_path'], +# "url": row['url'], +# "base_url": row['base_url'], +# "readable_filename": row['readable_filename'], +# "modified_at": current_time, +# "text": text_row +# } + +# embeddings.append(embeddings_row) +# metadata.append(meta_row) +# texts.append(text_row) + +# embeddings_np = np.array(embeddings, dtype=object) +# logging.info("Shape of embeddings: ", embeddings_np.shape) + +# # check dimension if embeddings_np is (n, 1536) +# if len(embeddings_np.shape) < 2: +# logging.info("Creating new embeddings...") + +# embeddings_model = OpenAIEmbeddings(openai_api_type="openai", +# openai_api_base="https://api.openai.com/v1/", +# openai_api_key=os.getenv('VLADS_OPENAI_KEY')) # type: ignore +# embeddings = embeddings_model.embed_documents(texts) + +# metadata = pd.DataFrame(metadata) +# embeddings = np.array(embeddings) + +# return embeddings, metadata + +# def rebuild_map(course_name:str, map_type:str): +# """ +# This function rebuilds a given map in Nomic. 
+# """ +# logging.info("in rebuild_map()") +# nomic.login(os.getenv('NOMIC_API_KEY')) +# if map_type.lower() == 'document': +# NOMIC_MAP_NAME_PREFIX = 'Document Map for ' +# else: +# NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' + +# try: +# # fetch project from Nomic +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# project = AtlasProject(name=project_name, add_datums_if_exists=True) + +# if project.is_accepting_data: # temporary fix - will skip rebuilding if project is locked +# project.rebuild_maps() +# return "success" +# except Exception as e: +# logging.info(e) +# sentry_sdk.capture_exception(e) +# return "Error in rebuilding map: {e}" + +# if __name__ == '__main__': +# pass diff --git a/ai_ta_backend/database/aws.py b/ai_ta_backend/database/aws.py index 68e61b68..047974cd 100644 --- a/ai_ta_backend/database/aws.py +++ b/ai_ta_backend/database/aws.py @@ -1,19 +1,31 @@ +import logging import os import boto3 from injector import inject -class AWSStorage: +class AWSStorage(): @inject def __init__(self): - # S3 - self.s3_client = boto3.client( - 's3', - aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'], - aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'], - ) + if all(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]): + logging.info("Using AWS for storage") + self.s3_client = boto3.client( + 's3', + aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), + aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), + ) + elif all(os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): + logging.info("Using Minio for storage") + self.s3_client = boto3.client('s3', + endpoint_url=os.getenv('MINIO_URL'), + aws_access_key_id=os.getenv('MINIO_ACCESS_KEY'), + aws_secret_access_key=os.getenv('MINIO_SECRET_KEY'), + config=boto3.session.Config(signature_version='s3v4'), + region_name='us-east-1') + else: + raise ValueError("No valid storage credentials found.") def upload_file(self, file_path: str, bucket_name: str, object_name: str): self.s3_client.upload_file(file_path, bucket_name, object_name) @@ -26,9 +38,4 @@ def delete_file(self, bucket_name: str, s3_path: str): def generatePresignedUrl(self, object: str, bucket_name: str, s3_path: str, expiration: int = 3600): # generate presigned URL - return self.s3_client.generate_presigned_url('get_object', - Params={ - 'Bucket': bucket_name, - 'Key': s3_path - }, - ExpiresIn=expiration) + return self.s3_client.generate_presigned_url('get_object', Params={'Bucket': bucket_name, 'Key': s3_path}, ExpiresIn=expiration) diff --git a/ai_ta_backend/database/poi_sql.py b/ai_ta_backend/database/poi_sql.py new file mode 100644 index 00000000..8cf94d4e --- /dev/null +++ b/ai_ta_backend/database/poi_sql.py @@ -0,0 +1,14 @@ +from typing import List +from injector import inject +from flask_sqlalchemy import SQLAlchemy +import ai_ta_backend.model.models as models +import logging + +from ai_ta_backend.model.response import DatabaseResponse + +class POISQLDatabase: + + @inject + def __init__(self, db: SQLAlchemy): + logging.info("Initializing SQLAlchemyDatabase") + self.db = db \ No newline at end of file diff --git a/ai_ta_backend/database/vector.py b/ai_ta_backend/database/qdrant.py similarity index 89% rename from ai_ta_backend/database/vector.py rename to ai_ta_backend/database/qdrant.py index f9d002ec..18a792ff 100644 --- a/ai_ta_backend/database/vector.py +++ b/ai_ta_backend/database/qdrant.py @@ -1,10 +1,12 @@ +import logging import os from typing import List from injector import inject from 
langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Qdrant
-from qdrant_client import QdrantClient, models
+from qdrant_client import models
+from qdrant_client import QdrantClient
OPENAI_API_TYPE = "azure"  # "openai" or "azure"
@@ -21,7 +23,8 @@ def __init__(self):
    """
    # vector DB
    self.qdrant_client = QdrantClient(
-        url=os.environ['QDRANT_URL'],
+        url='http://qdrant:6333',
+        https=False,
        api_key=os.environ['QDRANT_API_KEY'],
        timeout=20,  # default is 5 seconds. Getting timeout errors w/ document groups.
    )
@@ -37,11 +40,11 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q
    Search the vector database for a given query.
    """
    must_conditions = self._create_search_conditions(course_name, doc_groups)
-
+
    # Filter for the must_conditions
    myfilter = models.Filter(must=must_conditions)
-    print(f"Filter: {myfilter}")
-
+    logging.info(f"Qdrant search Filter: {myfilter}")
+
    # Search the vector database
    search_results = self.qdrant_client.search(
        collection_name=os.environ['QDRANT_COLLECTION_NAME'],
@@ -51,26 +54,25 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q
        limit=top_n,  # Return n closest points
        # In a system with high disk latency, the re-scoring step may become a bottleneck: https://qdrant.tech/documentation/guides/quantization/
        search_params=models.SearchParams(quantization=models.QuantizationSearchParams(rescore=False)))
+
    return search_results
  def _create_search_conditions(self, course_name, doc_groups: List[str]):
    """
    Create search conditions for the vector search.
    """
-    must_conditions: list[models.Condition] = [
-        models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name))
-    ]
-
+    must_conditions: list[models.Condition] = [models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name))]
+
    if doc_groups and 'All Documents' not in doc_groups:
      # Final combined condition
      combined_condition = None
      # Condition for matching any of the specified doc_groups
      match_any_condition = models.FieldCondition(key='doc_groups', match=models.MatchAny(any=doc_groups))
      combined_condition = models.Filter(should=[match_any_condition])
-
+
      # Add the combined condition to the must_conditions list
      must_conditions.append(combined_condition)
-
+
    return must_conditions
  def delete_data(self, collection_name: str, key: str, value: str):
diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py
index 6f7ae01d..d58dc618 100644
--- a/ai_ta_backend/database/sql.py
+++ b/ai_ta_backend/database/sql.py
@@ -1,125 +1,231 @@
-import os
+import logging
+from typing import List
-import supabase
+from flask_sqlalchemy import SQLAlchemy
from injector import inject
+import ai_ta_backend.model.models as models
+from ai_ta_backend.model.response import DatabaseResponse
-class SQLDatabase:
+
+class SQLAlchemyDatabase:
  @inject
-  def __init__(self):
-    # Create a Supabase client
-    self.supabase_client = supabase.create_client(  # type: ignore
-        supabase_url=os.environ['SUPABASE_URL'], supabase_key=os.environ['SUPABASE_API_KEY'])
+  def __init__(self, db: SQLAlchemy):
+    logging.info("Initializing SQLAlchemyDatabase")
+    self.db = db
  def getAllMaterialsForCourse(self, course_name: str):
-    return self.supabase_client.table(
-        os.environ['SUPABASE_DOCUMENTS_TABLE']).select('course_name, s3_path, readable_filename, url, base_url').eq(
-            'course_name', course_name).execute()
+    try:
+      query = self.db.select(models.Document).where(models.Document.course_name == course_name)
+      result = 
self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): - return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq( - 's3_path', s3_path).eq('course_name', course_name).execute() + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name, models.Document.s3_path == s3_path) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): - return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq( - key, value).eq('course_name', course_name).execute() + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name, getattr(models.Document, key) == value) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): - return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq(key, value).eq( - 'course_name', course_name).execute() + try: + query = self.db.delete(models.Document).where(models.Document.course_name == course_name, getattr(models.Document, key) == value) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() def deleteMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): - return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('s3_path', s3_path).eq( - 'course_name', course_name).execute() + try: + query = self.db.delete(models.Document).where(models.Document.course_name == course_name, models.Document.s3_path == s3_path) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() def getProjectsMapForCourse(self, course_name: str): - return self.supabase_client.table("projects").select("doc_map_id").eq("course_name", course_name).execute() - - def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: str, table_name: str): - if from_date != '' and to_date != '': - # query between the dates - print("from_date and to_date") - - response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", course_name).gte( - 'created_at', from_date).lte('created_at', to_date).order('id', desc=False).execute() - - elif from_date != '' and to_date == '': - # query from from_date to now - print("only from_date") - response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", course_name).gte( - 'created_at', from_date).order('id', desc=False).execute() - - elif from_date == '' and to_date != '': - # query from beginning to to_date - print("only to_date") - response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", course_name).lte( - 'created_at', to_date).order('id', desc=False).execute() - - else: - # query 
all data - print("No dates") - response = self.supabase_client.table(table_name).select("id", count='exact').eq( - "course_name", course_name).order('id', desc=False).execute() - return response - - def getAllFromTableForDownloadType(self, course_name: str, download_type: str, first_id: int): - if download_type == 'documents': - response = self.supabase_client.table("documents").select("*").eq("course_name", course_name).gte( - 'id', first_id).order('id', desc=False).limit(100).execute() - else: - response = self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gte( - 'id', first_id).order('id', desc=False).limit(100).execute() - - return response + try: + query = self.db.select(models.Project.doc_map_id).where(models.Project.course_name == course_name) + result = self.db.session.execute(query).scalars().all() + projects: List[models.Project] = [doc for doc in result] + return DatabaseResponse[models.Project](data=projects, count=len(result)) + finally: + self.db.session.close() + + def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: str): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name) + if from_date: + query = query.filter(models.Document.created_at >= from_date) + if to_date: + query = query.filter(models.Document.created_at <= to_date) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getConversationsBetweenDates(self, course_name: str, from_date: str, to_date: str): + try: + query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name) + if from_date: + query = query.filter(models.LlmConvoMonitor.created_at >= from_date) + if to_date: + query = query.filter(models.LlmConvoMonitor.created_at <= to_date) + result = self.db.session.execute(query).scalars().all() + documents: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getAllDocumentsForDownload(self, course_name: str, first_id: int): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name, models.Document.id >= first_id) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getAllConversationsForDownload(self, course_name: str, first_id: int): + try: + query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name, models.LlmConvoMonitor.id + >= first_id) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) + finally: + self.db.session.close() def getAllConversationsBetweenIds(self, course_name: str, first_id: int, last_id: int, limit: int = 50): - if last_id == 0: - return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gt( - 'id', first_id).order('id', desc=False).limit(limit).execute() - else: - return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", 
course_name).gte(
-        'id', first_id).lte('id', last_id).order('id', desc=False).limit(limit).execute()
-
+    try:
+      query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name, models.LlmConvoMonitor.id
+                                                            > first_id)
+      if last_id != 0:
+        query = query.filter(models.LlmConvoMonitor.id <= last_id)
+      query = query.limit(limit)
+      result = self.db.session.execute(query).scalars().all()
+      conversations: List[models.LlmConvoMonitor] = [doc for doc in result]
+      return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result))
+    finally:
+      self.db.session.close()
  def getDocsForIdsGte(self, course_name: str, first_id: int, fields: str = "*", limit: int = 100):
-    return self.supabase_client.table("documents").select(fields).eq("course_name", course_name).gte(
-        'id', first_id).order('id', desc=False).limit(limit).execute()
+    try:
+      fields_to_select = [getattr(models.Document, field) for field in fields.split(", ")]
+      query = self.db.select(*fields_to_select).where(models.Document.course_name == course_name, models.Document.id
+                                                      >= first_id).limit(limit)
+      result = self.db.session.execute(query).scalars().all()
+      documents: List[models.Document] = [doc for doc in result]
+      return DatabaseResponse[models.Document](data=documents, count=len(result))
+    finally:
+      self.db.session.close()
  def insertProjectInfo(self, project_info):
-    return self.supabase_client.table("projects").insert(project_info).execute()
+    try:
+      self.db.session.execute(self.db.insert(models.Project).values(**project_info))
+      self.db.session.commit()
+    finally:
+      self.db.session.close()
  def getAllFromLLMConvoMonitor(self, course_name: str):
-    return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).order('id', desc=False).execute()
-
+    try:
+      query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name)
+      result = self.db.session.execute(query).scalars().all()
+      conversations: List[models.LlmConvoMonitor] = [doc for doc in result]
+      return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result))
+    finally:
+      self.db.session.close()
+
  def getCountFromLLMConvoMonitor(self, course_name: str, last_id: int):
-    if last_id == 0:
-      return self.supabase_client.table("llm-convo-monitor").select("id", count='exact').eq("course_name", course_name).order('id', desc=False).execute()
-    else:
-      return self.supabase_client.table("llm-convo-monitor").select("id", count='exact').eq("course_name", course_name).gt("id", last_id).order('id', desc=False).execute()
-
+    try:
+      query = self.db.select(models.LlmConvoMonitor.id).where(models.LlmConvoMonitor.course_name == course_name)
+      if last_id != 0:
+        query = query.filter(models.LlmConvoMonitor.id > last_id)
+      count_query = self.db.select(self.db.func.count()).select_from(query.subquery())
+      count = self.db.session.execute(count_query).scalar()
+      return DatabaseResponse[models.LlmConvoMonitor](data=[], count=count or 0)
+    finally:
+      self.db.session.close()
+
  def getDocMapFromProjects(self, course_name: str):
-    return self.supabase_client.table("projects").select("doc_map_id").eq("course_name", course_name).execute()
-
+    try:
+      query = self.db.select(models.Project.doc_map_id).where(models.Project.course_name == course_name)
+      result = self.db.session.execute(query).scalars().all()
+      documents: List[models.Project] = [doc for doc in result]
+      return DatabaseResponse[models.Project](data=documents, count=len(result))
+    finally:
+      self.db.session.close()
+
  def 
getConvoMapFromProjects(self, course_name: str): - return self.supabase_client.table("projects").select("*").eq("course_name", course_name).execute() - + try: + query = self.db.select(models.Project).where(models.Project.course_name == course_name) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.Project] = [doc for doc in result] + return DatabaseResponse[models.Project](data=conversations, count=len(result)) + finally: + self.db.session.close() + def updateProjects(self, course_name: str, data: dict): - return self.supabase_client.table("projects").update(data).eq("course_name", course_name).execute() - + try: + query = self.db.update(models.Project).where(models.Project.course_name == course_name).values(**data) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() + def getLatestWorkflowId(self): - return self.supabase_client.table('n8n_workflows').select("latest_workflow_id").execute() - + try: + query = self.db.select(models.N8nWorkflows.latest_workflow_id) + result = self.db.session.execute(query).fetchone() + return result + finally: + self.db.session.close() + def lockWorkflow(self, id: str): - return self.supabase_client.table('n8n_workflows').insert({"latest_workflow_id": id, "is_locked": True}).execute() - + try: + new_workflow = models.N8nWorkflows(is_locked=True) + self.db.session.add(new_workflow) + self.db.session.commit() + finally: + self.db.session.close() + def deleteLatestWorkflowId(self, id: str): - return self.supabase_client.table('n8n_workflows').delete().eq('latest_workflow_id', id).execute() - + try: + query = self.db.delete(models.N8nWorkflows).where(models.N8nWorkflows.latest_workflow_id == id) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() + def unlockWorkflow(self, id: str): - return self.supabase_client.table('n8n_workflows').update({"is_locked": False}).eq('latest_workflow_id', id).execute() + try: + query = self.db.update(models.N8nWorkflows).where(models.N8nWorkflows.latest_workflow_id == id).values(is_locked=False) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() def getConversation(self, course_name: str, key: str, value: str): - return self.supabase_client.table("llm-convo-monitor").select("*").eq(key, value).eq("course_name", course_name).execute() - - + try: + query = self.db.select(models.LlmConvoMonitor).where(getattr(models.LlmConvoMonitor, key) == value) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) + finally: + self.db.session.close() diff --git a/ai_ta_backend/database/supabase.py b/ai_ta_backend/database/supabase.py new file mode 100644 index 00000000..5a1cf39c --- /dev/null +++ b/ai_ta_backend/database/supabase.py @@ -0,0 +1,138 @@ +import logging +import os + +from injector import inject +import supabase + + +class SQLDatabase(): + + @inject + def __init__(self): + # Create a Supabase client + self.supabase_client = supabase.create_client( # type: ignore + supabase_url=os.environ['SUPABASE_URL'], supabase_key=os.environ['SUPABASE_API_KEY']) + + def getAllMaterialsForCourse(self, course_name: str): + return self.supabase_client.table( + os.environ['SUPABASE_DOCUMENTS_TABLE']).select('course_name, s3_path, readable_filename, url, base_url').eq( + 'course_name', course_name).execute() + + def 
getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): + return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq('s3_path', s3_path).eq( + 'course_name', course_name).execute() + + def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): + return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq(key, value).eq( + 'course_name', course_name).execute() + + def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): + return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq(key, value).eq('course_name', + course_name).execute() + + def deleteMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): + return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('s3_path', + s3_path).eq('course_name', course_name).execute() + + def getProjectsMapForCourse(self, course_name: str): + return self.supabase_client.table("projects").select("doc_map_id").eq("course_name", course_name).execute() + + def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: str, table_name: str): + if from_date != '' and to_date != '': + # query between the dates + logging.info("from_date and to_date") + + response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", course_name).gte( + 'created_at', from_date).lte('created_at', to_date).order('id', desc=False).execute() + + elif from_date != '' and to_date == '': + # query from from_date to now + logging.info("only from_date") + response = self.supabase_client.table(table_name).select("id", + count='exact').eq("course_name", + course_name).gte('created_at', + from_date).order('id', + desc=False).execute() + + elif from_date == '' and to_date != '': + # query from beginning to to_date + logging.info("only to_date") + response = self.supabase_client.table(table_name).select("id", + count='exact').eq("course_name", + course_name).lte('created_at', + to_date).order('id', + desc=False).execute() + + else: + # query all data + logging.info("No dates") + response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", + course_name).order('id', desc=False).execute() + return response + + def getAllFromTableForDownloadType(self, course_name: str, download_type: str, first_id: int): + if download_type == 'documents': + response = self.supabase_client.table("documents").select("*").eq("course_name", + course_name).gte('id', + first_id).order('id', + desc=False).limit(100).execute() + else: + response = self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gte('id', first_id).order( + 'id', desc=False).limit(100).execute() + + return response + + def getAllConversationsBetweenIds(self, course_name: str, first_id: int, last_id: int, limit: int = 50): + if last_id == 0: + return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gt('id', first_id).order( + 'id', desc=False).limit(limit).execute() + else: + return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gte('id', first_id).lte( + 'id', last_id).order('id', desc=False).limit(limit).execute() + + def getDocsForIdsGte(self, course_name: str, first_id: int, fields: str = "*", limit: int = 100): + return self.supabase_client.table("documents").select(fields).eq("course_name", + 
course_name).gte('id', + first_id).order('id', + desc=False).limit(limit).execute() + + def insertProjectInfo(self, project_info): + return self.supabase_client.table("projects").insert(project_info).execute() + + def getAllFromLLMConvoMonitor(self, course_name: str): + return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).order('id', desc=False).execute() + + def getCountFromLLMConvoMonitor(self, course_name: str, last_id: int): + if last_id == 0: + return self.supabase_client.table("llm-convo-monitor").select("id", count='exact').eq("course_name", + course_name).order('id', desc=False).execute() + else: + return self.supabase_client.table("llm-convo-monitor").select("id", + count='exact').eq("course_name", + course_name).gt("id", + last_id).order('id', + desc=False).execute() + + def getDocMapFromProjects(self, course_name: str): + return self.supabase_client.table("projects").select("doc_map_id").eq("course_name", course_name).execute() + + def getConvoMapFromProjects(self, course_name: str): + return self.supabase_client.table("projects").select("*").eq("course_name", course_name).execute() + + def updateProjects(self, course_name: str, data: dict): + return self.supabase_client.table("projects").update(data).eq("course_name", course_name).execute() + + def getLatestWorkflowId(self): + return self.supabase_client.table('n8n_workflows').select("latest_workflow_id").execute() + + def lockWorkflow(self, id: str): + return self.supabase_client.table('n8n_workflows').insert({"latest_workflow_id": id, "is_locked": True}).execute() + + def deleteLatestWorkflowId(self, id: str): + return self.supabase_client.table('n8n_workflows').delete().eq('latest_workflow_id', id).execute() + + def unlockWorkflow(self, id: str): + return self.supabase_client.table('n8n_workflows').update({"is_locked": False}).eq('latest_workflow_id', id).execute() + + def getConversation(self, course_name: str, key: str, value: str): + return self.supabase_client.table("llm-convo-monitor").select("*").eq(key, value).eq("course_name", course_name).execute() diff --git a/ai_ta_backend/executors/process_pool_executor.py b/ai_ta_backend/executors/process_pool_executor.py index 81b4860c..33dc21aa 100644 --- a/ai_ta_backend/executors/process_pool_executor.py +++ b/ai_ta_backend/executors/process_pool_executor.py @@ -24,8 +24,7 @@ def __init__(self, max_workers=None): def submit(self, fn, *args, **kwargs): raise NotImplementedError( - "ProcessPoolExecutorAdapter does not support 'submit' directly due to its nature. Use 'map' or other methods as needed." - ) + "ProcessPoolExecutorAdapter does not support 'submit' directly due to its nature. 
Use 'map' or other methods as needed.") def map(self, fn, *iterables, timeout=None, chunksize=1): return self.executor.map(fn, *iterables, timeout=timeout, chunksize=chunksize) diff --git a/ai_ta_backend/extensions.py b/ai_ta_backend/extensions.py new file mode 100644 index 00000000..f0b13d6f --- /dev/null +++ b/ai_ta_backend/extensions.py @@ -0,0 +1,3 @@ +from flask_sqlalchemy import SQLAlchemy + +db = SQLAlchemy() diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 0085abd5..7d9cc47b 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -1,56 +1,56 @@ -import json +import logging import os import time from typing import List - +from ai_ta_backend.database.poi_sql import POISQLDatabase from dotenv import load_dotenv -from flask import ( - Flask, - Response, - abort, - jsonify, - make_response, - request, - send_from_directory, -) +from flask import abort +from flask import Flask +from flask import jsonify +from flask import make_response +from flask import request +from flask import Response +from flask import send_from_directory from flask_cors import CORS from flask_executor import Executor -from flask_injector import FlaskInjector, RequestScope -from injector import Binder, SingletonScope +from flask_injector import FlaskInjector +from flask_injector import RequestScope +from injector import Binder +from injector import SingletonScope from ai_ta_backend.database.aws import AWSStorage -from ai_ta_backend.database.sql import SQLDatabase -from ai_ta_backend.database.vector import VectorDatabase -from ai_ta_backend.executors.flask_executor import ( - ExecutorInterface, - FlaskExecutorAdapter, -) -from ai_ta_backend.executors.process_pool_executor import ( - ProcessPoolExecutorAdapter, - ProcessPoolExecutorInterface, -) -from ai_ta_backend.executors.thread_pool_executor import ( - ThreadPoolExecutorAdapter, - ThreadPoolExecutorInterface, -) +from ai_ta_backend.database.qdrant import VectorDatabase +from ai_ta_backend.database.sql import SQLAlchemyDatabase +from ai_ta_backend.executors.flask_executor import ExecutorInterface +from ai_ta_backend.executors.flask_executor import FlaskExecutorAdapter +from ai_ta_backend.executors.process_pool_executor import \ + ProcessPoolExecutorAdapter +from ai_ta_backend.executors.process_pool_executor import \ + ProcessPoolExecutorInterface +from ai_ta_backend.executors.thread_pool_executor import \ + ThreadPoolExecutorAdapter +from ai_ta_backend.executors.thread_pool_executor import \ + ThreadPoolExecutorInterface +from ai_ta_backend.extensions import db from ai_ta_backend.service.export_service import ExportService from ai_ta_backend.service.nomic_service import NomicService +from ai_ta_backend.service.poi_agent_service_v2 import POIAgentService from ai_ta_backend.service.posthog_service import PosthogService from ai_ta_backend.service.retrieval_service import RetrievalService from ai_ta_backend.service.sentry_service import SentryService - -from ai_ta_backend.beam.nomic_logging import create_document_map from ai_ta_backend.service.workflow_service import WorkflowService + +from langchain_core.messages import HumanMessage, SystemMessage + app = Flask(__name__) CORS(app) executor = Executor(app) # app.config['EXECUTOR_MAX_WORKERS'] = 5 nothing == picks defaults for me -#app.config['SERVER_TIMEOUT'] = 1000 # seconds +# app.config['SERVER_TIMEOUT'] = 1000 # seconds # load API keys from globally-availabe .env file -load_dotenv() - +load_dotenv(override=True) @app.route('/') def index() -> Response: @@ -62,8 +62,7 @@ def index() -> 
Response: Returns: JSON: _description_ """ - response = jsonify( - {"hi there, this is a 404": "Welcome to UIUC.chat backend πŸš… Read the docs here: https://docs.uiuc.chat/ "}) + response = jsonify({"hi there, this is a 404": "Welcome to UIUC.chat backend πŸš… Read the docs here: https://docs.uiuc.chat/ "}) response.headers.add('Access-Control-Allow-Origin', '*') return response @@ -111,6 +110,9 @@ def getTopContexts(service: RetrievalService) -> Response: token_limit: int = data.get('token_limit', 3000) doc_groups: List[str] = data.get('doc_groups', []) + logging.info(f"QDRANT URL {os.environ['QDRANT_URL']}") + logging.info(f"QDRANT_API_KEY {os.environ['QDRANT_API_KEY']}") + if search_query == '' or course_name == '': # proper web error "400 Bad request" abort( @@ -130,13 +132,12 @@ def getTopContexts(service: RetrievalService) -> Response: def getAll(service: RetrievalService) -> Response: """Get all course materials based on the course_name """ + logging.info("In getAll()") course_name: List[str] | str = request.args.get('course_name', default='', type=str) if course_name == '': # proper web error "400 Bad request" - abort( - 400, - description=f"Missing the one required parameter: 'course_name' must be provided. Course name: `{course_name}`") + abort(400, description=f"Missing the one required parameter: 'course_name' must be provided. Course name: `{course_name}`") distinct_dicts = service.getAll(course_name) @@ -166,8 +167,8 @@ def delete(service: RetrievalService, flaskExecutor: ExecutorInterface): start_time = time.monotonic() # background execution of tasks!! flaskExecutor.submit(service.delete_data, course_name, s3_path, source_url) - print(f"From {course_name}, deleted file: {s3_path}") - print(f"⏰ Runtime of FULL delete func: {(time.monotonic() - start_time):.2f} seconds") + logging.info(f"From {course_name}, deleted file: {s3_path}") + logging.info(f"⏰ Runtime of FULL delete func: {(time.monotonic() - start_time):.2f} seconds") # we need instant return. Delets are "best effort" assume always successful... sigh :( response = jsonify({"outcome": 'success'}) response.headers.add('Access-Control-Allow-Origin', '*') @@ -184,26 +185,27 @@ def nomic_map(service: NomicService): abort(400, description=f"Missing required parameter: 'course_name' must be provided. Course name: `{course_name}`") map_id = service.get_nomic_map(course_name, map_type) - print("nomic map\n", map_id) + logging.info("nomic map\n", map_id) response = jsonify(map_id) response.headers.add('Access-Control-Allow-Origin', '*') return response -@app.route('/createDocumentMap', methods=['GET']) -def createDocumentMap(service: NomicService): - course_name: str = request.args.get('course_name', default='', type=str) +# @app.route('/createDocumentMap', methods=['GET']) +# def createDocumentMap(service: NomicService): +# course_name: str = request.args.get('course_name', default='', type=str) - if course_name == '': - # proper web error "400 Bad request" - abort(400, description=f"Missing required parameter: 'course_name' must be provided. Course name: `{course_name}`") +# if course_name == '': +# # proper web error "400 Bad request" +# abort(400, description=f"Missing required parameter: 'course_name' must be provided. 
Course name: `{course_name}`") - map_id = create_document_map(course_name) +# map_id = create_document_map(course_name) + +# response = jsonify(map_id) +# response.headers.add('Access-Control-Allow-Origin', '*') +# return response - response = jsonify(map_id) - response.headers.add('Access-Control-Allow-Origin', '*') - return response @app.route('/createConversationMap', methods=['GET']) def createConversationMap(service: NomicService): @@ -219,6 +221,25 @@ def createConversationMap(service: NomicService): response.headers.add('Access-Control-Allow-Origin', '*') return response + +@app.route('/query_sql_agent', methods=['POST']) +def query_sql_agent(service: POIAgentService): + data = request.get_json() + user_input = data["query"] + system_message = SystemMessage(content="you are a helpful assistant and need to provide answers in text format about the plants found in India. If the Question is not related to plants in India answer 'I do not have any information on this.'") + + if not user_input: + return jsonify({"error": "No query provided"}), 400 + + try: + user_01 = HumanMessage(content=user_input) + inputs = {"messages": [system_message,user_01]} + response = service.run_workflow(inputs) + return str(response), 200 + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route('/logToConversationMap', methods=['GET']) def logToConversationMap(service: NomicService, flaskExecutor: ExecutorInterface): course_name: str = request.args.get('course_name', default='', type=str) @@ -248,11 +269,11 @@ def logToNomic(service: NomicService, flaskExecutor: ExecutorInterface): description= f"Missing one or more required parameters: 'course_name' and 'conversation' must be provided. Course name: `{course_name}`, Conversation: `{conversation}`" ) - print(f"In /onResponseCompletion for course: {course_name}") + logging.info(f"In /onResponseCompletion for course: {course_name}") # background execution of tasks!! #response = flaskExecutor.submit(service.log_convo_to_nomic, course_name, data) - result = flaskExecutor.submit(service.log_to_conversation_map, course_name, conversation).result() + flaskExecutor.submit(service.log_to_conversation_map, course_name, conversation).result() response = jsonify({'outcome': 'success'}) response.headers.add('Access-Control-Allow-Origin', '*') return response @@ -269,7 +290,7 @@ def export_convo_history(service: ExportService): abort(400, description=f"Missing required parameter: 'course_name' must be provided. 
Course name: `{course_name}`") export_status = service.export_convo_history_json(course_name, from_date, to_date) - print("EXPORT FILE LINKS: ", export_status) + logging.info("EXPORT FILE LINKS: ", export_status) if export_status['response'] == "No data found between the given dates.": response = Response(status=204) @@ -280,14 +301,14 @@ def export_convo_history(service: ExportService): response.headers.add('Access-Control-Allow-Origin', '*') else: - response = make_response( - send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) + response = make_response(send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) response.headers.add('Access-Control-Allow-Origin', '*') response.headers["Content-Disposition"] = f"attachment; filename={export_status['response'][1]}" os.remove(export_status['response'][0]) return response + @app.route('/export-conversations-custom', methods=['GET']) def export_conversations_custom(service: ExportService): course_name: str = request.args.get('course_name', default='', type=str) @@ -297,10 +318,10 @@ def export_conversations_custom(service: ExportService): if course_name == '' and emails == []: # proper web error "400 Bad request" - abort(400, description=f"Missing required parameter: 'course_name' and 'destination_email_ids' must be provided.") + abort(400, description="Missing required parameter: 'course_name' and 'destination_email_ids' must be provided.") export_status = service.export_conversations(course_name, from_date, to_date, emails) - print("EXPORT FILE LINKS: ", export_status) + logging.info("EXPORT FILE LINKS: ", export_status) if export_status['response'] == "No data found between the given dates.": response = Response(status=204) @@ -311,8 +332,7 @@ def export_conversations_custom(service: ExportService): response.headers.add('Access-Control-Allow-Origin', '*') else: - response = make_response( - send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) + response = make_response(send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) response.headers.add('Access-Control-Allow-Origin', '*') response.headers["Content-Disposition"] = f"attachment; filename={export_status['response'][1]}" os.remove(export_status['response'][0]) @@ -331,7 +351,7 @@ def exportDocuments(service: ExportService): abort(400, description=f"Missing required parameter: 'course_name' must be provided. 
Course name: `{course_name}`") export_status = service.export_documents_json(course_name, from_date, to_date) - print("EXPORT FILE LINKS: ", export_status) + logging.info("EXPORT FILE LINKS: ", export_status) if export_status['response'] == "No data found between the given dates.": response = Response(status=204) @@ -342,8 +362,7 @@ def exportDocuments(service: ExportService): response.headers.add('Access-Control-Allow-Origin', '*') else: - response = make_response( - send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) + response = make_response(send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) response.headers.add('Access-Control-Allow-Origin', '*') response.headers["Content-Disposition"] = f"attachment; filename={export_status['response'][1]}" os.remove(export_status['response'][0]) @@ -380,6 +399,7 @@ def getTopContextsWithMQR(service: RetrievalService, posthog_service: PosthogSer response.headers.add('Access-Control-Allow-Origin', '*') return response + @app.route('/getworkflows', methods=['GET']) def get_all_workflows(service: WorkflowService) -> Response: """ @@ -391,10 +411,9 @@ def get_all_workflows(service: WorkflowService) -> Response: pagination = request.args.get('pagination', default=True, type=bool) active = request.args.get('active', default=False, type=bool) name = request.args.get('workflow_name', default='', type=str) - print(request.args) - - print("In get_all_workflows.. api_key: ", api_key) + logging.info(request.args) + logging.info("In get_all_workflows.. api_key: ", api_key) # if no API Key, return empty set. # if api_key == '': @@ -408,10 +427,10 @@ def get_all_workflows(service: WorkflowService) -> Response: return response except Exception as e: if "unauthorized" in str(e).lower(): - print("Unauthorized error in get_all_workflows: ", e) + logging.info("Unauthorized error in get_all_workflows: ", e) abort(401, description=f"Unauthorized: 'api_key' is invalid. Search query: `{api_key}`") else: - print("Error in get_all_workflows: ", e) + logging.info("Error in get_all_workflows: ", e) abort(500, description=f"Failed to fetch n8n workflows: {e}") @@ -425,14 +444,14 @@ def switch_workflow(service: WorkflowService) -> Response: activate = request.args.get('activate', default='', type=str) id = request.args.get('id', default='', type=str) - print(request.args) + logging.info(request.args) if api_key == '': # proper web error "400 Bad request" abort(400, description=f"Missing N8N API_KEY: 'api_key' must be provided. 
Search query: `{api_key}`") try: - print("activation!!!!!!!!!!!", activate) + logging.info("activation!!!!!!!!!!!", activate) response = service.switch_workflow(id, api_key, activate) response = jsonify(response) response.headers.add('Access-Control-Allow-Origin', '*') @@ -454,7 +473,7 @@ def run_flow(service: WorkflowService) -> Response: name = request.json.get('name', '') data = request.json.get('data', '') - print("Got /run_flow request:", request.json) + logging.info("Got /run_flow request:", request.json) if api_key == '': # proper web error "400 Bad request" @@ -473,18 +492,88 @@ def run_flow(service: WorkflowService) -> Response: def configure(binder: Binder) -> None: - binder.bind(RetrievalService, to=RetrievalService, scope=RequestScope) - binder.bind(PosthogService, to=PosthogService, scope=SingletonScope) - binder.bind(SentryService, to=SentryService, scope=SingletonScope) - binder.bind(NomicService, to=NomicService, scope=SingletonScope) - binder.bind(ExportService, to=ExportService, scope=SingletonScope) - binder.bind(WorkflowService, to=WorkflowService, scope=SingletonScope) - binder.bind(VectorDatabase, to=VectorDatabase, scope=SingletonScope) - binder.bind(SQLDatabase, to=SQLDatabase, scope=SingletonScope) - binder.bind(AWSStorage, to=AWSStorage, scope=SingletonScope) + vector_bound = False + sql_bound = False + storage_bound = False + + # Define database URLs with conditional checks for environment variables + DB_URLS = { + 'supabase': + f"postgresql://{os.getenv('SUPABASE_KEY')}@{os.getenv('SUPABASE_URL')}" + if os.getenv('SUPABASE_KEY') and os.getenv('SUPABASE_URL') else None, + 'sqlite': + f"sqlite:///{os.getenv('SQLITE_DB_NAME')}" if os.getenv('SQLITE_DB_NAME') else None, + 'postgres': + f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@{os.getenv('POSTGRES_URL')}" + if os.getenv('POSTGRES_USER') and os.getenv('POSTGRES_PASSWORD') and os.getenv('POSTGRES_URL') else None + } + + # Bind to the first available SQL database configuration + for db_type, url in DB_URLS.items(): + if url: + logging.info(f"Binding to {db_type} database with URL: {url}") + with app.app_context(): + app.config['SQLALCHEMY_DATABASE_URI'] = url + db.init_app(app) + db.create_all() + binder.bind(SQLAlchemyDatabase, to=SQLAlchemyDatabase(db), scope=SingletonScope) + sql_bound = True + break + + if os.getenv("POI_SQL_DB_NAME"): + logging.info(f"Binding to POI SQL database with URL: {os.getenv('POI_SQL_DB_NAME')}") + binder.bind(POISQLDatabase, to=POISQLDatabase(db), scope=SingletonScope) + binder.bind(POIAgentService, to=POIAgentService, scope=SingletonScope) + # Conditionally bind databases based on the availability of their respective secrets + if all(os.getenv(key) for key in ["QDRANT_URL", "QDRANT_API_KEY", "QDRANT_COLLECTION_NAME"]) or any( + os.getenv(key) for key in ["PINECONE_API_KEY", "PINECONE_PROJECT_NAME"]): + logging.info("Binding to Qdrant database") + + logging.info(f"Qdrant Collection Name: {os.environ['QDRANT_COLLECTION_NAME']}") + logging.info(f"Qdrant URL: {os.environ['QDRANT_URL']}") + logging.info(f"Qdrant API Key: {os.environ['QDRANT_API_KEY']}") + binder.bind(VectorDatabase, to=VectorDatabase, scope=SingletonScope) + vector_bound = True + + if all(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "S3_BUCKET_NAME"]) or any( + os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): + if os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY"): + logging.info("Binding to AWS storage") + elif 
os.getenv("MINIO_ACCESS_KEY") and os.getenv("MINIO_SECRET_KEY"): + logging.info("Binding to Minio storage") + binder.bind(AWSStorage, to=AWSStorage, scope=SingletonScope) + storage_bound = True + + # Conditionally bind services based on the availability of their respective secrets + if os.getenv("NOMIC_API_KEY"): + logging.info("Binding to Nomic service") + binder.bind(NomicService, to=NomicService, scope=SingletonScope) + + if os.getenv("POSTHOG_API_KEY"): + logging.info("Binding to Posthog service") + binder.bind(PosthogService, to=PosthogService, scope=SingletonScope) + + if os.getenv("SENTRY_DSN"): + logging.info("Binding to Sentry service") + binder.bind(SentryService, to=SentryService, scope=SingletonScope) + + if os.getenv("EMAIL_SENDER"): + logging.info("Binding to Export service") + binder.bind(ExportService, to=ExportService, scope=SingletonScope) + + if os.getenv("N8N_URL"): + logging.info("Binding to Workflow service") + binder.bind(WorkflowService, to=WorkflowService, scope=SingletonScope) + + if vector_bound and sql_bound and storage_bound: + logging.info("Binding to Retrieval service") + binder.bind(RetrievalService, to=RetrievalService, scope=RequestScope) + + # Always bind the executor and its adapters binder.bind(ExecutorInterface, to=FlaskExecutorAdapter(executor), scope=SingletonScope) binder.bind(ThreadPoolExecutorInterface, to=ThreadPoolExecutorAdapter, scope=SingletonScope) binder.bind(ProcessPoolExecutorInterface, to=ProcessPoolExecutorAdapter, scope=SingletonScope) + logging.info("Configured all services and adapters", binder._bindings) FlaskInjector(app=app, modules=[configure]) diff --git a/ai_ta_backend/modal/pest_detection.py b/ai_ta_backend/modal/pest_detection.py index 1500a891..0a528acd 100644 --- a/ai_ta_backend/modal/pest_detection.py +++ b/ai_ta_backend/modal/pest_detection.py @@ -15,15 +15,20 @@ """ import inspect import json +import logging import os -import traceback -import uuid from tempfile import NamedTemporaryFile +import traceback from typing import List +import uuid -import modal from fastapi import Request -from modal import Secret, Stub, build, enter, web_endpoint +import modal +from modal import build +from modal import enter +from modal import Secret +from modal import Stub +from modal import web_endpoint # Simpler image, but slower cold starts: modal.Image.from_registry('ultralytics/ultralytics:latest-cpu') image = ( @@ -42,16 +47,10 @@ # Imports needed inside the image with image.imports(): - import inspect - import os - import traceback - import uuid - from tempfile import NamedTemporaryFile - from typing import List import boto3 - import requests from PIL import Image + import requests from ultralytics import YOLO @@ -90,20 +89,20 @@ async def predict(self, request: Request): This used to use the method decorator Run the pest detection plugin on an image. 
""" - print("Inside predict() endpoint") + logging.info("Inside predict() endpoint") input = await request.json() - print("Request.json(): ", input) + logging.info("Request.json(): ", input) image_urls = input.get('image_urls', []) if image_urls and isinstance(image_urls, str): image_urls = json.loads(image_urls) - print(f"Final image URLs: {image_urls}") + logging.info(f"Final image URLs: {image_urls}") try: # Run the plugin annotated_images = self._detect_pests(image_urls) - print(f"annotated_images found: {len(annotated_images)}") + logging.info(f"annotated_images found: {len(annotated_images)}") results = [] # Generate a unique ID for the request unique_id = uuid.uuid4() @@ -132,7 +131,7 @@ async def predict(self, request: Request): return results except Exception as e: err = f"❌❌ Error in (pest_detection): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n{traceback.format_exc()}" # type: ignore - print(err) + logging.info(err) # sentry_sdk.capture_exception(e) return err diff --git a/ai_ta_backend/model/models.py b/ai_ta_backend/model/models.py new file mode 100644 index 00000000..cac15cb7 --- /dev/null +++ b/ai_ta_backend/model/models.py @@ -0,0 +1,95 @@ +from sqlalchemy import BigInteger +from sqlalchemy import Boolean +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import ForeignKey +from sqlalchemy import Index +from sqlalchemy import JSON +from sqlalchemy import Text +from sqlalchemy.sql import func + +from ai_ta_backend.extensions import db + + +class Base(db.Model): + __abstract__ = True + + +class Document(Base): + __tablename__ = 'documents' + id = Column(BigInteger, primary_key=True, autoincrement=True) + created_at = Column(DateTime, default=func.now()) + s3_path = Column(Text) + readable_filename = Column(Text) + course_name = Column(Text) + url = Column(Text) + contexts = Column(JSON, default=lambda: [{"text": "", "timestamp": "", "embedding": "", "pagenumber": ""}]) + base_url = Column(Text) + + __table_args__ = ( + Index('documents_course_name_idx', 'course_name', postgresql_using='hash'), + Index('documents_created_at_idx', 'created_at', postgresql_using='btree'), + Index('idx_doc_s3_path', 's3_path', postgresql_using='btree'), + ) + + +class DocumentDocGroup(Base): + __tablename__ = 'documents_doc_groups' + document_id = Column(BigInteger, primary_key=True) + doc_group_id = Column(BigInteger, ForeignKey('doc_groups.id', ondelete='CASCADE'), primary_key=True) + created_at = Column(DateTime, default=func.now()) + + __table_args__ = ( + Index('documents_doc_groups_doc_group_id_idx', 'doc_group_id', postgresql_using='btree'), + Index('documents_doc_groups_document_id_idx', 'document_id', postgresql_using='btree'), + ) + + +class DocGroup(Base): + __tablename__ = 'doc_groups' + id = Column(BigInteger, primary_key=True, autoincrement=True) + name = Column(Text, nullable=False) + course_name = Column(Text, nullable=False) + created_at = Column(DateTime, default=func.now()) + enabled = Column(Boolean, default=True) + private = Column(Boolean, default=True) + doc_count = Column(BigInteger) + + __table_args__ = (Index('doc_groups_enabled_course_name_idx', 'enabled', 'course_name', postgresql_using='btree'),) + + +class Project(Base): + __tablename__ = 'projects' + id = Column(BigInteger, primary_key=True, autoincrement=True) + created_at = Column(DateTime, default=func.now()) + course_name = Column(Text) + doc_map_id = Column(Text) + convo_map_id = Column(Text) + n8n_api_key = Column(Text) + last_uploaded_doc_id = Column(BigInteger) + 
last_uploaded_convo_id = Column(BigInteger) + subscribed = Column(BigInteger, ForeignKey('doc_groups.id', onupdate='CASCADE', ondelete='SET NULL')) + + +class N8nWorkflows(Base): + __tablename__ = 'n8n_workflows' + latest_workflow_id = Column(BigInteger, primary_key=True, autoincrement=True) + is_locked = Column(Boolean, nullable=False) + + def __init__(self, is_locked: bool): + self.is_locked = is_locked + + +class LlmConvoMonitor(Base): + __tablename__ = 'llm_convo_monitor' + id = Column(BigInteger, primary_key=True, autoincrement=True) + created_at = Column(DateTime, default=func.now()) + convo = Column(JSON) + convo_id = Column(Text, unique=True) + course_name = Column(Text) + user_email = Column(Text) + + __table_args__ = ( + Index('llm_convo_monitor_course_name_idx', 'course_name', postgresql_using='hash'), + Index('llm_convo_monitor_convo_id_idx', 'convo_id', postgresql_using='hash'), + ) diff --git a/ai_ta_backend/model/response.py b/ai_ta_backend/model/response.py new file mode 100644 index 00000000..f0ae5f07 --- /dev/null +++ b/ai_ta_backend/model/response.py @@ -0,0 +1,12 @@ +from typing import Generic, List, TypeVar + +from flask_sqlalchemy.model import Model + +T = TypeVar('T', bound=Model) + + +class DatabaseResponse(Generic[T]): + + def __init__(self, data: List[T], count: int): + self.data = data + self.count = count diff --git a/ai_ta_backend/plants_of_India_demo.db b/ai_ta_backend/plants_of_India_demo.db new file mode 100644 index 00000000..634e4596 Binary files /dev/null and b/ai_ta_backend/plants_of_India_demo.db differ diff --git a/requirements.txt b/ai_ta_backend/requirements.txt similarity index 88% rename from requirements.txt rename to ai_ta_backend/requirements.txt index 848c10d0..dfb67bb0 100644 --- a/requirements.txt +++ b/ai_ta_backend/requirements.txt @@ -13,22 +13,26 @@ mkdocs-material==9.4.7 itsdangerous==2.1.2 Jinja2==3.1.2 mkdocs==1.5.3 -SQLAlchemy==2.0.22 +Flask-SQLAlchemy==3.1.1 tabulate==0.9.0 typing-inspect==0.9.0 typing_extensions==4.8.0 # Utils -tiktoken==0.5.1 +tiktoken==0.7.0 python-dotenv==1.0.0 pydantic==1.10.13 # pydantic v1 works better for ray flask-executor==1.0.0 # AI & core services nomic==2.0.14 -openai==0.28.1 -langchain==0.0.331 +openai==1.31.2 +langchain==0.2.2 langchainhub==0.1.14 +langgraph==0.0.69 +faiss-cpu==1.8.0 +langchain-community==0.2.3 +langchain-openai==0.1.8 # Data boto3==1.28.79 diff --git a/ai_ta_backend/service/export_service.py b/ai_ta_backend/service/export_service.py index 85f01118..a945453d 100644 --- a/ai_ta_backend/service/export_service.py +++ b/ai_ta_backend/service/export_service.py @@ -1,15 +1,17 @@ +from concurrent.futures import ProcessPoolExecutor import json +import logging import os import uuid import zipfile -from concurrent.futures import ProcessPoolExecutor +from injector import inject import pandas as pd import requests -from injector import inject from ai_ta_backend.database.aws import AWSStorage -from ai_ta_backend.database.sql import SQLDatabase +from ai_ta_backend.database.sql import SQLAlchemyDatabase +from ai_ta_backend.extensions import db from ai_ta_backend.service.sentry_service import SentryService from ai_ta_backend.utils.emails import send_email @@ -17,7 +19,7 @@ class ExportService: @inject - def __init__(self, sql: SQLDatabase, s3: AWSStorage, sentry: SentryService): + def __init__(self, sql: SQLAlchemyDatabase, s3: AWSStorage, sentry: SentryService): self.sql = sql self.s3 = s3 self.sentry = sentry @@ -33,9 +35,10 @@ def export_documents_json(self, course_name: str, from_date='', 
to_date=''): to_date (str, optional): The end date for the data export. Defaults to ''. """ - response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date, 'documents') + response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date) + # add a condition to route to direct download or s3 download - if response.count > 500: + if response.count and response.count > 500: # call background task to upload to s3 filename = course_name + '_' + str(uuid.uuid4()) + '_documents.zip' @@ -47,22 +50,22 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): else: # Fetch data - if response.count > 0: + if response.count and response.count > 0: # batch download total_doc_count = response.count - first_id = response.data[0]['id'] - last_id = response.data[-1]['id'] + first_id = int(str(response.data[0].id)) + last_id = int(str(response.data[-1].id)) - print("total_doc_count: ", total_doc_count) - print("first_id: ", first_id) - print("last_id: ", last_id) + logging.info("total_doc_count: ", total_doc_count) + logging.info("first_id: ", first_id) + logging.info("last_id: ", last_id) curr_doc_count = 0 filename = course_name + '_' + str(uuid.uuid4()) + '_documents.jsonl' file_path = os.path.join(os.getcwd(), filename) while curr_doc_count < total_doc_count: - print("Fetching data from id: ", first_id) + logging.info("Fetching data from id: ", first_id) response = self.sql.getDocsForIdsGte(course_name, first_id) df = pd.DataFrame(response.data) @@ -75,7 +78,7 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): df.to_json(file_path, orient='records', lines=True, mode='a') if len(response.data) > 0: - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 # Download file try: @@ -89,7 +92,7 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): os.remove(file_path) return {"response": (zip_file_path, zip_filename, os.getcwd())} except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return {"response": "Error downloading file."} else: @@ -103,9 +106,9 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): from_date (str, optional): The start date for the data export. Defaults to ''. to_date (str, optional): The end date for the data export. Defaults to ''. 
""" - print("Exporting conversation history to json file...") + logging.info("Exporting conversation history to json file...") - response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date, 'llm-convo-monitor') + response = self.sql.getConversationsBetweenDates(course_name, from_date, to_date) if response.count > 500: # call background task to upload to s3 @@ -118,9 +121,9 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): # Fetch data if response.count > 0: - print("id count greater than zero") - first_id = response.data[0]['id'] - last_id = response.data[-1]['id'] + logging.info("id count greater than zero") + first_id = int(str(response.data[0].id)) + last_id = int(str(response.data[-1].id)) total_count = response.count filename = course_name + '_' + str(uuid.uuid4()) + '_convo_history.jsonl' @@ -128,7 +131,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): curr_count = 0 # Fetch data in batches of 25 from first_id to last_id while curr_count < total_count: - print("Fetching data from id: ", first_id) + logging.info("Fetching data from id: ", first_id) response = self.sql.getAllConversationsBetweenIds(course_name, first_id, last_id) # Convert to pandas dataframe df = pd.DataFrame(response.data) @@ -142,8 +145,8 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): # Update first_id if len(response.data) > 0: - first_id = response.data[-1]['id'] + 1 - print("updated first_id: ", first_id) + first_id = int(str(response.data[-1].id)) + 1 + logging.info("updated first_id: ", first_id) # Download file try: @@ -157,7 +160,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): return {"response": (zip_file_path, zip_filename, os.getcwd())} except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return {"response": "Error downloading file!"} else: @@ -167,9 +170,9 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e """ Another function for exporting convos, emails are passed as a string. 
""" - print("Exporting conversation history to json file...") + logging.info("Exporting conversation history to json file...") - response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date, 'llm-convo-monitor') + response = self.sql.getConversationsBetweenDates(course_name, from_date, to_date) if response.count > 500: # call background task to upload to s3 @@ -182,9 +185,9 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e # Fetch data if response.count > 0: - print("id count greater than zero") - first_id = response.data[0]['id'] - last_id = response.data[-1]['id'] + logging.info("id count greater than zero") + first_id = int(str(response.data[0].id)) + last_id = int(str(response.data[-1].id)) total_count = response.count filename = course_name + '_' + str(uuid.uuid4()) + '_convo_history.jsonl' @@ -192,7 +195,7 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e curr_count = 0 # Fetch data in batches of 25 from first_id to last_id while curr_count < total_count: - print("Fetching data from id: ", first_id) + logging.info("Fetching data from id: ", first_id) response = self.sql.getAllConversationsBetweenIds(course_name, first_id, last_id) # Convert to pandas dataframe df = pd.DataFrame(response.data) @@ -206,8 +209,8 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e # Update first_id if len(response.data) > 0: - first_id = response.data[-1]['id'] + 1 - print("updated first_id: ", first_id) + first_id = int(str(response.data[-1].id)) + 1 + logging.info("updated first_id: ", first_id) # Download file try: @@ -221,12 +224,12 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e return {"response": (zip_file_path, zip_filename, os.getcwd())} except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return {"response": "Error downloading file!"} else: return {"response": "No data found between the given dates."} - + # Encountered pickling error while running the background task. So, moved the function outside the class. @@ -238,18 +241,18 @@ def export_data_in_bg(response, download_type, course_name, s3_path): 3. send an email to the course admins with the pre-signed URL. Args: - response (dict): The response from the Supabase query. - download_type (str): The type of download - 'documents' or 'conversations'. - course_name (str): The name of the course. - s3_path (str): The S3 path where the file will be uploaded. + response (dict): The response from the Supabase query. + download_type (str): The type of download - 'documents' or 'conversations'. + course_name (str): The name of the course. + s3_path (str): The S3 path where the file will be uploaded. 
""" s3 = AWSStorage() - sql = SQLDatabase() + sql = SQLAlchemyDatabase(db) total_doc_count = response.count first_id = response.data[0]['id'] - print("total_doc_count: ", total_doc_count) - print("pre-defined s3_path: ", s3_path) + logging.info("total_doc_count: ", total_doc_count) + logging.info("pre-defined s3_path: ", s3_path) curr_doc_count = 0 filename = s3_path.split('/')[-1].split('.')[0] + '.jsonl' @@ -257,8 +260,11 @@ def export_data_in_bg(response, download_type, course_name, s3_path): # download data in batches of 100 while curr_doc_count < total_doc_count: - print("Fetching data from id: ", first_id) - response = sql.getAllFromTableForDownloadType(course_name, download_type, first_id) + logging.info("Fetching data from id: ", first_id) + if download_type == "documents": + response = sql.getAllDocumentsForDownload(course_name, first_id) + else: + response = sql.getAllConversationsForDownload(course_name, first_id) df = pd.DataFrame(response.data) curr_doc_count += len(response.data) @@ -269,7 +275,7 @@ def export_data_in_bg(response, download_type, course_name, s3_path): df.to_json(file_path, orient='records', lines=True, mode='a') if len(response.data) > 0: - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 # zip file zip_filename = filename.split('.')[0] + '.zip' @@ -278,7 +284,7 @@ def export_data_in_bg(response, download_type, course_name, s3_path): with zipfile.ZipFile(zip_file_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf: zipf.write(file_path, filename) - print("zip file created: ", zip_file_path) + logging.info("zip file created: ", zip_file_path) try: # upload to S3 @@ -291,12 +297,11 @@ def export_data_in_bg(response, download_type, course_name, s3_path): os.remove(file_path) os.remove(zip_file_path) - print("file uploaded to s3: ", s3_file) + logging.info("file uploaded to s3: ", s3_file) # generate presigned URL s3_url = s3.generatePresignedUrl('get_object', os.environ['S3_BUCKET_NAME'], s3_path, 172800) - # get admin email IDs headers = {"Authorization": f"Bearer {os.environ['VERCEL_READ_ONLY_API_KEY']}", "Content-Type": "application/json"} @@ -315,8 +320,8 @@ def export_data_in_bg(response, download_type, course_name, s3_path): # add course owner email to admin_emails admin_emails.append(course_metadata['course_owner']) admin_emails = list(set(admin_emails)) - print("admin_emails: ", admin_emails) - print("bcc_emails: ", bcc_emails) + logging.info("admin_emails: ", admin_emails) + logging.info("bcc_emails: ", bcc_emails) # add a check for emails, don't send email if no admin emails if len(admin_emails) == 0: @@ -331,14 +336,15 @@ def export_data_in_bg(response, download_type, course_name, s3_path): subject = "UIUC.chat Export Complete for " + course_name body_text = "The data export for " + course_name + " is complete.\n\nYou can download the file from the following link: \n\n" + s3_url + "\n\nThis link will expire in 48 hours." email_status = send_email(subject, body_text, os.environ['EMAIL_SENDER'], admin_emails, bcc_emails) - print("email_status: ", email_status) + logging.info("email_status: ", email_status) return "File uploaded to S3. Email sent to admins." except Exception as e: - print(e) + logging.info(e) return "Error: " + str(e) + def export_data_in_bg_emails(response, download_type, course_name, s3_path, emails): """ This function is called in export_documents_csv() to upload the documents to S3. @@ -347,18 +353,18 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai 3. 
send an email to the course admins with the pre-signed URL. Args: - response (dict): The response from the Supabase query. - download_type (str): The type of download - 'documents' or 'conversations'. - course_name (str): The name of the course. - s3_path (str): The S3 path where the file will be uploaded. + response (dict): The response from the Supabase query. + download_type (str): The type of download - 'documents' or 'conversations'. + course_name (str): The name of the course. + s3_path (str): The S3 path where the file will be uploaded. """ s3 = AWSStorage() - sql = SQLDatabase() + sql = SQLAlchemyDatabase(db) total_doc_count = response.count first_id = response.data[0]['id'] - print("total_doc_count: ", total_doc_count) - print("pre-defined s3_path: ", s3_path) + logging.info("total_doc_count: ", total_doc_count) + logging.info("pre-defined s3_path: ", s3_path) curr_doc_count = 0 filename = s3_path.split('/')[-1].split('.')[0] + '.jsonl' @@ -366,8 +372,11 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai # download data in batches of 100 while curr_doc_count < total_doc_count: - print("Fetching data from id: ", first_id) - response = sql.getAllFromTableForDownloadType(course_name, download_type, first_id) + logging.info("Fetching data from id: ", first_id) + if download_type == "documents": + response = sql.getAllDocumentsForDownload(course_name, first_id) + else: + response = sql.getAllConversationsForDownload(course_name, first_id) df = pd.DataFrame(response.data) curr_doc_count += len(response.data) @@ -378,7 +387,7 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai df.to_json(file_path, orient='records', lines=True, mode='a') if len(response.data) > 0: - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 # zip file zip_filename = filename.split('.')[0] + '.zip' @@ -387,7 +396,7 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai with zipfile.ZipFile(zip_file_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf: zipf.write(file_path, filename) - print("zip file created: ", zip_file_path) + logging.info("zip file created: ", zip_file_path) try: # upload to S3 @@ -400,16 +409,16 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai os.remove(file_path) os.remove(zip_file_path) - print("file uploaded to s3: ", s3_file) + logging.info("file uploaded to s3: ", s3_file) # generate presigned URL s3_url = s3.generatePresignedUrl('get_object', os.environ['S3_BUCKET_NAME'], s3_path, 172800) admin_emails = emails bcc_emails = [] - - print("admin_emails: ", admin_emails) - print("bcc_emails: ", bcc_emails) + + logging.info("admin_emails: ", admin_emails) + logging.info("bcc_emails: ", bcc_emails) # add a check for emails, don't send email if no admin emails if len(admin_emails) == 0: @@ -424,10 +433,10 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai subject = "UIUC.chat Export Complete for " + course_name body_text = "The data export for " + course_name + " is complete.\n\nYou can download the file from the following link: \n\n" + s3_url + "\n\nThis link will expire in 48 hours." email_status = send_email(subject, body_text, os.environ['EMAIL_SENDER'], admin_emails, bcc_emails) - print("email_status: ", email_status) + logging.info("email_status: ", email_status) return "File uploaded to S3. Email sent to admins." 
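Both background export helpers finish by calling `send_email(subject, body_text, sender, admin_emails, bcc_emails)` from `ai_ta_backend.utils.emails`, which is outside this diff. Purely for orientation, a minimal SMTP-based stand-in with the same call signature could look like the sketch below; the SMTP host and credential environment variable names are assumptions, not the project's actual configuration:

```python
import os
import smtplib
from email.mime.text import MIMEText
from typing import List


def send_email(subject: str, body_text: str, sender: str, recipients: List[str], bcc_recipients: List[str]) -> str:
  """Hypothetical stand-in matching the send_email(...) call signature used above."""
  msg = MIMEText(body_text)
  msg['Subject'] = subject
  msg['From'] = sender
  msg['To'] = ', '.join(recipients)
  # SMTP_HOST / SMTP_PORT / SMTP_USERNAME / SMTP_PASSWORD are assumed variable names.
  with smtplib.SMTP(os.environ['SMTP_HOST'], int(os.environ.get('SMTP_PORT', '587'))) as server:
    server.starttls()
    server.login(os.environ['SMTP_USERNAME'], os.environ['SMTP_PASSWORD'])
    # Bcc recipients go only into the envelope, not the visible headers.
    server.sendmail(sender, recipients + bcc_recipients, msg.as_string())
  return "email sent"
```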
except Exception as e: - print(e) - return "Error: " + str(e) \ No newline at end of file + logging.info(e) + return "Error: " + str(e) diff --git a/ai_ta_backend/service/nomic_service.py b/ai_ta_backend/service/nomic_service.py index 80ca86ca..b97bfd1d 100644 --- a/ai_ta_backend/service/nomic_service.py +++ b/ai_ta_backend/service/nomic_service.py @@ -1,17 +1,17 @@ import datetime +import logging import os import time -from typing import Union -import backoff +from injector import inject +from langchain.embeddings.openai import OpenAIEmbeddings import nomic +from nomic import atlas +from nomic import AtlasProject import numpy as np import pandas as pd -from injector import inject -from langchain.embeddings.openai import OpenAIEmbeddings -from nomic import AtlasProject, atlas -from ai_ta_backend.database.sql import SQLDatabase +from ai_ta_backend.database.sql import SQLAlchemyDatabase from ai_ta_backend.service.sentry_service import SentryService LOCK_EXCEPTIONS = [ @@ -24,7 +24,7 @@ class NomicService(): @inject - def __init__(self, sentry: SentryService, sql: SQLDatabase): + def __init__(self, sentry: SentryService, sql: SQLAlchemyDatabase): nomic.login(os.environ['NOMIC_API_KEY']) self.sentry = sentry self.sql = sql @@ -50,20 +50,17 @@ def get_nomic_map(self, course_name: str, type: str): project = atlas.AtlasProject(name=project_name, add_datums_if_exists=True) map = project.get_map(project_name) - print(f"⏰ Nomic Full Map Retrieval: {(time.monotonic() - start_time):.2f} seconds") + logging.info(f"⏰ Nomic Full Map Retrieval: {(time.monotonic() - start_time):.2f} seconds") return {"map_id": f"iframe{map.id}", "map_link": map.map_link} except Exception as e: # Error: ValueError: You must specify a unique_id_field when creating a new project. if str(e) == 'You must specify a unique_id_field when creating a new project.': # type: ignore - print( - "Nomic map does not exist yet, probably because you have less than 20 queries/documents on your project: ", - e) + logging.info("Nomic map does not exist yet, probably because you have less than 20 queries/documents on your project: ", e) else: - print("ERROR in get_nomic_map():", e) + logging.info("ERROR in get_nomic_map():", e) self.sentry.capture_exception(e) return {"map_id": None, "map_link": None} - def log_to_conversation_map(self, course_name: str, conversation): """ This function logs new conversations to existing nomic maps. @@ -76,20 +73,20 @@ def log_to_conversation_map(self, course_name: str, conversation): try: # check if map exists response = self.sql.getConvoMapFromProjects(course_name) - print("Response from supabase: ", response.data) + logging.info("Response from supabase: ", response.data) # entry not present in projects table if not response.data: - print("Map does not exist for this course. Redirecting to map creation...") + logging.info("Map does not exist for this course. Redirecting to map creation...") return self.create_conversation_map(course_name) - + # entry present for doc map, but not convo map - elif not response.data[0]['convo_map_id']: - print("Map does not exist for this course. Redirecting to map creation...") + elif response.data[0].convo_map_id is not None: + logging.info("Map does not exist for this course. 
Redirecting to map creation...") return self.create_conversation_map(course_name) - - project_id = response.data[0]['convo_map_id'] - last_uploaded_convo_id = response.data[0]['last_uploaded_convo_id'] + + project_id = response.data[0].convo_map_id + last_uploaded_convo_id: int = int(str(response.data[0].last_uploaded_convo_id)) # check if project is accepting data project = AtlasProject(project_id=project_id, add_datums_if_exists=True) @@ -99,7 +96,7 @@ def log_to_conversation_map(self, course_name: str, conversation): # fetch count of conversations since last upload response = self.sql.getCountFromLLMConvoMonitor(course_name, last_id=last_uploaded_convo_id) total_convo_count = response.count - print("Total number of unlogged conversations in Supabase: ", total_convo_count) + logging.info("Total number of unlogged conversations in Supabase: ", total_convo_count) if total_convo_count == 0: # log to an existing conversation @@ -113,14 +110,14 @@ def log_to_conversation_map(self, course_name: str, conversation): while current_convo_count < total_convo_count: response = self.sql.getAllConversationsBetweenIds(course_name, first_id, 0, 100) - print("Response count: ", len(response.data)) + logging.info("Response count: ", len(response.data)) if len(response.data) == 0: break df = pd.DataFrame(response.data) combined_dfs.append(df) current_convo_count += len(response.data) convo_count += len(response.data) - print(current_convo_count) + logging.info(current_convo_count) if convo_count >= 500: # concat all dfs from the combined_dfs list @@ -128,24 +125,24 @@ def log_to_conversation_map(self, course_name: str, conversation): # prep data for nomic upload embeddings, metadata = self.data_prep_for_convo_map(final_df) # append to existing map - print("Appending data to existing map...") + logging.info("Appending data to existing map...") result = self.append_to_map(embeddings, metadata, NOMIC_MAP_NAME_PREFIX + course_name) if result == "success": last_id = int(final_df['id'].iloc[-1]) project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} project_response = self.sql.updateProjects(course_name, project_info) - print("Update response from supabase: ", project_response) + logging.info("Update response from supabase: ", project_response) # reset variables combined_dfs = [] convo_count = 0 - print("Records uploaded: ", current_convo_count) + logging.info("Records uploaded: ", current_convo_count) # set first_id for next iteration - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 # upload last set of convos if convo_count > 0: - print("Uploading last set of conversations...") + logging.info("Uploading last set of conversations...") final_df = pd.concat(combined_dfs, ignore_index=True) embeddings, metadata = self.data_prep_for_convo_map(final_df) result = self.append_to_map(embeddings, metadata, NOMIC_MAP_NAME_PREFIX + course_name) @@ -153,41 +150,40 @@ def log_to_conversation_map(self, course_name: str, conversation): last_id = int(final_df['id'].iloc[-1]) project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} project_response = self.sql.updateProjects(course_name, project_info) - print("Update response from supabase: ", project_response) - + logging.info("Update response from supabase: ", project_response) + # rebuild the map self.rebuild_map(course_name, "conversation") return "success" - + except Exception as e: - print(e) + logging.info(e) 
self.sentry.capture_exception(e) return "Error in logging to conversation map: {e}" - - + def log_to_existing_conversation(self, course_name: str, conversation): """ This function logs follow-up questions to existing conversations in the map. """ - print(f"in log_to_existing_conversation() for course: {course_name}") + logging.info(f"in log_to_existing_conversation() for course: {course_name}") try: conversation_id = conversation['id'] # fetch id from supabase incoming_id_response = self.sql.getConversation(course_name, key="convo_id", value=conversation_id) - + project_name = 'Conversation Map for ' + course_name project = AtlasProject(name=project_name, add_datums_if_exists=True) - prev_id = incoming_id_response.data[0]['id'] - uploaded_data = project.get_data(ids=[prev_id]) # fetch data point from nomic + prev_id = str(incoming_id_response.data[0].id) + uploaded_data = project.get_data(ids=[prev_id]) # fetch data point from nomic prev_convo = uploaded_data[0]['conversation'] # update conversation messages = conversation['messages'] messages_to_be_logged = messages[-2:] - + for message in messages_to_be_logged: if message['role'] == 'user': emoji = "πŸ™‹ " @@ -200,7 +196,7 @@ def log_to_existing_conversation(self, course_name: str, conversation): text = message['content'] prev_convo += "\n>>> " + emoji + message['role'] + ": " + text + "\n" - + # create embeddings of first query embeddings_model = OpenAIEmbeddings(openai_api_type="openai", openai_api_base="https://api.openai.com/v1/", @@ -216,27 +212,26 @@ def log_to_existing_conversation(self, course_name: str, conversation): metadata = pd.DataFrame(uploaded_data) embeddings = np.array(embeddings) - print("Metadata shape:", metadata.shape) - print("Embeddings shape:", embeddings.shape) + logging.info("Metadata shape:", metadata.shape) + logging.info("Embeddings shape:", embeddings.shape) # deleting existing map - print("Deleting point from nomic:", project.delete_data([prev_id])) + logging.info("Deleting point from nomic:", project.delete_data([prev_id])) # re-build map to reflect deletion project.rebuild_maps() # re-insert updated conversation result = self.append_to_map(embeddings, metadata, project_name) - print("Result of appending to existing map:", result) - + logging.info("Result of appending to existing map:", result) + return "success" except Exception as e: - print("Error in log_to_existing_conversation():", e) + logging.info("Error in log_to_existing_conversation():", e) self.sentry.capture_exception(e) return "Error in logging to existing conversation: {e}" - def create_conversation_map(self, course_name: str): """ This function creates a conversation map for a given course from scratch. @@ -246,9 +241,9 @@ def create_conversation_map(self, course_name: str): try: # check if map exists response = self.sql.getConvoMapFromProjects(course_name) - print("Response from supabase: ", response.data) + logging.info("Response from supabase: ", response.data) if response.data: - if response.data[0]['convo_map_id']: + if response.data[0].convo_map_id is not None: return "Map already exists for this course." 
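The map bookkeeping in `nomic_service.py` reads and writes the new `projects` table through `getConvoMapFromProjects()`, `updateProjects()` and `insertProjectInfo()`, none of which appear in this diff. A hedged sketch of the write path against the `Project` model (the bodies are assumptions; only the call signatures come from the code above):

```python
from ai_ta_backend.extensions import db
from ai_ta_backend.model.models import Project


def update_projects(course_name: str, project_info: dict) -> int:
  # Bulk-update the course's row, e.g. {'convo_map_id': ..., 'last_uploaded_convo_id': ...},
  # and return the number of affected rows.
  updated = (db.session.query(Project)
             .filter(Project.course_name == course_name)
             .update(project_info))
  db.session.commit()
  return updated


def insert_project_info(project_info: dict) -> Project:
  # Create the row the first time a map is built for a course.
  project = Project(**project_info)
  db.session.add(project)
  db.session.commit()
  return project
```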
# if no, fetch total count of records @@ -262,9 +257,9 @@ def create_conversation_map(self, course_name: str): # if >20, iteratively fetch records in batches of 100 total_convo_count = response.count - print("Total number of conversations in Supabase: ", total_convo_count) + logging.info("Total number of conversations in Supabase: ", total_convo_count) - first_id = response.data[0]['id'] - 1 + first_id = int(str(response.data[0].id)) - 1 combined_dfs = [] current_convo_count = 0 convo_count = 0 @@ -274,14 +269,14 @@ def create_conversation_map(self, course_name: str): # iteratively query in batches of 50 while current_convo_count < total_convo_count: response = self.sql.getAllConversationsBetweenIds(course_name, first_id, 0, 100) - print("Response count: ", len(response.data)) + logging.info("Response count: ", len(response.data)) if len(response.data) == 0: break df = pd.DataFrame(response.data) combined_dfs.append(df) current_convo_count += len(response.data) convo_count += len(response.data) - print(current_convo_count) + logging.info(current_convo_count) if convo_count >= 500: # concat all dfs from the combined_dfs list @@ -291,12 +286,11 @@ def create_conversation_map(self, course_name: str): if first_batch: # create a new map - print("Creating new map...") + logging.info("Creating new map...") index_name = course_name + "_convo_index" topic_label_field = "first_query" colorable_fields = ["user_email", "first_query", "conversation_id", "created_at"] - result = self.create_map(embeddings, metadata, project_name, index_name, topic_label_field, - colorable_fields) + result = self.create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) if result == "success": # update flag @@ -312,35 +306,35 @@ def create_conversation_map(self, course_name: str): project_response = self.sql.updateProjects(course_name, project_info) else: project_response = self.sql.insertProjectInfo(project_info) - print("Update response from supabase: ", project_response) + logging.info("Update response from supabase: ", project_response) else: # append to existing map - print("Appending data to existing map...") + logging.info("Appending data to existing map...") project = AtlasProject(name=project_name, add_datums_if_exists=True) result = self.append_to_map(embeddings, metadata, project_name) if result == "success": - print("map append successful") + logging.info("map append successful") last_id = int(final_df['id'].iloc[-1]) project_info = {'last_uploaded_convo_id': last_id} project_response = self.sql.updateProjects(course_name, project_info) - print("Update response from supabase: ", project_response) + logging.info("Update response from supabase: ", project_response) # reset variables combined_dfs = [] convo_count = 0 - print("Records uploaded: ", current_convo_count) + logging.info("Records uploaded: ", current_convo_count) # set first_id for next iteration try: - print("response: ", response.data[-1]['id']) - except: - print("response: ", response.data) - first_id = response.data[-1]['id'] + 1 + logging.info("response: ", response.data[-1].id) + except Exception as e: + logging.info("response: ", response.data) + first_id = int(str(response.data[-1].id)) + 1 - print("Convo count: ", convo_count) + logging.info("Convo count: ", convo_count) # upload last set of convos if convo_count > 0: - print("Uploading last set of conversations...") + logging.info("Uploading last set of conversations...") final_df = pd.concat(combined_dfs, ignore_index=True) embeddings, metadata = 
self.data_prep_for_convo_map(final_df) if first_batch: @@ -352,30 +346,29 @@ def create_conversation_map(self, course_name: str): else: # append to map - print("in map append") + logging.info("in map append") result = self.append_to_map(embeddings, metadata, project_name) if result == "success": - print("last map append successful") + logging.info("last map append successful") last_id = int(final_df['id'].iloc[-1]) project = AtlasProject(name=project_name, add_datums_if_exists=True) project_id = project.id project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} - print("Project info: ", project_info) + logging.info("Project info: ", project_info) # if entry already exists, update it projects_record = self.sql.getConvoMapFromProjects(course_name) if projects_record.data: project_response = self.sql.updateProjects(course_name, project_info) else: project_response = self.sql.insertProjectInfo(project_info) - print("Response from supabase: ", project_response) - + logging.info("Response from supabase: ", project_response) # rebuild the map self.rebuild_map(course_name, "conversation") return "success" except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return "Error in creating conversation map:" + str(e) @@ -385,7 +378,7 @@ def rebuild_map(self, course_name: str, map_type: str): """ This function rebuilds a given map in Nomic. """ - print("in rebuild_map()") + logging.info("in rebuild_map()") nomic.login(os.getenv('NOMIC_API_KEY')) if map_type.lower() == 'document': @@ -402,7 +395,7 @@ def rebuild_map(self, course_name: str, map_type: str): project.rebuild_maps() return "success" except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return "Error in rebuilding map: {e}" @@ -418,7 +411,7 @@ def create_map(self, embeddings, metadata, map_name, index_name, topic_label_fie colorable_fields: list of str """ nomic.login(os.environ['NOMIC_API_KEY']) - print("in create_map()") + logging.info("in create_map()") try: project = atlas.map_embeddings(embeddings=embeddings, data=metadata, @@ -431,7 +424,7 @@ def create_map(self, embeddings, metadata, map_name, index_name, topic_label_fie project.create_index(index_name, build_topic_model=True) return "success" except Exception as e: - print(e) + logging.info(e) return "Error in creating map: {e}" def append_to_map(self, embeddings, metadata, map_name): @@ -449,7 +442,7 @@ def append_to_map(self, embeddings, metadata, map_name): project.add_embeddings(embeddings=embeddings, data=metadata) return "success" except Exception as e: - print(e) + logging.info(e) return "Error in appending to map: {e}" def data_prep_for_convo_map(self, df: pd.DataFrame): @@ -461,7 +454,7 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): embeddings: np.array of embeddings metadata: pd.DataFrame of metadata """ - print("in data_prep_for_convo_map()") + logging.info("in data_prep_for_convo_map()") try: metadata = [] @@ -471,7 +464,6 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): for _index, row in df.iterrows(): current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") created_at = datetime.datetime.strptime(row['created_at'], "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y-%m-%d %H:%M:%S") - conversation_exists = False conversation = "" emoji = "" @@ -481,18 +473,20 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): user_email = row['user_email'] messages = row['convo']['messages'] + first_message = "" # some conversations include images, so the 
data structure is different if isinstance(messages[0]['content'], list): if 'text' in messages[0]['content'][0]: first_message = messages[0]['content'][0]['text'] - #print("First message:", first_message) + #logging.info("First message:", first_message) else: first_message = messages[0]['content'] user_queries.append(first_message) # construct metadata for multi-turn conversation for message in messages: + text = "" if message['role'] == 'user': emoji = "πŸ™‹ " else: @@ -517,7 +511,7 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): "created_at": created_at, "modified_at": current_time } - #print("Metadata row:", meta_row) + #logging.info("Metadata row:", meta_row) metadata.append(meta_row) embeddings_model = OpenAIEmbeddings(openai_api_type="openai", @@ -528,12 +522,12 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): metadata = pd.DataFrame(metadata) embeddings = np.array(embeddings) - print("Metadata shape:", metadata.shape) - print("Embeddings shape:", embeddings.shape) + logging.info("Metadata shape:", metadata.shape) + logging.info("Embeddings shape:", embeddings.shape) return embeddings, metadata except Exception as e: - print("Error in data_prep_for_convo_map():", e) + logging.info("Error in data_prep_for_convo_map():", e) self.sentry.capture_exception(e) return None, None @@ -545,18 +539,18 @@ def delete_from_document_map(self, project_id: str, ids: list): course_name: str ids: list of str """ - print("in delete_from_document_map()") + logging.info("in delete_from_document_map()") try: # fetch project from Nomic project = AtlasProject(project_id=project_id, add_datums_if_exists=True) # delete the ids from Nomic - print("Deleting point from document map:", project.delete_data(ids)) + logging.info("Deleting point from document map:", project.delete_data(ids)) with project.wait_for_project_lock(): project.rebuild_maps() return "Successfully deleted from Nomic map" except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return "Error in deleting from document map: {e}" diff --git a/ai_ta_backend/service/plants_of_India_demo.db b/ai_ta_backend/service/plants_of_India_demo.db new file mode 100644 index 00000000..634e4596 Binary files /dev/null and b/ai_ta_backend/service/plants_of_India_demo.db differ diff --git a/ai_ta_backend/service/poi_agent_service.py b/ai_ta_backend/service/poi_agent_service.py new file mode 100644 index 00000000..b249a5f7 --- /dev/null +++ b/ai_ta_backend/service/poi_agent_service.py @@ -0,0 +1,482 @@ +from dotenv import load_dotenv +from langchain_community.utilities.sql_database import SQLDatabase +from langchain_openai import ChatOpenAI, OpenAI +from langchain_community.agent_toolkits import create_sql_agent +from langchain_core.prompts import ( + ChatPromptTemplate, + FewShotPromptTemplate, + MessagesPlaceholder, + PromptTemplate, + SystemMessagePromptTemplate, +) +import os +import logging +from flask_sqlalchemy import SQLAlchemy + +from langchain_openai import ChatOpenAI + +from operator import itemgetter + +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnablePassthrough +from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool +from langchain_community.agent_toolkits import create_sql_agent +from langchain.tools import BaseTool, StructuredTool, Tool, tool +import random +from langgraph.prebuilt.tool_executor import ToolExecutor +from langchain.tools.render import format_tool_to_openai_function + 
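`data_prep_for_convo_map()` above now initialises `first_message` before branching because stored conversations that include images carry a list of content parts rather than a plain string. The same extraction can be read in isolation as the small helper below; the function name is hypothetical and the part-scanning loop is a slight generalisation of the branch in the diff:

```python
from typing import List, Union


def first_message_text(content: Union[str, List[dict]]) -> str:
  """Return the text of a message whose content is either a string or a list of parts."""
  if isinstance(content, list):
    # Multimodal messages mix parts such as {'type': 'text', 'text': ...} and
    # {'type': 'image_url', ...}; keep the first part that carries text.
    for part in content:
      if isinstance(part, dict) and 'text' in part:
        return part['text']
    return ""
  return content
```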
+ +from typing import TypedDict, Annotated, Sequence +import operator +from langchain_core.messages import BaseMessage + +from langchain_core.agents import AgentFinish +from langgraph.prebuilt import ToolInvocation +import json +from langchain_core.messages import FunctionMessage +from langchain_community.utilities import SQLDatabase +from langchain_community.vectorstores import FAISS +from langchain_core.example_selectors import SemanticSimilarityExampleSelector +from langchain_openai import OpenAIEmbeddings + +load_dotenv() + + +def get_dynamic_prompt_template(): + + examples = [ + { + "input": "How many accepted names are only distributed in Karnataka?", + "query": 'SELECT COUNT(*) as unique_pairs_count FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "Karnataka%"));' + }, + { + "input": "How many names were authored by Roxb?", + "query": 'SELECT COUNT(*) as unique_pairs_count FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Author_Name" LIKE "%Roxb%" AND "Record_Type_Code" IN ("AN", "SN"));' + }, + { + "input": "How many species have distributions in Myanmar, Meghalaya and Andhra Pradesh?", + "query": 'SELECT COUNT(*) FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%"));' + }, + { + "input": "List the accepted names common to Myanmar, Meghalaya, Odisha, Andhra Pradesh.", + "query": 'SELECT DISTINCT Scientific_Name FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Odisha%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%"));' + }, + { + "input": "List the accepted names that represent 'tree'.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "HB" AND "Additional_Details_2" LIKE "%tree%");' + }, + { + "input": "List the accepted names linked with Endemic tag.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND "Additional_Details_2" LIKE "%Endemic%");' + }, + { + "input": "List the accepted names published in Fl. Brit. India [J. D. Hooker].", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" in ("AN", "SN") AND ("Publication" LIKE "%Fl. Brit. India [J. D. Hooker]%" OR "Publication" LIKE "%[J. D. Hooker]%" OR "Publication" LIKE "%Fl. Brit. 
India%");' + }, + { + "input": "How many accepted names have β€˜Silhet’/ β€˜Sylhet’ in their Type?", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "TY" AND ("Additional_Details_2" LIKE "%Silhet%" OR "Additional_Details_2" LIKE "%Sylhet%"));' + }, + { + "input": "How many species were distributed in Sikkim and Meghalaya?", + "query": 'SELECT COUNT(*) AS unique_pairs FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Sikkim%" AND Additional_Details_2 LIKE "%Meghalaya%"));' + }, + { + "input": "List the accepted names common to Kerala, Tamil Nadu, Andhra Pradesh, Karnataka, Maharashtra, Odisha, Meghalaya and Myanmar.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Odisha%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%" AND "Additional_Details_2" LIKE "%Kerala%" AND "Additional_Details_2" LIKE "%Tamil Nadu%" AND "Additional_Details_2" LIKE "%Karnataka%" AND "Additional_Details_2" LIKE "%Maharashtra%"));' + }, + { + "input": "List the accepted names common to Europe, Afghanistan, Jammu & Kashmir, Himachal, Nepal, Sikkim, Bhutan, Arunachal Pradesh and China.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Europe%" AND Additional_Details_2 LIKE "%Afghanistan%" AND "Additional_Details_2" LIKE "%Jammu & Kashmir%" AND "Additional_Details_2" LIKE "%Himachal%" AND "Additional_Details_2" LIKE "%Nepal%" AND "Additional_Details_2" LIKE "%Sikkim%" AND "Additional_Details_2" LIKE "%Bhutan%" AND "Additional_Details_2" LIKE "%Arunachal Pradesh%" AND "Additional_Details_2" LIKE "%China%"));' + }, + { + "input": "List the accepted names common to Europe, Afghanistan, Austria, Belgium, Czechoslovakia, Denmark, France, Greece, Hungary, Italy, Moldava, Netherlands, Poland, Romania, Spain, Switzerland, Jammu & Kashmir, Himachal, Nepal, and China.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Europe%" AND Additional_Details_2 LIKE "%Afghanistan%" AND "Additional_Details_2" LIKE "%Jammu & Kashmir%" AND "Additional_Details_2" LIKE "%Himachal%" AND "Additional_Details_2" LIKE "%Nepal%" AND "Additional_Details_2" LIKE "%Austria%" AND "Additional_Details_2" LIKE "%Belgium%" AND "Additional_Details_2" LIKE "%Czechoslovakia%" AND "Additional_Details_2" LIKE "%China%" AND "Additional_Details_2" LIKE "%Denmark%" AND "Additional_Details_2" LIKE "%Greece%" AND "Additional_Details_2" LIKE "%France%" AND "Additional_Details_2" LIKE "%Hungary%" AND "Additional_Details_2" LIKE "%Italy%" AND "Additional_Details_2" LIKE "%Moldava%" AND "Additional_Details_2" LIKE 
"%Netherlands%" AND "Additional_Details_2" LIKE "%Poland%" AND "Additional_Details_2" LIKE "%Poland%" AND "Additional_Details_2" LIKE "%Romania%" AND "Additional_Details_2" LIKE "%Spain%" AND "Additional_Details_2" LIKE "%Switzerland%"));' + }, + { + "input": "List the species which are distributed in Sikkim and Meghalaya.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Sikkim%" AND Additional_Details_2 LIKE "%Meghalaya%"));' + }, + { + "input": "How many species are common to America, Europe, Africa, Asia, and Australia?", + "query": 'SELECT COUNT(*) AS unique_pairs IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%America%" AND Additional_Details_2 LIKE "%Europe%" AND "Additional_Details_2" LIKE "%Africa%" AND "Additional_Details_2" LIKE "%Asia%" AND "Additional_Details_2" LIKE "%Australia%"));' + }, + { + "input": "List the species names common to India and Myanmar, Malaysia, Indonesia, and Australia.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number","Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%India%" AND Additional_Details_2 LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Malaysia%" AND Additional_Details_2 LIKE "%Indonesia%" AND Additional_Details_2 LIKE "%Australia%"));' + }, + { + "input": "List all plants which are tagged as urban.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Urban" = "YES";' + }, + { + "input": "List all plants which are tagged as fruit.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Fruit" = "YES";' + }, + { + "input": "List all plants which are tagged as medicinal.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Medicinal" = "YES";' + }, + { + "input": "List all family names which are gymnosperms.", + "query": 'SELECT DISTINCT "Family_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Groups" = "Gymnosperms";' + }, + { + "input": "How many accepted names are tagged as angiosperms?", + "query": 'SELECT COUNT(DISTINCT "Scientific_Name") FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Groups" = "Angiosperms";' + }, + { + "input": "How many accepted names belong to the 'Saxifraga' genus?", + "query": 'SELECT COUNT(DISTINCT "Scientific_Name") FROM plants WHERE "Genus_Name" = "Saxifraga";' + }, + { + "input": "List the accepted names tagged as 'perennial herb' or 'climber'.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "HB" AND ("Additional_Details_2" LIKE "%perennial herb%" OR "Additional_Details_2" LIKE "%climber%"));' + }, + { + "input": "How many accepted names are native to South Africa?", + "query": 'SELECT COUNT(*) FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" 
FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%native%" AND "Additional_Details_2" LIKE "%south%" AND "Additional_Details_2" LIKE "%africa%");' + + }, + { + "input": "List the accepted names which were introduced and naturalized.", + "query": 'SELECT DISTINCT "Scientific_Name FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%introduced%" AND "Additional_Details_2" LIKE "%naturalized%");' + }, + { + "input": "List all ornamental plants.", + "query": 'SELECT DISTINCT "Scientific_Name FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%ornamental%");' + }, + { + "input": "How many plants from the 'Leguminosae' family have a altitudinal range up to 1000 m?", + "query": 'SELECT COUNT(*) FROM plants WHERE "Record_Type_Code" = "AL" AND "Family_Name" = "Leguminosae" AND "Additional_Details_2" LIKE "%1000%";' + }, + { + "input": "List the accepted names linked with the 'endemic' tag for Karnataka.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND "Additional_Details_2" LIKE "%Endemic%" AND "Additional_Details_2" LIKE "%Karnataka%");' + }, + {"input": "List all the accepted names under the family 'Gnetaceae'.", + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Family_Number" IN ( + SELECT DISTINCT "Family_Number" FROM plants WHERE "Family_Name" = "Gnetaceae" + ) +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +"""}, + { + "input": "List all the accepted species that are introduced.", + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Record_Type_Code" = 'RE'and "Additional_Details_2" LIKE '%cultivated%' +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +""", + }, + { + "input": "List all the accepted names with type 'Cycad'", + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Record_Type_Code" = 'HB'and "Additional_Details_2" LIKE '%Cycad%' + +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +""", + }, + { + "input": "List all the accepted names under the genus 'Cycas' with more than two synonyms.", + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Genus_Number" IN ( + SELECT DISTINCT "Genus_Number" FROM plants WHERE "Genus_Name" = 'Cycas' + ) + AND 
"Family_Number" IN ( + SELECT DISTINCT "Family_Number" FROM plants WHERE "Genus_Name" = 'Cycas' + ) + AND "Synonym_Number" > 2 +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +""", + }, + { + "input":'List all the accepted names published in Asian J. Conservation Biol.', + "query": """ + SELECT DISTINCT "Scientific_Name" + FROM plants + WHERE "Record_Type_Code" = 'AN' AND "Publication" LIKE '%Asian J. Conservation Biol%'; + +""", + }, + { + "input": 'List all the accepted names linked with endemic tag.', + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Record_Type_Code" = 'DB'and "Additional_Details_2" LIKE '%Endemic%' + +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +""", + }, + { + "input": 'List all the accepted names that have no synonyms.' , + "query": """ +SELECT DISTINCT a."Scientific_Name" FROM plants a +group by a."Family_Number",a."Genus_Number",a."Accepted_name_number" +HAVING SUM(a."Synonym_Number") = 0 AND a."Accepted_name_number" > 0; +""", + }, + { + "input": 'List all the accepted names authored by Roxb.', + "query": """ +SELECT "Scientific_Name" +FROM plants +WHERE "Record_Type_Code" = 'AN'AND "Author_Name" LIKE '%Roxb%'; +""", + }, + { + "input": 'List all genera within each family', + "query": """ +SELECT "Family_Name", "Genus_Name" +FROM plants +WHERE "Record_Type_Code" = 'GE'; +""", + }, + { + "input": 'Did Minq. discovered Cycas ryumphii?', + "query": """SELECT + CASE + WHEN EXISTS ( + SELECT 1 + FROM plants as a + WHERE a."Scientific_Name" = 'Cycas rumphii' + AND a."Author_Name" = 'Miq.' + ) THEN 'TRUE' + ELSE 'FALSE' + END AS ExistsCheck; +"""}, + + ] + + + example_selector = SemanticSimilarityExampleSelector.from_examples( + examples, + OpenAIEmbeddings(), + FAISS, + k=5, + input_keys=["input"], + ) + + + prefix_prompt = """ + You are an agent designed to interact with a SQL database. + Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. + You can order the results by a relevant column to return the most interesting examples in the database. + Never query for all the columns from a specific table, only ask for the relevant columns given the question. + You have access to tools for interacting with the database. + Only use the given tools. Only use the information returned by the tools to construct your final answer. + You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. + + - Restrict your queries to the "plants" table. + - Do not return more than {top_k} rows unless specified otherwise. + - Add a limit of 25 at the end of SQL query. + - If the SQLite query returns zero rows, return a message indicating the same. + - Only refer to the data contained in the {table_info} table. Do not fabricate any data. + - For filtering based on string comparison, always use the LIKE operator and enclose the string in `%`. + - Queries on the `Additional_Details_2` column should use sub-queries involving `Family_Number`, `Genus_Number` and `Accepted_name_number`. 
+ + Refer to the table description below for more details on the columns: + 1. **Record_Type_Code**: Contains text codes indicating the type of information in the row. + - FA: Family Name, Genus Name, Scientific Name + - TY: Type + - GE: Genus Name + - AN: Family Name (Accepted Name), Genus Name, Scientific Name, Author Name, Publication, Volume:Page, Year of Publication + - HB: Habit + - DB: Distribution/location of the plant + - RE: Remarks + - SN: Family Name (Synonym), Genus Name, Scientific Name, Author Name, Publication, Volume:Page, Year of Publication + 2. **Family_Name**: Contains the Family Name of the plant. + 3. **Genus_Name**: Contains the Genus Name of the plant. + 4. **Scientific_Name**: Contains the Scientific Name of the plant species. + 5. **Publication_Name**: Name of the journal or book where the plant discovery information is published. Use LIKE for queries. + 6. **Volume:_Page**: The volume and page number of the publication. + 7. **Year_of_Publication**: The year in which the plant information was published. + 8. **Author_Name**: May contain multiple authors separated by `&`. Use LIKE for queries. + 9. **Additional_Details**: Contains type, habit, distribution, and remarks. Use LIKE for queries. + - Type: General location information. + - Remarks: Location information about cultivation or native area. + - Distribution: Locations where the plant is common. May contain multiple locations, use LIKE for queries. + 10. **Groups**: Contains either "Gymnosperms" or "Angiosperms". + 11. **Urban**: Contains either "YES" or "NO". Specifies whether the plant is urban. + 12. **Fruit**: Contains either "YES" or "NO". Specifies whether the plant is a fruit plant. + 13. **Medicinal**: Contains either "YES" or "NO". Specifies whether the plant is medicinal. + 14. **Genus_Number**: Contains the Genus Number of the plant. + 15. **Accepted_name_number**: Contains the Accepted Name Number of the plant. + + Below are examples of questions and their corresponding SQL queries. 
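Because the selector above embeds only the `input` field (`input_keys=["input"]`), the `k=5` nearest examples for a candidate question can be previewed on their own, which helps when tuning `k` or auditing the example set. A small usage sketch (the question is illustrative):

```python
# Preview which few-shot examples the FAISS-backed selector would attach to a question.
selected = example_selector.select_examples({"input": "List the accepted names endemic to Kerala."})
for example in selected:
  print(example["input"])
```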
+ """ + + + + agent_prompt = PromptTemplate.from_template("User input: {input}\nSQL Query: {query}") + agent_prompt_obj = FewShotPromptTemplate( + example_selector=example_selector, + example_prompt=agent_prompt, + prefix=prefix_prompt, + suffix="", + input_variables=["input"], + ) + + full_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate(prompt=agent_prompt_obj), + ("human", "{input}"), + MessagesPlaceholder("agent_scratchpad"), + ] + ) + return full_prompt + +def initalize_sql_agent(): + + ### LLM + llm = ChatOpenAI(model="gpt-4o", temperature=0) + + ### DATABASE + db = SQLDatabase.from_uri("sqlite:///C:/Users/rohan/OneDrive/Desktop/NCSA_self_hostable_chatbot/self-hostable-ai-ta-backend/ai_ta_backend/service/plants_of_India_demo.db") + + dynamic_few_shot_prompt = get_dynamic_prompt_template() + + agent = create_sql_agent(llm, db=db, prompt=dynamic_few_shot_prompt, agent_type="openai-tools", verbose=True) + + return agent + +def generate_response_agent(agent,user_question): + response = agent.invoke({"input": user_question}) + return response + +##### Setting up the Graph Nodes, Edges and message communication + +class AgentState(TypedDict): + messages: Annotated[Sequence[BaseMessage], operator.add] + + +@tool("plants_sql_tool", return_direct=True) +def generate_sql_query(input:str) -> str: + """Given a query looks for the three most relevant SQL sample queries""" + user_question = input + sql_agent = initalize_sql_agent() + response = generate_response_agent(sql_agent,user_question) + return response + +model = ChatOpenAI(model="gpt-4o", temperature=0) + +tools = [generate_sql_query] +tool_executor = ToolExecutor(tools) +functions = [format_tool_to_openai_function(t) for t in tools] +model = model.bind_functions(functions) + +# Define the function that determines whether to continue or not +def should_continue(state): + messages = state['messages'] + last_message = messages[-1] + # If there is no function call, then we finish + if "function_call" not in last_message.additional_kwargs: + return "end" + # Otherwise if there is, we continue + else: + return "continue" + +# Define the function that calls the model +def call_model(state): + messages = state['messages'] + response = model.invoke(messages) + # We return a list, because this will get added to the existing list + return {"messages": [response]} + +# Define the function to execute tools +def call_tool(state): + messages = state['messages'] + # Based on the continue condition + # we know the last message involves a function call + last_message = messages[-1] + # We construct an ToolInvocation from the function_call + action = ToolInvocation( + tool=last_message.additional_kwargs["function_call"]["name"], + tool_input=json.loads(last_message.additional_kwargs["function_call"]["arguments"]), + ) + print(f"The agent action is {action}") + # We call the tool_executor and get back a response + response = tool_executor.invoke(action) + print(f"The tool result is: {response}") + # We use the response to create a FunctionMessage + function_message = FunctionMessage(content=str(response), name=action.tool) + # We return a list, because this will get added to the existing list + return {"messages": [function_message]} + + +from langgraph.graph import StateGraph, END +# Define a new graph +workflow = StateGraph(AgentState) + +# Define the two nodes we will cycle between +workflow.add_node("agent", call_model) +workflow.add_node("action", call_tool) + +# Set the entrypoint as `agent` where we start 
+workflow.set_entry_point("agent") + +# We now add a conditional edge +workflow.add_conditional_edges( + # First, we define the start node. We use `agent`. + # This means these are the edges taken after the `agent` node is called. + "agent", + # Next, we pass in the function that will determine which node is called next. + should_continue, + # Finally we pass in a mapping. + # The keys are strings, and the values are other nodes. + # END is a special node marking that the graph should finish. + # What will happen is we will call `should_continue`, and then the output of that + # will be matched against the keys in this mapping. + # Based on which one it matches, that node will then be called. + { + # If `tools`, then we call the tool node. + "continue": "action", + # Otherwise we finish. + "end": END + } +) + +# We now add a normal edge from `tools` to `agent`. +# This means that after `tools` is called, `agent` node is called next. +workflow.add_edge('action', 'agent') + +# Finally, we compile it! +# This compiles it into a LangChain Runnable, +# meaning you can use it as you would any other runnable +app = workflow.compile() + + +def generate_response(user_input): + #agent = initialize_agent() + output = app.invoke(user_input) + return output diff --git a/ai_ta_backend/service/poi_agent_service_v2.py b/ai_ta_backend/service/poi_agent_service_v2.py new file mode 100644 index 00000000..1a0e7b9c --- /dev/null +++ b/ai_ta_backend/service/poi_agent_service_v2.py @@ -0,0 +1,126 @@ +import os +from injector import inject +from langchain_openai import ChatOpenAI +from pydantic import BaseModel +from ai_ta_backend.database.poi_sql import POISQLDatabase +from langgraph.graph import StateGraph, END +from langchain_openai import ChatOpenAI +from langchain_community.utilities.sql_database import SQLDatabase + + +from langchain_openai import ChatOpenAI + + +from langchain.tools import tool, StructuredTool +from langgraph.prebuilt.tool_executor import ToolExecutor + + +from typing import TypedDict, Annotated, Sequence +import operator +from langchain_core.messages import BaseMessage + +from langgraph.prebuilt import ToolInvocation +import json +from langchain_core.messages import FunctionMessage +from langchain_core.utils.function_calling import convert_to_openai_function +from langgraph.graph import StateGraph, END + +from ai_ta_backend.utils.agent_utils import generate_response_agent, initalize_sql_agent +import traceback + +##### Setting up the Graph Nodes, Edges and message communication + +class AgentState(TypedDict): + messages: Annotated[Sequence[BaseMessage], operator.add] + +class POIInput(BaseModel): + input: str + +@tool("plants_sql_tool", return_direct=True, args_schema=POIInput) +def generate_sql_query(input:str) -> str: + """Given a query looks for the three most relevant SQL sample queries""" + user_question = input + llm = ChatOpenAI(model="gpt-4o", temperature=0) + ### DATABASE + db = SQLDatabase.from_uri(f"sqlite:///{os.environ['POI_SQL_DB_NAME']}") + sql_agent = initalize_sql_agent(llm, db) + response = generate_response_agent(sql_agent,user_question) + return response['output'] + +class POIAgentService: + @inject + def __init__(self, poi_sql_db: POISQLDatabase): + self.poi_sql_db = poi_sql_db + self.model = ChatOpenAI(model="gpt-4o", temperature=0) + # self.tools = [StructuredTool.from_function(self.generate_sql_query, name="Run SQL Query", args_schema=POIInput)] + self.tools = [generate_sql_query] + self.tool_executor = ToolExecutor(self.tools) + self.functions = 
[convert_to_openai_function(t) for t in self.tools] + self.model = self.model.bind_functions(self.functions) + self.workflow = self.initialize_workflow(self.model) + + + + # Define the function that determines whether to continue or not + def should_continue(self, state): + messages = state['messages'] + last_message = messages[-1] + # If there is no function call, then we finish + if "function_call" not in last_message.additional_kwargs: + return "end" + # Otherwise if there is, we continue + else: + return "continue" + + # Define the function that calls the model + def call_model(self, state): + messages = state['messages'] + response = self.model.invoke(messages) + # We return a list, because this will get added to the existing list + return {"messages": [response]} + + # Define the function to execute tools + def call_tool(self, state): + messages = state['messages'] + # Based on the continue condition + # we know the last message involves a function call + last_message = messages[-1] + # We construct an ToolInvocation from the function_call + action = ToolInvocation( + tool=last_message.additional_kwargs["function_call"]["name"], + tool_input=json.loads(last_message.additional_kwargs["function_call"]["arguments"]), + ) + print(f"The agent action is {action}") + # We call the tool_executor and get back a response + response = self.tool_executor.invoke(action) + print(f"The tool result is: {response}") + # We use the response to create a FunctionMessage + function_message = FunctionMessage(content=str(response), name=action.tool) + # We return a list, because this will get added to the existing list + return {"messages": [function_message]} + + def initialize_workflow(self, agent): + workflow = StateGraph(AgentState) + workflow.add_node("agent", self.call_model) + workflow.add_node("action", self.call_tool) + workflow.set_entry_point("agent") + workflow.add_conditional_edges( + "agent", + self.should_continue, + { + "continue": "action", + "end": END + } + ) + workflow.add_edge('action', 'agent') + return workflow.compile() + + def run_workflow(self, user_input): + #agent = initialize_agent() + try: + + output = self.workflow.invoke(user_input) + return output + except Exception as e: + traceback.print_exc() + return str(e) \ No newline at end of file diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index c53bcefb..17888361 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -1,18 +1,19 @@ import inspect +import logging import os import time import traceback -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union -import openai from injector import inject from langchain.chat_models import AzureChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings from langchain.schema import Document +import openai from ai_ta_backend.database.aws import AWSStorage -from ai_ta_backend.database.sql import SQLDatabase -from ai_ta_backend.database.vector import VectorDatabase +from ai_ta_backend.database.qdrant import VectorDatabase +from ai_ta_backend.database.sql import SQLAlchemyDatabase from ai_ta_backend.service.nomic_service import NomicService from ai_ta_backend.service.posthog_service import PosthogService from ai_ta_backend.service.sentry_service import SentryService @@ -25,8 +26,8 @@ class RetrievalService: """ @inject - def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, posthog: PosthogService, - sentry: SentryService, 
nomicService: NomicService): + def __init__(self, vdb: VectorDatabase, sqlDb: SQLAlchemyDatabase, aws: AWSStorage, posthog: Optional[PosthogService], + sentry: Optional[SentryService], nomicService: Optional[NomicService]): self.vdb = vdb self.sqlDb = sqlDb self.aws = aws @@ -34,6 +35,8 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos self.posthog = posthog self.nomicService = nomicService + logging.info(f"Vector DB: {self.vdb}") + openai.api_key = os.environ["OPENAI_API_KEY"] self.embeddings = OpenAIEmbeddings( @@ -73,9 +76,7 @@ def getTopContexts(self, try: start_time_overall = time.monotonic() - found_docs: list[Document] = self.vector_search(search_query=search_query, - course_name=course_name, - doc_groups=doc_groups) + found_docs: list[Document] = self.vector_search(search_query=search_query, course_name=course_name, doc_groups=doc_groups) pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n" # count tokens at start and end, then also count each context. @@ -88,7 +89,7 @@ def getTopContexts(self, doc_string = f"Document: {doc.metadata['readable_filename']}{', page: ' + str(doc.metadata['pagenumber']) if doc.metadata['pagenumber'] else ''}\n{str(doc.page_content)}\n" num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore - print( + logging.info( f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. πŸ“„ File: {doc.metadata['readable_filename']}" ) if token_counter + num_tokens <= token_limit: @@ -98,33 +99,35 @@ def getTopContexts(self, # filled our token size, time to return break - print(f"Total tokens used: {token_counter}. Docs used: {len(valid_docs)} of {len(found_docs)} docs retrieved") - print(f"Course: {course_name} ||| search_query: {search_query}") - print(f"⏰ ^^ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds") + logging.info(f"Total tokens used: {token_counter}. 
Docs used: {len(valid_docs)} of {len(found_docs)} docs retrieved") + logging.info(f"Course: {course_name} ||| search_query: {search_query}") + logging.info(f"⏰ ^^ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds") if len(valid_docs) == 0: return [] - self.posthog.capture( - event_name="getTopContexts_success_DI", - properties={ - "user_query": search_query, - "course_name": course_name, - "token_limit": token_limit, - "total_tokens_used": token_counter, - "total_contexts_used": len(valid_docs), - "total_unique_docs_retrieved": len(found_docs), - "getTopContext_total_latency_sec": time.monotonic() - start_time_overall, - }, - ) + if self.posthog is not None: + self.posthog.capture( + event_name="getTopContexts_success_DI", + properties={ + "user_query": search_query, + "course_name": course_name, + "token_limit": token_limit, + "total_tokens_used": token_counter, + "total_contexts_used": len(valid_docs), + "total_unique_docs_retrieved": len(found_docs), + "getTopContext_total_latency_sec": time.monotonic() - start_time_overall, + }, + ) return self.format_for_json(valid_docs) except Exception as e: # return full traceback to front end # err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}" # type: ignore - err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.print_exc} \n{e}" # type: ignore + err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.print_exc()} \n{e}" # type: ignore traceback.print_exc() - print(err) - self.sentry.capture_exception(e) + logging.info(err) + if self.sentry is not None: + self.sentry.capture_exception(e) return err def getAll( @@ -137,7 +140,6 @@ def getAll( Returns: list of dictionaries with distinct s3 path, readable_filename and course_name, url, base_url. 
""" - response = self.sqlDb.getAllMaterialsForCourse(course_name) data = response.data @@ -145,7 +147,7 @@ def getAll( distinct_dicts = [] for item in data: - combination = (item['s3_path'], item['readable_filename'], item['course_name'], item['url'], item['base_url']) + combination = (item.s3_path, item.readable_filename, item.course_name, item.url, item.base_url) if combination not in unique_combinations: unique_combinations.add(combination) distinct_dicts.append(item) @@ -154,7 +156,7 @@ def getAll( def delete_data(self, course_name: str, s3_path: str, source_url: str): """Delete file from S3, Qdrant, and Supabase.""" - print(f"Deleting data for course {course_name}") + logging.info(f"Deleting data for course {course_name}") # add delete from doc map logic here try: # Delete file from S3 @@ -163,7 +165,7 @@ def delete_data(self, course_name: str, s3_path: str, source_url: str): raise ValueError("S3_BUCKET_NAME environment variable is not set") identifier_key, identifier_value = ("s3_path", s3_path) if s3_path else ("url", source_url) - print(f"Deleting {identifier_value} from S3, Qdrant, and Supabase using {identifier_key}") + logging.info(f"Deleting {identifier_value} from S3, Qdrant, and Supabase using {identifier_key}") # Delete from S3 if identifier_key == "s3_path": @@ -178,36 +180,36 @@ def delete_data(self, course_name: str, s3_path: str, source_url: str): return "Success" except Exception as e: err: str = f"ERROR IN delete_data: Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore - print(err) - self.sentry.capture_exception(e) + logging.info(err) + if self.sentry is not None: + self.sentry.capture_exception(e) return err def delete_from_s3(self, bucket_name: str, s3_path: str): try: - print("Deleting from S3") + logging.info("Deleting from S3") response = self.aws.delete_file(bucket_name, s3_path) - print(f"AWS response: {response}") + logging.info(f"AWS response: {response}") except Exception as e: - print("Error in deleting file from s3:", e) - self.sentry.capture_exception(e) + logging.info("Error in deleting file from s3:", e) + if self.sentry is not None: + self.sentry.capture_exception(e) def delete_from_qdrant(self, identifier_key: str, identifier_value: str): try: - print("Deleting from Qdrant") + logging.info("Deleting from Qdrant") response = self.vdb.delete_data(os.environ['QDRANT_COLLECTION_NAME'], identifier_key, identifier_value) - print(f"Qdrant response: {response}") + logging.info(f"Qdrant response: {response}") except Exception as e: if "timed out" in str(e): # Timed out is fine. Still deletes. pass else: - print("Error in deleting file from Qdrant:", e) - self.sentry.capture_exception(e) + logging.info("Error in deleting file from Qdrant:", e) + if self.sentry is not None: + self.sentry.capture_exception(e) - def getTopContextsWithMQR(self, - search_query: str, - course_name: str, - token_limit: int = 4_000) -> Union[List[Dict], str]: + def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]: """ New info-retrieval pipeline that uses multi-query retrieval + filtering + reciprocal rank fusion + context padding. 1. Generate multiple queries based on the input search query. @@ -230,7 +232,7 @@ def getTopContextsWithMQR(self, # ) # generated_queries = generate_queries.invoke({"original_query": search_query}) - # print("generated_queries", generated_queries) + # logging.info("generated_queries", generated_queries) # # 2. 
VECTOR SEARCH FOR EACH QUERY # batch_found_docs_nested: list[list[Document]] = self.batch_vector_search(search_queries=generated_queries, @@ -240,10 +242,10 @@ def getTopContextsWithMQR(self, # # 3. RANK REMAINING DOCUMENTS -- good for parent doc padding of top 5 at the end. # found_docs = self.reciprocal_rank_fusion(batch_found_docs_nested) # found_docs = [doc for doc, score in found_docs] - # print(f"Num docs after re-ranking: {len(found_docs)}") + # logging.info(f"Num docs after re-ranking: {len(found_docs)}") # if len(found_docs) == 0: # return [] - # print(f"⏰ Total multi-query processing runtime: {(time.monotonic() - mq_start_time):.2f} seconds") + # logging.info(f"⏰ Total multi-query processing runtime: {(time.monotonic() - mq_start_time):.2f} seconds") # # 4. FILTER DOCS # filtered_docs = filter_top_contexts(contexts=found_docs, user_query=search_query, timeout=30, max_concurrency=180) @@ -252,7 +254,7 @@ def getTopContextsWithMQR(self, # # 5. TOP DOC CONTEXT PADDING // parent document retriever # final_docs = context_parent_doc_padding(filtered_docs, search_query, course_name) - # print(f"Number of final docs after context padding: {len(final_docs)}") + # logging.info(f"Number of final docs after context padding: {len(final_docs)}") # pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n" # token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + @@ -264,7 +266,7 @@ def getTopContextsWithMQR(self, # doc_string = f"Document: {doc['readable_filename']}{', page: ' + str(doc['pagenumber']) if doc['pagenumber'] else ''}\n{str(doc['text'])}\n" # num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore - # print(f"token_counter: {token_counter}, num_tokens: {num_tokens}, max_tokens: {token_limit}") + # logging.info(f"token_counter: {token_counter}, num_tokens: {num_tokens}, max_tokens: {token_limit}") # if token_counter + num_tokens <= token_limit: # token_counter += num_tokens # valid_docs.append(doc) @@ -272,9 +274,9 @@ def getTopContextsWithMQR(self, # # filled our token size, time to return # break - # print(f"Total tokens used: {token_counter} Used {len(valid_docs)} of total unique docs {len(found_docs)}.") - # print(f"Course: {course_name} ||| search_query: {search_query}") - # print(f"⏰ ^^ Runtime of getTopContextsWithMQR: {(time.monotonic() - start_time_overall):.2f} seconds") + # logging.info(f"Total tokens used: {token_counter} Used {len(valid_docs)} of total unique docs {len(found_docs)}.") + # logging.info(f"Course: {course_name} ||| search_query: {search_query}") + # logging.info(f"⏰ ^^ Runtime of getTopContextsWithMQR: {(time.monotonic() - start_time_overall):.2f} seconds") # if len(valid_docs) == 0: # return [] @@ -294,7 +296,7 @@ def getTopContextsWithMQR(self, # except Exception as e: # # return full traceback to front end # err: str = f"ERROR: In /getTopContextsWithMQR. 
Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.format_exc()}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}" # type: ignore - # print(err) + # logging.info(err) # sentry_sdk.capture_exception(e) # return err @@ -304,7 +306,7 @@ def format_for_json_mqr(self, found_docs) -> List[Dict]: """ for found_doc in found_docs: if "pagenumber" not in found_doc.keys(): - print("found no pagenumber") + logging.info("found no pagenumber") found_doc['pagenumber'] = found_doc['pagenumber_or_timestamp'] contexts = [ @@ -323,29 +325,35 @@ def format_for_json_mqr(self, found_docs) -> List[Dict]: def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, identifier_value: str): try: - print(f"Nomic delete. Course: {course_name} using {identifier_key}: {identifier_value}") + logging.info(f"Nomic delete. Course: {course_name} using {identifier_key}: {identifier_value}") response = self.sqlDb.getMaterialsForCourseAndKeyAndValue(course_name, identifier_key, identifier_value) - if not response.data: + data = response.data + if not data: raise Exception(f"No materials found for {course_name} using {identifier_key}: {identifier_value}") - data = response.data[0] # single record fetched - nomic_ids_to_delete = [str(data['id']) + "_" + str(i) for i in range(1, len(data['contexts']) + 1)] + data = data[0] # single record fetched + contexts_list = data.contexts if isinstance(data.contexts, list) else [] + nomic_ids_to_delete = [str(data.id) + "_" + str(i) for i in range(1, len(contexts_list) + 1)] # delete from Nomic response = self.sqlDb.getProjectsMapForCourse(course_name) - if not response.data: + data, _count = response.data, response.count + if not data: raise Exception(f"No document map found for this course: {course_name}") - project_id = response.data[0]['doc_map_id'] - self.nomicService.delete_from_document_map(project_id, nomic_ids_to_delete) + project_id = str(data[0].doc_map_id) + if self.nomicService is not None: + self.nomicService.delete_from_document_map(project_id, nomic_ids_to_delete) except Exception as e: - print(f"Nomic Error in deleting. {identifier_key}: {identifier_value}", e) - self.sentry.capture_exception(e) + logging.info(f"Nomic Error in deleting. {identifier_key}: {identifier_value}", e) + if self.sentry is not None: + self.sentry.capture_exception(e) try: - print(f"Supabase Delete. course: {course_name} using {identifier_key}: {identifier_value}") + logging.info(f"Supabase Delete. course: {course_name} using {identifier_key}: {identifier_value}") response = self.sqlDb.deleteMaterialsForCourseAndKeyAndValue(course_name, identifier_key, identifier_value) except Exception as e: - print(f"Supabase Error in delete. {identifier_key}: {identifier_value}", e) - self.sentry.capture_exception(e) + logging.info(f"Supabase Error in delete. 
{identifier_key}: {identifier_value}", e) + if self.sentry is not None: + self.sentry.capture_exception(e) def vector_search(self, search_query, course_name, doc_groups: List[str] | None = None): """ @@ -375,14 +383,15 @@ def _embed_query_and_measure_latency(self, search_query): return user_query_embedding def _capture_search_invoked_event(self, search_query, course_name, doc_groups): - self.posthog.capture( - event_name="vector_search_invoked", - properties={ - "user_query": search_query, - "course_name": course_name, - "doc_groups": doc_groups, - }, - ) + if self.posthog is not None: + self.posthog.capture( + event_name="vector_search_invoked", + properties={ + "user_query": search_query, + "course_name": course_name, + "doc_groups": doc_groups, + }, + ) def _perform_vector_search(self, search_query, course_name, doc_groups, user_query_embedding, top_n): qdrant_start_time = time.monotonic() @@ -402,26 +411,28 @@ def _process_search_results(self, search_results, course_name): found_docs.append(Document(page_content=page_content, metadata=metadata)) except Exception as e: - print(f"Error in vector_search(), for course: `{course_name}`. Error: {e}") - self.sentry.capture_exception(e) + logging.info(f"Error in vector_search(), for course: `{course_name}`. Error: {e}") + if self.sentry is not None: + self.sentry.capture_exception(e) return found_docs def _capture_search_succeeded_event(self, search_query, course_name, search_results): vector_score_calc_latency_sec = time.monotonic() max_vector_score, min_vector_score, avg_vector_score = self._calculate_vector_scores(search_results) - self.posthog.capture( - event_name="vector_search_succeeded", - properties={ - "user_query": search_query, - "course_name": course_name, - "qdrant_latency_sec": self.qdrant_latency_sec, - "openai_embedding_latency_sec": self.openai_embedding_latency, - "max_vector_score": max_vector_score, - "min_vector_score": min_vector_score, - "avg_vector_score": avg_vector_score, - "vector_score_calculation_latency_sec": time.monotonic() - vector_score_calc_latency_sec, - }, - ) + if self.posthog is not None: + self.posthog.capture( + event_name="vector_search_succeeded", + properties={ + "user_query": search_query, + "course_name": course_name, + "qdrant_latency_sec": self.qdrant_latency_sec, + "openai_embedding_latency_sec": self.openai_embedding_latency, + "max_vector_score": max_vector_score, + "min_vector_score": min_vector_score, + "avg_vector_score": avg_vector_score, + "vector_score_calculation_latency_sec": time.monotonic() - vector_score_calc_latency_sec, + }, + ) def _calculate_vector_scores(self, search_results): max_vector_score = 0 @@ -449,7 +460,7 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]: """ for found_doc in found_docs: if "pagenumber" not in found_doc.metadata.keys(): - print("found no pagenumber") + logging.info("found no pagenumber") found_doc.metadata["pagenumber"] = found_doc.metadata["pagenumber_or_timestamp"] contexts = [ diff --git a/ai_ta_backend/service/sentry_service.py b/ai_ta_backend/service/sentry_service.py index 53b780b0..6c35b066 100644 --- a/ai_ta_backend/service/sentry_service.py +++ b/ai_ta_backend/service/sentry_service.py @@ -1,7 +1,7 @@ import os -import sentry_sdk from injector import inject +import sentry_sdk class SentryService: diff --git a/ai_ta_backend/service/workflow_service.py b/ai_ta_backend/service/workflow_service.py index 1afaeda7..4f63c92d 100644 --- a/ai_ta_backend/service/workflow_service.py +++ b/ai_ta_backend/service/workflow_service.py @@ 
-1,11 +1,13 @@ -import requests -import time +import json +import logging import os -import supabase +import time from urllib.parse import quote -import json + from injector import inject -from ai_ta_backend.database.sql import SQLDatabase +import requests + +from ai_ta_backend.database.supabase import SQLDatabase class WorkflowService: @@ -78,12 +80,7 @@ def get_executions(self, limit, id=None, pagination: bool = True, api_key: str = else: return all_executions - def get_workflows(self, - limit, - pagination: bool = True, - api_key: str = "", - active: bool = False, - workflow_name: str = ''): + def get_workflows(self, limit, pagination: bool = True, api_key: str = "", active: bool = False, workflow_name: str = ''): if not api_key: raise ValueError('api_key is required') headers = {"X-N8N-API-KEY": api_key, "Accept": "application/json"} @@ -148,7 +145,7 @@ def format_data(self, inputted, api_key: str, workflow_name): new_data[data[k]] = v return new_data except Exception as e: - print("Error in format_data: ", e) + logging.info("Error in format_data: ", e) def switch_workflow(self, id, api_key: str = "", activate: 'str' = 'True'): if not api_key: @@ -165,7 +162,7 @@ def switch_workflow(self, id, api_key: str = "", activate: 'str' = 'True'): def main_flow(self, name: str, api_key: str = "", data: str = ""): if not api_key: raise ValueError('api_key is required') - print("Starting") + logging.info("Starting") hookId = self.get_hook(name, api_key) hook = self.url + f"/form/{hookId}" @@ -179,22 +176,22 @@ def main_flow(self, name: str, api_key: str = "", data: str = ""): if len(ids) > 0: id = max(ids) + 1 - print("Execution found in supabase: ", id) + logging.info("Execution found in supabase: ", id) else: execution = self.get_executions(limit=1, api_key=api_key, pagination=False) - print("Got executions") + logging.info("Got executions") if execution: - print(execution) + logging.info(execution) id = int(execution[0]['id']) + 1 - print("Execution found through n8n: ", id) + logging.info("Execution found through n8n: ", id) else: raise Exception('No executions found') id = str(id) try: self.sqlDb.lockWorkflow(id) - print("inserted flow into supabase") + logging.info("inserted flow into supabase") self.execute_flow(hook, new_data) - print("Executed workflow") + logging.info("Executed workflow") except Exception as e: # TODO: Decrease number by one, is locked false # self.supabase_client.table('n8n_workflows').update({"latest_workflow_id": str(int(id) - 1), "is_locked": False}).eq('latest_workflow_id', id).execute() @@ -208,11 +205,11 @@ def main_flow(self, name: str, api_key: str = "", data: str = ""): executions = self.get_executions(20, id, True, api_key) while executions is None: executions = self.get_executions(20, id, True, api_key) - print("Can't find id in executions") + logging.info("Can't find id in executions") time.sleep(1) - print("Found id in executions ") + logging.info("Found id in executions ") self.sqlDb.deleteLatestWorkflowId(id) - print("Deleted id") + logging.info("Deleted id") except Exception as e: self.sqlDb.deleteLatestWorkflowId(id) return {"error": str(e)} diff --git a/ai_ta_backend/utils/agent_utils.py b/ai_ta_backend/utils/agent_utils.py new file mode 100644 index 00000000..5181c828 --- /dev/null +++ b/ai_ta_backend/utils/agent_utils.py @@ -0,0 +1,370 @@ +import json +# from ai_ta_backend.model.response import FunctionMessage, ToolInvocation +from dotenv import load_dotenv +from langchain_community.utilities.sql_database import SQLDatabase +from langchain_openai 
import ChatOpenAI, OpenAI +from langchain_community.agent_toolkits import create_sql_agent +from langchain_core.prompts import ( + ChatPromptTemplate, + FewShotPromptTemplate, + MessagesPlaceholder, + PromptTemplate, + SystemMessagePromptTemplate, +) +import os +import logging +from flask_sqlalchemy import SQLAlchemy + +from langchain_openai import ChatOpenAI + +from operator import itemgetter + +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnablePassthrough +from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool +from langchain_community.agent_toolkits import create_sql_agent +from langchain.tools import BaseTool, StructuredTool, Tool, tool +import random +from langgraph.prebuilt.tool_executor import ToolExecutor +from langchain.tools.render import format_tool_to_openai_function + + +from typing import TypedDict, Annotated, Sequence +import operator +from langchain_core.messages import BaseMessage + +from langchain_core.agents import AgentFinish +from langgraph.prebuilt import ToolInvocation +import json +from langchain_core.messages import FunctionMessage +from langchain_community.utilities import SQLDatabase +from langchain_community.vectorstores import FAISS +from langchain_core.example_selectors import SemanticSimilarityExampleSelector +from langchain_openai import OpenAIEmbeddings + + + +def get_dynamic_prompt_template(): + + examples = [ + { + "input": "How many accepted names are only distributed in Karnataka?", + "query": 'SELECT COUNT(*) as unique_pairs_count FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "Karnataka%"));' + }, + { + "input": "How many names were authored by Roxb?", + "query": 'SELECT COUNT(*) as unique_pairs_count FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Author_Name" LIKE "%Roxb%" AND "Record_Type_Code" IN ("AN", "SN"));' + }, + { + "input": "How many species have distributions in Myanmar, Meghalaya and Andhra Pradesh?", + "query": 'SELECT COUNT(*) FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%"));' + }, + { + "input": "List the accepted names common to Myanmar, Meghalaya, Odisha, Andhra Pradesh.", + "query": 'SELECT DISTINCT Scientific_Name FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Odisha%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%"));' + }, + { + "input": "List the accepted names that represent 'tree'.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "HB" AND "Additional_Details_2" LIKE "%tree%");' + }, + { + "input": "List the accepted names linked with Endemic tag.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT 
DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND "Additional_Details_2" LIKE "%Endemic%");' + }, + { + "input": "List the accepted names published in Fl. Brit. India [J. D. Hooker].", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" in ("AN", "SN") AND ("Publication" LIKE "%Fl. Brit. India [J. D. Hooker]%" OR "Publication" LIKE "%[J. D. Hooker]%" OR "Publication" LIKE "%Fl. Brit. India%");' + }, + { + "input": "How many accepted names have β€˜Silhet’/ β€˜Sylhet’ in their Type?", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "TY" AND ("Additional_Details_2" LIKE "%Silhet%" OR "Additional_Details_2" LIKE "%Sylhet%"));' + }, + { + "input": "How many species were distributed in Sikkim and Meghalaya?", + "query": 'SELECT COUNT(*) AS unique_pairs FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Sikkim%" AND Additional_Details_2 LIKE "%Meghalaya%"));' + }, + { + "input": "List the accepted names common to Kerala, Tamil Nadu, Andhra Pradesh, Karnataka, Maharashtra, Odisha, Meghalaya and Myanmar.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Odisha%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%" AND "Additional_Details_2" LIKE "%Kerala%" AND "Additional_Details_2" LIKE "%Tamil Nadu%" AND "Additional_Details_2" LIKE "%Karnataka%" AND "Additional_Details_2" LIKE "%Maharashtra%"));' + }, + { + "input": "List the accepted names common to Europe, Afghanistan, Jammu & Kashmir, Himachal, Nepal, Sikkim, Bhutan, Arunachal Pradesh and China.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Europe%" AND Additional_Details_2 LIKE "%Afghanistan%" AND "Additional_Details_2" LIKE "%Jammu & Kashmir%" AND "Additional_Details_2" LIKE "%Himachal%" AND "Additional_Details_2" LIKE "%Nepal%" AND "Additional_Details_2" LIKE "%Sikkim%" AND "Additional_Details_2" LIKE "%Bhutan%" AND "Additional_Details_2" LIKE "%Arunachal Pradesh%" AND "Additional_Details_2" LIKE "%China%"));' + }, + { + "input": "List the accepted names common to Europe, Afghanistan, Austria, Belgium, Czechoslovakia, Denmark, France, Greece, Hungary, Italy, Moldava, Netherlands, Poland, Romania, Spain, Switzerland, Jammu & Kashmir, Himachal, Nepal, and China.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Europe%" AND Additional_Details_2 LIKE "%Afghanistan%" AND "Additional_Details_2" LIKE "%Jammu & Kashmir%" AND "Additional_Details_2" LIKE "%Himachal%" AND "Additional_Details_2" LIKE 
"%Nepal%" AND "Additional_Details_2" LIKE "%Austria%" AND "Additional_Details_2" LIKE "%Belgium%" AND "Additional_Details_2" LIKE "%Czechoslovakia%" AND "Additional_Details_2" LIKE "%China%" AND "Additional_Details_2" LIKE "%Denmark%" AND "Additional_Details_2" LIKE "%Greece%" AND "Additional_Details_2" LIKE "%France%" AND "Additional_Details_2" LIKE "%Hungary%" AND "Additional_Details_2" LIKE "%Italy%" AND "Additional_Details_2" LIKE "%Moldava%" AND "Additional_Details_2" LIKE "%Netherlands%" AND "Additional_Details_2" LIKE "%Poland%" AND "Additional_Details_2" LIKE "%Poland%" AND "Additional_Details_2" LIKE "%Romania%" AND "Additional_Details_2" LIKE "%Spain%" AND "Additional_Details_2" LIKE "%Switzerland%"));' + }, + { + "input": "List the species which are distributed in Sikkim and Meghalaya.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Sikkim%" AND Additional_Details_2 LIKE "%Meghalaya%"));' + }, + { + "input": "How many species are common to America, Europe, Africa, Asia, and Australia?", + "query": 'SELECT COUNT(*) AS unique_pairs IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%America%" AND Additional_Details_2 LIKE "%Europe%" AND "Additional_Details_2" LIKE "%Africa%" AND "Additional_Details_2" LIKE "%Asia%" AND "Additional_Details_2" LIKE "%Australia%"));' + }, + { + "input": "List the species names common to India and Myanmar, Malaysia, Indonesia, and Australia.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number","Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%India%" AND Additional_Details_2 LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Malaysia%" AND Additional_Details_2 LIKE "%Indonesia%" AND Additional_Details_2 LIKE "%Australia%"));' + }, + { + "input": "List all plants which are tagged as urban.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Urban" = "YES";' + }, + { + "input": "List all plants which are tagged as fruit.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Fruit" = "YES";' + }, + { + "input": "List all plants which are tagged as medicinal.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Medicinal" = "YES";' + }, + { + "input": "List all family names which are gymnosperms.", + "query": 'SELECT DISTINCT "Family_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Groups" = "Gymnosperms";' + }, + { + "input": "How many accepted names are tagged as angiosperms?", + "query": 'SELECT COUNT(DISTINCT "Scientific_Name") FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Groups" = "Angiosperms";' + }, + { + "input": "How many accepted names belong to the 'Saxifraga' genus?", + "query": 'SELECT COUNT(DISTINCT "Scientific_Name") FROM plants WHERE "Genus_Name" = "Saxifraga";' + }, + { + "input": "List the accepted names tagged as 'perennial herb' or 'climber'.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", 
"Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "HB" AND ("Additional_Details_2" LIKE "%perennial herb%" OR "Additional_Details_2" LIKE "%climber%"));' + }, + { + "input": "How many accepted names are native to South Africa?", + "query": 'SELECT COUNT(*) FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%native%" AND "Additional_Details_2" LIKE "%south%" AND "Additional_Details_2" LIKE "%africa%");' + + }, + { + "input": "List the accepted names which were introduced and naturalized.", + "query": 'SELECT DISTINCT "Scientific_Name FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%introduced%" AND "Additional_Details_2" LIKE "%naturalized%");' + }, + { + "input": "List all ornamental plants.", + "query": 'SELECT DISTINCT "Scientific_Name FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%ornamental%");' + }, + { + "input": "How many plants from the 'Leguminosae' family have a altitudinal range up to 1000 m?", + "query": 'SELECT COUNT(*) FROM plants WHERE "Record_Type_Code" = "AL" AND "Family_Name" = "Leguminosae" AND "Additional_Details_2" LIKE "%1000%";' + }, + { + "input": "List the accepted names linked with the 'endemic' tag for Karnataka.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND "Additional_Details_2" LIKE "%Endemic%" AND "Additional_Details_2" LIKE "%Karnataka%");' + }, + {"input": "List all the accepted names under the family 'Gnetaceae'.", + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Family_Number" IN ( + SELECT DISTINCT "Family_Number" FROM plants WHERE "Family_Name" = "Gnetaceae" + ) +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +"""}, + { + "input": "List all the accepted species that are introduced.", + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Record_Type_Code" = 'RE'and "Additional_Details_2" LIKE '%cultivated%' +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +""", + }, + { + "input": "List all the accepted names with type 'Cycad'", + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Record_Type_Code" = 'HB'and "Additional_Details_2" LIKE '%Cycad%' + +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" 
= b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +""", + }, + { + "input": "List all the accepted names under the genus 'Cycas' with more than two synonyms.", + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Genus_Number" IN ( + SELECT DISTINCT "Genus_Number" FROM plants WHERE "Genus_Name" = 'Cycas' + ) + AND "Family_Number" IN ( + SELECT DISTINCT "Family_Number" FROM plants WHERE "Genus_Name" = 'Cycas' + ) + AND "Synonym_Number" > 2 +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +""", + }, + { + "input":'List all the accepted names published in Asian J. Conservation Biol.', + "query": """ + SELECT DISTINCT "Scientific_Name" + FROM plants + WHERE "Record_Type_Code" = 'AN' AND "Publication" LIKE '%Asian J. Conservation Biol%'; + +""", + }, + { + "input": 'List all the accepted names linked with endemic tag.', + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Record_Type_Code" = 'DB'and "Additional_Details_2" LIKE '%Endemic%' + +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +""", + }, + { + "input": 'List all the accepted names that have no synonyms.' , + "query": """ +SELECT DISTINCT a."Scientific_Name" FROM plants a +group by a."Family_Number",a."Genus_Number",a."Accepted_name_number" +HAVING SUM(a."Synonym_Number") = 0 AND a."Accepted_name_number" > 0; +""", + }, + { + "input": 'List all the accepted names authored by Roxb.', + "query": """ +SELECT "Scientific_Name" +FROM plants +WHERE "Record_Type_Code" = 'AN'AND "Author_Name" LIKE '%Roxb%'; +""", + }, + { + "input": 'List all genera within each family', + "query": """ +SELECT "Family_Name", "Genus_Name" +FROM plants +WHERE "Record_Type_Code" = 'GE'; +""", + }, + { + "input": 'Did Minq. discovered Cycas ryumphii?', + "query": """SELECT + CASE + WHEN EXISTS ( + SELECT 1 + FROM plants as a + WHERE a."Scientific_Name" = 'Cycas rumphii' + AND a."Author_Name" = 'Miq.' + ) THEN 'TRUE' + ELSE 'FALSE' + END AS ExistsCheck; +"""}, + + ] + + + example_selector = SemanticSimilarityExampleSelector.from_examples( + examples, + OpenAIEmbeddings(), + FAISS, + k=5, + input_keys=["input"], + ) + + + prefix_prompt = """ + You are an agent designed to interact with a SQL database. + Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. + You can order the results by a relevant column to return the most interesting examples in the database. + Never query for all the columns from a specific table, only ask for the relevant columns given the question. + You have access to tools for interacting with the database. + Only use the given tools. Only use the information returned by the tools to construct your final answer. + You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. + + - Restrict your queries to the "plants" table. + - Do not return more than {top_k} rows unless specified otherwise. 
+ - Add a limit of 25 at the end of SQL query. + - If the SQLite query returns zero rows, return a message indicating the same. + - Only refer to the data contained in the {table_info} table. Do not fabricate any data. + - For filtering based on string comparison, always use the LIKE operator and enclose the string in `%`. + - Queries on the `Additional_Details_2` column should use sub-queries involving `Family_Number`, `Genus_Number` and `Accepted_name_number`. + + Refer to the table description below for more details on the columns: + 1. **Record_Type_Code**: Contains text codes indicating the type of information in the row. + - FA: Family Name, Genus Name, Scientific Name + - TY: Type + - GE: Genus Name + - AN: Family Name (Accepted Name), Genus Name, Scientific Name, Author Name, Publication, Volume:Page, Year of Publication + - HB: Habit + - DB: Distribution/location of the plant + - RE: Remarks + - SN: Family Name (Synonym), Genus Name, Scientific Name, Author Name, Publication, Volume:Page, Year of Publication + 2. **Family_Name**: Contains the Family Name of the plant. + 3. **Genus_Name**: Contains the Genus Name of the plant. + 4. **Scientific_Name**: Contains the Scientific Name of the plant species. + 5. **Publication_Name**: Name of the journal or book where the plant discovery information is published. Use LIKE for queries. + 6. **Volume:_Page**: The volume and page number of the publication. + 7. **Year_of_Publication**: The year in which the plant information was published. + 8. **Author_Name**: May contain multiple authors separated by `&`. Use LIKE for queries. + 9. **Additional_Details**: Contains type, habit, distribution, and remarks. Use LIKE for queries. + - Type: General location information. + - Remarks: Location information about cultivation or native area. + - Distribution: Locations where the plant is common. May contain multiple locations, use LIKE for queries. + 10. **Groups**: Contains either "Gymnosperms" or "Angiosperms". + 11. **Urban**: Contains either "YES" or "NO". Specifies whether the plant is urban. + 12. **Fruit**: Contains either "YES" or "NO". Specifies whether the plant is a fruit plant. + 13. **Medicinal**: Contains either "YES" or "NO". Specifies whether the plant is medicinal. + 14. **Genus_Number**: Contains the Genus Number of the plant. + 15. **Accepted_name_number**: Contains the Accepted Name Number of the plant. + + Below are examples of questions and their corresponding SQL queries. 
+ """ + + + + agent_prompt = PromptTemplate.from_template("User input: {input}\nSQL Query: {query}") + agent_prompt_obj = FewShotPromptTemplate( + example_selector=example_selector, + example_prompt=agent_prompt, + prefix=prefix_prompt, + suffix="", + input_variables=["input"], + ) + + full_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate(prompt=agent_prompt_obj), + ("human", "{input}"), + MessagesPlaceholder("agent_scratchpad"), + ] + ) + return full_prompt + +def initalize_sql_agent(llm, db): + + dynamic_few_shot_prompt = get_dynamic_prompt_template() + + agent = create_sql_agent(llm, db=db, prompt=dynamic_few_shot_prompt, agent_type="openai-tools", verbose=True) + + return agent + +def generate_response_agent(agent,user_question): + response = agent.invoke({"input": user_question}) + return response \ No newline at end of file diff --git a/ai_ta_backend/utils/context_parent_doc_padding.py b/ai_ta_backend/utils/context_parent_doc_padding.py index fc0ba19c..e42aed5f 100644 --- a/ai_ta_backend/utils/context_parent_doc_padding.py +++ b/ai_ta_backend/utils/context_parent_doc_padding.py @@ -1,8 +1,9 @@ -import os -import time from concurrent.futures import ProcessPoolExecutor from functools import partial +import logging from multiprocessing import Manager +import os +import time DOCUMENTS_TABLE = os.environ['SUPABASE_DOCUMENTS_TABLE'] # SUPABASE_CLIENT = supabase.create_client(supabase_url=os.environ['SUPABASE_URL'], @@ -13,7 +14,7 @@ def context_parent_doc_padding(found_docs, search_query, course_name): """ Takes top N contexts acquired from QRANT similarity search and pads them """ - print("inside main context padding") + logging.info("inside main context padding") start_time = time.monotonic() with Manager() as manager: @@ -33,7 +34,7 @@ def context_parent_doc_padding(found_docs, search_query, course_name): result_contexts = supabase_contexts_no_duplicates + list(qdrant_contexts) - print(f"⏰ Context padding runtime: {(time.monotonic() - start_time):.2f} seconds") + logging.info(f"⏰ Context padding runtime: {(time.monotonic() - start_time):.2f} seconds") return result_contexts @@ -68,14 +69,11 @@ def supabase_context_padding(doc, course_name, result_docs): # query by url or s3_path if 'url' in doc.metadata.keys() and doc.metadata['url']: parent_doc_id = doc.metadata['url'] - response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name', - course_name).eq('url', parent_doc_id).execute() + response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name', course_name).eq('url', parent_doc_id).execute() else: parent_doc_id = doc.metadata['s3_path'] - response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name', - course_name).eq('s3_path', - parent_doc_id).execute() + response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name', course_name).eq('s3_path', parent_doc_id).execute() data = response.data @@ -83,10 +81,10 @@ def supabase_context_padding(doc, course_name, result_docs): # do the padding filename = data[0]['readable_filename'] contexts = data[0]['contexts'] - #print("no of contexts within the og doc: ", len(contexts)) + #logging.info("no of contexts within the og doc: ", len(contexts)) if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys(): - #print("inside chunk index") + #logging.info("inside chunk index") # pad contexts by chunk index + 3 and - 3 target_chunk_index = doc.metadata['chunk_index'] for context in contexts: @@ -100,7 +98,7 @@ def supabase_context_padding(doc, 
course_name, result_docs): result_docs.append(context) elif doc.metadata['pagenumber'] != '': - #print("inside page number") + #logging.info("inside page number") # pad contexts belonging to same page number pagenumber = doc.metadata['pagenumber'] @@ -115,7 +113,7 @@ def supabase_context_padding(doc, course_name, result_docs): result_docs.append(context) else: - #print("inside else") + #logging.info("inside else") # refactor as a Supabase object and append context_dict = { 'text': doc.page_content, diff --git a/ai_ta_backend/utils/emails.py b/ai_ta_backend/utils/emails.py index 4312a35d..1d001bfa 100644 --- a/ai_ta_backend/utils/emails.py +++ b/ai_ta_backend/utils/emails.py @@ -1,7 +1,7 @@ -import os -import smtplib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +import os +import smtplib def send_email(subject: str, body_text: str, sender: str, receipients: list, bcc_receipients: list): diff --git a/ai_ta_backend/utils/filtering_contexts.py b/ai_ta_backend/utils/filtering_contexts.py index 03deede0..83d502c1 100644 --- a/ai_ta_backend/utils/filtering_contexts.py +++ b/ai_ta_backend/utils/filtering_contexts.py @@ -45,7 +45,7 @@ # def filter_context(self, context, user_query, langsmith_prompt_obj): # final_prompt = str(langsmith_prompt_obj.format(context=context, user_query=user_query)) -# # print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^") +# # logging.info(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^") # try: # # completion = run_caii_hosted_llm(final_prompt) # # completion = run_replicate(final_prompt) @@ -53,7 +53,7 @@ # return {"completion": completion, "context": context} # except Exception as e: # sentry_sdk.capture_exception(e) -# print(f"Error: {e}") +# logging.info(f"Error: {e}") # def run_caii_hosted_llm(prompt, max_tokens=300, temp=0.3, **kwargs): # """ @@ -87,7 +87,7 @@ # # "max_new_tokens": 250, # # "presence_penalty": 1 # # }) -# print(output) +# logging.info(output) # return output # def run_anyscale(prompt, model_name="HuggingFaceH4/zephyr-7b-beta"): @@ -110,12 +110,12 @@ # ) # output = ret["choices"][0]["message"]["content"] # type: ignore -# print("Response from Anyscale:", output[:150]) +# logging.info("Response from Anyscale:", output[:150]) # # input_length = len(tokenizer.encode(prompt)) # # output_length = len(tokenizer.encode(output)) # # Input tokens {input_length}, output tokens: {output_length}" -# print(f"^^^^ one anyscale call Runtime: {(time.monotonic() - start_time):.2f} seconds.") +# logging.info(f"^^^^ one anyscale call Runtime: {(time.monotonic() - start_time):.2f} seconds.") # return output # def parse_result(result: str): @@ -130,7 +130,7 @@ # timeout: Optional[float] = None, # max_concurrency: Optional[int] = 180): -# print("⏰⏰⏰ Starting filter_top_contexts() ⏰⏰⏰") +# logging.info("⏰⏰⏰ Starting filter_top_contexts() ⏰⏰⏰") # timeout = timeout or float(os.environ["FILTER_TOP_CONTEXTS_TIMEOUT_SECONDS"]) # # langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr") # TOO UNSTABLE, service offline @@ -138,8 +138,8 @@ # posthog = Posthog(sync_mode=True, project_api_key=os.environ['POSTHOG_API_KEY'], host='https://app.posthog.com') # max_concurrency = min(100, len(contexts)) -# print("max_concurrency is max of 100, or len(contexts), whichever is less ---- Max concurrency:", max_concurrency) -# print("Num contexts to filter:", len(contexts)) +# logging.info("max_concurrency is max of 100, or len(contexts), whichever is less ---- Max concurrency:", max_concurrency) +# 
logging.info("Num contexts to filter:", len(contexts)) # # START TASKS # actor = AsyncActor.options(max_concurrency=max_concurrency, num_cpus=0.001).remote() # type: ignore @@ -161,10 +161,10 @@ # r['context'] for r in results if r and 'context' in r and 'completion' in r and parse_result(r['completion']) # ] -# print("🧠🧠 TOTAL DOCS PROCESSED BY ANYSCALE FILTERING:", len(results)) -# print("🧠🧠 TOTAL DOCS KEPT, AFTER FILTERING:", len(best_contexts_to_keep)) +# logging.info("🧠🧠 TOTAL DOCS PROCESSED BY ANYSCALE FILTERING:", len(results)) +# logging.info("🧠🧠 TOTAL DOCS KEPT, AFTER FILTERING:", len(best_contexts_to_keep)) # mqr_runtime = round(time.monotonic() - start_time, 2) -# print(f"⏰ Total elapsed time: {mqr_runtime} seconds") +# logging.info(f"⏰ Total elapsed time: {mqr_runtime} seconds") # posthog.capture('distinct_id_of_the_user', # event='filter_top_contexts', @@ -182,9 +182,9 @@ # def run_main(): # start_time = time.monotonic() # # final_passage_list = filter_top_contexts(contexts=CONTEXTS * 2, user_query=USER_QUERY) -# # print("βœ…βœ…βœ… TOTAL included in results: ", len(final_passage_list)) -# print(f"⏰⏰⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds") -# # print("Total contexts:", len(CONTEXTS) * 2) +# # logging.info("βœ…βœ…βœ… TOTAL included in results: ", len(final_passage_list)) +# logging.info(f"⏰⏰⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds") +# # logging.info("Total contexts:", len(CONTEXTS) * 2) # # ! CONDA ENV: llm-serving # if __name__ == "__main__": diff --git a/ai_ta_backend/utils/utils_tokenization.py b/ai_ta_backend/utils/utils_tokenization.py index 956cc196..3736b418 100644 --- a/ai_ta_backend/utils/utils_tokenization.py +++ b/ai_ta_backend/utils/utils_tokenization.py @@ -1,13 +1,13 @@ +import logging import os from typing import Any import tiktoken -def count_tokens_and_cost( - prompt: str, - completion: str = '', - openai_model_name: str = "gpt-3.5-turbo"): # -> tuple[int, float] | tuple[int, float, int, float]: +def count_tokens_and_cost(prompt: str, + completion: str = '', + openai_model_name: str = "gpt-3.5-turbo"): # -> tuple[int, float] | tuple[int, float, int, float]: """ # TODO: improve w/ extra tokens used by model: https://github.com/openai/openai-cookbook/blob/d00e9a48a63739f5b038797594c81c8bb494fc09/examples/How_to_count_tokens_with_tiktoken.ipynb Returns the number of tokens in a text string. 
@@ -56,7 +56,7 @@ def count_tokens_and_cost( completion_token_cost = 0.0001 / 1_000 else: # no idea of cost - print(f"NO IDEA OF COST, pricing not supported for model model: `{openai_model_name}`") + logging.info(f"NO IDEA OF COST, pricing not supported for model: `{openai_model_name}`") prompt_token_cost = 0 completion_token_cost = 0 @@ -90,7 +90,7 @@ def analyze_conversations(supabase_client: Any = None): supabase_key=os.getenv('SUPABASE_API_KEY')) # type: ignore # Get all conversations response = supabase_client.table('llm-convo-monitor').select('convo').execute() - # print("total entries", response.data.count) + # logging.info("total entries", response.data.count) total_convos = 0 total_messages = 0 @@ -101,10 +101,10 @@ def analyze_conversations(supabase_client: Any = None): # for convo in response['data']: for convo in response.data: total_convos += 1 - # print(convo) + # logging.info(convo) # prase json from convo # parse json into dict - # print(type(convo)) + # logging.info(type(convo)) # convo = json.loads(convo) convo = convo['convo'] messages = convo['messages'] @@ -122,15 +122,13 @@ def analyze_conversations(supabase_client: Any = None): if role == 'user': num_tokens, cost = count_tokens_and_cost(prompt=content, openai_model_name=model_name) total_prompt_cost += cost - print(f'User Prompt: {content}, Tokens: {num_tokens}, cost: {cost}') + logging.info(f'User Prompt: {content}, Tokens: {num_tokens}, cost: {cost}') # If the message is from the assistant, it's a completion elif role == 'assistant': - num_tokens_completion, cost_completion = count_tokens_and_cost(prompt='', - completion=content, - openai_model_name=model_name) + num_tokens_completion, cost_completion = count_tokens_and_cost(prompt='', completion=content, openai_model_name=model_name) total_completion_cost += cost_completion - print(f'Assistant Completion: {content}\nTokens: {num_tokens_completion}, cost: {cost_completion}') + logging.info(f'Assistant Completion: {content}\nTokens: {num_tokens_completion}, cost: {cost_completion}') return total_convos, total_messages, total_prompt_cost, total_completion_cost @@ -138,7 +136,7 @@ def analyze_conversations(supabase_client: Any = None): pass # if __name__ == '__main__': -# print('starting main') +# logging.info('starting main') # total_convos, total_messages, total_prompt_cost, total_completion_cost = analyze_conversations() -# print(f'total_convos: {total_convos}, total_messages: {total_messages}') -# print(f'total_prompt_cost: {total_prompt_cost}, total_completion_cost: {total_completion_cost}') +# logging.info(f'total_convos: {total_convos}, total_messages: {total_messages}') +# logging.info(f'total_prompt_cost: {total_prompt_cost}, total_completion_cost: {total_completion_cost}') diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 00000000..c4d3c68d --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,77 @@ +services: + redis: + image: redis:latest + ports: + - 6379:6379 + networks: + - my-network + volumes: + - redis-data:/data + + qdrant: + image: qdrant/qdrant:v1.9.5 + restart: always + container_name: qdrant + ports: + - 6333:6333 + - 6334:6334 + expose: + - 6333 + - 6334 + - 6335 + volumes: + - ./qdrant_data:/qdrant/storage + - ./qdrant_config.yaml:/qdrant/config/production.yaml # Mount the config file directly as a volume + networks: + - my-network + healthcheck: + test: + [ + CMD, + curl, + -f, + -H, + "Authorization: Bearer qd-SbvSWrYpa473J33yPjdL", + http://localhost:6333/health, + ] + interval: 30s + timeout: 10s + 
retries: 3 + + minio: + image: minio/minio:RELEASE.2024-06-13T22-53-53Z + environment: + MINIO_ROOT_USER: minioadmin # Customize access key + MINIO_ROOT_PASSWORD: minioadmin # Customize secret key + command: server /data + ports: + - 9000:9000 # API access + - 9001:9001 # Console access + networks: + - my-network + volumes: + - minio-data:/data + + flask_app: + build: . # Directory with Dockerfile for Flask app + # image: kastanday/ai-ta-backend:gunicorn + ports: + - 8000:8000 + volumes: + - ./db:/usr/src/app/db # Mount local directory to store SQLite database + networks: + - my-network + depends_on: + - qdrant + - redis + - minio + +# declare the network resource +# this will allow you to use service discovery and address a container by its name from within the network +networks: + my-network: {} + +volumes: + redis-data: {} + qdrant-data: {} + minio-data: {} diff --git a/init-scripts/init-db.sql b/init-scripts/init-db.sql new file mode 100644 index 00000000..49b1b646 --- /dev/null +++ b/init-scripts/init-db.sql @@ -0,0 +1,15 @@ +CREATE TABLE public.documents ( + id BIGINT GENERATED BY DEFAULT AS IDENTITY, + created_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NOW(), + s3_path TEXT NULL, + readable_filename TEXT NULL, + course_name TEXT NULL, + url TEXT NULL, + contexts JSONB NULL, + base_url TEXT NULL, + CONSTRAINT documents_pkey PRIMARY KEY (id) +) TABLESPACE pg_default; + +CREATE INDEX IF NOT EXISTS documents_course_name_idx ON public.documents USING hash (course_name) TABLESPACE pg_default; + +CREATE INDEX IF NOT EXISTS documents_created_at_idx ON public.documents USING btree (created_at) TABLESPACE pg_default; diff --git a/plants_of_India_demo.db b/plants_of_India_demo.db new file mode 100644 index 00000000..e69de29b diff --git a/qdrant_config.yaml b/qdrant_config.yaml new file mode 100644 index 00000000..3adc978a --- /dev/null +++ b/qdrant_config.yaml @@ -0,0 +1,204 @@ +debug: false +log_level: INFO + +storage: + # Where to store all the data + # KEY: use the default location, then map that using docker -v nvme/storage:/qdrant/storage + storage_path: /qdrant/storage + + # Where to store snapshots + snapshots_path: /qdrant/storage/snapshots + + # Optional setting. Specify where else to store temp files as default is ./storage. + # Route to another location on your system to reduce network disk use. + temp_path: /qdrant/storage/temp + + # If true - a point's payload will not be stored in memory. + # It will be read from the disk every time it is requested. + # This setting saves RAM by (slightly) increasing the response time. + # Note: those payload values that are involved in filtering and are indexed - remain in RAM. + on_disk_payload: false + + # Write-ahead-log related configuration + wal: + # Size of a single WAL segment + wal_capacity_mb: 32 + + # Number of WAL segments to create ahead of actual data requirement + wal_segments_ahead: 0 + + # Normal node - receives all updates and answers all queries + node_type: Normal + + # Listener node - receives all updates, but does not answer search/read queries + # Useful for setting up a dedicated backup node + # node_type: "Listener" + + performance: + # Number of parallel threads used for search operations. If 0 - auto selection. + max_search_threads: 4 + # Max total number of threads, which can be used for running optimization processes across all collections. + # Note: Each optimization thread will also use `max_indexing_threads` for index building.
+ # So total number of threads used for optimization will be `max_optimization_threads * max_indexing_threads` + max_optimization_threads: 1 + + optimizers: + # The minimal fraction of deleted vectors in a segment, required to perform segment optimization + deleted_threshold: 0.2 + + # The minimal number of vectors in a segment, required to perform segment optimization + vacuum_min_vector_number: 1000 + + # Target amount of segments optimizer will try to keep. + # Real amount of segments may vary depending on multiple parameters: + # - Amount of stored points + # - Current write RPS + # + # It is recommended to select default number of segments as a factor of the number of search threads, + # so that each segment would be handled evenly by one of the threads. + # If `default_segment_number = 0`, will be automatically selected by the number of available CPUs + default_segment_number: 0 + + # Do not create segments larger than this size (in KiloBytes). + # Large segments might require disproportionately long indexation times, + # therefore it makes sense to limit the size of segments. + # + # If indexation speed has higher priority for you - make this parameter lower. + # If search speed is more important - make this parameter higher. + # Note: 1Kb = 1 vector of size 256 + # If not set, will be automatically selected considering the number of available CPUs. + max_segment_size_kb: null + + # Maximum size (in KiloBytes) of vectors to store in-memory per segment. + # Segments larger than this threshold will be stored as read-only memmapped files. + # To enable memmap storage, lower the threshold + # Note: 1Kb = 1 vector of size 256 + # To explicitly disable mmap optimization, set to `0`. + # If not set, will be disabled by default. + memmap_threshold_kb: null + + # Maximum size (in KiloBytes) of vectors allowed for plain index. + # Default value based on https://github.com/google-research/google-research/blob/master/scann/docs/algorithms.md + # Note: 1Kb = 1 vector of size 256 + # To explicitly disable vector indexing, set to `0`. + # If not set, the default value will be used. + indexing_threshold_kb: 20000 + + # Interval between forced flushes. + flush_interval_sec: 5 + + # Max number of threads, which can be used for optimization per collection. + # Note: Each optimization thread will also use `max_indexing_threads` for index building. + # So total number of threads used for optimization will be `max_optimization_threads * max_indexing_threads` + # If `max_optimization_threads = 0`, optimization will be disabled. + max_optimization_threads: 1 + + # Default parameters of HNSW Index. Could be overridden for each collection or named vector individually + hnsw_index: + # Number of edges per node in the index graph. Larger the value - more accurate the search, more space required. + m: 16 + # Number of neighbours to consider during the index building. Larger the value - more accurate the search, more time required to build index. + ef_construct: 100 + # Minimal size (in KiloBytes) of vectors for additional payload-based indexing. + # If payload chunk is smaller than `full_scan_threshold_kb` additional indexing won't be used - + # in this case full-scan search should be preferred by query planner and additional indexing is not required. + # Note: 1Kb = 1 vector of size 256 + full_scan_threshold_kb: 10000 + # Number of parallel threads used for background index building. If 0 - auto selection. + max_indexing_threads: 0 + # Store HNSW index on disk. If set to false, index will be stored in RAM.
Default: false + on_disk: false + # Custom M param for hnsw graph built for payload index. If not set, default M will be used. + payload_m: null + +service: + # Maximum size of POST data in a single request in megabytes + max_request_size_mb: 32 + + # Number of parallel workers used for serving the api. If 0 - equal to the number of available cores. + # If missing - Same as storage.max_search_threads + max_workers: 0 + + # Host to bind the service on + host: 0.0.0.0 + + # HTTP(S) port to bind the service on + http_port: 6333 + + # gRPC port to bind the service on. + # If `null` - gRPC is disabled. Default: null + grpc_port: 6334 + # Uncomment to enable gRPC: + # grpc_port: 6334 + + # Enable CORS headers in REST API. + # If enabled, browsers would be allowed to query REST endpoints regardless of query origin. + # More info: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + # Default: true + enable_cors: true + + # Use HTTPS for the REST API + enable_tls: false + + # Check user HTTPS client certificate against CA file specified in tls config + verify_https_client_certificate: false + + # Set an api-key. + # If set, all requests must include a header with the api-key. + # example header: `api-key: ` + # + # If you enable this you should also enable TLS. + # (Either above or via an external service like nginx.) + # Sending an api-key over an unencrypted channel is insecure. + # + # Uncomment to enable. + api_key: qd-SbvSWrYpa473J33yPjdL + +cluster: + # Use `enabled: true` to run Qdrant in distributed deployment mode + enabled: false + + # Configuration of the inter-cluster communication + p2p: + # Port for internal communication between peers + port: 6335 + + # Use TLS for communication between peers + enable_tls: false + + # Configuration related to distributed consensus algorithm + consensus: + # How frequently peers should ping each other. + # Setting this parameter to lower value will allow consensus + # to detect disconnected nodes earlier, but too frequent + # tick period may create significant network and CPU overhead. + # We encourage you NOT to change this parameter unless you know what you are doing. + tick_period_ms: 100 + +# Set to true to prevent service from sending usage statistics to the developers. +# Read more: https://qdrant.tech/documentation/telemetry +telemetry_disabled: false + +# TLS configuration. +# Required if either service.enable_tls or cluster.p2p.enable_tls is true. +tls: + # Server certificate chain file + cert: ./tls/cert.pem + + # Server private key file + key: ./tls/key.pem + + # Certificate authority certificate file. + # This certificate will be used to validate the certificates + # presented by other nodes during inter-cluster communication. + # + # If verify_https_client_certificate is true, it will verify + # HTTPS client certificate + # + # Required if cluster.p2p.enable_tls is true. + ca_cert: ./tls/cacert.pem + + # TTL, in seconds, to re-load certificate from disk. Useful for certificate rotations, + # Only works for HTTPS endpoints, gRPC endpoints (including intra-cluster communication) + # doesn't support certificate re-load + cert_ttl: 3600 diff --git a/railway.json b/railway.json index d6d92535..9810ca27 100644 --- a/railway.json +++ b/railway.json @@ -9,7 +9,7 @@ "cmds": [ "python -m venv --copies /opt/venv && . /opt/venv/bin/activate", "pip install pip==23.3.1", - "pip install -r requirements.txt" + "pip install -r ai_ta_backend/requirements.txt" ] }, "setup": {