Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions examples/async/assistants/search_index_call_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#!/usr/bin/env python3

from __future__ import annotations

import asyncio
import pathlib

from yandex_cloud_ml_sdk import AsyncYCloudML

# Label key used to tag every cloud entity created by this example.
LABEL_KEY = 'yc-ml-sdk-example'
# Path of this script; its parent directory name and file name are used in NAME.
PATH = pathlib.Path(__file__)
# Label value built from the parent directory name and the script file name.
NAME = f'example-{PATH.parent.name}-{PATH.name}'
# Labels mapping passed to created entities and matched during lookup and cleanup.
LABELS = {LABEL_KEY: NAME}


def local_path(path: str) -> pathlib.Path:
    """Resolve ``path`` relative to the directory that contains this script."""
    script_dir = pathlib.Path(__file__).parent
    return script_dir / path


async def get_search_index(sdk):
    """
    Return the demo search index for this example, creating it when missing.

    Existing search indexes are scanned for one labeled ``LABEL_KEY: NAME``.
    If none matches, the two example text files are uploaded, a new labeled
    index is built from them and the uploaded files are deleted afterwards.

    In a real application you would obtain the search index in whatever way
    suits your case.
    """

    async for candidate in sdk.search_indexes.list():
        labels = candidate.labels
        if not labels or labels.get(LABEL_KEY) != NAME:
            continue

        search_index = candidate
        print(f'using {search_index=}')
        return search_index

    print('no search indexes found, creating new one')
    uploads = [
        sdk.files.upload(local_path(file_name), ttl_days=5, expiration_policy="static")
        for file_name in ('turkey_example.txt', 'maldives_example.txt')
    ]
    files = await asyncio.gather(*uploads)

    # Index creation is deferred: obtain the operation first, then await its result.
    operation = await sdk.search_indexes.create_deferred(files, labels=LABELS)
    search_index = await operation
    print(f'new {search_index=}')

    # The uploaded source files are removed once the index has been built.
    for uploaded in files:
        await uploaded.delete()

    return search_index


async def delete_labeled_entities(iterator):
    """
    Delete every entity yielded by ``iterator`` that belongs to this example.

    An entity belongs to the example when its ``.labels`` attribute is
    non-empty and ``labels[LABEL_KEY] == NAME``; all other entities are
    left untouched.
    """

    async for entity in iterator:
        labels = entity.labels
        if not labels or labels.get(LABEL_KEY) != NAME:
            continue

        print(f'deleting {entity.__class__.__name__} with id={entity.id!r}')
        await entity.delete()


async def main() -> None:
    """
    Show how ``call_strategy`` controls when an assistant uses a search index.

    The example runs three requests against the same thread:

    1. An assistant whose search index tool has no explicit call strategy;
       the run is expected to return citations.
    2. An assistant whose tool has a ``function`` call strategy, asked the
       same query without mentioning the guide; no citations are expected.
    3. The same assistant, explicitly asked to look at the guide; citations
       are expected again.

    Assistants and threads labeled by this example are deleted at the end,
    including when an assertion or API call fails midway.
    """
    sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging(log_level='WARNING')

    search_index = await get_search_index(sdk)

    # Everything created below is labeled, so the cleanup in ``finally`` can
    # find it by label even when the demo stops early on a failed assertion.
    try:
        thread = await sdk.threads.create(labels=LABELS)

        tool = sdk.tools.search_index(search_index)
        assistant = await sdk.assistants.create('yandexgpt', tools=[tool], labels=LABELS)

        # If no call strategy is passed to the search index tool, the local tool
        # object keeps None, while the created assistant reports 'always'.
        assert tool.call_strategy is None
        assert assistant.tools[0].call_strategy.value == 'always'  # type: ignore[attr-defined]

        # First, use a request which will definitely find something.
        search_query = local_path('search_query.txt').read_text().splitlines()[0]
        await thread.write(search_query)
        run = await assistant.run(thread)
        result = await run.wait()
        # NB: citations tell whether the index was used or not.
        assert len(result.citations) > 0
        print(f'If you are using "always" call_strategy, it returns {len(result.citations)>0=} citations from search index')

        # Now create a search index tool which is used only when it is asked for.
        tool_with_call_strategy = sdk.tools.search_index(
            search_index,
            call_strategy={
                'type': 'function',
                'function': {'name': 'guide', 'instruction': 'use this only if you are asked to look in the guide'}
            }
        )
        assistant_with_call_strategy = await sdk.assistants.create(
            sdk.models.completions('yandexgpt', model_version='rc'),
            tools=[tool_with_call_strategy],
            labels=LABELS
        )

        await thread.write(search_query)
        run = await assistant_with_call_strategy.run(thread)
        result = await run.wait()
        # NB: citations tell whether the index was used or not.
        assert len(result.citations) == 0
        print(
            'When you are using special call_strategy and model decides not to use search index according '
            f'to call_strategy instruction, it returns {len(result.citations)>0=} citations from search index'
        )

        await thread.write(f"Look at the guide, please: {search_query}")
        run = await assistant_with_call_strategy.run(thread)
        result = await run.wait()
        # NB: citations tell whether the index was used or not.
        assert len(result.citations) > 0
        print(
            'When you are using special call_strategy and model decides to use search index according '
            f'to call_strategy instruction, it returns {len(result.citations)>0=} from search index'
        )
    finally:
        # Delete all assistants and threads created in this example so they do
        # not pile up, but keep the search index: index creation is a slow
        # operation and the index can be re-used by the next run.
        await delete_labeled_entities(sdk.assistants.list())
        await delete_labeled_entities(sdk.threads.list())


