Skip to content

Commit 164fd51

Browse files
authored
Add docstrings for _threads (#127)
1 parent 9428ac7 commit 164fd51

File tree

2 files changed

+102
-4
lines changed

2 files changed

+102
-4
lines changed

src/yandex_cloud_ml_sdk/_threads/domain.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@
1212
from yandex_cloud_ml_sdk._types.domain import BaseDomain
1313
from yandex_cloud_ml_sdk._types.expiration import ExpirationConfig, ExpirationPolicyAlias
1414
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value, is_defined
15+
from yandex_cloud_ml_sdk._utils.doc import doc_from
1516
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator
1617

1718
from .thread import AsyncThread, Thread, ThreadTypeT
1819

1920

2021
class BaseThreads(BaseDomain, Generic[ThreadTypeT]):
22+
"""A class for managing threads. It is a part of Assistants API.
23+
24+
This class provides methods to create, retrieve, and list threads.
25+
"""
2126
_thread_impl: type[ThreadTypeT]
2227

2328
async def _create(
@@ -30,6 +35,19 @@ async def _create(
3035
expiration_policy: UndefinedOr[ExpirationPolicyAlias] = UNDEFINED,
3136
timeout: float = 60,
3237
) -> ThreadTypeT:
38+
"""Create a new thread.
39+
40+
This method creates a new thread with the specified parameters.
41+
42+
:param name: the name of the thread.
43+
:param description: a description for the thread.
44+
:param labels: a set of labels for the thread.
45+
:param ttl_days: time-to-live in days for the thread.
46+
:param expiration_policy: expiration policy for the file.
47+
Assepts for passing ``static`` or ``since_last_active`` strings. Should be defined if ``ttl_days`` has been defined, otherwise both parameters should be undefined.
48+
:param timeout: timeout for the service call in seconds.
49+
Defaults to 60 seconds.
50+
"""
3351
if is_defined(ttl_days) != is_defined(expiration_policy):
3452
raise ValueError("ttl_days and expiration policy must be both defined either undefined")
3553

@@ -59,6 +77,14 @@ async def _get(
5977
*,
6078
timeout: float = 60,
6179
) -> ThreadTypeT:
80+
"""Retrieve a thread by its id.
81+
82+
This method fetches an already created thread using its unique identifier.
83+
84+
:param thread_id: the unique identifier of the thread to retrieve.
85+
:param timeout: timeout for the service call in seconds.
86+
Defaults to 60 seconds.
87+
"""
6288
# TODO: we need a global per-sdk cache on ids to rule out
6389
# possibility we have two Threads with same ids but different fields
6490
request = GetThreadRequest(thread_id=thread_id)
@@ -79,6 +105,15 @@ async def _list(
79105
page_size: UndefinedOr[int] = UNDEFINED,
80106
timeout: float = 60
81107
) -> AsyncIterator[ThreadTypeT]:
108+
"""List threads in the specified folder.
109+
110+
This method retrieves a list of threads. It continues
111+
to fetch threads until there are no more available.
112+
113+
:param page_size: the maximum number of threads to return per page.
114+
:param timeout: timeout for the service call in seconds.
115+
Defaults to 60 seconds.
116+
"""
82117
page_token_ = ''
83118
page_size_ = get_defined_value(page_size, 0)
84119

@@ -104,10 +139,11 @@ async def _list(
104139

105140
page_token_ = response.next_page_token
106141

107-
142+
@doc_from(BaseThreads)
108143
class AsyncThreads(BaseThreads[AsyncThread]):
109144
_thread_impl = AsyncThread
110145

146+
@doc_from(BaseThreads._create)
111147
async def create(
112148
self,
113149
*,
@@ -127,6 +163,7 @@ async def create(
127163
timeout=timeout,
128164
)
129165

166+
@doc_from(BaseThreads._get)
130167
async def get(
131168
self,
132169
thread_id: str,
@@ -138,6 +175,7 @@ async def get(
138175
timeout=timeout,
139176
)
140177

178+
@doc_from(BaseThreads._list)
141179
async def list(
142180
self,
143181
*,
@@ -150,14 +188,15 @@ async def list(
150188
):
151189
yield thread
152190

153-
191+
@doc_from(BaseThreads)
154192
class Threads(BaseThreads[Thread]):
155193
_thread_impl = Thread
156194

157195
__get = run_sync(BaseThreads._get)
158196
__create = run_sync(BaseThreads._create)
159197
__list = run_sync_generator(BaseThreads._list)
160198

199+
@doc_from(BaseThreads._create)
161200
def create(
162201
self,
163202
*,
@@ -177,6 +216,7 @@ def create(
177216
timeout=timeout,
178217
)
179218

219+
@doc_from(BaseThreads._get)
180220
def get(
181221
self,
182222
thread_id: str,
@@ -188,6 +228,7 @@ def get(
188228
timeout=timeout,
189229
)
190230

231+
@doc_from(BaseThreads._list)
191232
def list(
192233
self,
193234
*,

src/yandex_cloud_ml_sdk/_threads/thread.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,16 @@
1717
from yandex_cloud_ml_sdk._types.message import MessageType
1818
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value
1919
from yandex_cloud_ml_sdk._types.resource import ExpirableResource, safe_on_delete
20+
from yandex_cloud_ml_sdk._utils.doc import doc_from
2021
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator
2122

2223

2324
@dataclasses.dataclass(frozen=True)
2425
class BaseThread(ExpirableResource):
26+
"""A class for a thread resource.
27+
28+
It provides methods for working with messages that the thread contains (e.g. updating, deleting, writing to, and reading from).
29+
"""
2530
@safe_on_delete
2631
async def _update(
2732
self,
@@ -33,6 +38,17 @@ async def _update(
3338
expiration_policy: UndefinedOr[ExpirationPolicyAlias] = UNDEFINED,
3439
timeout: float = 60,
3540
) -> Self:
41+
"""Update the thread's properties, including the name, the description, labels,
42+
ttl days, and the expiration policy of the thread.
43+
44+
:param name: the new name of the thread.
45+
:param description: the new description for the thread.
46+
:param labels: a set of new labels for the thread.
47+
:param ttl_days: the updated time-to-live in days for the thread.
48+
:param expiration_policy: an updated expiration policy for the file.
49+
:param timeout: timeout for the operation in seconds.
50+
Defaults to 60 seconds.
51+
"""
3652
# pylint: disable=too-many-locals
3753
name_ = get_defined_value(name, '')
3854
description_ = get_defined_value(description, '')
@@ -78,6 +94,14 @@ async def _delete(
7894
*,
7995
timeout: float = 60,
8096
) -> None:
97+
"""Delete the thread.
98+
99+
This method deletes the thread and marks it as deleted.
100+
Raises an exception if the deletion fails.
101+
102+
:param timeout: timeout for the operation.
103+
Defaults to 60 seconds.
104+
"""
81105
request = DeleteThreadRequest(thread_id=self.id)
82106

83107
async with self._client.get_service_stub(ThreadServiceStub, timeout=timeout) as stub:
@@ -97,6 +121,16 @@ async def _write(
97121
labels: UndefinedOr[dict[str, str]] = UNDEFINED,
98122
timeout: float = 60,
99123
) -> Message:
124+
"""Write a message to the thread.
125+
126+
This method allows sending a message to the thread with optional labels.
127+
128+
:param message: the message to be sent to the thread. Could be a string, a dictionary, or a result object.
129+
Read more about other possible message types in the `documentation <https://yandex.cloud/docs/foundation-models/sdk/#usage>`_.
130+
:param labels: optional labels for the message.
131+
:param timeout: timeout for the operation.
132+
Defaults to 60 seconds.
133+
"""
100134
# pylint: disable-next=protected-access
101135
return await self._sdk._messages._create(
102136
thread_id=self.id,
@@ -110,6 +144,13 @@ async def _read(
110144
*,
111145
timeout: float = 60,
112146
) -> AsyncIterator[Message]:
147+
"""Read messages from the thread.
148+
149+
This method allows iterating over messages in the thread.
150+
151+
:param timeout: timeout for the operation.
152+
Defaults to 60 seconds.
153+
"""
113154
# NB: in other methods it is solved via @safe decorator, but it is doesn't work
114155
# with iterators, so, temporary here will be small copypaste
115156
# Also I'm not sure enough if we need to put whole thread reading under a lock
@@ -125,17 +166,26 @@ async def _read(
125166

126167
@dataclasses.dataclass(frozen=True)
127168
class RichThread(BaseThread):
169+
#: the name of the thread
128170
name: str | None
171+
#: the description of the thread
129172
description: str | None
173+
#: the identifier of the user who created the thread
130174
created_by: str
175+
#: the timestamp when the thread was created
131176
created_at: datetime
177+
#: the identifier of the user who last updated the thread
132178
updated_by: str
179+
#: the timestamp when the thread was last updated
133180
updated_at: datetime
181+
#: the timestamp when the thread will expire
134182
expires_at: datetime
183+
#: additional labels associated with the thread
135184
labels: dict[str, str] | None
136185

137-
138186
class AsyncThread(RichThread):
187+
188+
@doc_from(BaseThread._update)
139189
async def update(
140190
self,
141191
*,
@@ -155,13 +205,15 @@ async def update(
155205
timeout=timeout,
156206
)
157207

208+
@doc_from(BaseThread._delete)
158209
async def delete(
159210
self,
160211
*,
161212
timeout: float = 60,
162213
) -> None:
163214
await self._delete(timeout=timeout)
164215

216+
@doc_from(BaseThread._write)
165217
async def write(
166218
self,
167219
message: MessageType,
@@ -175,6 +227,7 @@ async def write(
175227
timeout=timeout
176228
)
177229

230+
@doc_from(BaseThread._read)
178231
async def read(
179232
self,
180233
*,
@@ -183,15 +236,16 @@ async def read(
183236
async for message in self._read(timeout=timeout):
184237
yield message
185238

239+
#: alias for the read method
186240
__aiter__ = read
187241

188-
189242
class Thread(RichThread):
190243
__update = run_sync(RichThread._update)
191244
__delete = run_sync(RichThread._delete)
192245
__write = run_sync(RichThread._write)
193246
__read = run_sync_generator(RichThread._read)
194247

248+
@doc_from(BaseThread._update)
195249
def update(
196250
self,
197251
*,
@@ -211,13 +265,15 @@ def update(
211265
timeout=timeout,
212266
)
213267

268+
@doc_from(BaseThread._delete)
214269
def delete(
215270
self,
216271
*,
217272
timeout: float = 60,
218273
) -> None:
219274
self.__delete(timeout=timeout)
220275

276+
@doc_from(BaseThread._write)
221277
def write(
222278
self,
223279
message: MessageType,
@@ -231,6 +287,7 @@ def write(
231287
timeout=timeout
232288
)
233289

290+
@doc_from(BaseThread._read)
234291
def read(
235292
self,
236293
*,

0 commit comments

Comments
 (0)