Skip to content
7 changes: 7 additions & 0 deletions src/yandex_cloud_ml_sdk/_messages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@

@dataclasses.dataclass(frozen=True)
class BaseMessage(BaseProtoResult[ProtoMessageTypeT_contra]):
"""
Base class for all message types in the SDK.
"""
#: Tuple containing message parts (can be strings or other types)
parts: tuple[Any, ...]

@property
def text(self) -> str:
"""
Get concatenated string of all text parts in the message by joining all string parts.
"""
return '\n'.join(
part for part in self.parts
if isinstance(part, str)
Expand Down
25 changes: 25 additions & 0 deletions src/yandex_cloud_ml_sdk/_messages/citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@

@dataclasses.dataclass(frozen=True)
class Citation(BaseProtoResult[ProtoCitation]):
"""
Represents a citation with multiple sources.
"""
#: Tuple of Source objects referenced in this citation
sources: tuple[Source, ...]

@classmethod
Expand All @@ -33,9 +37,15 @@ def _from_proto(cls, proto: ProtoCitation, sdk: BaseSDK) -> Citation: # type: i
)

class Source(BaseProtoResult[ProtoSource]):
"""
Abstract base class for citation sources.
"""
@property
@abc.abstractmethod
def type(self) -> str:
"""
Get the type identifier of this source.
"""
pass

@classmethod
Expand All @@ -48,11 +58,19 @@ def _from_proto(cls, proto: ProtoSource, sdk: BaseSDK) -> Source: # type: ignor

@dataclasses.dataclass(frozen=True)
class FileChunk(Source, BaseMessage[ProtoSource]):
"""
Represents a file chunk citation source.
"""
#: Search index this chunk belongs to
search_index: BaseSearchIndex
#: File this chunk belongs to (optional)
file: BaseFile | None

@property
def type(self) -> str:
"""
Get the type identifier for file chunks. Always returns 'filechunk'
"""
return 'filechunk'

@classmethod
Expand Down Expand Up @@ -85,10 +103,17 @@ def _from_proto(cls, proto: ProtoSource, sdk: BaseSDK) -> FileChunk | UnknownSou

@dataclasses.dataclass(frozen=True)
class UnknownSource(Source):
"""
Represents an unknown citation source type.
"""
#: Description of the unknown source
text: str

@property
def type(self) -> str:
"""
Get the type identifier for unknown sources. Always returns 'unknown'.
"""
return 'unknown'

@classmethod
Expand Down
38 changes: 36 additions & 2 deletions src/yandex_cloud_ml_sdk/_messages/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
from yandex_cloud_ml_sdk._types.domain import BaseDomain
from yandex_cloud_ml_sdk._types.message import MessageType, coerce_to_text_message_dict
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value
from yandex_cloud_ml_sdk._utils.doc import doc_from
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator

from .message import Message


class BaseMessages(BaseDomain):
"""
Base class for message operations (sync and async implementations).
"""
_message_impl = Message

