Skip to content

Commit 9247b6d

Browse files
authored
Support message status for assistants (#81)
1 parent c30a993 commit 9247b6d

File tree

8 files changed

+107
-27
lines changed

8 files changed

+107
-27
lines changed

examples/async/assistants/runs.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,15 @@ async def main() -> None:
3939
run = await assistant.run(thread)
4040
print(f'second {run=}')
4141
result = await run
42-
print(f'run {result=}')
43-
44-
run = await sdk.runs.get_last_by_thread(thread)
45-
print(f'last run in thread, same as last one: {run}')
42+
print(f'run {result=} with a run status {result.status.name}')
43+
44+
# you could get access to message status, which is different from run status!
45+
assert result.message
46+
print(f'resulting message have status {result.message.status}')
47+
# and check if message was not censored
48+
assert result.message.status.name != 'FILTERED_CONTENT'
49+
# or truncated because of token limits
50+
assert result.message.status.name != 'TRUNCATED'
4651

4752
# NB: it doesn't work at the moment at the backend
4853
# async for run in sdk.runs.list(page_size=10):

examples/async/assistants/threads.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ async def main() -> None:
2626
print("and now we could read it:")
2727
async for message in thread:
2828
print(f" {message=}")
29-
print(f" {message.text=}\n")
29+
print(f" {message.text=}")
30+
# Also every message could have TRUNCATED or FILTERED_CONTENT status
31+
print(f" {message.status.name=}\n")
3032

3133
async for thread in sdk.threads.list():
3234
print(f"deleting thread {thread=}")

examples/sync/assistants/runs.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,15 @@ def main() -> None:
3737
run = assistant.run(thread)
3838
print(f'second {run=}')
3939
result = run.wait()
40-
print(f'run {result=}')
41-
42-
run = sdk.runs.get_last_by_thread(thread)
43-
print(f'last run in thread, same as last one: {run}')
40+
print(f'run {result=} with a run status {result.status.name}')
41+
42+
# you could get access to message status, which is different from run status!
43+
assert result.message
44+
print(f'resulting message have status {result.message.status}')
45+
# and check if message was not censored
46+
assert result.message.status.name != 'FILTERED_CONTENT'
47+
# or truncated because of token limits
48+
assert result.message.status.name != 'TRUNCATED'
4449

4550
# NB: it doesn't work at the moment at the backend
4651
# for run in sdk.runs.list(page_size=10):

examples/sync/assistants/threads.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def main() -> None:
2424
print("and now we could read it:")
2525
for message in thread:
2626
print(f" {message=}")
27-
print(f" {message.text=}\n")
27+
print(f" {message.text=}")
28+
# Also every message could have TRUNCATED or FILTERED_CONTENT status
29+
print(f" {message.status.name=}\n")
2830

2931
for thread in sdk.threads.list():
3032
print(f"deleting thread {thread=}")

src/yandex_cloud_ml_sdk/_messages/message.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
from __future__ import annotations
33

44
import dataclasses
5+
import enum
56
from datetime import datetime
67
from typing import TYPE_CHECKING, Any
78

89
from yandex.cloud.ai.assistants.v1.threads.message_pb2 import Message as ProtoMessage
910
from yandex.cloud.ai.assistants.v1.threads.message_pb2 import MessageContent
1011

1112
from yandex_cloud_ml_sdk._types.resource import BaseResource
13+
from yandex_cloud_ml_sdk._utils.proto import ProtoEnumBase
1214

1315
from .base import BaseMessage
1416
from .citations import Citation
@@ -17,6 +19,17 @@
1719
from yandex_cloud_ml_sdk._sdk import BaseSDK
1820

1921

22+
class MessageStatus(ProtoEnumBase, enum.IntEnum):
23+
MESSAGE_STATUS_UNSPECIFIED = ProtoMessage.MessageStatus.MESSAGE_STATUS_UNSPECIFIED
24+
25+
# Message was successfully created by a user or generated by an assistant.
26+
COMPLETED = ProtoMessage.MessageStatus.COMPLETED
27+
# Message generation was truncated due to reaching the maximum allowed number of tokens.
28+
TRUNCATED = ProtoMessage.MessageStatus.TRUNCATED
29+
# Message generation was stopped because potentially sensitive content was detected either in the prompt or in the generated response.
30+
FILTERED_CONTENT = ProtoMessage.MessageStatus.FILTERED_CONTENT
31+
32+
2033
@dataclasses.dataclass(frozen=True)
2134
class Author:
2235
id: str
@@ -31,6 +44,7 @@ class Message(BaseMessage, BaseResource):
3144
labels: dict[str, str] | None
3245
author: Author
3346
citations: tuple[Citation, ...]
47+
status: MessageStatus
3448

3549
@classmethod
3650
def _kwargs_from_message(cls, proto: ProtoMessage, sdk: BaseSDK) -> dict[str, Any]: # type: ignore[override]
@@ -55,6 +69,7 @@ def _kwargs_from_message(cls, proto: ProtoMessage, sdk: BaseSDK) -> dict[str, An
5569
Citation._from_proto(proto=citation, sdk=sdk)
5670
for citation in raw_citations
5771
)
72+
kwargs['status'] = MessageStatus._coerce(proto.status)
5873

5974
return kwargs
6075

tests/assistants/cassettes/test_messages/test_message.gprc.json

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111
"module": "yandex.cloud.endpoint.api_endpoint_service_pb2",
1212
"message": {
1313
"endpoints": [
14+
{
15+
"id": "ai-assistants",
16+
"address": "assistant.api.cloud.yandex.net:443"
17+
},
18+
{
19+
"id": "ai-files",
20+
"address": "assistant.api.cloud.yandex.net:443"
21+
},
1422
{
1523
"id": "ai-foundation-models",
1624
"address": "llm.api.cloud.yandex.net:443"
@@ -91,9 +99,21 @@
9199
"id": "certificate-manager-data",
92100
"address": "data.certificate-manager.api.cloud.yandex.net:443"
93101
},
102+
{
103+
"id": "certificate-manager-private-ca",
104+
"address": "private-ca.certificate-manager.api.cloud.yandex.net:443"
105+
},
106+
{
107+
"id": "certificate-manager-private-ca-data",
108+
"address": "data.private-ca.certificate-manager.api.cloud.yandex.net:443"
109+
},
94110
{
95111
"id": "cic",
96-
"address": "cic-api.api.cloud.yandex.net:443"
112+
"address": "cic.api.cloud.yandex.net:443"
113+
},
114+
{
115+
"id": "cloud-registry",
116+
"address": "registry.api.cloud.yandex.net:443"
97117
},
98118
{
99119
"id": "cloudapps",
@@ -109,7 +129,7 @@
109129
},
110130
{
111131
"id": "cloudrouter",
112-
"address": "cic-api.api.cloud.yandex.net:443"
132+
"address": "cloudrouter.api.cloud.yandex.net:443"
113133
},
114134
{
115135
"id": "cloudvideo",
@@ -147,6 +167,14 @@
147167
"id": "endpoint",
148168
"address": "api.cloud.yandex.net:443"
149169
},
170+
{
171+
"id": "fomo-dataset",
172+
"address": "fomo-dataset.api.cloud.yandex.net:443"
173+
},
174+
{
175+
"id": "fomo-tuning",
176+
"address": "fomo-tuning.api.cloud.yandex.net:443"
177+
},
150178
{
151179
"id": "iam",
152180
"address": "iam.api.cloud.yandex.net:443"
@@ -231,6 +259,10 @@
231259
"id": "managed-kubernetes",
232260
"address": "mks.api.cloud.yandex.net:443"
233261
},
262+
{
263+
"id": "managed-metastore",
264+
"address": "metastore.api.cloud.yandex.net:443"
265+
},
234266
{
235267
"id": "managed-mongodb",
236268
"address": "mdb.api.cloud.yandex.net:443"
@@ -251,6 +283,10 @@
251283
"id": "managed-redis",
252284
"address": "mdb.api.cloud.yandex.net:443"
253285
},
286+
{
287+
"id": "managed-spark",
288+
"address": "spark.api.cloud.yandex.net:443"
289+
},
254290
{
255291
"id": "managed-sqlserver",
256292
"address": "mdb.api.cloud.yandex.net:443"
@@ -259,6 +295,10 @@
259295
"id": "marketplace",
260296
"address": "marketplace.api.cloud.yandex.net:443"
261297
},
298+
{
299+
"id": "marketplace-pim",
300+
"address": "marketplace.api.cloud.yandex.net:443"
301+
},
262302
{
263303
"id": "mdb-clickhouse",
264304
"address": "mdb.api.cloud.yandex.net:443"
@@ -303,6 +343,14 @@
303343
"id": "organizationmanager",
304344
"address": "organization-manager.api.cloud.yandex.net:443"
305345
},
346+
{
347+
"id": "quota-manager",
348+
"address": "quota-manager.api.cloud.yandex.net:443"
349+
},
350+
{
351+
"id": "quotamanager",
352+
"address": "quota-manager.api.cloud.yandex.net:443"
353+
},
306354
{
307355
"id": "resource-manager",
308356
"address": "resource-manager.api.cloud.yandex.net:443"
@@ -395,18 +443,18 @@
395443
"cls": "Thread",
396444
"module": "yandex.cloud.ai.assistants.v1.threads.thread_pb2",
397445
"message": {
398-
"id": "fvtdau7nugneg7src7jq",
446+
"id": "fvtitqjkhfg44j7i42kv",
399447
"folderId": "b1ghsjum2v37c2un8h64",
400-
"defaultMessageAuthorId": "fvt83tco9sr23t13qvo2",
401-
"createdBy": "ajek27c96hekgf8f8016",
402-
"createdAt": "2024-10-09T18:21:16.153971Z",
403-
"updatedBy": "ajek27c96hekgf8f8016",
404-
"updatedAt": "2024-10-09T18:21:16.153971Z",
448+
"defaultMessageAuthorId": "fvt9ot56k7q2men11ugp",
449+
"createdBy": "aje6euqn63oa635coh28",
450+
"createdAt": "2025-04-03T17:21:52.850407Z",
451+
"updatedBy": "aje6euqn63oa635coh28",
452+
"updatedAt": "2025-04-03T17:21:52.850407Z",
405453
"expirationConfig": {
406454
"expirationPolicy": "SINCE_LAST_ACTIVE",
407455
"ttlDays": "7"
408456
},
409-
"expiresAt": "2024-10-16T18:21:16.153971Z"
457+
"expiresAt": "2025-04-10T17:21:52.850407Z"
410458
}
411459
}
412460
},
@@ -415,7 +463,7 @@
415463
"cls": "CreateMessageRequest",
416464
"module": "yandex.cloud.ai.assistants.v1.threads.message_service_pb2",
417465
"message": {
418-
"threadId": "fvtdau7nugneg7src7jq",
466+
"threadId": "fvtitqjkhfg44j7i42kv",
419467
"labels": {
420468
"foo": "bar"
421469
},
@@ -434,12 +482,12 @@
434482
"cls": "Message",
435483
"module": "yandex.cloud.ai.assistants.v1.threads.message_pb2",
436484
"message": {
437-
"id": "fvtllnag6l7fbe09vgii",
438-
"threadId": "fvtdau7nugneg7src7jq",
439-
"createdBy": "ajek27c96hekgf8f8016",
440-
"createdAt": "2024-10-09T18:21:16.363275Z",
485+
"id": "fvtkkjv3nps3mqbbklaj",
486+
"threadId": "fvtitqjkhfg44j7i42kv",
487+
"createdBy": "aje6euqn63oa635coh28",
488+
"createdAt": "2025-04-03T17:21:52.982345Z",
441489
"author": {
442-
"id": "fvt83tco9sr23t13qvo2",
490+
"id": "fvt9ot56k7q2men11ugp",
443491
"role": "USER"
444492
},
445493
"labels": {
@@ -453,7 +501,8 @@
453501
}
454502
}
455503
]
456-
}
504+
},
505+
"status": "COMPLETED"
457506
}
458507
}
459508
},
@@ -462,7 +511,7 @@
462511
"cls": "DeleteThreadRequest",
463512
"module": "yandex.cloud.ai.assistants.v1.threads.thread_service_pb2",
464513
"message": {
465-
"threadId": "fvtdau7nugneg7src7jq"
514+
"threadId": "fvtitqjkhfg44j7i42kv"
466515
}
467516
},
468517
"response": {

tests/assistants/test_messages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ async def test_message(async_sdk):
2222
assert message.thread_id == thread.id
2323
assert message.text == 'foo'
2424
assert message.author.role == 'USER'
25+
assert message.status.name == 'COMPLETED'
2526

2627
await thread.delete()
2728

tests/models/test_image_generation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def check_messages(messages, expected):
9292
author=None,
9393
thread_id='2',
9494
citations=(),
95+
status=0,
9596
)
9697
messages = messages_to_proto(assistant_message)
9798
check_messages(messages, ['a\nb'])

0 commit comments

Comments
 (0)