Skip to content

Commit 821c0f5

Browse files
committed
Add citations to run result
1 parent f304b53 commit 821c0f5

File tree

9 files changed

+444
-119
lines changed

9 files changed

+444
-119
lines changed

examples/async/assistants/assistant_with_search_index.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,19 @@ async def main() -> None:
4444
print('Question:', search_query)
4545
print('Answer:', result.text)
4646

47+
# You can access the .citations attribute for debugging purposes
48+
for citation in result.citations:
49+
for source in citation.sources:
50+
# In the future there will be more source types
51+
if source.type != 'filechunk':
52+
continue
53+
print('Example source:', source)
54+
# One source is enough for this example; printing more would take up too much screen space
55+
break
56+
else:
57+
continue
58+
break
59+
4760
search_query = "Cколько пошлина в Анталье"
4861
await thread.write(search_query)
4962

examples/sync/assistants/assistant_with_search_index.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,19 @@ def main() -> None:
4242
print('Question:', search_query)
4343
print('Answer:', result.text)
4444

45+
# You can access the .citations attribute for debugging purposes
46+
for citation in result.citations:
47+
for source in citation.sources:
48+
# In the future there will be more source types
49+
if source.type != 'filechunk':
50+
continue
51+
print('Example source:', source)
52+
# One source is enough for this example; printing more would take up too much screen space
53+
break
54+
else:
55+
continue
56+
break
57+
4558
search_query = "Cколько пошлина в Анталье"
4659
thread.write(search_query)
4760

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from __future__ import annotations
2+
3+
import dataclasses
4+
from typing import Any
5+
6+
from yandex_cloud_ml_sdk._types.result import BaseResult
7+
8+
9+
@dataclasses.dataclass(frozen=True)
10+
class BaseMessage(BaseResult):
11+
parts: tuple[Any, ...]
12+
13+
@property
14+
def text(self):
15+
return '\n'.join(
16+
part for part in self.parts
17+
if isinstance(part, str)
18+
)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# pylint: disable=no-name-in-module
2+
3+
from __future__ import annotations
4+
5+
import abc
6+
import dataclasses
7+
from typing import TYPE_CHECKING
8+
9+
from yandex.cloud.ai.assistants.v1.threads.message_pb2 import Citation as ProtoCitation
10+
from yandex.cloud.ai.assistants.v1.threads.message_pb2 import Source as ProtoSource
11+
12+
from yandex_cloud_ml_sdk._files.file import BaseFile
13+
from yandex_cloud_ml_sdk._search_indexes.search_index import BaseSearchIndex
14+
from yandex_cloud_ml_sdk._types.result import BaseResult
15+
16+
from .base import BaseMessage
17+
18+
if TYPE_CHECKING:
19+
from yandex_cloud_ml_sdk._sdk import BaseSDK
20+
21+
22+
@dataclasses.dataclass(frozen=True)
23+
class Citation(BaseResult):
24+
sources: tuple[Source, ...]
25+
26+
@classmethod
27+
def _from_proto(cls, proto: ProtoCitation, sdk: BaseSDK) -> Citation: # type: ignore[override]
28+
return cls(
29+
sources=tuple(
30+
Source._from_proto(proto=source, sdk=sdk)
31+
for source in proto.sources
32+
)
33+
)
34+
35+
class Source(BaseResult):
36+
@property
37+
@abc.abstractmethod
38+
def type(self) -> str:
39+
pass
40+
41+
@classmethod
42+
def _from_proto(cls, proto: ProtoSource, sdk: BaseSDK) -> Source: # type: ignore[override]
43+
if proto.HasField('chunk'):
44+
return FileChunk._from_proto(proto=proto, sdk=sdk)
45+
46+
return UnknownSource._from_proto(proto=proto, sdk=sdk)
47+
48+
49+
@dataclasses.dataclass(frozen=True)
50+
class FileChunk(Source, BaseMessage):
51+
search_index: BaseSearchIndex
52+
file: BaseFile | None
53+
54+
@property
55+
def type(self) -> str:
56+
return 'filechunk'
57+
58+
@classmethod
59+
def _from_proto(cls, proto: ProtoSource, sdk: BaseSDK) -> FileChunk | UnknownSource: # type: ignore[override]
60+
# pylint: disable=protected-access
61+
chunk = proto.chunk
62+
assert chunk
63+
64+
raw_parts = (part.text.content for part in chunk.content.content)
65+
parts = tuple(part for part in raw_parts if part)
66+
67+
search_index = sdk.search_indexes._impl._from_proto(proto=chunk.search_index, sdk=sdk)
68+
file: BaseFile | None = None
69+
70+
# NB: at the moment the backend always returns a non-empty source_file field,
71+
# but in case it was deleted, source_file will contain an empty File structure
72+
if (
73+
chunk.HasField('source_file') and
74+
chunk.source_file and
75+
chunk.source_file.id
76+
):
77+
file = sdk.files._file_impl._from_proto(proto=chunk.source_file, sdk=sdk)
78+
79+
return cls(
80+
search_index=search_index,
81+
file=file,
82+
parts=parts,
83+
)
84+
85+
86+
@dataclasses.dataclass(frozen=True)
87+
class UnknownSource(Source):
88+
text: str
89+
90+
@property
91+
def type(self) -> str:
92+
return 'unknown'
93+
94+
@classmethod
95+
def _from_proto(cls, proto: ProtoSource, sdk: BaseSDK) -> UnknownSource: # type: ignore[override]
96+
return cls(
97+
text="Source's protobuf have unknown fields; try to update yandex-cloud-ml-sdk"
98+
)

