Skip to content

Commit d37c145

Browse files
authored
Provide in-line citations for file search (#485)
Closes #31 by adding an in-line citation mentioning the referencing file name when File Search is used. See discussion in #31 as to why this is the best we can do right now. Because of the limited amount of information, it would make no sense to direct users at the end of the message with footnotes.
1 parent 78e6c9a commit d37c145

5 files changed

Lines changed: 105 additions & 9 deletions

File tree

pingpong/ai.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from openai.types.beta.threads import ImageFile, MessageContentPartParam
1111
from openai.types.beta.threads.runs import ToolCallsStepDetails, CodeInterpreterToolCall
1212
from pingpong.schemas import CodeInterpreterMessage
13-
1413
import pingpong.models as models
1514
from .config import config
1615

@@ -118,9 +117,10 @@ async def get_ci_messages_from_step(
118117

119118

120119
class BufferedStreamHandler(openai.AsyncAssistantEventHandler):
121-
def __init__(self, *args, **kwargs):
120+
def __init__(self, file_names: dict[str, str], *args, **kwargs):
122121
super().__init__(*args, **kwargs)
123122
self.__buffer = io.BytesIO()
123+
self.file_names = file_names
124124

125125
def enqueue(self, data: Dict) -> None:
126126
self.__buffer.write(orjson.dumps(data))
@@ -150,10 +150,18 @@ async def on_message_created(self, message) -> None:
150150
)
151151

152152
async def on_message_delta(self, delta, snapshot) -> None:
153+
message_delta = delta.model_dump()
154+
for content in message_delta["content"]:
155+
if content["text"]["annotations"]:
156+
for annotation in content["text"]["annotations"]:
157+
if annotation["type"] == "file_citation":
158+
annotation["file_citation"]["file_name"] = self.file_names.get(
159+
annotation["file_citation"]["file_id"], ""
160+
)
153161
self.enqueue(
154162
{
155163
"type": "message_delta",
156-
"delta": delta.model_dump(),
164+
"delta": message_delta,
157165
}
158166
)
159167

@@ -199,6 +207,7 @@ async def run_thread(
199207
thread_id: str,
200208
assistant_id: int,
201209
message: list[MessageContentPartParam],
210+
file_names: dict[str, str] = {},
202211
metadata: Dict[str, str | int] | None = None,
203212
):
204213
try:
@@ -209,8 +218,7 @@ async def run_thread(
209218
content=message,
210219
metadata=metadata,
211220
)
212-
213-
handler = BufferedStreamHandler()
221+
handler = BufferedStreamHandler(file_names=file_names)
214222
async with cli.beta.threads.runs.stream(
215223
thread_id=thread_id,
216224
assistant_id=assistant_id,

pingpong/models.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import asyncio
12
import json
23
from datetime import datetime
3-
from typing import AsyncGenerator, List, Optional
4+
from typing import AsyncGenerator, List, Optional, Union
45

56
from sqlalchemy import Boolean, Column, DateTime, UniqueConstraint
67
from sqlalchemy import Enum as SQLEnum
@@ -687,6 +688,16 @@ async def get_file_ids_by_id(
687688
for file in vector_store.files:
688689
yield file.file_id, file.id
689690

691+
@classmethod
692+
async def get_file_names_ids_by_id(
693+
cls, session: AsyncSession, id_: int
694+
) -> dict[str, str]:
695+
stmt = select(VectorStore).where(VectorStore.id == int(id_))
696+
vector_store = await session.scalar(stmt)
697+
if not vector_store:
698+
return {}
699+
return {file.file_id: file.name for file in vector_store.files}
700+
690701
@classmethod
691702
async def add_files(
692703
cls, session: AsyncSession, vector_store_id: int, file_ids: list[str]
@@ -1698,3 +1709,52 @@ async def add_image_files(
16981709
)
16991710
)
17001711
await session.execute(stmt)
1712+
1713+
@classmethod
1714+
async def get_file_search_files(
1715+
cls, session: AsyncSession, thread_id: int
1716+
) -> dict[str, str]:
1717+
stmt = (
1718+
select(Thread)
1719+
.outerjoin(Thread.assistant)
1720+
.options(
1721+
contains_eager(Thread.assistant).load_only(Assistant.vector_store_id)
1722+
)
1723+
.where(Thread.id == thread_id)
1724+
)
1725+
thread = await session.scalar(stmt)
1726+
if not thread:
1727+
return {}
1728+
return await cls.get_file_search_files_by_thread(session, thread)
1729+
1730+
@classmethod
1731+
async def get_file_search_files_assistant(
1732+
cls, session: AsyncSession, thread_id: int
1733+
) -> tuple[Union["Assistant", None], dict[str, str]]:
1734+
stmt = (
1735+
select(Thread)
1736+
.options(joinedload(Thread.assistant))
1737+
.where(Thread.id == thread_id)
1738+
)
1739+
thread = await session.scalar(stmt)
1740+
if not thread:
1741+
return None, {}
1742+
return thread.assistant, await cls.get_file_search_files_by_thread(
1743+
session, thread
1744+
)
1745+
1746+
@classmethod
1747+
async def get_file_search_files_by_thread(
1748+
cls, session: AsyncSession, thread: "Thread"
1749+
) -> dict[str, str]:
1750+
vector_store_ids: list[int] = list(
1751+
filter(None, [thread.assistant.vector_store_id, thread.vector_store_id])
1752+
)
1753+
if not vector_store_ids:
1754+
return {}
1755+
tasks = [
1756+
VectorStore.get_file_names_ids_by_id(session, vector_store_id)
1757+
for vector_store_id in vector_store_ids
1758+
]
1759+
results = await asyncio.gather(*tasks)
1760+
return {k: v for result in results for k, v in result.items()}

pingpong/server.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,13 +1419,18 @@ async def get_thread(
14191419
class_id: str, thread_id: str, request: Request, openai_client: OpenAIClient
14201420
):
14211421
thread = await models.Thread.get_by_id(request.state.db, int(thread_id))
1422-
messages, assistant, runs_result = await asyncio.gather(
1422+
messages, [assistant, file_names], runs_result = await asyncio.gather(
14231423
openai_client.beta.threads.messages.list(
14241424
thread.thread_id, limit=20, order="desc"
14251425
),
1426-
models.Assistant.get_by_id(request.state.db, thread.assistant_id),
1426+
models.Thread.get_file_search_files_assistant(request.state.db, thread.id),
14271427
openai_client.beta.threads.runs.list(thread.thread_id, limit=1, order="desc"),
14281428
)
1429+
if not assistant:
1430+
raise HTTPException(
1431+
status_code=404,
1432+
detail="Assistant not found",
1433+
)
14291434
last_run = [r async for r in runs_result]
14301435
current_user_ids = [
14311436
request.state.session.user.id
@@ -1436,6 +1441,13 @@ async def get_thread(
14361441
users = {str(u.id): u for u in thread.users}
14371442

14381443
for message in messages.data:
1444+
for content in message.content:
1445+
if content.type and content.type == "text" and content.text.annotations:
1446+
for annotation in content.text.annotations:
1447+
if annotation.type == "file_citation":
1448+
annotation.file_citation.file_name = file_names.get(
1449+
annotation.file_citation.file_id, ""
1450+
)
14391451
user_id = message.metadata.pop("user_id", None)
14401452
if not user_id:
14411453
continue
@@ -1527,11 +1539,19 @@ async def list_thread_messages(
15271539
messages = await openai_client.beta.threads.messages.list(
15281540
thread.thread_id, limit=limit, order="asc", before=before
15291541
)
1542+
file_names = await models.Thread.get_file_search_files(request.state.db, thread.id)
15301543

15311544
if messages.data:
15321545
users = {u.id: u.created for u in thread.users}
15331546

15341547
for message in messages.data:
1548+
for content in message.content:
1549+
if content.type == "text" and content.text.annotations:
1550+
for annotation in content.text.annotations:
1551+
if annotation.type == "file_citation":
1552+
annotation.file_citation.file_name = file_names.get(
1553+
annotation.file_citation.file_id, ""
1554+
)
15351555
user_id = message.metadata.pop("user_id", None)
15361556
if not user_id:
15371557
continue
@@ -1831,12 +1851,13 @@ async def create_run(
18311851
):
18321852
thread = await models.Thread.get_by_id(request.state.db, int(thread_id))
18331853
asst = await models.Assistant.get_by_id(request.state.db, thread.assistant_id)
1834-
1854+
file_names = await models.Thread.get_file_search_files(request.state.db, thread.id)
18351855
stream = run_thread(
18361856
openai_client,
18371857
thread_id=thread.thread_id,
18381858
assistant_id=asst.assistant_id,
18391859
message=[],
1860+
file_names=file_names,
18401861
)
18411862

18421863
return StreamingResponse(stream, media_type="text/event-stream")
@@ -1926,13 +1947,15 @@ async def send_message(
19261947
thread=thread.thread_id,
19271948
)
19281949

1950+
file_names = await models.Thread.get_file_search_files(request.state.db, thread.id)
19291951
# Create a generator that will stream chunks to the client.
19301952
stream = run_thread(
19311953
openai_client,
19321954
thread_id=thread.thread_id,
19331955
assistant_id=asst.assistant_id,
19341956
message=messageContent,
19351957
metadata={"user_id": str(request.state.session.user.id)},
1958+
file_names=file_names,
19361959
)
19371960
return StreamingResponse(stream, media_type="text/event-stream")
19381961

web/pingpong/src/lib/api.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,6 +1307,7 @@ export type TextAnnotationFilePath = {
13071307

13081308
export type TextAnnotationFileCitationFileCitation = {
13091309
file_id: string;
1310+
file_name: string;
13101311
quote: string;
13111312
};
13121313

web/pingpong/src/lib/content.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ export const parseTextContent = (text: Text, baseUrl: string = '') => {
1818
const { start_index, end_index, file_path } = annotation;
1919
const url = join(baseUrl, `/file/${file_path.file_id}`);
2020
replacements.push({ start: start_index, end: end_index, newValue: url });
21+
} else if (annotation.type === 'file_citation') {
22+
const { start_index, end_index, file_citation } = annotation;
23+
const fileName = ` (${file_citation.file_name})`;
24+
replacements.push({ start: start_index, end: end_index, newValue: fileName });
2125
}
2226
}
2327
}

0 commit comments

Comments
 (0)