if __name__ == '__main__':
    # Script entry point: run the async example to completion.
    asyncio.run(main())
128 changes: 128 additions & 0 deletions examples/sync/assistants/search_index_call_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#!/usr/bin/env python3

from __future__ import annotations

import pathlib

from yandex_cloud_ml_sdk import YCloudML

# Label key used to tag every cloud entity created by this example.
LABEL_KEY = 'yc-ml-sdk-example'
# Path of this script; its parent directory name and file name are used in NAME.
PATH = pathlib.Path(__file__)
# Label value built from the parent directory name and the script file name.
NAME = f'example-{PATH.parent.name}-{PATH.name}'
# Labels mapping passed to created entities and matched during lookup and cleanup.
LABELS = {LABEL_KEY: NAME}


def local_path(path: str) -> pathlib.Path:
    """Resolve ``path`` relative to the directory that contains this script."""
    script_dir = pathlib.Path(__file__).parent
    return script_dir / path


def get_search_index(sdk):
    """
    Return the demo search index for this example, creating it when missing.

    Existing search indexes are scanned for one labeled ``LABEL_KEY: NAME``.
    If none matches, the two example text files are uploaded, a new labeled
    index is built from them and the uploaded files are deleted afterwards.

    In a real application you would obtain the search index in whatever way
    suits your case.
    """

    for search_index in sdk.search_indexes.list():
        if search_index.labels and search_index.labels.get(LABEL_KEY) == NAME:
            print(f'using {search_index=}')
            break
    else:
        print('no search indexes found, creating new one')
        files = [
            sdk.files.upload(
                local_path(path),
                ttl_days=5,
                expiration_policy="static",
            )
            for path in ['turkey_example.txt', 'maldives_example.txt']
        ]
        operation = sdk.search_indexes.create_deferred(files, labels=LABELS)
        # create_deferred returns an operation, not the index itself; block
        # until it finishes so callers receive the search index (the sync
        # counterpart of ``await operation`` in the async example).
        search_index = operation.wait()
        print(f'new {search_index=}')

        # The uploaded source files are removed once the index has been built.
        for file in files:
            file.delete()

    return search_index


def delete_labeled_entities(iterator):
    """
    Delete every entity yielded by ``iterator`` that belongs to this example.

    An entity belongs to the example when its ``.labels`` attribute is
    non-empty and ``labels[LABEL_KEY] == NAME``; all other entities are
    left untouched.
    """

    for entity in iterator:
        labels = entity.labels
        if not labels or labels.get(LABEL_KEY) != NAME:
            continue

        print(f'deleting {entity.__class__.__name__} with id={entity.id!r}')
        entity.delete()


