Skip to content

Commit 2167f24

Browse files
authored
Add .add_files method for search indices (#68)
1 parent cd96d9f commit 2167f24

File tree

7 files changed

+860
-22
lines changed

7 files changed

+860
-22
lines changed

examples/async/assistants/search_indexes.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ async def main() -> None:
2828
files = await asyncio.gather(*file_coros)
2929

3030
operation = await sdk.search_indexes.create_deferred(
31-
files,
31+
[files[0]],
3232
index_type=TextSearchIndexType(
3333
chunking_strategy=StaticIndexChunkingStrategy(
3434
max_chunk_size_tokens=700,
@@ -45,11 +45,16 @@ async def main() -> None:
4545
await search_index.update(name="foo")
4646
print(f"now with a name {search_index=}")
4747

48-
# NB: it doesn't work at the moment
49-
# index_files = [file async for file in search_index.list_files()]
50-
# print(f"search index files: {index_files}")
51-
# index_file = await search_index.get_file(index_files[0].id)
52-
# print(f"search index file: {index_file}")
48+
# We can also add files to the index later:
49+
add_operation = await search_index.add_files_deferred(files[1])
50+
new_index_files = await add_operation.wait()
51+
print(f"{new_index_files=}")
52+
53+
index_files = [file async for file in search_index.list_files()]
54+
print(f"search index files: {index_files}")
55+
56+
index_file = await search_index.get_file(index_files[0].id)
57+
print(f"search index file: {index_file}")
5358

5459
for file in files:
5560
await file.delete()

examples/sync/assistants/search_indexes.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def main() -> None:
2626
files.append(file)
2727

2828
operation = sdk.search_indexes.create_deferred(
29-
files,
29+
[files[0]],
3030
index_type=TextSearchIndexType(
3131
chunking_strategy=StaticIndexChunkingStrategy(
3232
max_chunk_size_tokens=700,
@@ -43,11 +43,16 @@ def main() -> None:
4343
search_index.update(name="foo")
4444
print(f"now with a name {search_index=}")
4545

46-
# NB: it doesn't work at the moment
47-
# index_files = [file for file in search_index.list_files()]
48-
# print(f"search index files: {index_files}")
49-
# index_file = search_index.get_file(index_files[0].id)
50-
# print(f"search index file: {index_file}")
46+
# We can also add files to the index later:
47+
add_operation = search_index.add_files_deferred(files[1])
48+
new_index_files = add_operation.wait()
49+
print(f"{new_index_files=}")
50+
51+
index_files = [file for file in search_index.list_files()]
52+
print(f"search index files: {index_files}")
53+
54+
index_file = search_index.get_file(index_files[0].id)
55+
print(f"search index file: {index_file}")
5156

5257
for file in files:
5358
file.delete()

src/yandex_cloud_ml_sdk/_search_indexes/search_index.py

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,29 @@
33

44
import dataclasses
55
from datetime import datetime
6-
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, TypeVar
6+
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, TypeVar, cast
77

8-
from typing_extensions import Self
8+
from typing_extensions import Self, TypeAlias
99
from yandex.cloud.ai.assistants.v1.searchindex.search_index_file_pb2 import SearchIndexFile as ProtoSearchIndexFile
1010
from yandex.cloud.ai.assistants.v1.searchindex.search_index_file_service_pb2 import (
11-
GetSearchIndexFileRequest, ListSearchIndexFilesRequest, ListSearchIndexFilesResponse
11+
BatchCreateSearchIndexFileRequest, BatchCreateSearchIndexFileResponse, GetSearchIndexFileRequest,
12+
ListSearchIndexFilesRequest, ListSearchIndexFilesResponse
1213
)
1314
from yandex.cloud.ai.assistants.v1.searchindex.search_index_file_service_pb2_grpc import SearchIndexFileServiceStub
1415
from yandex.cloud.ai.assistants.v1.searchindex.search_index_pb2 import SearchIndex as ProtoSearchIndex
1516
from yandex.cloud.ai.assistants.v1.searchindex.search_index_service_pb2 import (
1617
DeleteSearchIndexRequest, DeleteSearchIndexResponse, UpdateSearchIndexRequest
1718
)
1819
from yandex.cloud.ai.assistants.v1.searchindex.search_index_service_pb2_grpc import SearchIndexServiceStub
20+
from yandex.cloud.operation.operation_pb2 import Operation as ProtoOperation
1921

22+
from yandex_cloud_ml_sdk._files.file import BaseFile
2023
from yandex_cloud_ml_sdk._types.expiration import ExpirationConfig, ExpirationPolicyAlias
2124
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value
25+
from yandex_cloud_ml_sdk._types.operation import AsyncOperation, Operation, OperationTypeT, ReturnsOperationMixin
2226
from yandex_cloud_ml_sdk._types.resource import ExpirableResource, safe_on_delete
2327
from yandex_cloud_ml_sdk._types.result import BaseResult
28+
from yandex_cloud_ml_sdk._utils.coerce import ResourceType, coerce_resource_ids
2429
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator
2530

2631
from .file import SearchIndexFile
@@ -30,8 +35,11 @@
3035
from yandex_cloud_ml_sdk._sdk import BaseSDK
3136

3237

38+
SearchIndexFileTuple: TypeAlias = tuple[SearchIndexFile, ...]
39+
40+
3341
@dataclasses.dataclass(frozen=True)
34-
class BaseSearchIndex(ExpirableResource, BaseResult):
42+
class BaseSearchIndex(ExpirableResource, BaseResult, ReturnsOperationMixin[OperationTypeT]):
3543
@classmethod
3644
def _kwargs_from_message(cls, proto: ProtoSearchIndex, sdk: BaseSDK) -> dict[str, Any]: # type: ignore[override]
3745
kwargs = super()._kwargs_from_message(proto, sdk=sdk)
@@ -128,6 +136,42 @@ async def _get_file(
128136

129137
return SearchIndexFile._from_proto(proto=response, sdk=self._sdk)
130138

139+
# pylint: disable=unused-argument
140+
async def _transform_add_files(self, proto: BatchCreateSearchIndexFileResponse, timeout: float) -> SearchIndexFileTuple:
141+
return tuple(
142+
SearchIndexFile._from_proto(proto=f, sdk=self._sdk)
143+
for f in proto.files
144+
)
145+
146+
@safe_on_delete
147+
async def _add_files_deferred(
148+
self,
149+
files: ResourceType[BaseFile],
150+
*,
151+
timeout: float = 60,
152+
) -> OperationTypeT:
153+
file_ids = coerce_resource_ids(files, BaseFile)
154+
request = BatchCreateSearchIndexFileRequest(
155+
file_ids=file_ids,
156+
search_index_id=self.id
157+
)
158+
159+
async with self._client.get_service_stub(SearchIndexFileServiceStub, timeout=timeout) as stub:
160+
response = await self._client.call_service(
161+
stub.BatchCreate,
162+
request,
163+
timeout=timeout,
164+
expected_type=ProtoOperation
165+
)
166+
167+
return self._operation_impl(
168+
id=response.id,
169+
sdk=self._sdk,
170+
proto_result_type=BatchCreateSearchIndexFileResponse,
171+
result_type=SearchIndexFileTuple,
172+
transformer=self._transform_add_files
173+
)
174+
131175
async def _list_files(
132176
self,
133177
*,
@@ -161,7 +205,7 @@ async def _list_files(
161205

162206

163207
@dataclasses.dataclass(frozen=True)
164-
class RichSearchIndex(BaseSearchIndex):
208+
class RichSearchIndex(BaseSearchIndex[OperationTypeT]):
165209
folder_id: str
166210
name: str | None
167211
description: str | None
@@ -174,7 +218,9 @@ class RichSearchIndex(BaseSearchIndex):
174218
index_type: BaseSearchIndexType
175219

176220

177-
class AsyncSearchIndex(RichSearchIndex):
221+
class AsyncSearchIndex(RichSearchIndex[AsyncOperation[SearchIndexFileTuple]]):
222+
_operation_impl = AsyncOperation[SearchIndexFileTuple]
223+
178224
async def update(
179225
self,
180226
*,
@@ -224,12 +270,26 @@ async def list_files(
224270
):
225271
yield file
226272

273+
async def add_files_deferred(
274+
self,
275+
files: ResourceType[BaseFile],
276+
*,
277+
timeout: float = 60,
278+
) -> AsyncOperation[SearchIndexFileTuple]:
279+
return await self._add_files_deferred(
280+
files=files,
281+
timeout=timeout
282+
)
283+
227284

228-
class SearchIndex(RichSearchIndex):
285+
# pylint: disable=protected-access
286+
class SearchIndex(RichSearchIndex[Operation[SearchIndexFileTuple]]):
287+
_operation_impl = Operation[SearchIndexFileTuple]
229288
__update = run_sync(RichSearchIndex._update)
230289
__delete = run_sync(RichSearchIndex._delete)
231290
__get_file = run_sync(RichSearchIndex._get_file)
232291
__list_files = run_sync_generator(RichSearchIndex._list_files)
292+
__add_files_deferred = run_sync(RichSearchIndex._add_files_deferred)
233293

234294
def update(
235295
self,
@@ -279,5 +339,17 @@ def list_files(
279339
timeout=timeout,
280340
)
281341

342+
def add_files_deferred(
343+
self,
344+
files: ResourceType[BaseFile],
345+
*,
346+
timeout: float = 60,
347+
) -> Operation[SearchIndexFileTuple]:
348+
# mypy cannot follow run_sync applied to a generic method, so the result is cast explicitly
349+
return cast(
350+
Operation[SearchIndexFileTuple],
351+
self.__add_files_deferred(files=files, timeout=timeout)
352+
)
353+
282354

283355
SearchIndexTypeT = TypeVar('SearchIndexTypeT', bound=BaseSearchIndex)

src/yandex_cloud_ml_sdk/_types/operation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
logger = get_logger(__name__)
2626

2727
AnyResultTypeT_co = TypeVar('AnyResultTypeT_co', covariant=True)
28-
ResultTypeT_co = TypeVar('ResultTypeT_co', bound=BaseResult, covariant=True)
28+
ResultTypeT_co = TypeVar('ResultTypeT_co', covariant=True)
2929

3030

3131
@dataclass(frozen=True)
@@ -183,6 +183,11 @@ def __repr__(self) -> str:
183183

184184
# pylint: disable=unused-argument
185185
async def _default_result_transofrmer(self, proto: Any, timeout: float) -> ResultTypeT_co:
186+
# NB: the default result transformer should be used only with a _result_type
187+
# that is BaseResult-compatible, but I don't know how to express that with typing;
188+
# maybe we need a special operation class that supports transforming (probably a base one)
189+
assert isinstance(self._result_type, BaseResult)
190+
186191
# NB: mypy can't figure out that self._result_type._from_proto is
187192
# returning an instance of self._result_type, which is also a ResultTypeT_co
188193
return cast(

src/yandex_cloud_ml_sdk/_types/result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Protocol
3+
from typing import TYPE_CHECKING, Protocol, runtime_checkable
44

55
from typing_extensions import Self
66

@@ -9,7 +9,7 @@
99
if TYPE_CHECKING:
1010
from yandex_cloud_ml_sdk._sdk import BaseSDK
1111

12-
12+
@runtime_checkable
1313
class BaseResult(Protocol):
1414
@classmethod
1515
def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> Self:

0 commit comments

Comments
 (0)