src/yandex_cloud_ml_sdk/_messages/message.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
from yandex_cloud_ml_sdk._types.resource import BaseResource
1212

13+
from .base import BaseMessage
14+
from .citations import Citation
15+
1316
if TYPE_CHECKING:
1417
from yandex_cloud_ml_sdk._sdk import BaseSDK
1518

@@ -21,24 +24,13 @@ class Author:
2124

2225

2326
@dataclasses.dataclass(frozen=True)
24-
class BaseMessage(BaseResource):
25-
parts: tuple[Any, ...]
26-
27-
@property
28-
def text(self):
29-
return '\n'.join(
30-
part for part in self.parts
31-
if isinstance(part, str)
32-
)
33-
34-
35-
@dataclasses.dataclass(frozen=True)
36-
class Message(BaseMessage):
27+
class Message(BaseMessage, BaseResource):
3728
thread_id: str
3829
created_by: str
3930
created_at: datetime
4031
labels: dict[str, str] | None
4132
author: Author
33+
citations: tuple[Citation, ...]
4234

4335
@classmethod
4436
def _kwargs_from_message(cls, proto: ProtoMessage, sdk: BaseSDK) -> dict[str, Any]: # type: ignore[override]
@@ -58,12 +50,17 @@ def _kwargs_from_message(cls, proto: ProtoMessage, sdk: BaseSDK) -> dict[str, An
5850
role=proto.author.role,
5951
id=proto.author.id
6052
)
53+
raw_citations = proto.citations or ()
54+
kwargs['citations'] = tuple(
55+
Citation._from_proto(proto=citation, sdk=sdk)
56+
for citation in raw_citations
57+
)
6158

6259
return kwargs
6360

6461

6562
@dataclasses.dataclass(frozen=True)
66-
class PartialMessage(BaseMessage):
63+
class PartialMessage(BaseMessage, BaseResource):
6764
@classmethod
6865
def _kwargs_from_message(cls, proto: MessageContent, sdk: BaseSDK) -> dict[str, Any]: # type: ignore[override]
6966
kwargs = super()._kwargs_from_message(proto, sdk=sdk)

src/yandex_cloud_ml_sdk/_runs/result.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from yandex.cloud.ai.assistants.v1.runs.run_pb2 import Run as ProtoRun
99
from yandex.cloud.ai.assistants.v1.runs.run_service_pb2 import StreamEvent as ProtoStreamEvent
1010

11+
from yandex_cloud_ml_sdk._messages.citations import Citation
1112
from yandex_cloud_ml_sdk._messages.message import BaseMessage, Message, PartialMessage
1213
from yandex_cloud_ml_sdk._models.completions.result import Usage
1314
from yandex_cloud_ml_sdk._types.result import BaseResult, ProtoMessage
@@ -60,7 +61,7 @@ def text(self) -> str:
6061
return self.message.text
6162

6263
@property
63-
def parts(self) -> tuple[Any]:
64+
def parts(self) -> tuple[Any, ...]:
6465
return self.message.parts
6566

6667

@@ -101,6 +102,10 @@ def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> RunResult:
101102
usage=usage,
102103
)
103104

105+
@property
106+
def citations(self) -> tuple[Citation, ...]:
107+
return self.message.citations
108+
104109

105110
@dataclasses.dataclass(frozen=True)
106111
class RunStreamEvent(BaseRunResult[StreamEvent, BaseMessage]):

0 commit comments

Comments
 (0)