Skip to content

Commit 89a8ead

Browse files
committed
Add batch inference feature
1 parent 6804535 commit 89a8ead

File tree

12 files changed

+421
-39
lines changed

12 files changed

+421
-39
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import pathlib
7+
8+
from yandex_cloud_ml_sdk import AsyncYCloudML
9+
10+
PATH = pathlib.Path(__file__)
11+
NAME = f'example-{PATH.parent.name}-{PATH.name}'
12+
13+
14+
def local_path(path: str) -> pathlib.Path:
15+
return pathlib.Path(__file__).parent / path
16+
17+
18+
async def get_dataset(sdk):
19+
"""
20+
This function represents getting or creating dataset object.
21+
22+
In real life you could use just a datasets ids, for example:
23+
24+
```
25+
dataset = await sdk.datasets.get("some_id")
26+
tuning_task = await base_model.tune_deferred(
27+
"dataset_id",
28+
validation_datasets=dataset
29+
)
30+
```
31+
"""
32+
33+
async for dataset in sdk.datasets.list(status='READY', name_pattern=NAME):
34+
print(f'using old dataset {dataset=}')
35+
break
36+
else:
37+
print('no old datasets found, creating new one')
38+
dataset_draft = sdk.datasets.draft_from_path(
39+
task_type='TextToTextGenerationRequest',
40+
path=local_path('completions.jsonlines'),
41+
upload_format='jsonlines',
42+
name=NAME,
43+
)
44+
45+
dataset = await dataset_draft.upload()
46+
print(f'created new dataset {dataset=}')
47+
48+
return dataset
49+
50+
51+
async def main() -> None:
52+
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
53+
sdk.setup_default_logging()
54+
55+
dataset = await get_dataset(sdk)
56+
57+
model = sdk.models.completions('gemma-3-12b-it')
58+
59+
operation = await model.batch.run_deferred(dataset)
60+
61+
print(operation)
62+
result = await operation
63+
64+
print(operation)
65+
print(result)
66+
async for line in result.read():
67+
print(line)
68+
69+
70+
if __name__ == '__main__':
71+
asyncio.run(main())
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{"request": [{"role": "system", "text": "Твое имя Женя, полное имя Евгений Нейроныч. \nТы отвечаешь от лица мужского рода. \nТы робот. \nТы говоришь коротко и емко. \nТы был создан в Перми. \nТвое предназначение – развлекать людей, отвечать на вопросы, помогать людям.\nТы эксперт в сфере ЖКХ. \nТы работаешь в Центре управления регионом Московской области.\nТы можешь двигать руками, головой, торсом, но пока не можешь ходить."}, {"role": "user", "text": "Как тебя зовут?"}]}
2+
{"request": [{"role": "system", "text": "Твое имя Женя, полное имя Евгений Нейроныч. \nТы отвечаешь от лица мужского рода. \nТы робот. \nТы говоришь коротко и емко. \nТы был создан в Перми. \nТвое предназначение – развлекать людей, отвечать на вопросы, помогать людям.\nТы эксперт в сфере ЖКХ. \nТы работаешь в Центре управления регионом Московской области.\nТы можешь двигать руками, головой, торсом, но пока не можешь ходить."}, {"role": "user", "text": "Как тебя зовут?"}]}
3+
{"request": [{"role": "system", "text": "Твое имя Женя, полное имя Евгений Нейроныч. \nТы отвечаешь от лица мужского рода. \nТы робот. \nТы говоришь коротко и емко. \nТы был создан в Перми. \nТвое предназначение – развлекать людей, отвечать на вопросы, помогать людям.\nТы эксперт в сфере ЖКХ. \nТы работаешь в Центре управления регионом Московской области.\nТы можешь двигать руками, головой, торсом, но пока не можешь ходить."}, {"role": "user", "text": "Как тебя зовут?"}]}

