|
1 |
| -import csv |
2 | 1 | import io
|
3 | 2 | from datetime import datetime
|
4 | 3 | from datetime import timezone
|
|
9 | 8 | from fastapi import Depends
|
10 | 9 | from fastapi import HTTPException
|
11 | 10 | from fastapi import Query
|
| 11 | +from fastapi import Response |
12 | 12 | from fastapi.responses import StreamingResponse
|
13 | 13 | from sqlalchemy.orm import Session
|
14 | 14 |
|
15 |
| -from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time |
| 15 | +from ee.onyx.background.celery.apps.primary import query_history_report_task |
16 | 16 | from ee.onyx.db.query_history import get_page_of_chat_sessions
|
17 | 17 | from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
|
18 | 18 | from ee.onyx.server.query_history.models import ChatSessionMinimal
|
19 | 19 | from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
20 |
| -from ee.onyx.server.query_history.models import MessageSnapshot |
21 |
| -from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot |
| 20 | +from ee.onyx.server.query_history.utils import ONYX_ANONYMIZED_EMAIL |
| 21 | +from ee.onyx.server.query_history.utils import snapshot_from_chat_session |
22 | 22 | from onyx.auth.users import current_admin_user
|
23 |
| -from onyx.auth.users import get_display_email |
24 |
| -from onyx.chat.chat_utils import create_chat_chain |
25 | 23 | from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
26 |
| -from onyx.configs.constants import MessageType |
27 | 24 | from onyx.configs.constants import QAFeedbackType
|
28 | 25 | from onyx.configs.constants import QueryHistoryType
|
29 |
| -from onyx.configs.constants import SessionType |
30 | 26 | from onyx.db.chat import get_chat_session_by_id
|
31 | 27 | from onyx.db.chat import get_chat_sessions_by_user
|
32 | 28 | from onyx.db.engine import get_session
|
33 |
| -from onyx.db.models import ChatSession |
| 29 | +from onyx.db.enums import TaskStatus |
34 | 30 | from onyx.db.models import User
|
| 31 | +from onyx.db.tasks import get_task_by_task_id |
35 | 32 | from onyx.server.documents.models import PaginatedReturn
|
36 | 33 | from onyx.server.query_and_chat.models import ChatSessionDetails
|
37 | 34 | from onyx.server.query_and_chat.models import ChatSessionsResponse
|
38 | 35 |
|
39 | 36 | router = APIRouter()
|
40 | 37 |
|
41 |
| -ONYX_ANONYMIZED_EMAIL = "[email protected]" |
42 |
| - |
43 |
| - |
44 |
| -def fetch_and_process_chat_session_history( |
45 |
| - db_session: Session, |
46 |
| - start: datetime, |
47 |
| - end: datetime, |
48 |
| - feedback_type: QAFeedbackType | None, |
49 |
| - limit: int | None = 500, |
50 |
| -) -> list[ChatSessionSnapshot]: |
51 |
| - # observed to be slow a scale of 8192 sessions and 4 messages per session |
52 |
| - |
53 |
| - # this is a little slow (5 seconds) |
54 |
| - chat_sessions = fetch_chat_sessions_eagerly_by_time( |
55 |
| - start=start, end=end, db_session=db_session, limit=limit |
56 |
| - ) |
57 |
| - |
58 |
| - # this is VERY slow (80 seconds) due to create_chat_chain being called |
59 |
| - # for each session. Needs optimizing. |
60 |
| - chat_session_snapshots = [ |
61 |
| - snapshot_from_chat_session(chat_session=chat_session, db_session=db_session) |
62 |
| - for chat_session in chat_sessions |
63 |
| - ] |
64 |
| - |
65 |
| - valid_snapshots = [ |
66 |
| - snapshot for snapshot in chat_session_snapshots if snapshot is not None |
67 |
| - ] |
68 |
| - |
69 |
| - if feedback_type: |
70 |
| - valid_snapshots = [ |
71 |
| - snapshot |
72 |
| - for snapshot in valid_snapshots |
73 |
| - if any( |
74 |
| - message.feedback_type == feedback_type for message in snapshot.messages |
75 |
| - ) |
76 |
| - ] |
77 |
| - |
78 |
| - return valid_snapshots |
79 |
| - |
80 |
| - |
81 |
| -def snapshot_from_chat_session( |
82 |
| - chat_session: ChatSession, |
83 |
| - db_session: Session, |
84 |
| -) -> ChatSessionSnapshot | None: |
85 |
| - try: |
86 |
| - # Older chats may not have the right structure |
87 |
| - last_message, messages = create_chat_chain( |
88 |
| - chat_session_id=chat_session.id, db_session=db_session |
89 |
| - ) |
90 |
| - messages.append(last_message) |
91 |
| - except RuntimeError: |
92 |
| - return None |
93 |
| - |
94 |
| - flow_type = SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT |
95 |
| - |
96 |
| - return ChatSessionSnapshot( |
97 |
| - id=chat_session.id, |
98 |
| - user_email=get_display_email( |
99 |
| - chat_session.user.email if chat_session.user else None |
100 |
| - ), |
101 |
| - name=chat_session.description, |
102 |
| - messages=[ |
103 |
| - MessageSnapshot.build(message) |
104 |
| - for message in messages |
105 |
| - if message.message_type != MessageType.SYSTEM |
106 |
| - ], |
107 |
| - assistant_id=chat_session.persona_id, |
108 |
| - assistant_name=chat_session.persona.name if chat_session.persona else None, |
109 |
| - time_created=chat_session.time_created, |
110 |
| - flow_type=flow_type, |
111 |
| - ) |
112 |
| - |
113 | 38 |
|
114 | 39 | @router.get("/admin/chat-sessions")
|
115 | 40 | def get_user_chat_sessions(
|
@@ -238,52 +163,100 @@ def get_chat_session_admin(
|
238 | 163 | return snapshot
|
239 | 164 |
|
240 | 165 |
|
241 |
| -@router.get("/admin/query-history-csv") |
242 |
| -def get_query_history_as_csv( |
243 |
| - _: User | None = Depends(current_admin_user), |
| 166 | +@router.post("/admin/query-history-csv") |
| 167 | +def post_query_history_as_csv( |
| 168 | + response: Response, |
244 | 169 | start: datetime | None = None,
|
245 | 170 | end: datetime | None = None,
|
| 171 | + _: User | None = Depends(current_admin_user), |
246 | 172 | db_session: Session = Depends(get_session),
|
247 |
| -) -> StreamingResponse: |
| 173 | +) -> None: |
248 | 174 | if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
249 | 175 | raise HTTPException(
|
250 | 176 | status_code=HTTPStatus.FORBIDDEN,
|
251 | 177 | detail="Query history has been disabled by the administrator.",
|
252 | 178 | )
|
253 | 179 |
|
254 |
| - # this call is very expensive and is timing out via endpoint |
255 |
| - # TODO: optimize call and/or generate via background task |
256 |
| - complete_chat_session_history = fetch_and_process_chat_session_history( |
257 |
| - db_session=db_session, |
258 |
| - start=start or datetime.fromtimestamp(0, tz=timezone.utc), |
259 |
| - end=end or datetime.now(tz=timezone.utc), |
260 |
| - feedback_type=None, |
261 |
| - limit=None, |
| 180 | + start = start or datetime.fromtimestamp(0, tz=timezone.utc) |
| 181 | + end = end or datetime.now(tz=timezone.utc) |
| 182 | + task = query_history_report_task.delay( |
| 183 | + start=start, |
| 184 | + end=end, |
262 | 185 | )
|
263 | 186 |
|
264 |
| - question_answer_pairs: list[QuestionAnswerPairSnapshot] = [] |
265 |
| - for chat_session_snapshot in complete_chat_session_history: |
266 |
| - if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED: |
267 |
| - chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL |
| 187 | + response.status_code = HTTPStatus.ACCEPTED |
| 188 | + response.headers[ |
| 189 | + "Location" |
| 190 | + ] = f"/admin/query-history-csv/status?request_id={task.id}" |
| 191 | + return {"request_id": task.id} |
| 192 | + |
| 193 | + |
| 194 | +@router.get("/admin/query-history-csv/status") |
| 195 | +def get_query_history_csv_status( |
| 196 | + request_id: str, |
| 197 | + response: Response, |
| 198 | + _: User | None = Depends(current_admin_user), |
| 199 | + db_session: Session = Depends(get_session), |
| 200 | +) -> dict[str, TaskStatus]: |
| 201 | + if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED: |
| 202 | + raise HTTPException( |
| 203 | + status_code=HTTPStatus.FORBIDDEN, |
| 204 | + detail="Query history has been disabled by the administrator.", |
| 205 | + ) |
268 | 206 |
|
269 |
| - question_answer_pairs.extend( |
270 |
| - QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot) |
| 207 | + task_queue_state = get_task_by_task_id(request_id, db_session) |
| 208 | + if task_queue_state is None: |
| 209 | + raise HTTPException( |
| 210 | + status_code=HTTPStatus.NOT_FOUND, |
| 211 | + detail="Task queue state not found for task id.", |
271 | 212 | )
|
272 | 213 |
|
273 |
| - # Create an in-memory text stream |
274 |
| - stream = io.StringIO() |
275 |
| - writer = csv.DictWriter( |
276 |
| - stream, fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()) |
277 |
| - ) |
278 |
| - writer.writeheader() |
279 |
| - for row in question_answer_pairs: |
280 |
| - writer.writerow(row.to_json()) |
| 214 | + return { |
| 215 | + "status": task_queue_state.status, |
| 216 | + } |
| 217 | + |
| 218 | + |
| 219 | +@router.get("/admin/query-history-csv/download") |
| 220 | +def download_query_history_csv( |
| 221 | + request_id: str, |
| 222 | + _: User | None = Depends(current_admin_user), |
| 223 | + db_session: Session = Depends(get_session), |
| 224 | +) -> StreamingResponse: |
| 225 | + if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED: |
| 226 | + raise HTTPException( |
| 227 | + status_code=HTTPStatus.FORBIDDEN, |
| 228 | + detail="Query history has been disabled by the administrator.", |
| 229 | + ) |
281 | 230 |
|
282 |
| - # Reset the stream's position to the start |
283 |
| - stream.seek(0) |
| 231 | + task_queue_state = get_task_by_task_id(request_id, db_session) |
| 232 | + if task_queue_state is None: |
| 233 | + raise HTTPException( |
| 234 | + status_code=HTTPStatus.NOT_FOUND, |
| 235 | + detail="Task queue state not found for task id.", |
| 236 | + ) |
| 237 | + elif task_queue_state.status == TaskStatus.FAILURE: |
| 238 | + raise HTTPException( |
| 239 | + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, |
| 240 | + detail="Task failed to complete.", |
| 241 | + ) |
| 242 | + elif task_queue_state.status != TaskStatus.SUCCESS: |
| 243 | + raise HTTPException( |
| 244 | + status_code=HTTPStatus.NO_CONTENT, |
| 245 | + detail="Task is still pending.", |
| 246 | + ) |
284 | 247 |
|
| 248 | + # TODO: change this to read from the file store with the file name |
| 249 | + # `query_history_report_{request_id}.csv` |
| 250 | + test_csv = "user_message,assistant_message,date\n" |
| 251 | + test_csv += "Hello, how are you?,I am fine,2021-01-01\n" |
| 252 | + test_csv += ( |
| 253 | + "What is the weather in Tokyo?,The weather in Tokyo is sunny,2021-01-02\n" |
| 254 | + ) |
| 255 | + test_csv += ( |
| 256 | + "What is the capital of France?,The capital of France is Paris,2021-01-03\n" |
| 257 | + ) |
285 | 258 | return StreamingResponse(
|
286 |
| - iter([stream.getvalue()]), |
| 259 | + io.StringIO(test_csv), |
287 | 260 | media_type="text/csv",
|
288 |
| - headers={"Content-Disposition": "attachment;filename=onyx_query_history.csv"}, |
| 261 | + headers={"Content-Disposition": "attachment;filename=query_history.csv"}, |
289 | 262 | )
|
0 commit comments