Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[draft] Feat/optimize query history report #4391

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions backend/ee/onyx/background/celery/apps/primary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from datetime import datetime

from celery import Task

from ee.onyx.background.celery.tasks.query_history import query_history_report
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
from ee.onyx.background.task_name_builders import name_chat_ttl_task
from ee.onyx.background.task_name_builders import name_query_history_report_task
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
Expand Down Expand Up @@ -66,6 +72,11 @@ def check_ttl_management_task(*, tenant_id: str) -> None:
)


#####
# Non-Periodic Tasks
#####


@celery_app.task(
name="autogenerate_usage_report_task",
ignore_result=True,
Expand All @@ -79,3 +90,17 @@ def autogenerate_usage_report_task(*, tenant_id: str) -> None:
user_id=None,
period=None,
)


@build_celery_task_wrapper(name_query_history_report_task)
@celery_app.task(
name="query_history_report_task",
bind=True,
# ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def query_history_report_task(self: Task, *, start: datetime, end: datetime) -> str:
with get_session_with_current_tenant() as db_session:
return query_history_report(
db_session=db_session, request_id=self.request.id, start=start, end=end
)
78 changes: 78 additions & 0 deletions backend/ee/onyx/background/celery/tasks/query_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import csv
import io
from datetime import datetime

from sqlalchemy.orm import Session

from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from ee.onyx.server.query_history.utils import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.utils import ONYX_ANONYMIZED_EMAIL
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import QueryHistoryType
from onyx.file_store.file_store import get_default_file_store


def query_history_report(
db_session: Session, request_id: str, start: datetime, end: datetime
) -> str:
chat_session_history = fetch_chat_session_history(
db_session=db_session, start=start, end=end
)
qa_pairs = construct_qa_pairs(chat_session_history)
persist_chat_session_history(
db_session=db_session,
report_name=f"query_history_report_{request_id}.csv",
qa_pairs=qa_pairs,
)


def fetch_chat_session_history(
db_session: Session, start: datetime, end: datetime
) -> list[ChatSessionSnapshot]:
return fetch_and_process_chat_session_history(
db_session=db_session,
start=start,
end=end,
feedback_type=None,
limit=None,
)


def construct_qa_pairs(
chat_session_history: list[ChatSessionSnapshot],
) -> list[QuestionAnswerPairSnapshot]:
qa_pairs: list[QuestionAnswerPairSnapshot] = []
for chat_session_snapshot in chat_session_history:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL

qa_pairs.extend(
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
)

return qa_pairs


def persist_chat_session_history(
db_session: Session, report_name: str, qa_pairs: list[QuestionAnswerPairSnapshot]
):
file_store = get_default_file_store(db_session)
stream = io.StringIO()
writer = csv.DictWriter(
stream, fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys())
)

writer.writeheader()
for row in qa_pairs:
writer.writerow(row.to_json())

stream.seek(0)
file_store.save_file(
file_name=report_name,
content=stream,
display_name=report_name,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)
9 changes: 9 additions & 0 deletions backend/ee/onyx/background/task_name_builders.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,11 @@
from datetime import datetime


def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None) -> str:
return f"chat_ttl_{retention_limit_days}_days"


def name_query_history_report_task(start: datetime, end: datetime) -> str:
start_epoch = int(start.timestamp())
end_epoch = int(end.timestamp())
return f"query_history_report_{start_epoch}_{end_epoch}"
195 changes: 84 additions & 111 deletions backend/ee/onyx/server/query_history/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import csv
import io
from datetime import datetime
from datetime import timezone
Expand All @@ -9,107 +8,33 @@
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Response
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
from ee.onyx.background.celery.apps.primary import query_history_report_task
from ee.onyx.db.query_history import get_page_of_chat_sessions
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import MessageSnapshot
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from ee.onyx.server.query_history.utils import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.utils import snapshot_from_chat_session
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine import get_session
from onyx.db.models import ChatSession
from onyx.db.enums import TaskStatus
from onyx.db.models import User
from onyx.db.tasks import get_task_by_task_id
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse

router = APIRouter()

ONYX_ANONYMIZED_EMAIL = "[email protected]"


def fetch_and_process_chat_session_history(
db_session: Session,
start: datetime,
end: datetime,
feedback_type: QAFeedbackType | None,
limit: int | None = 500,
) -> list[ChatSessionSnapshot]:
# observed to be slow a scale of 8192 sessions and 4 messages per session