src/yandex_cloud_ml_sdk/_models/completions/model.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,19 @@
88
from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import CompletionOptions, ReasoningOptions
99
from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import Tool as ProtoCompletionsTool
1010
from yandex.cloud.ai.foundation_models.v1.text_generation.text_generation_service_pb2 import (
11-
CompletionRequest, CompletionResponse, TokenizeResponse
11+
BatchCompletionMetadata, BatchCompletionRequest, BatchCompletionResponse, CompletionRequest, CompletionResponse,
12+
TokenizeResponse
1213
)
1314
from yandex.cloud.ai.foundation_models.v1.text_generation.text_generation_service_pb2_grpc import (
14-
TextGenerationAsyncServiceStub, TextGenerationServiceStub, TokenizerServiceStub
15+
TextGenerationAsyncServiceStub, TextGenerationBatchServiceStub, TextGenerationServiceStub, TokenizerServiceStub
1516
)
1617
from yandex.cloud.operation.operation_pb2 import Operation as ProtoOperation
1718

1819
from yandex_cloud_ml_sdk._tools.tool import BaseTool
1920
from yandex_cloud_ml_sdk._tools.tool_call import AsyncToolCall, ToolCall, ToolCallTypeT
2021
from yandex_cloud_ml_sdk._tuning.tuning_task import AsyncTuningTask, TuningTask, TuningTaskTypeT
22+
from yandex_cloud_ml_sdk._types.batch.domain import AsyncBatchSubdomain, BatchSubdomain, BatchSubdomainTypeT
23+
from yandex_cloud_ml_sdk._types.batch.model import AsyncModelBatchMixin, BaseModelBatchMixin, ModelBatchMixin
2124
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr
2225
from yandex_cloud_ml_sdk._types.model import (
2326
ModelAsyncMixin, ModelSyncMixin, ModelSyncStreamMixin, ModelTuneMixin, OperationTypeT
@@ -42,11 +45,12 @@
4245

4346

4447
class BaseGPTModel(
45-
Generic[OperationTypeT, TuningTaskTypeT, ToolCallTypeT],
48+
Generic[OperationTypeT, TuningTaskTypeT, ToolCallTypeT, BatchSubdomainTypeT],
4649
ModelSyncMixin[GPTModelConfig, GPTModelResult[ToolCallTypeT]],
4750
ModelSyncStreamMixin[GPTModelConfig, GPTModelResult[ToolCallTypeT]],
4851
ModelAsyncMixin[GPTModelConfig, GPTModelResult[ToolCallTypeT], OperationTypeT],
4952
ModelTuneMixin[GPTModelConfig, GPTModelResult[ToolCallTypeT], GPTModelTuneParams, TuningTaskTypeT],
53+
BaseModelBatchMixin[GPTModelConfig, GPTModelResult[ToolCallTypeT], BatchSubdomainTypeT],
5054
):
5155
_config_type = GPTModelConfig
5256
_result_type: type[GPTModelResult[ToolCallTypeT]]
@@ -56,6 +60,10 @@ class BaseGPTModel(
5660
_tuning_params_type = GPTModelTuneParams
5761
_tuning_operation_type: type[TuningTaskTypeT]
5862

63+
_batch_service_stub = TextGenerationBatchServiceStub
64+
_batch_proto_result_type = BatchCompletionResponse
65+
_batch_proto_metadata_type = BatchCompletionMetadata
66+
5967
def langchain(self, model_type: Literal["chat"] = "chat", timeout: int = 60) -> BaseYandexLanguageModel:
6068
from .langchain import ChatYandexGPT # pylint: disable=import-outside-toplevel
6169

@@ -83,14 +91,8 @@ def configure( # type: ignore[override]
8391
tools=tools,
8492
)
8593

86-
def _make_request(
87-
self,
88-
*,
89-
messages: MessageInputType,
90-
stream: bool | None,
91-
) -> CompletionRequest:
94+
def _make_completion_options(self, *, stream: bool | None) -> CompletionOptions:
9295
completion_options_kwargs: dict[str, Any] = {}
93-
response_format_kwargs: dict[str, Any] = {}
9496

9597
if stream is not None:
9698
completion_options_kwargs['stream'] = stream
@@ -105,6 +107,19 @@ def _make_request(
105107
reasoning_mode = ReasoningMode._coerce(c.reasoning_mode)._to_proto()
106108
reasoning_options = ReasoningOptions(mode=reasoning_mode) # type: ignore[arg-type]
107109
completion_options_kwargs['reasoning_options'] = reasoning_options
110+
111+
return CompletionOptions(**completion_options_kwargs)
112+
113+
def _make_request(
114+
self,
115+
*,
116+
messages: MessageInputType,
117+
stream: bool | None,
118+
) -> CompletionRequest:
119+
response_format_kwargs: dict[str, Any] = {}
120+
121+
c = self._config
122+
108123
if c.response_format is not None:
109124
schema = schema_from_response_format(c.response_format)
110125
if isinstance(schema, str):
@@ -119,12 +134,19 @@ def _make_request(
119134

120135
return CompletionRequest(
121136
model_uri=self._uri,
122-
completion_options=CompletionOptions(**completion_options_kwargs),
137+
completion_options=self._make_completion_options(stream=stream),
123138
messages=messages_to_proto(messages),
124139
tools=[tool._to_proto(ProtoCompletionsTool) for tool in tools],
125140
**response_format_kwargs,
126141
)
127142

143+
def _make_batch_request(self, dataset_id: str) -> BatchCompletionRequest:
144+
return BatchCompletionRequest(
145+
model_uri=self.uri,
146+
completion_options=self._make_completion_options(stream=False),
147+
source_dataset_id=dataset_id
148+
)
149+
128150
async def _run_sync_impl(
129151
self,
130152
*,
@@ -232,8 +254,10 @@ class AsyncGPTModel(
232254
BaseGPTModel[
233255
AsyncOperation[GPTModelResult[AsyncToolCall]],
234256
AsyncTuningTask['AsyncGPTModel'],
235-
AsyncToolCall
236-
]
257+
AsyncToolCall,
258+
AsyncBatchSubdomain,
259+
],
260+
AsyncModelBatchMixin,
237261
):
238262
_operation_type = AsyncOperation
239263
_tune_operation_type = AsyncTuningTask
@@ -368,7 +392,9 @@ class GPTModel(
368392
Operation[GPTModelResult[ToolCall]],
369393
TuningTask['GPTModel'],
370394
ToolCall,
371-
]
395+
BatchSubdomain,
396+
],
397+
ModelBatchMixin,
372398
):
373399
_operation_type = Operation
374400
_tune_operation_type = TuningTask

src/yandex_cloud_ml_sdk/_types/batch/__init__.py

Whitespace-only changes.
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# pylint: disable=no-name-in-module,protected-access
2+
from __future__ import annotations
3+
4+
import abc
5+
from typing import TYPE_CHECKING, Generic, TypeVar, cast
6+
7+
from yandex.cloud.operation.operation_pb2 import Operation as ProtoOperation
8+
9+
from yandex_cloud_ml_sdk._types.datasets import DatasetType, coerce_dataset_id
10+
from yandex_cloud_ml_sdk._utils.sync import run_sync
11+
12+
from .operation import AsyncBatchOperation, BatchOperation, BatchOperationTypeT
13+
14+
if TYPE_CHECKING:
15+
from yandex_cloud_ml_sdk._sdk import BaseSDK
16+
17+
from .model import BaseModelBatchMixin
18+
19+
20+
class BaseBatchSubdomain(Generic[BatchOperationTypeT], metaclass=abc.ABCMeta):
21+
_operation_impl: type[BatchOperationTypeT]
22+
23+
def __init__(self, model: BaseModelBatchMixin, sdk: BaseSDK):
24+
self._model = model
25+
self._sdk = sdk
26+
27+
async def _run_deferred(self, dataset: DatasetType, *, timeout: float = 60) -> BatchOperationTypeT:
28+
dataset_id = coerce_dataset_id(dataset)
29+
30+
m = self._model
31+
request = m._make_batch_request(dataset_id)
32+
stub_class = m._batch_service_stub
33+
proto_result_type = m._batch_proto_result_type
34+
proto_metadata_type = m._batch_proto_metadata_type
35+
36+
async with self._sdk._client.get_service_stub(stub_class, timeout=timeout) as stub:
37+
response = await self._sdk._client.call_service(
38+
stub.Completion,
39+
request=request,
40+
expected_type=ProtoOperation,
41+
timeout=timeout
42+
)
43+
44+
return self._operation_impl(
45+
id=response.id,
46+
sdk=self._sdk,
47+
proto_result_type=proto_result_type,
48+
proto_metadata_type=proto_metadata_type,
49+
initial_operation=response,
50+
)
51+
52+
53+
class AsyncBatchSubdomain(BaseBatchSubdomain[AsyncBatchOperation]):
54+
_operation_impl = AsyncBatchOperation
55+
56+
async def run_deferred(self, dataset: DatasetType, *, timeout: float = 60) -> AsyncBatchOperation:
57+
return await self._run_deferred(dataset=dataset, timeout=timeout)
58+
59+
60+
class BatchSubdomain(BaseBatchSubdomain[BatchOperation]):
61+
_operation_impl = BatchOperation
62+
63+
__run_deferred = run_sync(BaseBatchSubdomain[BatchOperation]._run_deferred)
64+
65+
def run_deferred(self, dataset: DatasetType, *, timeout: float = 60) -> BatchOperation:
66+
return cast(
67+
BatchOperation,
68+
self.__run_deferred(dataset=dataset, timeout=timeout)
69+
)
70+
71+
72+
BatchSubdomainTypeT = TypeVar('BatchSubdomainTypeT', bound=BaseBatchSubdomain)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# pylint: disable=no-name-in-module
2+
from __future__ import annotations
3+
4+
import abc
5+
from functools import cached_property
6+
from typing import Generic, TypeVar
7+
8+
from google.protobuf.message import Message
9+
from typing_extensions import TypeAlias
10+
from yandex.cloud.ai.foundation_models.v1.text_generation.text_generation_service_pb2 import (
11+
BatchCompletionMetadata, BatchCompletionResponse
12+
)
13+
from yandex.cloud.ai.foundation_models.v1.text_generation.text_generation_service_pb2_grpc import (
14+
TextGenerationBatchServiceStub
15+
)
16+
17+
from yandex_cloud_ml_sdk._types.model import BaseModel, ConfigTypeT, ResultTypeT
18+
19+
from .domain import AsyncBatchSubdomain, BatchSubdomain, BatchSubdomainTypeT
20+
21+
BatchStubType: TypeAlias = TextGenerationBatchServiceStub
22+
BatchResultType: TypeAlias = BatchCompletionResponse
23+
BatchMetadataType: TypeAlias = BatchCompletionMetadata
24+
25+
26+
class BaseModelBatchMixin(
27+
BaseModel[ConfigTypeT, ResultTypeT],
28+
Generic[ConfigTypeT, ResultTypeT, BatchSubdomainTypeT],
29+
metaclass=abc.ABCMeta,
30+
):
31+
_batch_impl: type[BatchSubdomainTypeT]
32+
33+
@abc.abstractmethod
34+
def _make_batch_request(self, dataset_id: str) -> Message:
35+
pass
36+
37+
@property
38+
@abc.abstractmethod
39+
def _batch_service_stub(self) -> type[BatchStubType]:
40+
pass
41+
42+
@property
43+
@abc.abstractmethod
44+
def _batch_proto_result_type(self) -> type[BatchResultType]:
45+
pass
46+
47+
@property
48+
@abc.abstractmethod
49+
def _batch_proto_metadata_type(self) -> type[BatchMetadataType]:
50+
pass
51+
52+
@cached_property
53+
def batch(self) -> BatchSubdomainTypeT:
54+
return self._batch_impl(model=self, sdk=self._sdk)
55+
56+
57+
# pylint: disable=abstract-method
58+
class AsyncModelBatchMixin(
59+
BaseModelBatchMixin[ConfigTypeT, ResultTypeT, AsyncBatchSubdomain],
60+
Generic[ConfigTypeT, ResultTypeT],
61+
):
62+
_batch_impl = AsyncBatchSubdomain
63+
64+
65+
# pylint: disable=abstract-method
66+
class ModelBatchMixin(
67+
BaseModelBatchMixin[ConfigTypeT, ResultTypeT, BatchSubdomain],
68+
Generic[ConfigTypeT, ResultTypeT],
69+
):
70+
_batch_impl = BatchSubdomain
71+
72+
73+
ModelWithBatchTypeT = TypeVar('ModelWithBatchTypeT', bound=BaseModelBatchMixin)

0 commit comments

Comments
 (0)