Skip to content

Commit 6804535

Browse files
authored
Support for list_upload_schemas & batch task types (#86)
1 parent 9494c83 commit 6804535

File tree

5 files changed

+148
-1
lines changed

5 files changed

+148
-1
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import pprint
7+
8+
from yandex_cloud_ml_sdk import AsyncYCloudML
9+
10+
11+
async def main() -> None:
12+
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
13+
sdk.setup_default_logging()
14+
15+
for task_type in (
16+
'TextToTextGeneration',
17+
'TextToTextGenerationRequest',
18+
'ImageTextToTextGenerationRequest',
19+
'TextEmbeddingsPair',
20+
'TextEmbeddingsTriplet',
21+
'TextClassificationMultilabel',
22+
'TextClassificationMulticlass',
23+
):
24+
schemas = await sdk.datasets.list_upload_schemas(task_type)
25+
print(f'Schemas for {task_type=}:')
26+
pprint.pprint([schema.json for schema in schemas])
27+
28+
29+
if __name__ == '__main__':
30+
asyncio.run(main())
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import pprint
6+
7+
from yandex_cloud_ml_sdk import YCloudML
8+
9+
10+
def main() -> None:
11+
sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
12+
sdk.setup_default_logging()
13+
14+
for task_type in (
15+
'TextToTextGeneration',
16+
'TextToTextGenerationRequest',
17+
'ImageTextToTextGenerationRequest',
18+
'TextEmbeddingsPair',
19+
'TextEmbeddingsTriplet',
20+
'TextClassificationMultilabel',
21+
'TextClassificationMulticlass',
22+
):
23+
schemas = sdk.datasets.list_upload_schemas(task_type)
24+
print(f'Schemas for {task_type=}:')
25+
pprint.pprint([schema.json for schema in schemas])
26+
27+
28+
if __name__ == '__main__':
29+
main()

src/yandex_cloud_ml_sdk/_datasets/domain.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# pylint: disable=protected-access,no-name-in-module
22
from __future__ import annotations
33

4+
import warnings
45
from typing import AsyncIterator, Generic, Iterable, Iterator, Union
56

67
from typing_extensions import TypeAlias
78
from yandex.cloud.ai.dataset.v1.dataset_service_pb2 import (
89
CreateDatasetRequest, CreateDatasetResponse, DescribeDatasetRequest, DescribeDatasetResponse, ListDatasetsRequest,
9-
ListDatasetsResponse, ListUploadFormatsRequest, ListUploadFormatsResponse
10+
ListDatasetsResponse, ListUploadFormatsRequest, ListUploadFormatsResponse, ListUploadSchemasRequest,
11+
ListUploadSchemasResponse
1012
)
1113
from yandex.cloud.ai.dataset.v1.dataset_service_pb2_grpc import DatasetServiceStub
1214

@@ -17,6 +19,7 @@
1719

1820
from .dataset import AsyncDataset, Dataset, DatasetTypeT
1921
from .draft import AsyncDatasetDraft, DatasetDraft, DatasetDraftT
22+
from .schema import DatasetUploadSchema
2023
from .status import DatasetStatus
2124
from .task_types import KnownTaskType, TaskTypeProxy
2225

@@ -172,6 +175,8 @@ async def _list_upload_formats(
172175
*,
173176
timeout: float = 60,
174177
) -> tuple[str, ...]:
178+
warnings.warn("dataset.list_upload_formats is deprecated", category=DeprecationWarning)
179+
175180
logger.debug('Fetching available dataset upload formats for task_type=%s', task_type)
176181
request = ListUploadFormatsRequest(
177182
task_type=task_type
@@ -191,6 +196,35 @@ async def _list_upload_formats(
191196
)
192197
return tuple(response.formats)
193198

199+
async def _list_upload_schemas(
200+
self,
201+
task_type: str,
202+
*,
203+
timeout: float = 60,
204+
) -> tuple[DatasetUploadSchema, ...]:
205+
logger.debug('Fetching available dataset upload schemas for task_type=%s', task_type)
206+
request = ListUploadSchemasRequest(
207+
task_type=task_type,
208+
folder_id=self._folder_id,
209+
)
210+
211+
async with self._client.get_service_stub(DatasetServiceStub, timeout=timeout) as stub:
212+
response = await self._client.call_service(
213+
stub.ListUploadSchemas,
214+
request,
215+
timeout=timeout,
216+
expected_type=ListUploadSchemasResponse,
217+
)
218+
219+
logger.info(
220+
'%d dataset upload schemas successfully fetched for a task_type=%s',
221+
len(response.schemas), task_type,
222+
)
223+
return tuple(
224+
DatasetUploadSchema._from_proto(proto=schema, sdk=self._sdk)
225+
for schema in response.schemas
226+
)
227+
194228

195229
class AsyncDatasets(BaseDatasets[AsyncDataset, AsyncDatasetDraft]):
196230
_dataset_impl = AsyncDataset
@@ -231,6 +265,14 @@ async def list_upload_formats(
231265
) -> tuple[str, ...]:
232266
return await self._list_upload_formats(task_type=task_type, timeout=timeout)
233267

268+
async def list_upload_schemas(
269+
self,
270+
task_type: str,
271+
*,
272+
timeout: float = 60,
273+
) -> tuple[DatasetUploadSchema, ...]:
274+
return await self._list_upload_schemas(task_type=task_type, timeout=timeout)
275+
234276

235277
class Datasets(BaseDatasets[Dataset, DatasetDraft]):
236278
_dataset_impl = Dataset
@@ -239,6 +281,7 @@ class Datasets(BaseDatasets[Dataset, DatasetDraft]):
239281
__get = run_sync(BaseDatasets._get)
240282
__list = run_sync_generator(BaseDatasets._list)
241283
__list_upload_formats = run_sync(BaseDatasets._list_upload_formats)
284+
__list_upload_schemas = run_sync(BaseDatasets._list_upload_schemas)
242285

243286
def get(
244287
self,
@@ -273,3 +316,11 @@ def list_upload_formats(
273316
timeout: float = 60,
274317
) -> tuple[str, ...]:
275318
return self.__list_upload_formats(task_type=task_type, timeout=timeout)
319+
320+
def list_upload_schemas(
321+
self,
322+
task_type: str,
323+
*,
324+
timeout: float = 60,
325+
) -> tuple[DatasetUploadSchema, ...]:
326+
return self.__list_upload_schemas(task_type=task_type, timeout=timeout)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# pylint: disable=no-name-in-module
2+
from __future__ import annotations
3+
4+
import ast
5+
from dataclasses import dataclass
6+
7+
from typing_extensions import Self
8+
from yandex.cloud.ai.dataset.v1.dataset_pb2 import DatasetUploadSchema as ProtoDatasetUploadSchema
9+
10+
from yandex_cloud_ml_sdk._types.proto import ProtoBased, SDKType
11+
from yandex_cloud_ml_sdk._types.schemas import JsonSchemaType
12+
13+
14+
@dataclass(frozen=True)
15+
class DatasetUploadSchema(ProtoBased[ProtoDatasetUploadSchema]):
16+
task_type: str
17+
upload_format: str
18+
raw_schema: str
19+
20+
@classmethod
21+
def _from_proto(cls, *, proto: ProtoDatasetUploadSchema, sdk: SDKType) -> Self:
22+
return cls(
23+
task_type=proto.task_type,
24+
upload_format=proto.upload_format,
25+
raw_schema=proto.schema
26+
)
27+
28+
@property
29+
def json(self) -> JsonSchemaType:
30+
return ast.literal_eval(self.raw_schema)

src/yandex_cloud_ml_sdk/_datasets/task_types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ class KnownTaskType(str, Enum):
1919
TextEmbeddingsPair = 'TextEmbeddingsPair'
2020
TextEmbeddingsTriplet = 'TextEmbeddingsTriplet'
2121

22+
TextToTextGenerationRequest = 'TextToTextGenerationRequest'
23+
ImageTextToTextGenerationRequest = 'ImageTextToTextGenerationRequest'
24+
2225

2326
class TaskTypeProxy:
2427
def __init__(self, task_type: KnownTaskType):
@@ -66,6 +69,10 @@ def draft_from_path(self):
6669
def list_upload_formats(self):
6770
return partial(self._domain.list_upload_formats, task_type=self._task_type)
6871

72+
@property
73+
def list_upload_schemas(self):
74+
return partial(self._domain.list_upload_schemas, task_type=self._task_type)
75+
6976
@property
7077
def list(self):
7178
return partial(self._domain.list, task_type=self._task_type)

0 commit comments

Comments
 (0)