Skip to content

Commit 5d18f84

Browse files
authored
Merge branch 'master' into refactor_proto_result
2 parents 4fa3ddd + eee6521 commit 5d18f84

File tree

5 files changed

+191
-6
lines changed

5 files changed

+191
-6
lines changed

CONTRIBUTING.md

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Tests
2+
3+
## Run all the tests
4+
```bash
5+
tox
6+
```
7+
8+
## Run specific python version
9+
10+
```bash
11+
tox -e python3.12-extra-deps
12+
```
13+
14+
## Filtering tests further
15+
16+
```bash
17+
tox -e ... -- path/to/test/file.py -k <substring_of_test_name>
18+
```
19+
20+
or to speedup the process, look into `tox.ini` and run commands without tox environment:
21+
22+
```bash
23+
pip install -r test_requirements.txt -r test_requirements_extra.txt
24+
25+
# to run one and only one test
26+
pytest path/to/test/file.py -k <substring_of_test_name>
27+
28+
# to run only flakes
29+
pytest --flakes src/ examples/
30+
```
31+
32+
Bash above is just examples how to run some of the commands without tox, there many more possibilities.
33+
34+
## Test cassetes
35+
36+
In some tests we are using [pytest-recording](https://github.com/kiwicom/pytest-recording) library to record
37+
http requests to a local test cassetes to not to use network in future runs and to increase determinacy
38+
of test runs (with the assupmtion that backend will not break contract and will not make any breaking changes).
39+
40+
But pytest-recording supports only http/https cassetes so we have something like this for grpc requests
41+
[written by us](src/yandex_cloud_ml_sdk/_testing/interceptor.py).
42+
43+
It is not so convenient to use and have it's problems, but basically to use it you need to:
44+
45+
1) Place `@pytest.mark.allow_grpc` to the test.
46+
47+
2) Edit `tests/conftest.py:fixture_folder_id` (TODO: make it using env vars by default)
48+
49+
3) Export any auth into environment like `export YC_API_KEY="..."` or any other auth method
50+
51+
4) Run your new test for the first time `pytest path/to/test/file.py -k <test_name> --generate-grpc` which will create
52+
a new cassete file
53+
54+
5) In case of the test failed or if you want to regenerate cassete -- `pytest <...> --regenerate-grpc` will help
55+
56+
6) When you will thinkn that cassete is okay, run pytest without any `--...-grpc` flags
57+
58+
7) Do not forget to commit new cassete file.
59+
60+
61+
# Pre-commit hooks
62+
63+
We are using https://pre-commit.com/ to improve our PR review experience.
64+
65+
It is generally a good idea to setup pre-commit locally, otherwise some robot will come to your PR and
66+
will either make commit with fixes (and you will need to pull it or overwrite with a force push),
67+
either it will break the tests with the things it can't fix.
68+
69+
70+
# Documentation
71+
72+
On PR we are running `sphinx-build -W` which are failing in case of any warnings such as
73+
unresolved references or wrong rst syntaxt in docstrings.
74+
75+
To install sphinx and other required for work with documentation, run
76+
`pip isntall -r docs/requirements.txt -e .` in the repo root.
77+
78+
To run sphinx locally to check if there are any errors or warnings,
79+
run `sphinx-build docs docs/_build -E -W` in the repo root.
80+
It will generate html in the `docs/_build` folder which you could check out in case
81+
you want to be sure how is your doc will looks like.
82+
83+
Also you could install `sphinx-autobuild` and run `sphinx-autobuild docs docs/_build --watch src`
84+
to run a webserver (it will be available at http://localhost:8000 by-default) and automatically
85+
rebuild a doc when you save a file.

src/yandex_cloud_ml_sdk/_threads/domain.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@
1212
from yandex_cloud_ml_sdk._types.domain import BaseDomain
1313
from yandex_cloud_ml_sdk._types.expiration import ExpirationConfig, ExpirationPolicyAlias
1414
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value, is_defined
15+
from yandex_cloud_ml_sdk._utils.doc import doc_from
1516
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator
1617

1718
from .thread import AsyncThread, Thread, ThreadTypeT
1819

1920

2021
class BaseThreads(BaseDomain, Generic[ThreadTypeT]):
22+
"""A class for managing threads. It is a part of Assistants API.
23+
24+
This class provides methods to create, retrieve, and list threads.
25+
"""
2126
_thread_impl: type[ThreadTypeT]
2227

2328
async def _create(
@@ -30,6 +35,19 @@ async def _create(
3035
expiration_policy: UndefinedOr[ExpirationPolicyAlias] = UNDEFINED,
3136
timeout: float = 60,
3237
) -> ThreadTypeT:
38+
"""Create a new thread.
39+
40+
This method creates a new thread with the specified parameters.
41+
42+
:param name: the name of the thread.
43+
:param description: a description for the thread.
44+
:param labels: a set of labels for the thread.
45+
:param ttl_days: time-to-live in days for the thread.
46+
:param expiration_policy: expiration policy for the file.
47+
Assepts for passing ``static`` or ``since_last_active`` strings. Should be defined if ``ttl_days`` has been defined, otherwise both parameters should be undefined.
48+
:param timeout: timeout for the service call in seconds.
49+
Defaults to 60 seconds.
50+
"""
3351
if is_defined(ttl_days) != is_defined(expiration_policy):
3452
raise ValueError("ttl_days and expiration policy must be both defined either undefined")
3553

@@ -59,6 +77,14 @@ async def _get(
5977
*,
6078
timeout: float = 60,
6179
) -> ThreadTypeT:
80+
"""Retrieve a thread by its id.
81+
82+
This method fetches an already created thread using its unique identifier.
83+
84+
:param thread_id: the unique identifier of the thread to retrieve.
85+
:param timeout: timeout for the service call in seconds.
86+
Defaults to 60 seconds.
87+
"""
6288
# TODO: we need a global per-sdk cache on ids to rule out
6389
# possibility we have two Threads with same ids but different fields
6490
request = GetThreadRequest(thread_id=thread_id)
@@ -79,6 +105,15 @@ async def _list(
79105
page_size: UndefinedOr[int] = UNDEFINED,
80106
timeout: float = 60
81107
) -> AsyncIterator[ThreadTypeT]:
108+
"""List threads in the specified folder.
109+
110+
This method retrieves a list of threads. It continues
111+
to fetch threads until there are no more available.
112+
113+
:param page_size: the maximum number of threads to return per page.
114+
:param timeout: timeout for the service call in seconds.
115+
Defaults to 60 seconds.
116+
"""
82117
page_token_ = ''
83118
page_size_ = get_defined_value(page_size, 0)
84119

@@ -104,10 +139,11 @@ async def _list(
104139

105140
page_token_ = response.next_page_token
106141

107-
142+
@doc_from(BaseThreads)
108143
class AsyncThreads(BaseThreads[AsyncThread]):
109144
_thread_impl = AsyncThread
110145

146+
@doc_from(BaseThreads._create)
111147
async def create(
112148
self,
113149
*,
@@ -127,6 +163,7 @@ async def create(
127163
timeout=timeout,
128164
)
129165

166+
@doc_from(BaseThreads._get)
130167
async def get(
131168
self,
132169
thread_id: str,
@@ -138,6 +175,7 @@ async def get(
138175
timeout=timeout,
139176
)
140177

178+
@doc_from(BaseThreads._list)
141179
async def list(
142180
self,
143181
*,
@@ -150,14 +188,15 @@ async def list(
150188
):
151189
yield thread
152190

153-
191+
@doc_from(BaseThreads)
154192
class Threads(BaseThreads[Thread]):
155193
_thread_impl = Thread
156194

157195
__get = run_sync(BaseThreads._get)
158196
__create = run_sync(BaseThreads._create)
159197
__list = run_sync_generator(BaseThreads._list)
160198

199+
@doc_from(BaseThreads._create)
161200
def create(
162201
self,
163202
*,
@@ -177,6 +216,7 @@ def create(
177216
timeout=timeout,
178217
)
179218

219+
@doc_from(BaseThreads._get)
180220
def get(
181221
self,
182222
thread_id: str,
@@ -188,6 +228,7 @@ def get(
188228
timeout=timeout,
189229
)
190230

231+
@doc_from(BaseThreads._list)
191232
def list(
192233
self,
193234
*,

src/yandex_cloud_ml_sdk/_threads/thread.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,17 @@
1717
from yandex_cloud_ml_sdk._types.message import MessageType
1818
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value
1919
from yandex_cloud_ml_sdk._types.resource import ExpirableResource, safe_on_delete
20+
from yandex_cloud_ml_sdk._utils.doc import doc_from
2021
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator
2122

2223

2324
@dataclasses.dataclass(frozen=True)
2425
class BaseThread(ExpirableResource[ProtoThread]):
26+
"""A class for a thread resource.
27+
28+
It provides methods for working with messages that the thread contains (e.g. updating, deleting, writing to, and reading from).
29+
"""
30+
2531
@safe_on_delete
2632
async def _update(
2733
self,
@@ -33,6 +39,17 @@ async def _update(
3339
expiration_policy: UndefinedOr[ExpirationPolicyAlias] = UNDEFINED,
3440
timeout: float = 60,
3541
) -> Self:
42+
"""Update the thread's properties, including the name, the description, labels,
43+
ttl days, and the expiration policy of the thread.
44+
45+
:param name: the new name of the thread.
46+
:param description: the new description for the thread.
47+
:param labels: a set of new labels for the thread.
48+
:param ttl_days: the updated time-to-live in days for the thread.
49+
:param expiration_policy: an updated expiration policy for the file.
50+
:param timeout: timeout for the operation in seconds.
51+
Defaults to 60 seconds.
52+
"""
3653
# pylint: disable=too-many-locals
3754
name_ = get_defined_value(name, '')
3855
description_ = get_defined_value(description, '')
@@ -78,6 +95,14 @@ async def _delete(
7895
*,
7996
timeout: float = 60,
8097
) -> None:
98+
"""Delete the thread.
99+
100+
This method deletes the thread and marks it as deleted.
101+
Raises an exception if the deletion fails.
102+
103+
:param timeout: timeout for the operation.
104+
Defaults to 60 seconds.
105+
"""
81106
request = DeleteThreadRequest(thread_id=self.id)
82107

83108
async with self._client.get_service_stub(ThreadServiceStub, timeout=timeout) as stub:
@@ -97,6 +122,16 @@ async def _write(
97122
labels: UndefinedOr[dict[str, str]] = UNDEFINED,
98123
timeout: float = 60,
99124
) -> Message:
125+
"""Write a message to the thread.
126+
127+
This method allows sending a message to the thread with optional labels.
128+
129+
:param message: the message to be sent to the thread. Could be a string, a dictionary, or a result object.
130+
Read more about other possible message types in the `documentation <https://yandex.cloud/docs/foundation-models/sdk/#usage>`_.
131+
:param labels: optional labels for the message.
132+
:param timeout: timeout for the operation.
133+
Defaults to 60 seconds.
134+
"""
100135
# pylint: disable-next=protected-access
101136
return await self._sdk._messages._create(
102137
thread_id=self.id,
@@ -110,6 +145,13 @@ async def _read(
110145
*,
111146
timeout: float = 60,
112147
) -> AsyncIterator[Message]:
148+
"""Read messages from the thread.
149+
150+
This method allows iterating over messages in the thread.
151+
152+
:param timeout: timeout for the operation.
153+
Defaults to 60 seconds.
154+
"""
113155
# NB: in other methods it is solved via @safe decorator, but it is doesn't work
114156
# with iterators, so, temporary here will be small copypaste
115157
# Also I'm not sure enough if we need to put whole thread reading under a lock
@@ -125,17 +167,26 @@ async def _read(
125167

126168
@dataclasses.dataclass(frozen=True)
127169
class RichThread(BaseThread):
170+
#: the name of the thread
128171
name: str | None
172+
#: the description of the thread
129173
description: str | None
174+
#: the identifier of the user who created the thread
130175
created_by: str
176+
#: the timestamp when the thread was created
131177
created_at: datetime
178+
#: the identifier of the user who last updated the thread
132179
updated_by: str
180+
#: the timestamp when the thread was last updated
133181
updated_at: datetime
182+
#: the timestamp when the thread will expire
134183
expires_at: datetime
184+
#: additional labels associated with the thread
135185
labels: dict[str, str] | None
136186

137-
138187
class AsyncThread(RichThread):
188+
189+
@doc_from(BaseThread._update)
139190
async def update(
140191
self,
141192
*,
@@ -155,13 +206,15 @@ async def update(
155206
timeout=timeout,
156207
)
157208

209+
@doc_from(BaseThread._delete)
158210
async def delete(
159211
self,
160212
*,
161213
timeout: float = 60,
162214
) -> None:
163215
await self._delete(timeout=timeout)
164216

217+
@doc_from(BaseThread._write)
165218
async def write(
166219
self,
167220
message: MessageType,
@@ -175,6 +228,7 @@ async def write(
175228
timeout=timeout
176229
)
177230

231+
@doc_from(BaseThread._read)
178232
async def read(
179233
self,
180234
*,
@@ -183,15 +237,16 @@ async def read(
183237
async for message in self._read(timeout=timeout):
184238
yield message
185239

240+
#: alias for the read method
186241
__aiter__ = read
187242

188-
189243
class Thread(RichThread):
190244
__update = run_sync(RichThread._update)
191245
__delete = run_sync(RichThread._delete)
192246
__write = run_sync(RichThread._write)
193247
__read = run_sync_generator(RichThread._read)
194248

249+
@doc_from(BaseThread._update)
195250
def update(
196251
self,
197252
*,
@@ -211,13 +266,15 @@ def update(
211266
timeout=timeout,
212267
)
213268

269+
@doc_from(BaseThread._delete)
214270
def delete(
215271
self,
216272
*,
217273
timeout: float = 60,
218274
) -> None:
219275
self.__delete(timeout=timeout)
220276

277+
@doc_from(BaseThread._write)
221278
def write(
222279
self,
223280
message: MessageType,
@@ -231,6 +288,7 @@ def write(
231288
timeout=timeout
232289
)
233290

291+
@doc_from(BaseThread._read)
234292
def read(
235293
self,
236294
*,

test_requirements_extra.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
langchain-core>=0.3; python_version >= '3.9'
2+
numpy

0 commit comments

Comments
 (0)