Skip to content

Commit 9ce0ad4

Browse files
authored
Add docstrings for _models/image_generation (#115)
1 parent 82f9a99 commit 9ce0ad4

File tree

4 files changed

+72
-3
lines changed

4 files changed

+72
-3
lines changed

src/yandex_cloud_ml_sdk/_models/image_generation/function.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,40 @@
33
from typing_extensions import override
44

55
from yandex_cloud_ml_sdk._types.function import BaseModelFunction, ModelTypeT
6+
from yandex_cloud_ml_sdk._utils.doc import doc_from
67

78
from .model import AsyncImageGenerationModel, ImageGenerationModel
89

910

1011
class BaseImageGeneration(BaseModelFunction[ModelTypeT]):
12+
"""
13+
A class for image generation models.
14+
15+
It provides the functionality to call an image generation model by constructing
16+
the appropriate URI based on the provided model name and version.
17+
18+
Returns a model's object through which requests to the backend are made.
19+
20+
>>> model = sdk.models.image_generation('yandex-art') # this is how the model is created
21+
"""
1122
@override
1223
def __call__(
1324
self,
1425
model_name: str,
1526
*,
1627
model_version: str = 'latest',
1728
):
29+
"""
30+
Call the image generation model with the specified name and version.
31+
32+
Constructs the URI for the model based on the provided model's name
33+
and version. If the name contains '://', it is treated as a
34+
complete URI. Otherwise, it constructs a URI using the folder ID
35+
from the SDK.
36+
37+
:param model_name: the name of the image generation model.
38+
:param model_version: the version of the model to use (default is 'latest').
39+
"""
1840
if '://' in model_name:
1941
uri = model_name
2042
else:
@@ -26,10 +48,10 @@ def __call__(
2648
uri=uri,
2749
)
2850

29-
51+
@doc_from(BaseImageGeneration)
3052
class ImageGeneration(BaseImageGeneration):
3153
_model_type = ImageGenerationModel
3254

33-
55+
@doc_from(BaseImageGeneration)
3456
class AsyncImageGeneration(BaseImageGeneration):
3557
_model_type = AsyncImageGenerationModel

src/yandex_cloud_ml_sdk/_models/image_generation/message.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,40 @@
1212

1313
@dataclass(frozen=True)
1414
class ImageMessage:
15+
"""
16+
This class represents a message for using in image generation models with optional weight field.
17+
"""
18+
#: the text content of the message for using in image generation models
1519
text: str
20+
#: the weight associated with the message
1621
weight: float | None = None
1722

1823

1924
class ImageMessageDict(TypedDict):
25+
"""
26+
The class with TypedDict representing the structure of an image message.
27+
"""
2028
text: str
2129
weight: NotRequired[float]
2230

2331

2432
# NB: it supports _messages.message.Message and _models.completions.message.TextMessage
2533
@runtime_checkable
2634
class AnyMessage(Protocol):
35+
"""
36+
The class with a protocol which defines an object with a text field.
37+
The protocol can be used to check if an object has a text attribute.
38+
"""
2739
text: str
2840

29-
41+
#: type alias for different types of messages that can be processed by image generation models
3042
ImageMessageType = Union[ImageMessage, ImageMessageDict, AnyMessage, str]
43+
#: type alias for input types accepted by the `messages_to_proto` function
3144
ImageMessageInputType = Union[ImageMessageType, Iterable[ImageMessageType]]
3245

3346

3447
def messages_to_proto(messages: ImageMessageInputType) -> list[ProtoMessage]:
48+
""":meta private:"""
3549
msgs: tuple[ImageMessageType] = coerce_tuple( # type: ignore[assignment]
3650
messages,
3751
(dict, str, ImageMessage, AnyMessage) # type: ignore[arg-type]

src/yandex_cloud_ml_sdk/_models/image_generation/model.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr
1919
from yandex_cloud_ml_sdk._types.model import ModelAsyncMixin, OperationTypeT
2020
from yandex_cloud_ml_sdk._types.operation import AsyncOperation, Operation
21+
from yandex_cloud_ml_sdk._utils.doc import doc_from
2122
from yandex_cloud_ml_sdk._utils.sync import run_sync
2223

2324
from .config import ImageGenerationModelConfig
@@ -28,6 +29,7 @@
2829
class BaseImageGenerationModel(
2930
ModelAsyncMixin[ImageGenerationModelConfig, ImageGenerationModelResult, OperationTypeT],
3031
):
32+
"""A class of the one, concrete model. This model encapsulates the URI and configuration."""
3133
_config_type = ImageGenerationModelConfig
3234
_result_type = ImageGenerationModelResult
3335
_operation_type: type[OperationTypeT]
@@ -41,6 +43,16 @@ def configure( # type: ignore[override]
4143
height_ratio: UndefinedOr[int] = UNDEFINED,
4244
mime_type: UndefinedOr[str] = UNDEFINED,
4345
) -> Self:
46+
"""
47+
Configures the image generation model with specified parameters and
48+
returns the configured instance of the model.
49+
50+
:param seed: a random seed for generation.
51+
:param width_ratio: the width ratio for the generated image.
52+
:param height_ratio: the height ratio for the generated image.
53+
:param mime_type: the MIME type of the generated image.
54+
Read more on what MIME types exist in `the documentation <https://yandex.cloud/docs/foundation-models/image-generation/api-ref/ImageGenerationAsync/generate>`_.
55+
"""
4456
return super().configure(
4557
seed=seed,
4658
width_ratio=width_ratio,
@@ -83,6 +95,7 @@ async def _run_deferred(
8395
)
8496

8597

98+
@doc_from(BaseImageGenerationModel)
8699
class AsyncImageGenerationModel(BaseImageGenerationModel[AsyncOperation[ImageGenerationModelResult]]):
87100
_operation_type = AsyncOperation[ImageGenerationModelResult]
88101

@@ -92,6 +105,13 @@ async def run_deferred(
92105
*,
93106
timeout: float = 60,
94107
) -> AsyncOperation[ImageGenerationModelResult]:
108+
"""Executes the image generation operation asynchronously
109+
and returns an operation representing the ongoing image generation process.
110+
111+
:param messages: the input messages for image generation.
112+
:param timeout: the timeout, or the maximum time to wait for the request to complete in seconds.
113+
Defaults to 60 seconds.
114+
"""
95115
return await self._run_deferred(
96116
messages=messages,
97117
timeout=timeout
@@ -102,14 +122,22 @@ async def attach_deferred(
102122
operation_id: str,
103123
timeout: float = 60,
104124
) -> AsyncOperation[ImageGenerationModelResult]:
125+
"""Attaches to an ongoing image generation operation.
126+
127+
:param operation_id: the ID of the operation to attach to.
128+
:param timeout: the timeout, or the maximum time to wait for the request to complete in seconds.
129+
Defaults to 60 seconds.
130+
"""
105131
return await self._attach_deferred(operation_id=operation_id, timeout=timeout)
106132

107133

134+
@doc_from(BaseImageGenerationModel)
108135
class ImageGenerationModel(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]):
109136
_operation_type = Operation[ImageGenerationModelResult]
110137
__run_deferred = run_sync(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]._run_deferred)
111138
__attach_deferred = run_sync(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]._attach_deferred)
112139

140+
@doc_from(AsyncImageGenerationModel.run_deferred)
113141
def run_deferred(
114142
self,
115143
messages: ImageMessageInputType,
@@ -122,6 +150,7 @@ def run_deferred(
122150
timeout=timeout
123151
)
124152

153+
@doc_from(AsyncImageGenerationModel.attach_deferred)
125154
def attach_deferred(self, operation_id: str, timeout: float = 60) -> Operation[ImageGenerationModelResult]:
126155
return cast(
127156
Operation[ImageGenerationModelResult],

src/yandex_cloud_ml_sdk/_models/image_generation/result.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515

1616
@dataclass(frozen=True, repr=False)
1717
class ImageGenerationModelResult(BaseResult):
18+
"""This class represents the result of an image generation model inference."""
19+
#: the generated image in bytes
1820
image_bytes: bytes
21+
#: the version of the model used for generation
1922
model_version: str
2023

2124
@classmethod
@@ -27,6 +30,7 @@ def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> Self: # pylint: d
2730
)
2831

2932
def _repr_jpeg_(self) -> bytes | None:
33+
""":meta public:"""
3034
# NB: currently model could return only jpeg,
3135
# but for future I will put this check here to
3236
# remember we will need to make a _repr_png_ and such

0 commit comments

Comments
 (0)