async def _create(
Expand All @@ -30,6 +34,14 @@ async def _create(
labels: UndefinedOr[dict[str, str]] = UNDEFINED,
timeout: float = 60,
) -> Message:
"""Create a new message (internal implementation).

:param message: Message content to create
:param thread_id: ID of the thread to add message to
:param labels: Optional dictionary of message labels
:param timeout: The timeout, or the maximum time to wait for the request to complete in seconds.
Defaults to 60 seconds.
"""
message_dict = coerce_to_text_message_dict(message)
content = message_dict['text']
author: Author | None = None
Expand Down Expand Up @@ -67,6 +79,14 @@ async def _get(
message_id: str,
timeout: float = 60,
) -> Message:
"""
Get a message by ID (internal implementation).

:param thread_id: ID of the thread containing the message
:param message_id: ID of the message to retrieve
:param timeout: The timeout, or the maximum time to wait for the request to complete in seconds.
Defaults to 60 seconds.
"""
# TODO: we need a global per-sdk cache on ids to rule out
# possibility we have two Messages with same ids but different fields
request = GetMessageRequest(thread_id=thread_id, message_id=message_id)
Expand All @@ -87,6 +107,13 @@ async def _list(
thread_id: str,
timeout: float = 60
) -> AsyncIterator[Message]:
"""
List messages in a thread (internal implementation).

:param thread_id: ID of the thread to list messages from
:param timeout: The timeout, or the maximum time to wait for the request to complete in seconds.
Defaults to 60 seconds.
"""
request = ListMessagesRequest(thread_id=thread_id)

async with self._client.get_service_stub(MessageServiceStub, timeout=timeout) as stub:
Expand All @@ -98,8 +125,9 @@ async def _list(
):
yield self._message_impl._from_proto(proto=response, sdk=self._sdk)


@doc_from(BaseMessages)
class AsyncMessages(BaseMessages):
@doc_from(BaseMessages._create)
async def create(
self,
message: MessageType,
Expand All @@ -115,6 +143,7 @@ async def create(
timeout=timeout
)

@doc_from(BaseMessages._get)
async def get(
self,
*,
Expand All @@ -128,6 +157,7 @@ async def get(
timeout=timeout
)

@doc_from(BaseMessages._list)
async def list(
self,
*,
Expand All @@ -140,12 +170,14 @@ async def list(
):
yield message


@doc_from(BaseMessages)
class Messages(BaseMessages):

__get = run_sync(BaseMessages._get)
__create = run_sync(BaseMessages._create)
__list = run_sync_generator(BaseMessages._list)

@doc_from(BaseMessages._create)
def create(
self,
message: MessageType,
Expand All @@ -161,6 +193,7 @@ def create(
timeout=timeout
)

@doc_from(BaseMessages._get)
def get(
self,
*,
Expand All @@ -174,6 +207,7 @@ def get(
timeout=timeout
)

@doc_from(BaseMessages._list)
def list(
self,
*,
Expand Down
27 changes: 27 additions & 0 deletions src/yandex_cloud_ml_sdk/_messages/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@


class MessageStatus(ProtoEnumBase, enum.IntEnum):
"""
Enum representing possible message statuses.

.. note:: Values are inherited from protobuf definitions.
"""
MESSAGE_STATUS_UNSPECIFIED = ProtoMessage.MessageStatus.MESSAGE_STATUS_UNSPECIFIED

# Message was successfully created by a user or generated by an assistant.
Expand All @@ -32,22 +37,40 @@ class MessageStatus(ProtoEnumBase, enum.IntEnum):

@dataclasses.dataclass(frozen=True)
class Author:
"""
Represents the author of a message.
"""
#: Unique identifier of the message author
id: str
#: Role of the author (e.g., 'user', 'assistant')
role: str


@dataclasses.dataclass(frozen=True)
class Message(BaseMessage[ProtoMessage], BaseResource[ProtoMessage]):
"""
Represents a message in a conversation thread.
"""
#: ID of the thread containing this message
thread_id: str
#: ID of the user/assistant who created the message
created_by: str
#: Timestamp when the message was created
created_at: datetime
#: A set of labels for the message.
labels: dict[str, str] | None
#: Author information
author: Author
#: Tuple of citations in this message
citations: tuple[Citation, ...]
#: Current status of the message
status: MessageStatus

@property
def role(self) -> str:
"""
Get the role of the message author.
"""
return self.author.role

@classmethod
Expand Down Expand Up @@ -80,6 +103,10 @@ def _kwargs_from_message(cls, proto: ProtoMessage, sdk: BaseSDK) -> dict[str, An

@dataclasses.dataclass(frozen=True)
class PartialMessage(BaseMessage[MessageContent], BaseResource[MessageContent]):
"""
Represents a partial message (content only without full metadata).
"""

@classmethod
def _kwargs_from_message(cls, proto: MessageContent, sdk: BaseSDK) -> dict[str, Any]: # type: ignore[override]
kwargs = super()._kwargs_from_message(proto, sdk=sdk)
Expand Down