def main() -> None:
    """
    Show how ``call_strategy`` controls when an assistant uses a search index.

    The example runs three requests against the same thread:

    1. An assistant whose search index tool has no explicit call strategy;
       the run is expected to return citations.
    2. An assistant whose tool has a ``function`` call strategy, asked the
       same query without mentioning the guide; no citations are expected.
    3. The same assistant, explicitly asked to look at the guide; citations
       are expected again.

    Assistants and threads labeled by this example are deleted at the end,
    including when an assertion or API call fails midway.
    """
    sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging(log_level='WARNING')

    search_index = get_search_index(sdk)

    # Everything created below is labeled, so the cleanup in ``finally`` can
    # find it by label even when the demo stops early on a failed assertion.
    try:
        thread = sdk.threads.create(labels=LABELS)

        tool = sdk.tools.search_index(search_index)
        assistant = sdk.assistants.create('yandexgpt', tools=[tool], labels=LABELS)

        # If no call strategy is passed to the search index tool, the local tool
        # object keeps None, while the created assistant reports 'always'.
        assert tool.call_strategy is None
        assert assistant.tools[0].call_strategy.value == 'always'  # type: ignore[attr-defined]

        # First, use a request which will definitely find something.
        search_query = local_path('search_query.txt').read_text().splitlines()[0]
        thread.write(search_query)
        run = assistant.run(thread)
        result = run.wait()
        # NB: citations tell whether the index was used or not.
        assert len(result.citations) > 0
        print(f'If you are using "always" call_strategy, it returns {len(result.citations)>0=} citations from search index')

        # Now create a search index tool which is used only when it is asked for.
        tool_with_call_strategy = sdk.tools.search_index(
            search_index,
            call_strategy={
                'type': 'function',
                'function': {'name': 'guide', 'instruction': 'use this only if you are asked to look in the guide'}
            }
        )
        assistant_with_call_strategy = sdk.assistants.create(
            sdk.models.completions('yandexgpt', model_version='rc'),
            tools=[tool_with_call_strategy],
            labels=LABELS
        )

        thread.write(search_query)
        run = assistant_with_call_strategy.run(thread)
        result = run.wait()
        # NB: citations tell whether the index was used or not.
        assert len(result.citations) == 0
        print(
            'When you are using special call_strategy and model decides not to use search index according '
            f'to call_strategy instruction, it returns {len(result.citations)>0=} citations from search index'
        )

        thread.write(f"Look at the guide, please: {search_query}")
        run = assistant_with_call_strategy.run(thread)
        result = run.wait()
        # NB: citations tell whether the index was used or not.
        assert len(result.citations) > 0
        print(
            'When you are using special call_strategy and model decides to use search index according '
            f'to call_strategy instruction, it returns {len(result.citations)>0=} from search index'
        )
    finally:
        # Delete all assistants and threads created in this example so they do
        # not pile up, but keep the search index: index creation is a slow
        # operation and the index can be re-used by the next run.
        delete_labeled_entities(sdk.assistants.list())
        delete_labeled_entities(sdk.threads.list())


if __name__ == '__main__':
    # Script entry point: run the synchronous example.
    main()
2 changes: 1 addition & 1 deletion src/yandex_cloud_ml_sdk/_models/completions/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from yandex_cloud_ml_sdk._tools.tool import FunctionTool
from yandex_cloud_ml_sdk._types.model_config import BaseModelConfig
from yandex_cloud_ml_sdk._types.schemas import ResponseType
from yandex_cloud_ml_sdk._types.tool_choice import ToolChoiceType
from yandex_cloud_ml_sdk._types.tools.tool_choice import ToolChoiceType
from yandex_cloud_ml_sdk._utils.proto import ProtoEnumBase

_m = ProtoReasoningOptions.ReasoningMode
Expand Down
4 changes: 2 additions & 2 deletions src/yandex_cloud_ml_sdk/_models/completions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
)
from yandex_cloud_ml_sdk._types.operation import AsyncOperation, Operation
from yandex_cloud_ml_sdk._types.schemas import ResponseType, make_response_format_kwargs
from yandex_cloud_ml_sdk._types.tool_choice import ToolChoiceType
from yandex_cloud_ml_sdk._types.tool_choice import coerce_to_proto as coerce_to_proto_tool_choice
from yandex_cloud_ml_sdk._types.tools.tool_choice import ToolChoiceType
from yandex_cloud_ml_sdk._types.tools.tool_choice import coerce_to_proto as coerce_to_proto_tool_choice
from yandex_cloud_ml_sdk._types.tuning.datasets import TuningDatasetsType
from yandex_cloud_ml_sdk._types.tuning.optimizers import BaseOptimizer
from yandex_cloud_ml_sdk._types.tuning.schedulers import BaseScheduler
Expand Down
11 changes: 9 additions & 2 deletions src/yandex_cloud_ml_sdk/_tools/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from yandex_cloud_ml_sdk._utils.coerce import ResourceType, coerce_resource_ids

from .function import AsyncFunctionTools, FunctionTools, FunctionToolsTypeT
from .rephraser.function import RephraserFunction, RephraserInputType
from .tool import SearchIndexTool
from .search_index.call_strategy import CallStrategy, CallStrategyInputType
from .search_index.rephraser.function import RephraserFunction, RephraserInputType
from .search_index.tool import SearchIndexTool


class BaseTools(BaseDomain, Generic[FunctionToolsTypeT]):
Expand All @@ -38,6 +39,7 @@ def search_index(
*,
max_num_results: UndefinedOr[int] = UNDEFINED,
rephraser: UndefinedOr[RephraserInputType] = UNDEFINED,
call_strategy: UndefinedOr[CallStrategyInputType] = UNDEFINED,
) -> SearchIndexTool:
"""Creates SearchIndexTool (not to be confused with :py:class:`~.SearchIndex`).

Expand All @@ -58,10 +60,15 @@ def search_index(
# this is coercing any RephraserInputType to Rephraser
rephraser_ = self.rephraser(rephraser) # type: ignore[arg-type]

call_strategy_ = None
if is_defined(call_strategy):
call_strategy_ = CallStrategy._coerce(call_strategy)

return SearchIndexTool(
search_index_ids=tuple(index_ids),
max_num_results=max_num_results_,
rephraser=rephraser_,
call_strategy=call_strategy_,
)


Expand Down
Loading