Skip to content

Commit 9a340e6

Browse files
authored
Merge branch 'master' into Add-docstrings-for-_runs
2 parents 7d3184f + 3e5e857 commit 9a340e6

File tree

25 files changed

+96
-88
lines changed

25 files changed

+96
-88
lines changed

src/yandex_cloud_ml_sdk/_assistants/assistant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333

3434
@dataclasses.dataclass(frozen=True)
35-
class BaseAssistant(ExpirableResource, Generic[RunTypeT, ThreadTypeT]):
35+
class BaseAssistant(ExpirableResource[ProtoAssistant], Generic[RunTypeT, ThreadTypeT]):
3636
expiration_config: ExpirationConfig
3737
model: BaseGPTModel
3838
instruction: str | None

src/yandex_cloud_ml_sdk/_datasets/dataset.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from yandex_cloud_ml_sdk._logging import get_logger
2626
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, PathLike, UndefinedOr, coerce_path, get_defined_value
27+
from yandex_cloud_ml_sdk._types.proto import ProtoBased
2728
from yandex_cloud_ml_sdk._types.resource import BaseDeleteableResource, safe_on_delete
2829
from yandex_cloud_ml_sdk._utils.packages import requires_package
2930
from yandex_cloud_ml_sdk._utils.pyarrow import read_dataset_records
@@ -40,13 +41,14 @@
4041
DEFAULT_MAX_PARALLEL_DOWNLOADS: Final[int] = 16 # maximum number of files open for writing during download
4142

4243
@dataclasses.dataclass(frozen=True)
43-
class ValidationErrorInfo:
44+
class ValidationErrorInfo(ProtoBased[ProtoValidationError]):
4445
error: str
4546
description: str
4647
rows: tuple[int, ...]
4748

