Skip to content

Commit c21e070

Browse files
authored
Add call_strategy to search index (#128)
1 parent 54e9aad commit c21e070

File tree

18 files changed

+2132
-88
lines changed

18 files changed

+2132
-88
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import pathlib
7+
8+
from yandex_cloud_ml_sdk import AsyncYCloudML
9+
10+
LABEL_KEY = 'yc-ml-sdk-example'
11+
PATH = pathlib.Path(__file__)
12+
NAME = f'example-{PATH.parent.name}-{PATH.name}'
13+
LABELS = {LABEL_KEY: NAME}
14+
15+
16+
def local_path(path: str) -> pathlib.Path:
    """Return *path* resolved against the directory that holds this script."""
    script_dir = pathlib.Path(__file__).parent
    return script_dir.joinpath(path)
18+
19+
20+
async def get_search_index(sdk):
    """
    Return the demo search index, creating it when none exists yet.

    An index that already carries this example's label is reused. Otherwise
    the two sample text files are uploaded, an index is created from them
    and the uploaded files are deleted afterwards.

    A real application would obtain its search index in whatever way suits it.
    """

    async for search_index in sdk.search_indexes.list():
        if not search_index.labels:
            continue
        if search_index.labels.get(LABEL_KEY) == NAME:
            print(f'using {search_index=}')
            return search_index

    print('no search indexes found, creating new one')
    uploads = [
        sdk.files.upload(local_path(filename), ttl_days=5, expiration_policy="static")
        for filename in ('turkey_example.txt', 'maldives_example.txt')
    ]
    files = await asyncio.gather(*uploads)

    operation = await sdk.search_indexes.create_deferred(files, labels=LABELS)
    search_index = await operation
    print(f'new {search_index=}')

    # The uploaded files are removed one by one once the index exists.
    for uploaded in files:
        await uploaded.delete()

    return search_index
50+
51+
52+
async def delete_labeled_entities(iterator):
    """
    Delete every entity yielded by ``iterator`` that belongs to this example.

    An entity is deleted when its ``.labels`` mapping is non-empty and
    ``labels[LABEL_KEY] == NAME``; all other entities are left untouched.
    """

    async for candidate in iterator:
        labels = candidate.labels
        if not labels or labels.get(LABEL_KEY) != NAME:
            continue
        print(f'deleting {candidate.__class__.__name__} with id={candidate.id!r}')
        await candidate.delete()
62+
63+
64+
async def main() -> None:
    """
    Demonstrate how ``call_strategy`` affects a search index tool.

    The same query is run against an assistant whose search index tool has no
    explicit call strategy and against an assistant whose tool carries a
    function-type call strategy; ``result.citations`` is checked after each
    run to see whether the index was consulted.

    Assistants and threads labelled by this example are deleted at the end,
    even when one of the demonstration steps fails.
    """
    sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging(log_level='WARNING')

    search_index = await get_search_index(sdk)

    try:
        thread = await sdk.threads.create(labels=LABELS)

        tool = sdk.tools.search_index(search_index)
        assistant = await sdk.assistants.create('yandexgpt', tools=[tool], labels=LABELS)

        # If you don't pass a call strategy to a search index tool,
        # 'always' is used by default.
        assert tool.call_strategy is None
        assert assistant.tools[0].call_strategy.value == 'always'  # type: ignore[attr-defined]

        # First, use a request which will definitely find something.
        search_query = local_path('search_query.txt').read_text().splitlines()[0]
        await thread.write(search_query)
        run = await assistant.run(thread)
        result = await run.wait()
        # NB: citations show whether the index was used or not
        assert len(result.citations) > 0
        print(f'If you are using "always" call_strategy, it returns {len(result.citations)>0=} citations from search index')

        # Now use a search index tool which is called only when it is asked for.
        tool_with_call_strategy = sdk.tools.search_index(
            search_index,
            call_strategy={
                'type': 'function',
                'function': {'name': 'guide', 'instruction': 'use this only if you are asked to look in the guide'}
            }
        )
        assistant_with_call_strategy = await sdk.assistants.create(
            sdk.models.completions('yandexgpt', model_version='rc'),
            tools=[tool_with_call_strategy],
            labels=LABELS
        )

        await thread.write(search_query)
        run = await assistant_with_call_strategy.run(thread)
        result = await run.wait()
        # NB: citations show whether the index was used or not
        assert len(result.citations) == 0
        print(
            'When you are using special call_strategy and model decides not to use search index according '
            f'to call_strategy instruction, it returns {len(result.citations)>0=} citations from search index'
        )

        await thread.write(f"Look at the guide, please: {search_query}")
        run = await assistant_with_call_strategy.run(thread)
        result = await run.wait()
        # NB: citations show whether the index was used or not
        assert len(result.citations) > 0
        print(
            'When you are using special call_strategy and model decides to use search index according '
            f'to call_strategy instruction, it returns {len(result.citations)>0=} citations from search index'
        )
    finally:
        # Delete all assistants and threads created in this example so as not
        # to increase the chaos level. This runs in ``finally`` so that a
        # failed step does not leave them behind. The search index is kept
        # on purpose: index creation is a slow operation and the index can
        # be re-used by the next run of this example.
        await delete_labeled_entities(sdk.assistants.list())
        await delete_labeled_entities(sdk.threads.list())
127+
128+
129+
if __name__ == '__main__':
130+
asyncio.run(main())
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import pathlib
6+
7+
from yandex_cloud_ml_sdk import YCloudML
8+
9+
LABEL_KEY = 'yc-ml-sdk-example'
10+
PATH = pathlib.Path(__file__)
11+
NAME = f'example-{PATH.parent.name}-{PATH.name}'
12+
LABELS = {LABEL_KEY: NAME}
13+
14+
15+
def local_path(path: str) -> pathlib.Path:
    """Build the location of *path* next to this example script."""
    here = pathlib.Path(__file__).parent
    return here.joinpath(path)
17+
18+
19+
def get_search_index(sdk):
    """
    Return the demo search index, creating it if it does not exist yet.

    An index labelled by a previous run of this example is reused. Otherwise
    the two sample text files are uploaded, an index is created from them and
    the uploaded files are deleted once the index has been built.

    In real life you will get it any other way that would suit your case.
    """

    for search_index in sdk.search_indexes.list():
        if search_index.labels and search_index.labels.get(LABEL_KEY) == NAME:
            print(f'using {search_index=}')
            return search_index

    print('no search indexes found, creating new one')
    files = [
        sdk.files.upload(
            local_path(path),
            ttl_days=5,
            expiration_policy="static",
        )
        for path in ['turkey_example.txt', 'maldives_example.txt']
    ]
    operation = sdk.search_indexes.create_deferred(files, labels=LABELS)
    # create_deferred hands back an operation, not the index itself; wait for
    # it so that callers receive the finished search index rather than the
    # operation handle.
    search_index = operation.wait()
    print(f'new {search_index=}')

    for file in files:
        file.delete()

    return search_index
48+
49+
50+
def delete_labeled_entities(iterator):
    """
    Delete every entity from ``iterator`` that belongs to this example.

    An entity is deleted when its ``.labels`` mapping is non-empty and
    ``labels[LABEL_KEY] == NAME``; all other entities are left untouched.
    """

    for candidate in iterator:
        labels = candidate.labels
        if not labels or labels.get(LABEL_KEY) != NAME:
            continue
        print(f'deleting {candidate.__class__.__name__} with id={candidate.id!r}')
        candidate.delete()
60+
61+
62+
def main() -> None:
    """
    Demonstrate how ``call_strategy`` affects a search index tool.

    The same query is run against an assistant whose search index tool has no
    explicit call strategy and against an assistant whose tool carries a
    function-type call strategy; ``result.citations`` is checked after each
    run to see whether the index was consulted.

    Assistants and threads labelled by this example are deleted at the end,
    even when one of the demonstration steps fails.
    """
    sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging(log_level='WARNING')

    search_index = get_search_index(sdk)

    try:
        thread = sdk.threads.create(labels=LABELS)

        tool = sdk.tools.search_index(search_index)
        assistant = sdk.assistants.create('yandexgpt', tools=[tool], labels=LABELS)

        # If you don't pass a call strategy to a search index tool,
        # 'always' is used by default.
        assert tool.call_strategy is None
        assert assistant.tools[0].call_strategy.value == 'always'  # type: ignore[attr-defined]

        # First, use a request which will definitely find something.
        search_query = local_path('search_query.txt').read_text().splitlines()[0]
        thread.write(search_query)
        run = assistant.run(thread)
        result = run.wait()
        # NB: citations show whether the index was used or not
        assert len(result.citations) > 0
        print(f'If you are using "always" call_strategy, it returns {len(result.citations)>0=} citations from search index')

        # Now use a search index tool which is called only when it is asked for.
        tool_with_call_strategy = sdk.tools.search_index(
            search_index,
            call_strategy={
                'type': 'function',
                'function': {'name': 'guide', 'instruction': 'use this only if you are asked to look in the guide'}
            }
        )
        assistant_with_call_strategy = sdk.assistants.create(
            sdk.models.completions('yandexgpt', model_version='rc'),
            tools=[tool_with_call_strategy],
            labels=LABELS
        )

        thread.write(search_query)
        run = assistant_with_call_strategy.run(thread)
        result = run.wait()
        # NB: citations show whether the index was used or not
        assert len(result.citations) == 0
        print(
            'When you are using special call_strategy and model decides not to use search index according '
            f'to call_strategy instruction, it returns {len(result.citations)>0=} citations from search index'
        )

        thread.write(f"Look at the guide, please: {search_query}")
        run = assistant_with_call_strategy.run(thread)
        result = run.wait()
        # NB: citations show whether the index was used or not
        assert len(result.citations) > 0
        print(
            'When you are using special call_strategy and model decides to use search index according '
            f'to call_strategy instruction, it returns {len(result.citations)>0=} citations from search index'
        )
    finally:
        # Delete all assistants and threads created in this example so as not
        # to increase the chaos level. This runs in ``finally`` so that a
        # failed step does not leave them behind. The search index is kept
        # on purpose: index creation is a slow operation and the index can
        # be re-used by the next run of this example.
        delete_labeled_entities(sdk.assistants.list())
        delete_labeled_entities(sdk.threads.list())
125+
126+
127+
if __name__ == '__main__':
128+
main()

src/yandex_cloud_ml_sdk/_models/completions/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from yandex_cloud_ml_sdk._tools.tool import FunctionTool
1313
from yandex_cloud_ml_sdk._types.model_config import BaseModelConfig
1414
from yandex_cloud_ml_sdk._types.schemas import ResponseType
15-
from yandex_cloud_ml_sdk._types.tool_choice import ToolChoiceType
15+
from yandex_cloud_ml_sdk._types.tools.tool_choice import ToolChoiceType
1616
from yandex_cloud_ml_sdk._utils.proto import ProtoEnumBase
1717

1818
_m = ProtoReasoningOptions.ReasoningMode

src/yandex_cloud_ml_sdk/_models/completions/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
)
3131
from yandex_cloud_ml_sdk._types.operation import AsyncOperation, Operation
3232
from yandex_cloud_ml_sdk._types.schemas import ResponseType, make_response_format_kwargs
33-
from yandex_cloud_ml_sdk._types.tool_choice import ToolChoiceType
34-
from yandex_cloud_ml_sdk._types.tool_choice import coerce_to_proto as coerce_to_proto_tool_choice
33+
from yandex_cloud_ml_sdk._types.tools.tool_choice import ToolChoiceType
34+
from yandex_cloud_ml_sdk._types.tools.tool_choice import coerce_to_proto as coerce_to_proto_tool_choice
3535
from yandex_cloud_ml_sdk._types.tuning.datasets import TuningDatasetsType
3636
from yandex_cloud_ml_sdk._types.tuning.optimizers import BaseOptimizer
3737
from yandex_cloud_ml_sdk._types.tuning.schedulers import BaseScheduler