# this is a little slow (5 seconds)
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)

# this is VERY slow (80 seconds) due to create_chat_chain being called
# for each session. Needs optimizing.
chat_session_snapshots = [
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
for chat_session in chat_sessions
]

valid_snapshots = [
snapshot for snapshot in chat_session_snapshots if snapshot is not None
]

if feedback_type:
valid_snapshots = [
snapshot
for snapshot in valid_snapshots
if any(
message.feedback_type == feedback_type for message in snapshot.messages
)
]

return valid_snapshots


def snapshot_from_chat_session(
chat_session: ChatSession,
db_session: Session,
) -> ChatSessionSnapshot | None:
try:
# Older chats may not have the right structure
last_message, messages = create_chat_chain(
chat_session_id=chat_session.id, db_session=db_session
)
messages.append(last_message)
except RuntimeError:
return None

flow_type = SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT

return ChatSessionSnapshot(
id=chat_session.id,
user_email=get_display_email(
chat_session.user.email if chat_session.user else None
),
name=chat_session.description,
messages=[
MessageSnapshot.build(message)
for message in messages
if message.message_type != MessageType.SYSTEM
],
assistant_id=chat_session.persona_id,
assistant_name=chat_session.persona.name if chat_session.persona else None,
time_created=chat_session.time_created,
flow_type=flow_type,
)


@router.get("/admin/chat-sessions")
def get_user_chat_sessions(
Expand Down Expand Up @@ -238,52 +163,100 @@ def get_chat_session_admin(
return snapshot


@router.get("/admin/query-history-csv")
def get_query_history_as_csv(
_: User | None = Depends(current_admin_user),
@router.post("/admin/query-history-csv")
def post_query_history_as_csv(
response: Response,
start: datetime | None = None,
end: datetime | None = None,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
) -> None:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)

# this call is very expensive and is timing out via endpoint
# TODO: optimize call and/or generate via background task
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
end=end or datetime.now(tz=timezone.utc),
feedback_type=None,
limit=None,
start = start or datetime.fromtimestamp(0, tz=timezone.utc)
end = end or datetime.now(tz=timezone.utc)
task = query_history_report_task.delay(
start=start,
end=end,
)

question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
for chat_session_snapshot in complete_chat_session_history:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
response.status_code = HTTPStatus.ACCEPTED
response.headers[
"Location"
] = f"/admin/query-history-csv/status?request_id={task.id}"
return {"request_id": task.id}


@router.get("/admin/query-history-csv/status")
def get_query_history_csv_status(
request_id: str,
response: Response,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, TaskStatus]:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)

question_answer_pairs.extend(
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
task_queue_state = get_task_by_task_id(request_id, db_session)
if task_queue_state is None:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Task queue state not found for task id.",
)

# Create an in-memory text stream
stream = io.StringIO()
writer = csv.DictWriter(
stream, fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys())
)
writer.writeheader()
for row in question_answer_pairs:
writer.writerow(row.to_json())
return {
"status": task_queue_state.status,
}


@router.get("/admin/query-history-csv/download")
def download_query_history_csv(
request_id: str,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)

# Reset the stream's position to the start
stream.seek(0)
task_queue_state = get_task_by_task_id(request_id, db_session)
if task_queue_state is None:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Task queue state not found for task id.",
)
elif task_queue_state.status == TaskStatus.FAILURE:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Task failed to complete.",
)
elif task_queue_state.status != TaskStatus.SUCCESS:
raise HTTPException(
status_code=HTTPStatus.NO_CONTENT,
detail="Task is still pending.",
)

# TODO: change this to read from the file store with the file name
# `query_history_report_{request_id}.csv`
test_csv = "user_message,assistant_message,date\n"
test_csv += "Hello, how are you?,I am fine,2021-01-01\n"
test_csv += (
"What is the weather in Tokyo?,The weather in Tokyo is sunny,2021-01-02\n"
)
test_csv += (
"What is the capital of France?,The capital of France is Paris,2021-01-03\n"
)
return StreamingResponse(
iter([stream.getvalue()]),
io.StringIO(test_csv),
media_type="text/csv",
headers={"Content-Disposition": "attachment;filename=onyx_query_history.csv"},
headers={"Content-Disposition": "attachment;filename=query_history.csv"},
)
Loading
Loading