-
Notifications
You must be signed in to change notification settings - Fork 27
Add docstrings for _models/image_generation #115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
38f672f
1e3c4dd
4dc64eb
3e0790b
ec11730
cf6ce94
ad281bb
e08d487
5a34ef2
3f38094
f3d841e
cd7c1c1
f4b7892
ae40a0e
c37a528
f903ef5
5baf79f
ccf3dd7
b8a2f24
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,26 +12,41 @@ | |
|
|
||
| @dataclass(frozen=True) | ||
| class ImageMessage: | ||
| """ | ||
| This class represents an image message with optional weight field. | ||
| """ | ||
| #: the text content of the image message | ||
| text: str | ||
| #: the weight associated with the message | ||
| weight: float | None = None | ||
|
|
||
|
|
||
| class ImageMessageDict(TypedDict): | ||
| """ | ||
| The class with TypedDict representing the structure of an image message. | ||
| """ | ||
| text: str | ||
| weight: NotRequired[float] | ||
|
|
||
|
|
||
| # NB: it supports _messages.message.Message and _models.completions.message.TextMessage | ||
| @runtime_checkable | ||
| class AnyMessage(Protocol): | ||
| """ | ||
| The class with a protocol which defines an object with a text field. | ||
| The protocol can be used to check if an object has a text attribute. | ||
| """ | ||
| text: str | ||
|
|
||
|
|
||
|
||
| ImageMessageType = Union[ImageMessage, ImageMessageDict, AnyMessage, str] | ||
| """Type alias for different types of image messages that can be processed.""" | ||
|
||
| ImageMessageInputType = Union[ImageMessageType, Iterable[ImageMessageType]] | ||
| """Type alias for input types accepted by the `messages_to_proto` function.""" | ||
|
|
||
|
|
||
| def messages_to_proto(messages: ImageMessageInputType) -> list[ProtoMessage]: | ||
| """:meta private:""" | ||
| msgs: tuple[ImageMessageType] = coerce_tuple( # type: ignore[assignment] | ||
| messages, | ||
| (dict, str, ImageMessage, AnyMessage) # type: ignore[arg-type] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr | ||
| from yandex_cloud_ml_sdk._types.model import ModelAsyncMixin, OperationTypeT | ||
| from yandex_cloud_ml_sdk._types.operation import AsyncOperation, Operation | ||
| from yandex_cloud_ml_sdk._utils.doc import doc_from | ||
| from yandex_cloud_ml_sdk._utils.sync import run_sync | ||
|
|
||
| from .config import ImageGenerationModelConfig | ||
|
|
@@ -28,6 +29,7 @@ | |
| class BaseImageGenerationModel( | ||
| ModelAsyncMixin[ImageGenerationModelConfig, ImageGenerationModelResult, OperationTypeT], | ||
| ): | ||
| """A class for image generation models.""" | ||
|
||
| _config_type = ImageGenerationModelConfig | ||
| _result_type = ImageGenerationModelResult | ||
| _operation_type: type[OperationTypeT] | ||
|
|
@@ -41,6 +43,15 @@ def configure( # type: ignore[override] | |
| height_ratio: UndefinedOr[int] = UNDEFINED, | ||
| mime_type: UndefinedOr[str] = UNDEFINED, | ||
| ) -> Self: | ||
| """ | ||
| Configures the image generation model with specified parameters and | ||
| returns the configured instance of the model. | ||
|
|
||
| :param seed: a random seed for generation. | ||
| :param width_ratio: the width ratio for the generated image. | ||
| :param height_ratio: the height ratio for the generated image. | ||
| :param mime_type: the MIME type of the generated image. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Здесь бы не помешала бы ссылка на доку про то, а какие MIME-типы мы умеем принимать |
||
| """ | ||
| return super().configure( | ||
| seed=seed, | ||
| width_ratio=width_ratio, | ||
|
|
@@ -83,6 +94,7 @@ async def _run_deferred( | |
| ) | ||
|
|
||
|
|
||
| @doc_from(BaseImageGenerationModel) | ||
| class AsyncImageGenerationModel(BaseImageGenerationModel[AsyncOperation[ImageGenerationModelResult]]): | ||
| _operation_type = AsyncOperation[ImageGenerationModelResult] | ||
|
|
||
|
|
@@ -92,6 +104,13 @@ async def run_deferred( | |
| *, | ||
| timeout: float = 60, | ||
| ) -> AsyncOperation[ImageGenerationModelResult]: | ||
| """Executes the image generation operation asynchronously | ||
| and returns an operation representing the ongoing image generation process. | ||
|
|
||
| :param messages: the input messages for image generation. | ||
| :param timeout: the timeout for the operation in seconds. | ||
|
||
| Defaults to 60 seconds. | ||
| """ | ||
| return await self._run_deferred( | ||
| messages=messages, | ||
| timeout=timeout | ||
|
|
@@ -102,14 +121,22 @@ async def attach_deferred( | |
| operation_id: str, | ||
| timeout: float = 60, | ||
| ) -> AsyncOperation[ImageGenerationModelResult]: | ||
| """Attaches to an ongoing image generation operation. | ||
|
|
||
| :param operation_id: the ID of the operation to attach to. | ||
| :param timeout: the timeout for the operation in seconds. | ||
| Defaults to 60 seconds. | ||
| """ | ||
| return await self._attach_deferred(operation_id=operation_id, timeout=timeout) | ||
|
|
||
|
|
||
| @doc_from(BaseImageGenerationModel) | ||
| class ImageGenerationModel(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]): | ||
| _operation_type = Operation[ImageGenerationModelResult] | ||
| __run_deferred = run_sync(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]._run_deferred) | ||
| __attach_deferred = run_sync(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]._attach_deferred) | ||
|
|
||
| @doc_from(AsyncImageGenerationModel.run_deferred) | ||
| def run_deferred( | ||
| self, | ||
| messages: ImageMessageInputType, | ||
|
|
@@ -122,6 +149,7 @@ def run_deferred( | |
| timeout=timeout | ||
| ) | ||
|
|
||
| @doc_from(AsyncImageGenerationModel.attach_deferred) | ||
| def attach_deferred(self, operation_id: str, timeout: float = 60) -> Operation[ImageGenerationModelResult]: | ||
| return cast( | ||
| Operation[ImageGenerationModelResult], | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,18 +15,23 @@ | |
|
|
||
| @dataclass(frozen=True, repr=False) | ||
| class ImageGenerationModelResult(BaseResult): | ||
| """This class represents the result of an image generation model.""" | ||
|
||
| #: the generated image in bytes | ||
| image_bytes: bytes | ||
| #: the version of the model used for generation | ||
| model_version: str | ||
|
|
||
| @classmethod | ||
| def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> Self: # pylint: disable=unused-argument | ||
| """:meta private:""" | ||
|
||
| proto = cast(ImageGenerationResponse, proto) | ||
| return cls( | ||
| image_bytes=proto.image, | ||
| model_version=proto.model_version, | ||
| ) | ||
|
|
||
| def _repr_jpeg_(self) -> bytes | None: | ||
| """:meta private:""" | ||
|
||
| # NB: currently model could return only jpeg, | ||
| # but for future I will put this check here to | ||
| # remember we will need to make a _repr_png_ and such | ||
|
|
@@ -38,5 +43,6 @@ def _repr_jpeg_(self) -> bytes | None: | |
| return None | ||
|
|
||
| def __repr__(self) -> str: | ||
| """:meta private:""" | ||
|
||
| size = len(self.image_bytes) | ||
| return f'{self.__class__.__name__}(model_version={self.model_version!r}, image_bytes=<{size} bytes>)' | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Что такое image message?)
На самом деле, это просто структура сообщения, которую модель ожидает от пользователя.
То есть "message for using in image generation models", ниже тоже самое