src/yandex_cloud_ml_sdk/_tools/domain.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from yandex_cloud_ml_sdk._utils.coerce import ResourceType, coerce_resource_ids
1111

1212
from .function import AsyncFunctionTools, FunctionTools, FunctionToolsTypeT
13-
from .rephraser.function import RephraserFunction, RephraserInputType
14-
from .tool import SearchIndexTool
13+
from .search_index.call_strategy import CallStrategy, CallStrategyInputType
14+
from .search_index.rephraser.function import RephraserFunction, RephraserInputType
15+
from .search_index.tool import SearchIndexTool
1516

1617

1718
class BaseTools(BaseDomain, Generic[FunctionToolsTypeT]):
@@ -38,6 +39,7 @@ def search_index(
3839
*,
3940
max_num_results: UndefinedOr[int] = UNDEFINED,
4041
rephraser: UndefinedOr[RephraserInputType] = UNDEFINED,
42+
call_strategy: UndefinedOr[CallStrategyInputType] = UNDEFINED,
4143
) -> SearchIndexTool:
4244
"""Creates SearchIndexTool (not to be confused with :py:class:`~.SearchIndex`).
4345
@@ -58,10 +60,15 @@ def search_index(
5860
# this is coercing any RephraserInputType to Rephraser
5961
rephraser_ = self.rephraser(rephraser) # type: ignore[arg-type]
6062

63+
call_strategy_ = None
64+
if is_defined(call_strategy):
65+
call_strategy_ = CallStrategy._coerce(call_strategy)
66+
6167
return SearchIndexTool(
6268
search_index_ids=tuple(index_ids),
6369
max_num_results=max_num_results_,
6470
rephraser=rephraser_,
71+
call_strategy=call_strategy_,
6572
)
6673

6774

src/yandex_cloud_ml_sdk/_tools/rephraser/__init__.py renamed to src/yandex_cloud_ml_sdk/_tools/search_index/__init__.py

File renamed without changes.

0 commit comments

Comments
 (0)