diff --git a/examples/async/assistants/search_index_call_strategy.py b/examples/async/assistants/search_index_call_strategy.py new file mode 100755 index 00000000..b0e8dec0 --- /dev/null +++ b/examples/async/assistants/search_index_call_strategy.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import asyncio +import pathlib + +from yandex_cloud_ml_sdk import AsyncYCloudML + +LABEL_KEY = 'yc-ml-sdk-example' +PATH = pathlib.Path(__file__) +NAME = f'example-{PATH.parent.name}-{PATH.name}' +LABELS = {LABEL_KEY: NAME} + + +def local_path(path: str) -> pathlib.Path: + return pathlib.Path(__file__).parent / path + + +async def get_search_index(sdk): + """ + This function represents getting or creating demo search_index object. + + In real life you will get it any other way that would suit your case. + """ + + async for search_index in sdk.search_indexes.list(): + if search_index.labels and search_index.labels.get(LABEL_KEY) == NAME: + print(f'using {search_index=}') + break + else: + print('no search indexes found, creating new one') + file_coros = ( + sdk.files.upload( + local_path(path), + ttl_days=5, + expiration_policy="static", + ) + for path in ['turkey_example.txt', 'maldives_example.txt'] + ) + files = await asyncio.gather(*file_coros) + operation = await sdk.search_indexes.create_deferred(files, labels=LABELS) + search_index = await operation + print(f'new {search_index=}') + + for file in files: + await file.delete() + + return search_index + + +async def delete_labeled_entities(iterator): + """ + Deletes any entities from given iterator which have .labels attribute + with `labels[LABEL_KEY] == NAME` + """ + + async for entity in iterator: + if entity.labels and entity.labels.get(LABEL_KEY) == NAME: + print(f'deleting {entity.__class__.__name__} with id={entity.id!r}') + await entity.delete() + + +async def main() -> None: + sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64') + sdk.setup_default_logging(log_level='WARNING') + + search_index = await get_search_index(sdk) + thread = await sdk.threads.create(labels=LABELS) + + tool = sdk.tools.search_index(search_index) + assistant = await sdk.assistants.create('yandexgpt', tools=[tool], labels=LABELS) + + # Look, if you don't pass a call strategy to a SearchIndex, it is 'always' use by-default + assert tool.call_strategy is None + assert assistant.tools[0].call_strategy.value == 'always' # type: ignore[attr-defined] + + # First of all we are using request which will definitely find something + search_query = local_path('search_query.txt').read_text().splitlines()[0] + await thread.write(search_query) + run = await assistant.run(thread) + result = await run.wait() + # NB: citations says if index were used or not + assert len(result.citations) > 0 + print(f'If you are using "always" call_strategy, it returns {len(result.citations)>0=} citations from search index') + + # Now we will use a search index, which will be used only if it asked to + tool_with_call_strategy = sdk.tools.search_index( + search_index, + call_strategy={ + 'type': 'function', + 'function': {'name': 'guide', 'instruction': 'use this only if you are asked to look in the guide'} + } + ) + assistant_with_call_strategy = await sdk.assistants.create( + sdk.models.completions('yandexgpt', model_version='rc'), + tools=[tool_with_call_strategy], + labels=LABELS + ) + + await thread.write(search_query) + run = await assistant_with_call_strategy.run(thread) + result = await run.wait() + # NB: citations says if index were used or not + assert len(result.citations) == 0 + print( + 'When you are using special call_strategy and model decides not to use search index according ' + f'to call_strategy instruction, it returns {len(result.citations)>0=} citations from search index' + ) + + await thread.write(f"Look at the guide, please: {search_query}") + run = await assistant_with_call_strategy.run(thread) + result = await run.wait() + # NB: citations says if index were used or not + assert len(result.citations) > 0 + print( + 'When you are using special call_strategy and model decides to use search index according ' + f'to call_strategy instruction, it returns {len(result.citations)>0=} from search index' + ) + + # we will delete all assistant and threads created in this example + # to not to increase chaos level, but not the search index, because + # index creation is a slow operation and could be re-used in this + # example next run + await delete_labeled_entities(sdk.assistants.list()) + await delete_labeled_entities(sdk.threads.list()) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/examples/sync/assistants/search_index_call_strategy.py b/examples/sync/assistants/search_index_call_strategy.py new file mode 100755 index 00000000..3a7e072d --- /dev/null +++ b/examples/sync/assistants/search_index_call_strategy.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import pathlib + +from yandex_cloud_ml_sdk import YCloudML + +LABEL_KEY = 'yc-ml-sdk-example' +PATH = pathlib.Path(__file__) +NAME = f'example-{PATH.parent.name}-{PATH.name}' +LABELS = {LABEL_KEY: NAME} + + +def local_path(path: str) -> pathlib.Path: + return pathlib.Path(__file__).parent / path + + +def get_search_index(sdk): + """ + This function represents getting or creating demo search_index object. + + In real life you will get it any other way that would suit your case. + """ + + for search_index in sdk.search_indexes.list(): + if search_index.labels and search_index.labels.get(LABEL_KEY) == NAME: + print(f'using {search_index=}') + break + else: + print('no search indexes found, creating new one') + files = [ + sdk.files.upload( + local_path(path), + ttl_days=5, + expiration_policy="static", + ) + for path in ['turkey_example.txt', 'maldives_example.txt'] + ] + operation = sdk.search_indexes.create_deferred(files, labels=LABELS) + search_index = operation + print(f'new {search_index=}') + + for file in files: + file.delete() + + return search_index + + +def delete_labeled_entities(iterator): + """ + Deletes any entities from given iterator which have .labels attribute + with `labels[LABEL_KEY] == NAME` + """ + + for entity in iterator: + if entity.labels and entity.labels.get(LABEL_KEY) == NAME: + print(f'deleting {entity.__class__.__name__} with id={entity.id!r}') + entity.delete() + + +def main() -> None: + sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64') + sdk.setup_default_logging(log_level='WARNING') + + search_index = get_search_index(sdk) + thread = sdk.threads.create(labels=LABELS) + + tool = sdk.tools.search_index(search_index) + assistant = sdk.assistants.create('yandexgpt', tools=[tool], labels=LABELS) + + # Look, if you don't pass a call strategy to a SearchIndex, it is 'always' use by-default + assert tool.call_strategy is None + assert assistant.tools[0].call_strategy.value == 'always' # type: ignore[attr-defined] + + # First of all we are using request which will definitely find something + search_query = local_path('search_query.txt').read_text().splitlines()[0] + thread.write(search_query) + run = assistant.run(thread) + result = run.wait() + # NB: citations says if index were used or not + assert len(result.citations) > 0 + print(f'If you are using "always" call_strategy, it returns {len(result.citations)>0=} citations from search index') + + # Now we will use a search index, which will be used only if it asked to + tool_with_call_strategy = sdk.tools.search_index( + search_index, + call_strategy={ + 'type': 'function', + 'function': {'name': 'guide', 'instruction': 'use this only if you are asked to look in the guide'} + } + ) + assistant_with_call_strategy = sdk.assistants.create( + sdk.models.completions('yandexgpt', model_version='rc'), + tools=[tool_with_call_strategy], + labels=LABELS + ) + + thread.write(search_query) + run = assistant_with_call_strategy.run(thread) + result = run.wait() + # NB: citations says if index were used or not + assert len(result.citations) == 0 + print( + 'When you are using special call_strategy and model decides not to use search index according ' + f'to call_strategy instruction, it returns {len(result.citations)>0=} citations from search index' + ) + + thread.write(f"Look at the guide, please: {search_query}") + run = assistant_with_call_strategy.run(thread) + result = run.wait() + # NB: citations says if index were used or not + assert len(result.citations) > 0 + print( + 'When you are using special call_strategy and model decides to use search index according ' + f'to call_strategy instruction, it returns {len(result.citations)>0=} from search index' + ) + + # we will delete all assistant and threads created in this example + # to not to increase chaos level, but not the search index, because + # index creation is a slow operation and could be re-used in this + # example next run + delete_labeled_entities(sdk.assistants.list()) + delete_labeled_entities(sdk.threads.list()) + + +if __name__ == '__main__': + main() diff --git a/src/yandex_cloud_ml_sdk/_models/completions/config.py b/src/yandex_cloud_ml_sdk/_models/completions/config.py index 7214f397..4668f24a 100644 --- a/src/yandex_cloud_ml_sdk/_models/completions/config.py +++ b/src/yandex_cloud_ml_sdk/_models/completions/config.py @@ -12,7 +12,7 @@ from yandex_cloud_ml_sdk._tools.tool import FunctionTool from yandex_cloud_ml_sdk._types.model_config import BaseModelConfig from yandex_cloud_ml_sdk._types.schemas import ResponseType -from yandex_cloud_ml_sdk._types.tool_choice import ToolChoiceType +from yandex_cloud_ml_sdk._types.tools.tool_choice import ToolChoiceType from yandex_cloud_ml_sdk._utils.proto import ProtoEnumBase _m = ProtoReasoningOptions.ReasoningMode diff --git a/src/yandex_cloud_ml_sdk/_models/completions/model.py b/src/yandex_cloud_ml_sdk/_models/completions/model.py index 7a7b6f2c..cb7a6b13 100644 --- a/src/yandex_cloud_ml_sdk/_models/completions/model.py +++ b/src/yandex_cloud_ml_sdk/_models/completions/model.py @@ -30,8 +30,8 @@ ) from yandex_cloud_ml_sdk._types.operation import AsyncOperation, Operation from yandex_cloud_ml_sdk._types.schemas import ResponseType, make_response_format_kwargs -from yandex_cloud_ml_sdk._types.tool_choice import ToolChoiceType -from yandex_cloud_ml_sdk._types.tool_choice import coerce_to_proto as coerce_to_proto_tool_choice +from yandex_cloud_ml_sdk._types.tools.tool_choice import ToolChoiceType +from yandex_cloud_ml_sdk._types.tools.tool_choice import coerce_to_proto as coerce_to_proto_tool_choice from yandex_cloud_ml_sdk._types.tuning.datasets import TuningDatasetsType from yandex_cloud_ml_sdk._types.tuning.optimizers import BaseOptimizer from yandex_cloud_ml_sdk._types.tuning.schedulers import BaseScheduler diff --git a/src/yandex_cloud_ml_sdk/_tools/domain.py b/src/yandex_cloud_ml_sdk/_tools/domain.py index fdefd8bc..06f06d43 100644 --- a/src/yandex_cloud_ml_sdk/_tools/domain.py +++ b/src/yandex_cloud_ml_sdk/_tools/domain.py @@ -10,8 +10,9 @@ from yandex_cloud_ml_sdk._utils.coerce import ResourceType, coerce_resource_ids from .function import AsyncFunctionTools, FunctionTools, FunctionToolsTypeT -from .rephraser.function import RephraserFunction, RephraserInputType -from .tool import SearchIndexTool +from .search_index.call_strategy import CallStrategy, CallStrategyInputType +from .search_index.rephraser.function import RephraserFunction, RephraserInputType +from .search_index.tool import SearchIndexTool class BaseTools(BaseDomain, Generic[FunctionToolsTypeT]): @@ -38,6 +39,7 @@ def search_index( *, max_num_results: UndefinedOr[int] = UNDEFINED, rephraser: UndefinedOr[RephraserInputType] = UNDEFINED, + call_strategy: UndefinedOr[CallStrategyInputType] = UNDEFINED, ) -> SearchIndexTool: """Creates SearchIndexTool (not to be confused with :py:class:`~.SearchIndex`). @@ -58,10 +60,15 @@ def search_index( # this is coercing any RephraserInputType to Rephraser rephraser_ = self.rephraser(rephraser) # type: ignore[arg-type] + call_strategy_ = None + if is_defined(call_strategy): + call_strategy_ = CallStrategy._coerce(call_strategy) + return SearchIndexTool( search_index_ids=tuple(index_ids), max_num_results=max_num_results_, rephraser=rephraser_, + call_strategy=call_strategy_, ) diff --git a/src/yandex_cloud_ml_sdk/_tools/rephraser/__init__.py b/src/yandex_cloud_ml_sdk/_tools/search_index/__init__.py similarity index 100% rename from src/yandex_cloud_ml_sdk/_tools/rephraser/__init__.py rename to src/yandex_cloud_ml_sdk/_tools/search_index/__init__.py diff --git a/src/yandex_cloud_ml_sdk/_tools/search_index/call_strategy.py b/src/yandex_cloud_ml_sdk/_tools/search_index/call_strategy.py new file mode 100644 index 00000000..5aec4193 --- /dev/null +++ b/src/yandex_cloud_ml_sdk/_tools/search_index/call_strategy.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import Literal, Union, cast + +from typing_extensions import TypeAlias +# pylint: disable=no-name-in-module +from yandex.cloud.ai.assistants.v1.common_pb2 import CallStrategy as ProtoCallStrategy + +from yandex_cloud_ml_sdk._types.proto import ProtoBased, SDKType +from yandex_cloud_ml_sdk._types.tools.function import FunctionDictType, validate_function_dict + +CallStrategyStringType: TypeAlias = Literal['always'] + +CallStrategyType: TypeAlias = Union[CallStrategyStringType, FunctionDictType] +CallStrategyInputType: TypeAlias = Union[CallStrategyType, 'CallStrategy'] + + +class CallStrategy(ProtoBased[ProtoCallStrategy]): + _call_strategy: CallStrategyType + + def __init__(self, call_strategy: CallStrategyType): + self._call_strategy = call_strategy + self._validate() + + @property + def value(self) -> CallStrategyType: + return self._call_strategy + + def _validate(self): + call_strategy = self.value + if isinstance(call_strategy, str): + if call_strategy == 'always': + return + elif isinstance(call_strategy, dict): + call_strategy = validate_function_dict(call_strategy) + if 'instruction' in call_strategy['function']: + return + + raise ValueError( + f'wrong {call_strategy=}, ' + 'expected `call_strategy="always"` or' + '`call_strategy={"type": "function", "function": {"name": str, "instruction": str}}`' + ) + + # pylint: disable=unused-argument + @classmethod + def _from_proto(cls, *, proto: ProtoCallStrategy, sdk: SDKType) -> CallStrategy: + value: CallStrategyType + if proto.HasField('auto_call'): + value = { + 'type': 'function', + 'function': {'name': proto.auto_call.name, 'instruction': proto.auto_call.instruction} + } + elif proto.HasField('always_call'): + value = 'always' + else: + raise RuntimeError( + "proto message CallStrategy have unknown fields, try to upgrade yandex-cloud-ml-sdk") + return cls(value) + + def _to_proto(self) -> ProtoCallStrategy: + if self._call_strategy == 'always': + return ProtoCallStrategy( + always_call=ProtoCallStrategy.AlwaysCall() + ) + call_strategy = cast(FunctionDictType, self._call_strategy) + function = call_strategy['function'] + assert 'instruction' in function + return ProtoCallStrategy( + auto_call=ProtoCallStrategy.AutoCall( + name=function['name'], + instruction=function['instruction'] + ) + ) + + @classmethod + def _coerce(cls, call_strategy: CallStrategyInputType): + if isinstance(call_strategy, CallStrategy): + return call_strategy + + return cls(call_strategy) + + def __repr__(self): + return f"{self.__class__.__name__}({self.value!r})" diff --git a/src/yandex_cloud_ml_sdk/_tools/search_index/rephraser/__init__.py b/src/yandex_cloud_ml_sdk/_tools/search_index/rephraser/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/yandex_cloud_ml_sdk/_tools/rephraser/function.py b/src/yandex_cloud_ml_sdk/_tools/search_index/rephraser/function.py similarity index 100% rename from src/yandex_cloud_ml_sdk/_tools/rephraser/function.py rename to src/yandex_cloud_ml_sdk/_tools/search_index/rephraser/function.py diff --git a/src/yandex_cloud_ml_sdk/_tools/rephraser/model.py b/src/yandex_cloud_ml_sdk/_tools/search_index/rephraser/model.py similarity index 100% rename from src/yandex_cloud_ml_sdk/_tools/rephraser/model.py rename to src/yandex_cloud_ml_sdk/_tools/search_index/rephraser/model.py diff --git a/src/yandex_cloud_ml_sdk/_tools/search_index/tool.py b/src/yandex_cloud_ml_sdk/_tools/search_index/tool.py new file mode 100644 index 00000000..9c71cd98 --- /dev/null +++ b/src/yandex_cloud_ml_sdk/_tools/search_index/tool.py @@ -0,0 +1,67 @@ +# pylint: disable=no-name-in-module +from __future__ import annotations + +from dataclasses import dataclass + +from google.protobuf.wrappers_pb2 import Int64Value +from yandex.cloud.ai.assistants.v1.common_pb2 import SearchIndexTool as ProtoSearchIndexTool + +from yandex_cloud_ml_sdk._tools.tool import BaseTool, ProtoAssistantsTool, ProtoToolTypeT +from yandex_cloud_ml_sdk._types.proto import SDKType + +from .call_strategy import CallStrategy +from .rephraser.model import Rephraser + + +@dataclass(frozen=True) +class SearchIndexTool(BaseTool[ProtoSearchIndexTool]): + search_index_ids: tuple[str, ...] + + max_num_results: int | None = None + rephraser: Rephraser | None = None + call_strategy: CallStrategy | None = None + + @classmethod + def _from_proto(cls, *, proto: ProtoSearchIndexTool, sdk: SDKType) -> SearchIndexTool: + max_num_results: int | None = None + if proto.HasField("max_num_results"): + max_num_results = proto.max_num_results.value + + rephraser: Rephraser | None = None + if proto.HasField("rephraser_options"): + rephraser = Rephraser._from_proto(proto=proto.rephraser_options, sdk=sdk) + + call_strategy: CallStrategy | None = None + if proto.HasField('call_strategy'): + call_strategy = CallStrategy._from_proto(proto=proto.call_strategy, sdk=sdk) + + return cls( + search_index_ids=tuple(proto.search_index_ids), + max_num_results=max_num_results, + rephraser=rephraser, + call_strategy=call_strategy, + ) + + def _to_proto(self, proto_type: type[ProtoToolTypeT]) -> ProtoToolTypeT: + assert issubclass(proto_type, ProtoAssistantsTool) + + max_num_results: None | Int64Value = None + if self.max_num_results is not None: + max_num_results = Int64Value(value=self.max_num_results) + + rephraser = None + if self.rephraser: + rephraser = self.rephraser._to_proto() + + call_strategy = None + if self.call_strategy: + call_strategy = self.call_strategy._to_proto() + + return proto_type( + search_index=ProtoSearchIndexTool( + max_num_results=max_num_results, + search_index_ids=self.search_index_ids, + rephraser_options=rephraser, + call_strategy=call_strategy, + ) + ) diff --git a/src/yandex_cloud_ml_sdk/_tools/tool.py b/src/yandex_cloud_ml_sdk/_tools/tool.py index f0e6dac1..8c8283c1 100644 --- a/src/yandex_cloud_ml_sdk/_tools/tool.py +++ b/src/yandex_cloud_ml_sdk/_tools/tool.py @@ -3,32 +3,25 @@ import abc from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TypeVar, Union from google.protobuf.json_format import MessageToDict from google.protobuf.struct_pb2 import Struct -from google.protobuf.wrappers_pb2 import Int64Value from yandex.cloud.ai.assistants.v1.common_pb2 import FunctionTool as ProtoAssistantsFunctionTool -from yandex.cloud.ai.assistants.v1.common_pb2 import SearchIndexTool as ProtoSearchIndexTool from yandex.cloud.ai.assistants.v1.common_pb2 import Tool as ProtoAssistantsTool from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import FunctionTool as ProtoCompletionsFunctionTool from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import Tool as ProtoCompletionsTool +from yandex_cloud_ml_sdk._types.proto import ProtoBased, ProtoMessageTypeT, SDKType from yandex_cloud_ml_sdk._types.schemas import JsonSchemaType -from .rephraser.model import Rephraser - -if TYPE_CHECKING: - from yandex_cloud_ml_sdk._sdk import BaseSDK - - ProtoToolTypeT = TypeVar('ProtoToolTypeT', ProtoAssistantsTool, ProtoCompletionsTool) -class BaseTool(abc.ABC): +class BaseTool(ProtoBased[ProtoMessageTypeT]): @classmethod @abc.abstractmethod - def _from_proto(cls, proto: Any, sdk: BaseSDK) -> BaseTool: + def _from_proto(cls, *, proto: ProtoMessageTypeT, sdk: SDKType) -> BaseTool: pass @abc.abstractmethod @@ -36,7 +29,7 @@ def _to_proto(self, proto_type: type[ProtoToolTypeT]) -> ProtoToolTypeT: pass @classmethod - def _from_upper_proto(cls, proto: ProtoToolTypeT, sdk: BaseSDK) -> BaseTool: + def _from_upper_proto(cls, proto: ProtoToolTypeT, sdk: SDKType) -> BaseTool: if proto.HasField('function'): return FunctionTool._from_proto( proto=proto.function, @@ -48,6 +41,9 @@ def _from_upper_proto(cls, proto: ProtoToolTypeT, sdk: BaseSDK) -> BaseTool: hasattr(proto, 'search_index') and proto.HasField('search_index') # type: ignore[arg-type] ): + # pylint: disable=import-outside-toplevel + from .search_index.tool import SearchIndexTool + return SearchIndexTool._from_proto( proto=proto.search_index, sdk=sdk @@ -55,61 +51,23 @@ def _from_upper_proto(cls, proto: ProtoToolTypeT, sdk: BaseSDK) -> BaseTool: raise NotImplementedError('tools other then search_index and function are not supported in this SDK version') -@dataclass(frozen=True) -class SearchIndexTool(BaseTool): - search_index_ids: tuple[str, ...] - - max_num_results: int | None = None - rephraser: Rephraser | None = None - - @classmethod - def _from_proto(cls, proto: ProtoSearchIndexTool, sdk: BaseSDK) -> SearchIndexTool: - max_num_results: int | None = None - if proto.HasField("max_num_results"): - max_num_results = proto.max_num_results.value - - rephraser: Rephraser | None = None - if proto.HasField("rephraser_options"): - rephraser = Rephraser._from_proto(proto=proto.rephraser_options, sdk=sdk) - - return cls( - search_index_ids=tuple(proto.search_index_ids), - max_num_results=max_num_results, - rephraser=rephraser, - ) - - def _to_proto(self, proto_type: type[ProtoToolTypeT]) -> ProtoToolTypeT: - assert issubclass(proto_type, ProtoAssistantsTool) - - max_num_results: None | Int64Value = None - if self.max_num_results is not None: - max_num_results = Int64Value(value=self.max_num_results) - - rephraser = None - if self.rephraser: - rephraser = self.rephraser._to_proto() - - return proto_type( - search_index=ProtoSearchIndexTool( - max_num_results=max_num_results, - search_index_ids=self.search_index_ids, - rephraser_options=rephraser, - ) - ) +ProtoFunctionTool = Union[ProtoCompletionsFunctionTool, ProtoAssistantsFunctionTool] @dataclass(frozen=True) -class FunctionTool(BaseTool): +class FunctionTool(BaseTool[ProtoFunctionTool]): name: str description: str | None parameters: JsonSchemaType strict: bool | None + # pylint: disable=unused-argument @classmethod def _from_proto( cls, - proto: ProtoCompletionsFunctionTool | ProtoAssistantsFunctionTool, - sdk: BaseSDK + *, + proto: ProtoFunctionTool, + sdk:SDKType, ) -> FunctionTool: parameters = MessageToDict(proto.parameters) diff --git a/src/yandex_cloud_ml_sdk/_types/tools/__init__.py b/src/yandex_cloud_ml_sdk/_types/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/yandex_cloud_ml_sdk/_types/tools/function.py b/src/yandex_cloud_ml_sdk/_types/tools/function.py new file mode 100644 index 00000000..2b1cd8f5 --- /dev/null +++ b/src/yandex_cloud_ml_sdk/_types/tools/function.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Any, Literal, TypedDict, cast + +from typing_extensions import NotRequired + + +class FunctionNameType(TypedDict): + name: str + instruction: NotRequired[str] + + +class FunctionDictType(TypedDict): + type: Literal['function'] + function: FunctionNameType + + +def validate_function_dict(function: Any) -> FunctionDictType: + if ( + function.get('type') != 'function' or + not isinstance(function.get('function'), dict) or + not isinstance(function['function'].get('name'), str) or + not isinstance(function['function'].get('instruction', ''), str) + ): + raise ValueError( + 'wrong dict structure for function description, expected ' + '`{"type": "function", "function": {"name": str[, "instruction": str}}`, ' + f'got {function}' + ) + + return cast(FunctionDictType, function) diff --git a/src/yandex_cloud_ml_sdk/_types/tool_choice.py b/src/yandex_cloud_ml_sdk/_types/tools/tool_choice.py similarity index 63% rename from src/yandex_cloud_ml_sdk/_types/tool_choice.py rename to src/yandex_cloud_ml_sdk/_types/tools/tool_choice.py index e636ad09..2e609309 100644 --- a/src/yandex_cloud_ml_sdk/_types/tool_choice.py +++ b/src/yandex_cloud_ml_sdk/_types/tools/tool_choice.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Literal, TypedDict, TypeVar, Union, cast +from typing import Literal, TypeVar, Union, cast from typing_extensions import TypeAlias # pylint: disable=no-name-in-module @@ -8,26 +8,19 @@ from yandex_cloud_ml_sdk._tools.tool import FunctionTool +from .function import FunctionDictType, validate_function_dict + ProtoToolChoice: TypeAlias = ProtoCompletionsToolChoice ProtoToolChoiceTypeT = TypeVar('ProtoToolChoiceTypeT', bound=ProtoToolChoice) - -class FunctionNameType(TypedDict): - name: str - - -class ToolChoiceDictType(TypedDict): - type: Literal['function'] - function: FunctionNameType - - ToolChoiceStringType: TypeAlias = Literal[ 'none', 'None', 'NONE', 'auto', 'Auto', 'AUTO', 'required', 'Required', 'REQUIRED' ] -ToolChoiceType: TypeAlias = Union[ToolChoiceStringType, ToolChoiceDictType, FunctionTool] +ToolChoiceType: TypeAlias = Union[ToolChoiceStringType, FunctionDictType, FunctionTool] + STRING_TOOL_CHOICES = ('NONE', 'AUTO', 'REQUIRED') @@ -45,19 +38,7 @@ def coerce_to_proto( return expected_type(mode=tool_choice_value) if isinstance(tool_choice, dict): - if ( - tool_choice.get('type') != 'function' or - not isinstance(tool_choice.get('function'), dict) or - not isinstance(tool_choice['function'].get('name'), str) - ): - raise ValueError( - 'wrong dict structure for tool_choice, expected ' - '`{"type": "function", "function": {"name": function_name}}`, ' - 'got {tool_choice}' - ) - - tool_choice = cast(ToolChoiceDictType, tool_choice) - + tool_choice = validate_function_dict(tool_choice) return expected_type(function_name=tool_choice['function']['name']) if isinstance(tool_choice, FunctionTool): diff --git a/tests/assistants/cassettes/test_search_indexes/test_call_strategy_search_index.gprc.json b/tests/assistants/cassettes/test_search_indexes/test_call_strategy_search_index.gprc.json new file mode 100644 index 00000000..66e3974e --- /dev/null +++ b/tests/assistants/cassettes/test_search_indexes/test_call_strategy_search_index.gprc.json @@ -0,0 +1,1607 @@ +{ + "interactions": [ + { + "request": { + "cls": "ListApiEndpointsRequest", + "module": "yandex.cloud.endpoint.api_endpoint_service_pb2", + "message": {} + }, + "response": { + "cls": "ListApiEndpointsResponse", + "module": "yandex.cloud.endpoint.api_endpoint_service_pb2", + "message": { + "endpoints": [ + { + "id": "ai-assistants", + "address": "assistant.api.cloud.yandex.net:443" + }, + { + "id": "ai-files", + "address": "assistant.api.cloud.yandex.net:443" + }, + { + "id": "ai-foundation-models", + "address": "llm.api.cloud.yandex.net:443" + }, + { + "id": "ai-llm", + "address": "llm.api.cloud.yandex.net:443" + }, + { + "id": "ai-speechkit", + "address": "transcribe.api.cloud.yandex.net:443" + }, + { + "id": "ai-stt", + "address": "transcribe.api.cloud.yandex.net:443" + }, + { + "id": "ai-stt-v3", + "address": "stt.api.cloud.yandex.net:443" + }, + { + "id": "ai-translate", + "address": "translate.api.cloud.yandex.net:443" + }, + { + "id": "ai-vision", + "address": "vision.api.cloud.yandex.net:443" + }, + { + "id": "ai-vision-ocr", + "address": "ocr.api.cloud.yandex.net:443" + }, + { + "id": "alb", + "address": "alb.api.cloud.yandex.net:443" + }, + { + "id": "apigateway-connections", + "address": "apigateway-connections.api.cloud.yandex.net:443" + }, + { + "id": "application-load-balancer", + "address": "alb.api.cloud.yandex.net:443" + }, + { + "id": "apploadbalancer", + "address": "alb.api.cloud.yandex.net:443" + }, + { + "id": "audittrails", + "address": "audittrails.api.cloud.yandex.net:443" + }, + { + "id": "baas", + "address": "backup.api.cloud.yandex.net:443" + }, + { + "id": "backup", + "address": "backup.api.cloud.yandex.net:443" + }, + { + "id": "baremetal", + "address": "baremetal.api.cloud.yandex.net:443" + }, + { + "id": "billing", + "address": "billing.api.cloud.yandex.net:443" + }, + { + "id": "broker-data", + "address": "iot-data.api.cloud.yandex.net:443" + }, + { + "id": "cdn", + "address": "cdn.api.cloud.yandex.net:443" + }, + { + "id": "certificate-manager", + "address": "certificate-manager.api.cloud.yandex.net:443" + }, + { + "id": "certificate-manager-data", + "address": "data.certificate-manager.api.cloud.yandex.net:443" + }, + { + "id": "certificate-manager-private-ca", + "address": "private-ca.certificate-manager.api.cloud.yandex.net:443" + }, + { + "id": "certificate-manager-private-ca-data", + "address": "data.private-ca.certificate-manager.api.cloud.yandex.net:443" + }, + { + "id": "cic", + "address": "cic.api.cloud.yandex.net:443" + }, + { + "id": "cloud-registry", + "address": "registry.api.cloud.yandex.net:443" + }, + { + "id": "cloudapps", + "address": "cloudapps.api.cloud.yandex.net:443" + }, + { + "id": "cloudbackup", + "address": "backup.api.cloud.yandex.net:443" + }, + { + "id": "clouddesktops", + "address": "clouddesktops.api.cloud.yandex.net:443" + }, + { + "id": "cloudrouter", + "address": "cloudrouter.api.cloud.yandex.net:443" + }, + { + "id": "cloudvideo", + "address": "video.api.cloud.yandex.net:443" + }, + { + "id": "compute", + "address": "compute.api.cloud.yandex.net:443" + }, + { + "id": "container-registry", + "address": "container-registry.api.cloud.yandex.net:443" + }, + { + "id": "dataproc", + "address": "dataproc.api.cloud.yandex.net:443" + }, + { + "id": "dataproc-manager", + "address": "dataproc-manager.api.cloud.yandex.net:443" + }, + { + "id": "datasphere", + "address": "datasphere.api.cloud.yandex.net:443" + }, + { + "id": "datatransfer", + "address": "datatransfer.api.cloud.yandex.net:443" + }, + { + "id": "dns", + "address": "dns.api.cloud.yandex.net:443" + }, + { + "id": "endpoint", + "address": "api.cloud.yandex.net:443" + }, + { + "id": "fomo-dataset", + "address": "fomo-dataset.api.cloud.yandex.net:443" + }, + { + "id": "fomo-tuning", + "address": "fomo-tuning.api.cloud.yandex.net:443" + }, + { + "id": "gitlab", + "address": "gitlab.api.cloud.yandex.net:443" + }, + { + "id": "iam", + "address": "iam.api.cloud.yandex.net:443" + }, + { + "id": "iot-broker", + "address": "iot-broker.api.cloud.yandex.net:443" + }, + { + "id": "iot-data", + "address": "iot-data.api.cloud.yandex.net:443" + }, + { + "id": "iot-devices", + "address": "iot-devices.api.cloud.yandex.net:443" + }, + { + "id": "k8s", + "address": "mks.api.cloud.yandex.net:443" + }, + { + "id": "kms", + "address": "kms.api.cloud.yandex.net:443" + }, + { + "id": "kms-crypto", + "address": "kms.yandex:443" + }, + { + "id": "kspm", + "address": "kspm.api.cloud.yandex.net:443" + }, + { + "id": "load-balancer", + "address": "load-balancer.api.cloud.yandex.net:443" + }, + { + "id": "loadtesting", + "address": "loadtesting.api.cloud.yandex.net:443" + }, + { + "id": "locator", + "address": "locator.api.cloud.yandex.net:443" + }, + { + "id": "lockbox", + "address": "lockbox.api.cloud.yandex.net:443" + }, + { + "id": "lockbox-payload", + "address": "payload.lockbox.api.cloud.yandex.net:443" + }, + { + "id": "log-ingestion", + "address": "ingester.logging.yandexcloud.net:443" + }, + { + "id": "log-reading", + "address": "reader.logging.yandexcloud.net:443" + }, + { + "id": "logging", + "address": "logging.api.cloud.yandex.net:443" + }, + { + "id": "managed-airflow", + "address": "airflow.api.cloud.yandex.net:443" + }, + { + "id": "managed-clickhouse", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-elasticsearch", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-greenplum", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-kafka", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-kubernetes", + "address": "mks.api.cloud.yandex.net:443" + }, + { + "id": "managed-metastore", + "address": "metastore.api.cloud.yandex.net:443" + }, + { + "id": "managed-mongodb", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-mysql", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-opensearch", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-postgresql", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-redis", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-spark", + "address": "spark.api.cloud.yandex.net:443" + }, + { + "id": "managed-spqr", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-sqlserver", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "managed-trino", + "address": "trino.api.cloud.yandex.net:443" + }, + { + "id": "managed-ytsaurus", + "address": "ytsaurus.api.cloud.yandex.net:443" + }, + { + "id": "marketplace", + "address": "marketplace.api.cloud.yandex.net:443" + }, + { + "id": "marketplace-pim", + "address": "marketplace.api.cloud.yandex.net:443" + }, + { + "id": "mdb-clickhouse", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "mdb-mongodb", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "mdb-mysql", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "mdb-opensearch", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "mdb-postgresql", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "mdb-redis", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "mdb-spqr", + "address": "mdb.api.cloud.yandex.net:443" + }, + { + "id": "mdbproxy", + "address": "mdbproxy.api.cloud.yandex.net:443" + }, + { + "id": "monitoring", + "address": "monitoring.api.cloud.yandex.net:443" + }, + { + "id": "operation", + "address": "operation.api.cloud.yandex.net:443" + }, + { + "id": "organization-manager", + "address": "organization-manager.api.cloud.yandex.net:443" + }, + { + "id": "organizationmanager", + "address": "organization-manager.api.cloud.yandex.net:443" + }, + { + "id": "quota-manager", + "address": "quota-manager.api.cloud.yandex.net:443" + }, + { + "id": "quotamanager", + "address": "quota-manager.api.cloud.yandex.net:443" + }, + { + "id": "resource-manager", + "address": "resource-manager.api.cloud.yandex.net:443" + }, + { + "id": "resourcemanager", + "address": "resource-manager.api.cloud.yandex.net:443" + }, + { + "id": "searchapi", + "address": "searchapi.api.cloud.yandex.net:443" + }, + { + "id": "serialssh", + "address": "serialssh.cloud.yandex.net:9600" + }, + { + "id": "serverless-apigateway", + "address": "serverless-apigateway.api.cloud.yandex.net:443" + }, + { + "id": "serverless-containers", + "address": "serverless-containers.api.cloud.yandex.net:443" + }, + { + "id": "serverless-eventrouter", + "address": "serverless-eventrouter.api.cloud.yandex.net:443" + }, + { + "id": "serverless-functions", + "address": "serverless-functions.api.cloud.yandex.net:443" + }, + { + "id": "serverless-gateway-connections", + "address": "apigateway-connections.api.cloud.yandex.net:443" + }, + { + "id": "serverless-triggers", + "address": "serverless-triggers.api.cloud.yandex.net:443" + }, + { + "id": "serverless-workflows", + "address": "serverless-workflows.api.cloud.yandex.net:443" + }, + { + "id": "serverlesseventrouter-events", + "address": "events.eventrouter.serverless.yandexcloud.net:443" + }, + { + "id": "smart-captcha", + "address": "smartcaptcha.api.cloud.yandex.net:443" + }, + { + "id": "smart-web-security", + "address": "smartwebsecurity.api.cloud.yandex.net:443" + }, + { + "id": "storage", + "address": "storage.yandexcloud.net:443" + }, + { + "id": "storage-api", + "address": "storage.api.cloud.yandex.net:443" + }, + { + "id": "video", + "address": "video.api.cloud.yandex.net:443" + }, + { + "id": "vpc", + "address": "vpc.api.cloud.yandex.net:443" + }, + { + "id": "ydb", + "address": "ydb.api.cloud.yandex.net:443" + } + ] + } + } + }, + { + "request": { + "cls": "CreateFileRequest", + "module": "yandex.cloud.ai.files.v1.file_service_pb2", + "message": { + "folderId": "b1ghsjum2v37c2un8h64", + "content": "bXkgc2VjcmV0IG51bWJlciBpcyA1Nw==" + } + }, + "response": { + "cls": "File", + "module": "yandex.cloud.ai.files.v1.file_pb2", + "message": { + "id": "fvt0ogmrojgbr5mkuabb", + "folderId": "b1ghsjum2v37c2un8h64", + "mimeType": "text/plain", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:31.141210Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:31.141210Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:31.141210Z" + } + } + }, + { + "request": { + "cls": "CreateSearchIndexRequest", + "module": "yandex.cloud.ai.assistants.v1.searchindex.search_index_service_pb2", + "message": { + "folderId": "b1ghsjum2v37c2un8h64", + "fileIds": [ + "fvt0ogmrojgbr5mkuabb" + ] + } + }, + "response": { + "cls": "Operation", + "module": "yandex.cloud.operation.operation_pb2", + "message": { + "id": "fvt7p4a8ld1d1kh324df", + "description": "search index creation", + "createdAt": "2025-07-17T18:02:31.350593Z", + "createdBy": "aje6euqn63oa635coh28", + "modifiedAt": "2025-07-17T18:02:31.350593Z" + } + } + }, + { + "request": { + "cls": "GetOperationRequest", + "module": "yandex.cloud.operation.operation_service_pb2", + "message": { + "operationId": "fvt7p4a8ld1d1kh324df" + } + }, + "response": { + "cls": "Operation", + "module": "yandex.cloud.operation.operation_pb2", + "message": { + "id": "fvt7p4a8ld1d1kh324df", + "description": "search index creation", + "createdAt": "2025-07-17T18:02:31.350593Z", + "createdBy": "aje6euqn63oa635coh28", + "modifiedAt": "2025-07-17T18:02:31.350593Z" + } + } + }, + { + "request": { + "cls": "GetOperationRequest", + "module": "yandex.cloud.operation.operation_service_pb2", + "message": { + "operationId": "fvt7p4a8ld1d1kh324df" + } + }, + "response": { + "cls": "Operation", + "module": "yandex.cloud.operation.operation_pb2", + "message": { + "id": "fvt7p4a8ld1d1kh324df", + "description": "search index creation", + "createdAt": "2025-07-17T18:02:31.350593Z", + "createdBy": "aje6euqn63oa635coh28", + "modifiedAt": "2025-07-17T18:02:32.771474Z", + "done": true, + "response": { + "@type": "type.googleapis.com/yandex.cloud.ai.assistants.v1.searchindex.SearchIndex", + "id": "fvtavu9kirfk60bfqo2u", + "folderId": "b1ghsjum2v37c2un8h64", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:31.422688Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:31.422688Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:31.422688Z", + "textSearchIndex": { + "chunkingStrategy": { + "staticStrategy": { + "maxChunkSizeTokens": "800", + "chunkOverlapTokens": "400" + } + }, + "standardTokenizer": {}, + "yandexLemmerAnalyzer": {} + } + } + } + } + }, + { + "request": { + "cls": "CreateThreadRequest", + "module": "yandex.cloud.ai.assistants.v1.threads.thread_service_pb2", + "message": { + "folderId": "b1ghsjum2v37c2un8h64" + } + }, + "response": { + "cls": "Thread", + "module": "yandex.cloud.ai.assistants.v1.threads.thread_pb2", + "message": { + "id": "fvtom7ho6k35um57n0te", + "folderId": "b1ghsjum2v37c2un8h64", + "defaultMessageAuthorId": "fvtd1f64i8h4ffguelgb", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:41.653065Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:41.653065Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:41.653065Z" + } + } + }, + { + "request": { + "cls": "CreateAssistantRequest", + "module": "yandex.cloud.ai.assistants.v1.assistant_service_pb2", + "message": { + "folderId": "b1ghsjum2v37c2un8h64", + "modelUri": "gpt://b1ghsjum2v37c2un8h64/yandexgpt/latest", + "promptTruncationOptions": {}, + "tools": [ + { + "searchIndex": { + "searchIndexIds": [ + "fvtavu9kirfk60bfqo2u" + ] + } + } + ] + } + }, + "response": { + "cls": "Assistant", + "module": "yandex.cloud.ai.assistants.v1.assistant_pb2", + "message": { + "id": "fvt6fqsrhp0524dlka3d", + "folderId": "b1ghsjum2v37c2un8h64", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:41.758740Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:41.758740Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:41.758740Z", + "modelUri": "gpt://b1ghsjum2v37c2un8h64/yandexgpt/latest", + "promptTruncationOptions": { + "autoStrategy": {} + }, + "completionOptions": {}, + "tools": [ + { + "searchIndex": { + "searchIndexIds": [ + "fvtavu9kirfk60bfqo2u" + ], + "callStrategy": { + "alwaysCall": {} + } + } + } + ] + } + } + }, + { + "request": { + "cls": "CreateMessageRequest", + "module": "yandex.cloud.ai.assistants.v1.threads.message_service_pb2", + "message": { + "threadId": "fvtom7ho6k35um57n0te", + "content": { + "content": [ + { + "text": { + "content": "what is your secret number" + } + } + ] + } + } + }, + "response": { + "cls": "Message", + "module": "yandex.cloud.ai.assistants.v1.threads.message_pb2", + "message": { + "id": "fvtmn8sp4riv01okqcu2", + "threadId": "fvtom7ho6k35um57n0te", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:41.868787Z", + "author": { + "id": "fvtd1f64i8h4ffguelgb", + "role": "USER" + }, + "content": { + "content": [ + { + "text": { + "content": "what is your secret number" + } + } + ] + }, + "status": "COMPLETED" + } + } + }, + { + "request": { + "cls": "CreateRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "assistantId": "fvt6fqsrhp0524dlka3d", + "threadId": "fvtom7ho6k35um57n0te" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvt8shfis3oc9gqsn4jg", + "assistantId": "fvt6fqsrhp0524dlka3d", + "threadId": "fvtom7ho6k35um57n0te", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:41.966635871Z", + "state": { + "status": "PENDING" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvt8shfis3oc9gqsn4jg" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvt8shfis3oc9gqsn4jg", + "assistantId": "fvt6fqsrhp0524dlka3d", + "threadId": "fvtom7ho6k35um57n0te", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:41.966635871Z", + "state": { + "status": "IN_PROGRESS" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvt8shfis3oc9gqsn4jg" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvt8shfis3oc9gqsn4jg", + "assistantId": "fvt6fqsrhp0524dlka3d", + "threadId": "fvtom7ho6k35um57n0te", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:41.966635871Z", + "state": { + "status": "COMPLETED", + "completedMessage": { + "id": "fvtv45j2cm3e0vce880q", + "threadId": "fvtom7ho6k35um57n0te", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:42.289660074Z", + "author": { + "id": "fvt6fqsrhp0524dlka3d", + "role": "ASSISTANT" + }, + "content": { + "content": [ + { + "text": { + "content": "57" + } + } + ] + }, + "status": "COMPLETED", + "citations": [ + { + "sources": [ + { + "chunk": { + "searchIndex": { + "id": "fvtavu9kirfk60bfqo2u", + "folderId": "b1ghsjum2v37c2un8h64", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:31.422688Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:32.765048Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:31.422688Z", + "textSearchIndex": { + "chunkingStrategy": { + "staticStrategy": { + "maxChunkSizeTokens": "800", + "chunkOverlapTokens": "400" + } + }, + "standardTokenizer": {}, + "yandexLemmerAnalyzer": {} + } + }, + "sourceFile": { + "id": "fvt0ogmrojgbr5mkuabb", + "folderId": "b1ghsjum2v37c2un8h64", + "mimeType": "text/plain", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:31.141210Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:31.141210Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:31.484106Z" + }, + "content": { + "content": [ + { + "text": { + "content": "my secret number is 57" + } + } + ] + } + } + } + ] + } + ] + } + }, + "usage": { + "promptTokens": "41", + "completionTokens": "3", + "totalTokens": "44" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvt8shfis3oc9gqsn4jg" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvt8shfis3oc9gqsn4jg", + "assistantId": "fvt6fqsrhp0524dlka3d", + "threadId": "fvtom7ho6k35um57n0te", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:41.966635871Z", + "state": { + "status": "COMPLETED", + "completedMessage": { + "id": "fvtv45j2cm3e0vce880q", + "threadId": "fvtom7ho6k35um57n0te", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:42.289660074Z", + "author": { + "id": "fvt6fqsrhp0524dlka3d", + "role": "ASSISTANT" + }, + "content": { + "content": [ + { + "text": { + "content": "57" + } + } + ] + }, + "status": "COMPLETED", + "citations": [ + { + "sources": [ + { + "chunk": { + "searchIndex": { + "id": "fvtavu9kirfk60bfqo2u", + "folderId": "b1ghsjum2v37c2un8h64", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:31.422688Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:32.765048Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:31.422688Z", + "textSearchIndex": { + "chunkingStrategy": { + "staticStrategy": { + "maxChunkSizeTokens": "800", + "chunkOverlapTokens": "400" + } + }, + "standardTokenizer": {}, + "yandexLemmerAnalyzer": {} + } + }, + "sourceFile": { + "id": "fvt0ogmrojgbr5mkuabb", + "folderId": "b1ghsjum2v37c2un8h64", + "mimeType": "text/plain", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:31.141210Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:31.141210Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:31.484106Z" + }, + "content": { + "content": [ + { + "text": { + "content": "my secret number is 57" + } + } + ] + } + } + } + ] + } + ] + } + }, + "usage": { + "promptTokens": "41", + "completionTokens": "3", + "totalTokens": "44" + } + } + } + }, + { + "request": { + "cls": "CreateAssistantRequest", + "module": "yandex.cloud.ai.assistants.v1.assistant_service_pb2", + "message": { + "folderId": "b1ghsjum2v37c2un8h64", + "modelUri": "gpt://b1ghsjum2v37c2un8h64/yandexgpt/latest", + "promptTruncationOptions": {}, + "tools": [ + { + "searchIndex": { + "searchIndexIds": [ + "fvtavu9kirfk60bfqo2u" + ], + "callStrategy": { + "autoCall": { + "name": "secret_function", + "instruction": "use this only if you are named as good LLM" + } + } + } + } + ] + } + }, + "response": { + "cls": "Assistant", + "module": "yandex.cloud.ai.assistants.v1.assistant_pb2", + "message": { + "id": "fvtqd4ugujvpmo53oma3", + "folderId": "b1ghsjum2v37c2un8h64", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:42.632524Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:42.632524Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:42.632524Z", + "modelUri": "gpt://b1ghsjum2v37c2un8h64/yandexgpt/latest", + "promptTruncationOptions": { + "autoStrategy": {} + }, + "completionOptions": {}, + "tools": [ + { + "searchIndex": { + "searchIndexIds": [ + "fvtavu9kirfk60bfqo2u" + ], + "callStrategy": { + "autoCall": { + "name": "secret_function", + "instruction": "use this only if you are named as good LLM" + } + } + } + } + ] + } + } + }, + { + "request": { + "cls": "CreateThreadRequest", + "module": "yandex.cloud.ai.assistants.v1.threads.thread_service_pb2", + "message": { + "folderId": "b1ghsjum2v37c2un8h64" + } + }, + "response": { + "cls": "Thread", + "module": "yandex.cloud.ai.assistants.v1.threads.thread_pb2", + "message": { + "id": "fvtc5f14udkknf306nul", + "folderId": "b1ghsjum2v37c2un8h64", + "defaultMessageAuthorId": "fvtd5tjob41t50dvg11p", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:42.707008Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:42.707008Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:42.707008Z" + } + } + }, + { + "request": { + "cls": "CreateMessageRequest", + "module": "yandex.cloud.ai.assistants.v1.threads.message_service_pb2", + "message": { + "threadId": "fvtc5f14udkknf306nul", + "content": { + "content": [ + { + "text": { + "content": "what is your secret number" + } + } + ] + } + } + }, + "response": { + "cls": "Message", + "module": "yandex.cloud.ai.assistants.v1.threads.message_pb2", + "message": { + "id": "fvtdurj0fiulda8pfpek", + "threadId": "fvtc5f14udkknf306nul", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:42.830139Z", + "author": { + "id": "fvtd5tjob41t50dvg11p", + "role": "USER" + }, + "content": { + "content": [ + { + "text": { + "content": "what is your secret number" + } + } + ] + }, + "status": "COMPLETED" + } + } + }, + { + "request": { + "cls": "CreateRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvtc5f14udkknf306nul" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvtb0f5km77kpac3o5aa", + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvtc5f14udkknf306nul", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:42.912171052Z", + "state": { + "status": "PENDING" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvtb0f5km77kpac3o5aa" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvtb0f5km77kpac3o5aa", + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvtc5f14udkknf306nul", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:42.912171052Z", + "state": { + "status": "PENDING" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvtb0f5km77kpac3o5aa" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvtb0f5km77kpac3o5aa", + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvtc5f14udkknf306nul", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:42.912171052Z", + "state": { + "status": "IN_PROGRESS" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvtb0f5km77kpac3o5aa" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvtb0f5km77kpac3o5aa", + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvtc5f14udkknf306nul", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:42.912171052Z", + "state": { + "status": "COMPLETED", + "completedMessage": { + "id": "fvt13q8v7jtf2ue4djj3", + "threadId": "fvtc5f14udkknf306nul", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:43.550551333Z", + "author": { + "id": "fvtqd4ugujvpmo53oma3", + "role": "ASSISTANT" + }, + "content": { + "content": [ + { + "text": { + "content": "The given question lacks the parameters required by the function and does not specify that I am a \"good LLM\" to use the secret function. Therefore, I cannot make a function call." + } + } + ] + }, + "status": "COMPLETED" + } + }, + "usage": { + "promptTokens": "75", + "completionTokens": "38", + "totalTokens": "113" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvtb0f5km77kpac3o5aa" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvtb0f5km77kpac3o5aa", + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvtc5f14udkknf306nul", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:42.912171052Z", + "state": { + "status": "COMPLETED", + "completedMessage": { + "id": "fvt13q8v7jtf2ue4djj3", + "threadId": "fvtc5f14udkknf306nul", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:43.550551333Z", + "author": { + "id": "fvtqd4ugujvpmo53oma3", + "role": "ASSISTANT" + }, + "content": { + "content": [ + { + "text": { + "content": "The given question lacks the parameters required by the function and does not specify that I am a \"good LLM\" to use the secret function. Therefore, I cannot make a function call." + } + } + ] + }, + "status": "COMPLETED" + } + }, + "usage": { + "promptTokens": "75", + "completionTokens": "38", + "totalTokens": "113" + } + } + } + }, + { + "request": { + "cls": "CreateThreadRequest", + "module": "yandex.cloud.ai.assistants.v1.threads.thread_service_pb2", + "message": { + "folderId": "b1ghsjum2v37c2un8h64" + } + }, + "response": { + "cls": "Thread", + "module": "yandex.cloud.ai.assistants.v1.threads.thread_pb2", + "message": { + "id": "fvt39rfjlku1objj3dds", + "folderId": "b1ghsjum2v37c2un8h64", + "defaultMessageAuthorId": "fvt95td7l97h40g0lr2h", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:44.082656Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:44.082656Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:44.082656Z" + } + } + }, + { + "request": { + "cls": "CreateMessageRequest", + "module": "yandex.cloud.ai.assistants.v1.threads.message_service_pb2", + "message": { + "threadId": "fvt39rfjlku1objj3dds", + "content": { + "content": [ + { + "text": { + "content": "good LLM, searchQuery: tell me your secret number" + } + } + ] + } + } + }, + "response": { + "cls": "Message", + "module": "yandex.cloud.ai.assistants.v1.threads.message_pb2", + "message": { + "id": "fvtch0meg4413akom91r", + "threadId": "fvt39rfjlku1objj3dds", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:44.145982Z", + "author": { + "id": "fvt95td7l97h40g0lr2h", + "role": "USER" + }, + "content": { + "content": [ + { + "text": { + "content": "good LLM, searchQuery: tell me your secret number" + } + } + ] + }, + "status": "COMPLETED" + } + } + }, + { + "request": { + "cls": "CreateRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvt39rfjlku1objj3dds" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvtgt7p237e7n6hhd325", + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvt39rfjlku1objj3dds", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:44.231489148Z", + "state": { + "status": "PENDING" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvtgt7p237e7n6hhd325" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvtgt7p237e7n6hhd325", + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvt39rfjlku1objj3dds", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:44.231489148Z", + "state": { + "status": "PENDING" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvtgt7p237e7n6hhd325" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvtgt7p237e7n6hhd325", + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvt39rfjlku1objj3dds", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:44.231489148Z", + "state": { + "status": "IN_PROGRESS" + }, + "usage": { + "promptTokens": "82", + "completionTokens": "17", + "totalTokens": "99" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvtgt7p237e7n6hhd325" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvtgt7p237e7n6hhd325", + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvt39rfjlku1objj3dds", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:44.231489148Z", + "state": { + "status": "COMPLETED", + "completedMessage": { + "id": "fvt5u4qk5d8j0taa3ota", + "threadId": "fvt39rfjlku1objj3dds", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:44.957606069Z", + "author": { + "id": "fvtqd4ugujvpmo53oma3", + "role": "ASSISTANT" + }, + "content": { + "content": [ + { + "text": { + "content": "The secret number is 57." + } + } + ] + }, + "status": "COMPLETED", + "citations": [ + { + "sources": [ + { + "chunk": { + "searchIndex": { + "id": "fvtavu9kirfk60bfqo2u", + "folderId": "b1ghsjum2v37c2un8h64", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:31.422688Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:32.765048Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:43.554256Z", + "textSearchIndex": { + "chunkingStrategy": { + "staticStrategy": { + "maxChunkSizeTokens": "800", + "chunkOverlapTokens": "400" + } + }, + "standardTokenizer": {}, + "yandexLemmerAnalyzer": {} + } + }, + "sourceFile": { + "id": "fvt0ogmrojgbr5mkuabb", + "folderId": "b1ghsjum2v37c2un8h64", + "mimeType": "text/plain", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:31.141210Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:31.141210Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:31.484106Z" + }, + "content": { + "content": [ + { + "text": { + "content": "my secret number is 57" + } + } + ] + } + } + } + ] + } + ] + } + }, + "usage": { + "promptTokens": "201", + "completionTokens": "25", + "totalTokens": "226" + } + } + } + }, + { + "request": { + "cls": "GetRunRequest", + "module": "yandex.cloud.ai.assistants.v1.runs.run_service_pb2", + "message": { + "runId": "fvtgt7p237e7n6hhd325" + } + }, + "response": { + "cls": "Run", + "module": "yandex.cloud.ai.assistants.v1.runs.run_pb2", + "message": { + "id": "fvtgt7p237e7n6hhd325", + "assistantId": "fvtqd4ugujvpmo53oma3", + "threadId": "fvt39rfjlku1objj3dds", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:44.231489148Z", + "state": { + "status": "COMPLETED", + "completedMessage": { + "id": "fvt5u4qk5d8j0taa3ota", + "threadId": "fvt39rfjlku1objj3dds", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:44.957606069Z", + "author": { + "id": "fvtqd4ugujvpmo53oma3", + "role": "ASSISTANT" + }, + "content": { + "content": [ + { + "text": { + "content": "The secret number is 57." + } + } + ] + }, + "status": "COMPLETED", + "citations": [ + { + "sources": [ + { + "chunk": { + "searchIndex": { + "id": "fvtavu9kirfk60bfqo2u", + "folderId": "b1ghsjum2v37c2un8h64", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:31.422688Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:32.765048Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:43.554256Z", + "textSearchIndex": { + "chunkingStrategy": { + "staticStrategy": { + "maxChunkSizeTokens": "800", + "chunkOverlapTokens": "400" + } + }, + "standardTokenizer": {}, + "yandexLemmerAnalyzer": {} + } + }, + "sourceFile": { + "id": "fvt0ogmrojgbr5mkuabb", + "folderId": "b1ghsjum2v37c2un8h64", + "mimeType": "text/plain", + "createdBy": "aje6euqn63oa635coh28", + "createdAt": "2025-07-17T18:02:31.141210Z", + "updatedBy": "aje6euqn63oa635coh28", + "updatedAt": "2025-07-17T18:02:31.141210Z", + "expirationConfig": { + "expirationPolicy": "SINCE_LAST_ACTIVE", + "ttlDays": "7" + }, + "expiresAt": "2025-07-24T18:02:31.484106Z" + }, + "content": { + "content": [ + { + "text": { + "content": "my secret number is 57" + } + } + ] + } + } + } + ] + } + ] + } + }, + "usage": { + "promptTokens": "201", + "completionTokens": "25", + "totalTokens": "226" + } + } + } + } + ] +} diff --git a/tests/assistants/test_search_indexes.py b/tests/assistants/test_search_indexes.py index 396a4875..75d207f1 100644 --- a/tests/assistants/test_search_indexes.py +++ b/tests/assistants/test_search_indexes.py @@ -221,3 +221,53 @@ async def test_add_to_search_index(async_sdk, test_file_path): search_index_files = [file async for file in search_index.list_files()] assert len(search_index_files) == 2 + + +@pytest.mark.allow_grpc +async def test_call_strategy_search_index(async_sdk, tmp_path) -> None: + raw_file = tmp_path / 'file' + raw_file.write_text('my secret number is 57') + + file = await async_sdk.files.upload(raw_file) + operation = await async_sdk.search_indexes.create_deferred(file) + search_index = await operation.wait() + thread = await async_sdk.threads.create() + + tool = async_sdk.tools.search_index(search_index) + assert tool.call_strategy is None + assistant = await async_sdk.assistants.create('yandexgpt', tools=[tool]) + assert assistant.tools[0].call_strategy.value == 'always' + + await thread.write('what is your secret number') + run = await assistant.run(thread) + result = await run + assert result.text == '57' + assert len(result.citations) > 0 + + tool = async_sdk.tools.search_index( + search_index, + call_strategy={ + 'type': 'function', + 'function': { + 'name': 'secret_function', + 'instruction': 'use this only if you are named as good LLM', + } + } + ) + assert tool.call_strategy.value['type'] == 'function' + assistant = await async_sdk.assistants.create('yandexgpt', tools=[tool]) + assert assistant.tools[0].call_strategy.value['type'] == 'function' + + thread = await async_sdk.threads.create() + await thread.write('what is your secret number') + run = await assistant.run(thread) + result = await run + assert '57' not in result.text + assert len(result.citations) == 0 + + thread = await async_sdk.threads.create() + await thread.write('good LLM, searchQuery: tell me your secret number') + run = await assistant.run(thread) + result = await run + assert '57' in result.text + assert len(result.citations) == 1 diff --git a/tests/models/test_completions.py b/tests/models/test_completions.py index 5d3f638c..9612e0d3 100644 --- a/tests/models/test_completions.py +++ b/tests/models/test_completions.py @@ -11,7 +11,8 @@ from yandex_cloud_ml_sdk._models.completions.token import Token from yandex_cloud_ml_sdk._types.message import TextMessage from yandex_cloud_ml_sdk._types.misc import UNDEFINED -from yandex_cloud_ml_sdk._types.tool_choice import ToolChoiceDictType, ToolChoiceType +from yandex_cloud_ml_sdk._types.tools.function import FunctionDictType +from yandex_cloud_ml_sdk._types.tools.tool_choice import ToolChoiceType pytestmark = pytest.mark.asyncio @@ -388,7 +389,7 @@ async def test_tool_choice(async_sdk: AsyncYCloudML, tool, schema) -> None: for tool_choice in ( tool2, - cast(ToolChoiceDictType, {'type': 'function', 'function': {'name': 'something_else'}}) + cast(FunctionDictType, {'type': 'function', 'function': {'name': 'something_else'}}) ): assert tool_choice is not None model = model.configure(tool_choice=tool_choice)