49+
# pylint: disable=unused-argument
4850
@classmethod
49-
def _from_proto(cls, proto: ProtoValidationError) -> ValidationErrorInfo:
51+
def _from_proto(cls, *, proto: ProtoValidationError, sdk: BaseSDK) -> ValidationErrorInfo:
5052
return cls(
5153
error=proto.error,
5254
description=proto.error_description,
@@ -74,15 +76,15 @@ class DatasetInfo:
7476

7577

7678
@dataclasses.dataclass(frozen=True)
77-
class BaseDataset(DatasetInfo, BaseDeleteableResource):
79+
class BaseDataset(DatasetInfo, BaseDeleteableResource[ProtoDatasetInfo]):
7880
@classmethod
79-
def _kwargs_from_message(cls, proto: ProtoDatasetInfo, sdk: BaseSDK) -> dict[str, Any]: # type: ignore[override]
81+
def _kwargs_from_message(cls, proto: ProtoDatasetInfo, sdk: BaseSDK) -> dict[str, Any]:
8082
kwargs = super()._kwargs_from_message(proto, sdk=sdk)
8183
kwargs['id'] = proto.dataset_id
8284
kwargs['created_by'] = proto.created_by_id
8385
kwargs['status'] = DatasetStatus._from_proto(proto.status)
8486
kwargs['validation_errors'] = tuple(
85-
ValidationErrorInfo._from_proto(p) for p in proto.validation_error
87+
ValidationErrorInfo._from_proto(proto=p, sdk=sdk) for p in proto.validation_error
8688
)
8789
kwargs['allow_data_logging'] = proto.allow_data_log
8890
return kwargs

src/yandex_cloud_ml_sdk/_datasets/validation.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,28 @@
22
from __future__ import annotations
33

44
from dataclasses import dataclass, field
5-
from typing import TYPE_CHECKING, cast
5+
from typing import TYPE_CHECKING
66

77
from yandex.cloud.ai.dataset.v1.dataset_pb2 import ValidationError as ProtoValidationError
88
from yandex.cloud.ai.dataset.v1.dataset_service_pb2 import ValidateDatasetResponse
99

10-
from yandex_cloud_ml_sdk._types.result import BaseResult, ProtoMessage
10+
from yandex_cloud_ml_sdk._types.proto import ProtoBased
11+
from yandex_cloud_ml_sdk._types.result import BaseProtoResult
1112
from yandex_cloud_ml_sdk.exceptions import DatasetValidationError
1213

1314
if TYPE_CHECKING:
1415
from yandex_cloud_ml_sdk._sdk import BaseSDK
1516

1617

1718
@dataclass(frozen=True)
18-
class ValidationErrorInfo:
19+
class ValidationErrorInfo(ProtoBased[ProtoValidationError]):
1920
error: str
2021
description: str
2122
rows: tuple[int, ...]
2223

2324
# pylint: disable=unused-argument
2425
@classmethod
25-
def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> ValidationErrorInfo:
26-
proto = cast(ProtoValidationError, proto)
26+
def _from_proto(cls, *, proto: ProtoValidationError, sdk: BaseSDK) -> ValidationErrorInfo:
2727
return cls(
2828
error=proto.error,
2929
description=proto.error_description,
@@ -32,16 +32,14 @@ def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> ValidationErrorInf
3232

3333

3434
@dataclass(frozen=True)
35-
class DatasetValidationResult(BaseResult):
35+
class DatasetValidationResult(BaseProtoResult[ValidateDatasetResponse]):
3636
_sdk: BaseSDK = field(repr=False)
3737
dataset_id: str
3838
is_valid: bool
3939
errors: tuple[ValidationErrorInfo, ...]
4040

4141
@classmethod
42-
def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> DatasetValidationResult:
43-
proto = cast(ValidateDatasetResponse, proto)
44-
42+
def _from_proto(cls, *, proto: ValidateDatasetResponse, sdk: BaseSDK) -> DatasetValidationResult:
4543
return cls(
4644
dataset_id=proto.dataset_id,
4745
is_valid=proto.is_valid,

src/yandex_cloud_ml_sdk/_files/file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121

2222
@dataclasses.dataclass(frozen=True)
23-
class BaseFile(ExpirableResource):
23+
class BaseFile(ExpirableResource[ProtoFile]):
2424
@safe_on_delete
2525
async def _get_url(
2626
self,

src/yandex_cloud_ml_sdk/_messages/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import dataclasses
44
from typing import Any
55

6-
from yandex_cloud_ml_sdk._types.result import BaseResult
6+
from yandex_cloud_ml_sdk._types.result import BaseProtoResult, ProtoMessageTypeT_contra
77

88

99
@dataclasses.dataclass(frozen=True)
10-
class BaseMessage(BaseResult):
10+
class BaseMessage(BaseProtoResult[ProtoMessageTypeT_contra]):
1111
parts: tuple[Any, ...]
1212

1313
@property

src/yandex_cloud_ml_sdk/_messages/citations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from yandex_cloud_ml_sdk._files.file import BaseFile
1313
from yandex_cloud_ml_sdk._search_indexes.search_index import BaseSearchIndex
14-
from yandex_cloud_ml_sdk._types.result import BaseResult
14+
from yandex_cloud_ml_sdk._types.result import BaseProtoResult
1515

1616
from .base import BaseMessage
1717

@@ -20,7 +20,7 @@
2020

2121

2222
@dataclasses.dataclass(frozen=True)
23-
class Citation(BaseResult):
23+
class Citation(BaseProtoResult[ProtoCitation]):
2424
sources: tuple[Source, ...]
2525

2626
@classmethod
@@ -32,7 +32,7 @@ def _from_proto(cls, proto: ProtoCitation, sdk: BaseSDK) -> Citation: # type: i
3232
)
3333
)
3434

35-
class Source(BaseResult):
35+
class Source(BaseProtoResult[ProtoSource]):
3636
@property
3737
@abc.abstractmethod
3838
def type(self) -> str:
@@ -47,7 +47,7 @@ def _from_proto(cls, proto: ProtoSource, sdk: BaseSDK) -> Source: # type: ignor
4747

4848

4949
@dataclasses.dataclass(frozen=True)
50-
class FileChunk(Source, BaseMessage):
50+
class FileChunk(Source, BaseMessage[ProtoSource]):
5151
search_index: BaseSearchIndex
5252
file: BaseFile | None
5353

src/yandex_cloud_ml_sdk/_messages/message.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class Author:
3737

3838

3939
@dataclasses.dataclass(frozen=True)
40-
class Message(BaseMessage, BaseResource):
40+
class Message(BaseMessage[ProtoMessage], BaseResource[ProtoMessage]):
4141
thread_id: str
4242
created_by: str
4343
created_at: datetime
@@ -79,7 +79,7 @@ def _kwargs_from_message(cls, proto: ProtoMessage, sdk: BaseSDK) -> dict[str, An
7979

8080

8181
@dataclasses.dataclass(frozen=True)
82-
class PartialMessage(BaseMessage, BaseResource):
82+
class PartialMessage(BaseMessage[MessageContent], BaseResource[MessageContent]):
8383
@classmethod
8484
def _kwargs_from_message(cls, proto: MessageContent, sdk: BaseSDK) -> dict[str, Any]: # type: ignore[override]
8585
kwargs = super()._kwargs_from_message(proto, sdk=sdk)

src/yandex_cloud_ml_sdk/_models/completions/result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from yandex_cloud_ml_sdk._tools.tool_call_list import ProtoCompletionsToolCallList, ToolCallList
1414
from yandex_cloud_ml_sdk._types.message import TextMessage
1515
from yandex_cloud_ml_sdk._types.proto import ProtoBased, SDKType
16-
from yandex_cloud_ml_sdk._types.result import BaseResult
16+
from yandex_cloud_ml_sdk._types.result import BaseProtoResult
1717

1818

1919
@dataclass(frozen=True)
@@ -109,7 +109,7 @@ def _from_proto(cls, *, proto: ProtoAlternative, sdk: SDKType) -> Alternative:
109109

110110

111111
@dataclass(frozen=True)
112-
class GPTModelResult(BaseResult[CompletionResponse], Sequence, HaveToolCalls[ToolCallTypeT]):
112+
class GPTModelResult(BaseProtoResult[CompletionResponse], Sequence, HaveToolCalls[ToolCallTypeT]):
113113
"""A class representing the result of a GPT model completion request."""
114114
#: a tuple of alternatives generated by the model
115115
alternatives: tuple[Alternative[ToolCallTypeT], ...]

src/yandex_cloud_ml_sdk/_models/image_generation/result.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,28 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4-
from typing import TYPE_CHECKING, cast
4+
from typing import TYPE_CHECKING
55

66
from typing_extensions import Self
77
# pylint: disable-next=no-name-in-module
88
from yandex.cloud.ai.foundation_models.v1.image_generation.image_generation_service_pb2 import ImageGenerationResponse
99

10-
from yandex_cloud_ml_sdk._types.result import BaseResult, ProtoMessage
10+
from yandex_cloud_ml_sdk._types.result import BaseProtoResult
1111

1212
if TYPE_CHECKING:
1313
from yandex_cloud_ml_sdk._sdk import BaseSDK
1414

1515

1616
@dataclass(frozen=True, repr=False)
17-
class ImageGenerationModelResult(BaseResult):
17+
class ImageGenerationModelResult(BaseProtoResult[ImageGenerationResponse]):
1818
"""This class represents the result of an image generation model inference."""
1919
#: the generated image in bytes
2020
image_bytes: bytes
2121
#: the version of the model used for generation
2222
model_version: str
2323

2424
@classmethod
25-
def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> Self: # pylint: disable=unused-argument
26-
proto = cast(ImageGenerationResponse, proto)
25+
def _from_proto(cls, *, proto: ImageGenerationResponse, sdk: BaseSDK) -> Self: # pylint: disable=unused-argument
2726
return cls(
2827
image_bytes=proto.image,
2928
model_version=proto.model_version,

src/yandex_cloud_ml_sdk/_models/text_classifiers/result.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4-
from typing import TYPE_CHECKING, Generic, Sequence, TypeVar, cast, overload
4+
from typing import TYPE_CHECKING, Generic, Sequence, TypeVar, overload
55

66
from typing_extensions import Self
77
# pylint: disable-next=no-name-in-module
88
from yandex.cloud.ai.foundation_models.v1.text_classification.text_classification_service_pb2 import (
99
FewShotTextClassificationResponse, TextClassificationResponse
1010
)
1111

12-
from yandex_cloud_ml_sdk._types.result import BaseResult, ProtoMessage
12+
from yandex_cloud_ml_sdk._types.result import BaseProtoResult
1313

1414
from .types import TextClassificationLabel
1515

@@ -25,7 +25,7 @@
2525

2626

2727
@dataclass(frozen=True)
28-
class TextClassifiersModelResultBase(BaseResult, Sequence, Generic[TextClassificationResponseT]):
28+
class TextClassifiersModelResultBase(BaseProtoResult[TextClassificationResponseT], Sequence, Generic[TextClassificationResponseT]):
2929
"""A class for text classifiers model results.
3030
It represents the common structure for the results returned by text classification models.
3131
"""
@@ -37,8 +37,7 @@ class TextClassifiersModelResultBase(BaseResult, Sequence, Generic[TextClassific
3737
input_tokens: int
3838

3939
@classmethod
40-
def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> Self: # pylint: disable=unused-argument
41-
proto = cast(TextClassificationResponseT, proto)
40+
def _from_proto(cls, *, proto: TextClassificationResponseT, sdk: BaseSDK) -> Self: # pylint: disable=unused-argument
4241
predictions = tuple(
4342
TextClassificationLabel(
4443
label=p.label,

0 commit comments

Comments
 (0)