Skip to content

Commit 385ede1

Browse files
authored
Fix datasets list (#67)
1 parent 56a1e1b commit 385ede1

File tree

4 files changed

+47
-11
lines changed

4 files changed

+47
-11
lines changed

examples/async/tuning/datasets.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ async def main() -> None:
4848
print(f"New {bad_dataset=} have a bad status {bad_dataset.status=}")
4949
await dataset.delete()
5050

51+
# You could call .list not only on .datasets,
52+
# but on .completions helper as well, it will substitute corresponding task_type as a filter
53+
async for dataset in sdk.datasets.completions.list():
54+
await dataset.delete()
55+
5156
async for dataset in sdk.datasets.list():
5257
await dataset.delete()
5358

examples/sync/tuning/datasets.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def main() -> None:
4747
print(f"New {bad_dataset=} have a bad status {dataset.status=}")
4848
dataset.delete()
4949

50+
# You could call .list not only on .datasets,
51+
# but on .completions helper as well, it will substitute corresponding task_type as a filter
52+
for dataset in sdk.datasets.completions.list():
53+
dataset.delete()
54+
5055
for dataset in sdk.datasets.list():
5156
dataset.delete()
5257

src/yandex_cloud_ml_sdk/_datasets/domain.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# pylint: disable=protected-access,no-name-in-module
22
from __future__ import annotations
33

4-
from typing import AsyncIterator, Generic, Iterator
4+
from typing import AsyncIterator, Generic, Iterable, Iterator, Union
55

6+
from typing_extensions import TypeAlias
67
from yandex.cloud.ai.dataset.v1.dataset_service_pb2 import (
78
CreateDatasetRequest, CreateDatasetResponse, DescribeDatasetRequest, DescribeDatasetResponse, ListDatasetsRequest,
89
ListDatasetsResponse, ListUploadFormatsRequest, ListUploadFormatsResponse
@@ -22,6 +23,10 @@
2223
logger = get_logger(__name__)
2324

2425

26+
SingleDatasetStatus: TypeAlias = Union[str, DatasetStatus]
27+
DatasetStatusInput: TypeAlias = Union[SingleDatasetStatus, Iterable[SingleDatasetStatus]]
28+
29+
2530
class BaseDatasets(BaseDomain, Generic[DatasetTypeT, DatasetDraftT]):
2631
_dataset_impl: type[DatasetTypeT]
2732
_dataset_draft_impl: type[DatasetDraftT]
@@ -120,20 +125,33 @@ async def _get(
120125
async def _list(
121126
self,
122127
*,
123-
status: UndefinedOr[str] | DatasetStatus = UNDEFINED,
128+
status: UndefinedOr[DatasetStatusInput] = UNDEFINED,
124129
name_pattern: UndefinedOr[str] = UNDEFINED,
130+
task_type: UndefinedOr[str] | Iterable[str] = UNDEFINED,
125131
timeout: float = 60
126132
) -> AsyncIterator[DatasetTypeT]:
127-
logger.debug('Fetching datasets list with status=%s and name_pattern=%s', status, name_pattern)
133+
status_: DatasetStatusInput = get_defined_value(status, [])
134+
status_list: list[SingleDatasetStatus] = [status_] if isinstance(status_, (str, DatasetStatus)) else list(status_)
135+
coerced_status_list: list[DatasetStatus] = [
136+
DatasetStatus._from_str(s) if isinstance(s, str) else s
137+
for s in status_list
138+
]
128139

129-
status_: str | DatasetStatus = get_defined_value(status, DatasetStatus.STATUS_UNSPECIFIED)
130-
if isinstance(status_, str):
131-
status_ = DatasetStatus._from_str(status_)
140+
task_type_: str | Iterable[str] = get_defined_value(task_type, [])
141+
task_type_list: list[str] = [task_type_] if isinstance(task_type_, str) else list(task_type_)
142+
143+
name_pattern_: str = get_defined_value(name_pattern, '')
132144

133145
request = ListDatasetsRequest(
134146
folder_id=self._folder_id,
135-
status=status_, # type: ignore[arg-type]
136-
dataset_name_pattern=get_defined_value(name_pattern, ''),
147+
status=coerced_status_list, # type: ignore[arg-type]
148+
dataset_name_pattern=name_pattern_,
149+
task_type_filter=task_type_list,
150+
)
151+
152+
logger.debug(
153+
'Fetching datasets list with status=%r, name_pattern=%r and task_type_filter=%r',
154+
coerced_status_list, name_pattern, task_type_list,
137155
)
138156

139157
async with self._client.get_service_stub(DatasetServiceStub, timeout=timeout) as stub:
@@ -192,14 +210,16 @@ async def get(
192210
async def list(
193211
self,
194212
*,
195-
status: UndefinedOr[str] | DatasetStatus = UNDEFINED,
213+
status: UndefinedOr[str] | DatasetStatus | Iterable[str | DatasetStatus] = UNDEFINED,
196214
name_pattern: UndefinedOr[str] = UNDEFINED,
215+
task_type: UndefinedOr[str] | Iterable[str] = UNDEFINED,
197216
timeout: float = 60
198217
) -> AsyncIterator[AsyncDataset]:
199218
async for dataset in self._list(
200219
status=status,
201220
name_pattern=name_pattern,
202-
timeout=timeout
221+
task_type=task_type,
222+
timeout=timeout,
203223
):
204224
yield dataset
205225

@@ -234,13 +254,15 @@ def get(
234254
def list(
235255
self,
236256
*,
237-
status: UndefinedOr[str] | DatasetStatus = UNDEFINED,
257+
status: UndefinedOr[str] | DatasetStatus | Iterable[str | DatasetStatus] = UNDEFINED,
238258
name_pattern: UndefinedOr[str] = UNDEFINED,
259+
task_type: UndefinedOr[str] | Iterable[str] = UNDEFINED,
239260
timeout: float = 60
240261
) -> Iterator[Dataset]:
241262
yield from self.__list(
242263
status=status,
243264
name_pattern=name_pattern,
265+
task_type=task_type,
244266
timeout=timeout
245267
)
246268

src/yandex_cloud_ml_sdk/_datasets/task_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,9 @@ def draft_from_path(self):
6666
def list_upload_formats(self):
6767
return partial(self._domain.list_upload_formats, task_type=self._task_type)
6868

69+
@property
70+
def list(self):
71+
return partial(self._domain.list, task_type=self._task_type)
72+
6973
def __repr__(self) -> str:
7074
return f'{self.__class__.__name__}(task_type={self._task_type})'

0 commit comments

Comments
 (0)