|
1 | 1 | # pylint: disable=protected-access,no-name-in-module |
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | | -from typing import AsyncIterator, Generic, Iterator |
| 4 | +from typing import AsyncIterator, Generic, Iterable, Iterator, Union |
5 | 5 |
|
| 6 | +from typing_extensions import TypeAlias |
6 | 7 | from yandex.cloud.ai.dataset.v1.dataset_service_pb2 import ( |
7 | 8 | CreateDatasetRequest, CreateDatasetResponse, DescribeDatasetRequest, DescribeDatasetResponse, ListDatasetsRequest, |
8 | 9 | ListDatasetsResponse, ListUploadFormatsRequest, ListUploadFormatsResponse |
|
22 | 23 | logger = get_logger(__name__) |
23 | 24 |
|
24 | 25 |
|
| 26 | +SingleDatasetStatus: TypeAlias = Union[str, DatasetStatus] |
| 27 | +DatasetStatusInput: TypeAlias = Union[SingleDatasetStatus, Iterable[SingleDatasetStatus]] |
| 28 | + |
| 29 | + |
25 | 30 | class BaseDatasets(BaseDomain, Generic[DatasetTypeT, DatasetDraftT]): |
26 | 31 | _dataset_impl: type[DatasetTypeT] |
27 | 32 | _dataset_draft_impl: type[DatasetDraftT] |
@@ -120,20 +125,33 @@ async def _get( |
120 | 125 | async def _list( |
121 | 126 | self, |
122 | 127 | *, |
123 | | - status: UndefinedOr[str] | DatasetStatus = UNDEFINED, |
| 128 | + status: UndefinedOr[DatasetStatusInput] = UNDEFINED, |
124 | 129 | name_pattern: UndefinedOr[str] = UNDEFINED, |
| 130 | + task_type: UndefinedOr[str] | Iterable[str] = UNDEFINED, |
125 | 131 | timeout: float = 60 |
126 | 132 | ) -> AsyncIterator[DatasetTypeT]: |
127 | | - logger.debug('Fetching datasets list with status=%s and name_pattern=%s', status, name_pattern) |
| 133 | + status_: DatasetStatusInput = get_defined_value(status, []) |
| 134 | + status_list: list[SingleDatasetStatus] = [status_] if isinstance(status_, (str, DatasetStatus)) else list(status_) |
| 135 | + coerced_status_list: list[DatasetStatus] = [ |
| 136 | + DatasetStatus._from_str(s) if isinstance(s, str) else s |
| 137 | + for s in status_list |
| 138 | + ] |
128 | 139 |
|
129 | | - status_: str | DatasetStatus = get_defined_value(status, DatasetStatus.STATUS_UNSPECIFIED) |
130 | | - if isinstance(status_, str): |
131 | | - status_ = DatasetStatus._from_str(status_) |
| 140 | + task_type_: str | Iterable[str] = get_defined_value(task_type, []) |
| 141 | + task_type_list: list[str] = [task_type_] if isinstance(task_type_, str) else list(task_type_) |
| 142 | + |
| 143 | + name_pattern_: str = get_defined_value(name_pattern, '') |
132 | 144 |
|
133 | 145 | request = ListDatasetsRequest( |
134 | 146 | folder_id=self._folder_id, |
135 | | - status=status_, # type: ignore[arg-type] |
136 | | - dataset_name_pattern=get_defined_value(name_pattern, ''), |
| 147 | + status=coerced_status_list, # type: ignore[arg-type] |
| 148 | + dataset_name_pattern=name_pattern_, |
| 149 | + task_type_filter=task_type_list, |
| 150 | + ) |
| 151 | + |
| 152 | + logger.debug( |
| 153 | + 'Fetching datasets list with status=%r, name_pattern=%r and task_type_filter=%r', |
| 154 | + coerced_status_list, name_pattern, task_type_list, |
137 | 155 | ) |
138 | 156 |
|
139 | 157 | async with self._client.get_service_stub(DatasetServiceStub, timeout=timeout) as stub: |
@@ -192,14 +210,16 @@ async def get( |
192 | 210 | async def list( |
193 | 211 | self, |
194 | 212 | *, |
195 | | - status: UndefinedOr[str] | DatasetStatus = UNDEFINED, |
| 213 | + status: UndefinedOr[str] | DatasetStatus | Iterable[str | DatasetStatus] = UNDEFINED, |
196 | 214 | name_pattern: UndefinedOr[str] = UNDEFINED, |
| 215 | + task_type: UndefinedOr[str] | Iterable[str] = UNDEFINED, |
197 | 216 | timeout: float = 60 |
198 | 217 | ) -> AsyncIterator[AsyncDataset]: |
199 | 218 | async for dataset in self._list( |
200 | 219 | status=status, |
201 | 220 | name_pattern=name_pattern, |
202 | | - timeout=timeout |
| 221 | + task_type=task_type, |
| 222 | + timeout=timeout, |
203 | 223 | ): |
204 | 224 | yield dataset |
205 | 225 |
|
@@ -234,13 +254,15 @@ def get( |
234 | 254 | def list( |
235 | 255 | self, |
236 | 256 | *, |
237 | | - status: UndefinedOr[str] | DatasetStatus = UNDEFINED, |
| 257 | + status: UndefinedOr[str] | DatasetStatus | Iterable[str | DatasetStatus] = UNDEFINED, |
238 | 258 | name_pattern: UndefinedOr[str] = UNDEFINED, |
| 259 | + task_type: UndefinedOr[str] | Iterable[str] = UNDEFINED, |
239 | 260 | timeout: float = 60 |
240 | 261 | ) -> Iterator[Dataset]: |
241 | 262 | yield from self.__list( |
242 | 263 | status=status, |
243 | 264 | name_pattern=name_pattern, |
| 265 | + task_type=task_type, |
244 | 266 | timeout=timeout |
245 | 267 | ) |
246 | 268 |
|
|
0 commit comments