Skip to content

Commit ef23b5d

Browse files
authored
* Support for prompt truncation strategy (#82)
1 parent 9247b6d commit ef23b5d

24 files changed

+3920
-994
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
7+
from yandex_cloud_ml_sdk import AsyncYCloudML
8+
from yandex_cloud_ml_sdk.assistants import AutoPromptTruncationStrategy, LastMessagesPromptTruncationStrategy
9+
10+
# Labels attached to every entity this example creates, so that the cleanup
# step at the end can find and delete exactly the entities made by this script.
LABEL_KEY = 'yc-ml-sdk-example'
LABEL_VALUE = 'prompt-truncation-options'
12+
13+
14+
async def new_thread(sdk):
    """Create a labeled thread and pre-fill it with two user messages.

    The label marks the thread for later cleanup by ``delete_labeled_entities``.
    """
    created = await sdk.threads.create(labels={LABEL_KEY: LABEL_VALUE})
    for message in ('hey, how are you?', 'what is your name?'):
        await created.write(message)
    return created
19+
20+
21+
async def delete_labeled_entities(iterator):
    """Delete every entity yielded by *iterator* that carries this example's label."""
    async for entity in iterator:
        labels = entity.labels
        # Skip anything not created by this example (guard clause form).
        if not labels or labels.get(LABEL_KEY) != LABEL_VALUE:
            continue
        print(f'deleting {entity.__class__.__name__} with id={entity.id!r}')
        await entity.delete()
26+
27+
28+
async def main() -> None:
    """Demonstrate assistant prompt-truncation options and their per-run overrides.

    Creates an assistant with explicit ``max_prompt_tokens`` and a
    ``LastMessagesPromptTruncationStrategy``, then shows how ``custom_*``
    parameters of ``run()`` override those settings for a single run.
    Cleans up all entities it created at the end.
    """
    sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    assistant = await sdk.assistants.create(
        'yandexgpt',
        labels={LABEL_KEY: LABEL_VALUE},
        # you could choose value for max_prompt_tokens, default value
        # is 7000 by the time I'm making this example
        max_prompt_tokens=500,
        # default prompt truncation strategy is AutoPromptTruncationStrategy, you could
        # change it as well
        prompt_truncation_strategy=LastMessagesPromptTruncationStrategy(num_messages=10),
    )

    thread = await new_thread(sdk)
    # You could also override prompt truncation options via custom_* run() parameters:
    run = await assistant.run(
        thread,
        custom_max_prompt_tokens=1,
        custom_prompt_truncation_strategy=AutoPromptTruncationStrategy()
    )
    result = await run
    # This run should fail because of custom_max_prompt_tokens=1
    assert result.is_failed
    print(f'{result.error=}')

    thread = await new_thread(sdk)
    run = await assistant.run(
        thread,
        custom_prompt_truncation_strategy=LastMessagesPromptTruncationStrategy(num_messages=1)
    )
    result = await run
    assert result.usage
    one_message_input_tokens = result.usage.input_text_tokens

    thread = await new_thread(sdk)
    # NB: 'auto' is a shortcut for AutoPromptTruncationStrategy
    run = await assistant.run(thread, custom_prompt_truncation_strategy='auto')
    result = await run
    assert result.usage
    two_message_input_tokens = result.usage.input_text_tokens

    print('Input tokens used with LastMessagesPromptTruncationStrategy(1) < AutoPromptTruncationStrategy():')
    print(f'    {one_message_input_tokens} < {two_message_input_tokens}')

    # Remove everything this example created so the folder stays clean.
    await delete_labeled_entities(sdk.assistants.list())
    await delete_labeled_entities(sdk.threads.list())
76+
77+
78+
# Script entry point: run the async example inside a fresh event loop.
if __name__ == '__main__':
    asyncio.run(main())
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
from yandex_cloud_ml_sdk import YCloudML
6+
from yandex_cloud_ml_sdk.assistants import AutoPromptTruncationStrategy, LastMessagesPromptTruncationStrategy
7+
8+
# Labels attached to every entity this example creates, so that the cleanup
# step at the end can find and delete exactly the entities made by this script.
LABEL_KEY = 'yc-ml-sdk-example'
LABEL_VALUE = 'prompt-truncation-options'
10+
11+
12+
def new_thread(sdk):
    """Create a labeled thread and pre-fill it with two user messages.

    The label marks the thread for later cleanup by ``delete_labeled_entities``.
    """
    created = sdk.threads.create(labels={LABEL_KEY: LABEL_VALUE})
    for message in ('hey, how are you?', 'what is your name?'):
        created.write(message)
    return created
17+
18+
19+
def delete_labeled_entities(iterator):
    """Delete every entity yielded by *iterator* that carries this example's label."""
    for entity in iterator:
        labels = entity.labels
        # Skip anything not created by this example (guard clause form).
        if not labels or labels.get(LABEL_KEY) != LABEL_VALUE:
            continue
        print(f'deleting {entity.__class__.__name__} with id={entity.id!r}')
        entity.delete()
24+
25+
26+
def main() -> None:
    """Demonstrate assistant prompt-truncation options and their per-run overrides.

    Creates an assistant with explicit ``max_prompt_tokens`` and a
    ``LastMessagesPromptTruncationStrategy``, then shows how ``custom_*``
    parameters of ``run()`` override those settings for a single run.
    Cleans up all entities it created at the end.
    """
    sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    assistant = sdk.assistants.create(
        'yandexgpt',
        labels={LABEL_KEY: LABEL_VALUE},
        # you could choose value for max_prompt_tokens, default value
        # is 7000 by the time I'm making this example
        max_prompt_tokens=500,
        # default prompt truncation strategy is AutoPromptTruncationStrategy, you could
        # change it as well
        prompt_truncation_strategy=LastMessagesPromptTruncationStrategy(num_messages=10),
    )

    thread = new_thread(sdk)
    # You could also override prompt truncation options via custom_* run() parameters:
    run = assistant.run(
        thread,
        custom_max_prompt_tokens=1,
        custom_prompt_truncation_strategy=AutoPromptTruncationStrategy()
    )
    result = run.wait()
    # This run should fail because of custom_max_prompt_tokens=1
    assert result.is_failed
    print(f'{result.error=}')

    thread = new_thread(sdk)
    run = assistant.run(
        thread,
        custom_prompt_truncation_strategy=LastMessagesPromptTruncationStrategy(num_messages=1)
    )
    result = run.wait()
    assert result.usage
    one_message_input_tokens = result.usage.input_text_tokens

    thread = new_thread(sdk)
    # NB: 'auto' is a shortcut for AutoPromptTruncationStrategy
    run = assistant.run(thread, custom_prompt_truncation_strategy='auto')
    result = run.wait()
    assert result.usage
    two_message_input_tokens = result.usage.input_text_tokens

    print('Input tokens used with LastMessagesPromptTruncationStrategy(1) < AutoPromptTruncationStrategy():')
    print(f'    {one_message_input_tokens} < {two_message_input_tokens}')

    # Remove everything this example created so the folder stays clean.
    delete_labeled_entities(sdk.assistants.list())
    delete_labeled_entities(sdk.threads.list())
74+
75+
76+
# Script entry point: run the synchronous example.
if __name__ == '__main__':
    main()

src/yandex_cloud_ml_sdk/_assistants/assistant.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
from yandex_cloud_ml_sdk._utils.coerce import coerce_tuple
2626
from yandex_cloud_ml_sdk._utils.sync import run_sync_generator_impl, run_sync_impl
2727

28-
from .utils import get_completion_options, get_prompt_trunctation_options
28+
from .prompt_truncation_options import PromptTruncationOptions, PromptTruncationStrategyType
29+
from .utils import get_completion_options
2930

3031
if TYPE_CHECKING:
3132
from yandex_cloud_ml_sdk._sdk import BaseSDK
@@ -36,9 +37,13 @@ class BaseAssistant(ExpirableResource, Generic[RunTypeT, ThreadTypeT]):
3637
expiration_config: ExpirationConfig
3738
model: BaseGPTModel
3839
instruction: str | None
39-
max_prompt_tokens: int | None
40+
prompt_truncation_options: PromptTruncationOptions
4041
tools: tuple[BaseTool, ...]
4142

43+
@property
44+
def max_prompt_tokens(self) -> int | None:
45+
return self.prompt_truncation_options.max_prompt_tokens
46+
4247
@classmethod
4348
def _kwargs_from_message(cls, proto: ProtoAssistant, sdk: BaseSDK) -> dict[str, Any]: # type: ignore[override]
4449
kwargs = super()._kwargs_from_message(proto, sdk=sdk)
@@ -55,9 +60,10 @@ def _kwargs_from_message(cls, proto: ProtoAssistant, sdk: BaseSDK) -> dict[str,
5560
BaseTool._from_upper_proto(tool, sdk=sdk)
5661
for tool in proto.tools
5762
)
58-
59-
if max_prompt_tokens := proto.prompt_truncation_options.max_prompt_tokens.value:
60-
kwargs['max_prompt_tokens'] = max_prompt_tokens
63+
kwargs['prompt_truncation_options'] = PromptTruncationOptions._from_proto(
64+
proto=proto.prompt_truncation_options,
65+
sdk=sdk
66+
)
6167

6268
return kwargs
6369

@@ -71,6 +77,7 @@ async def _update(
7177
max_tokens: UndefinedOr[int] = UNDEFINED,
7278
instruction: UndefinedOr[str] = UNDEFINED,
7379
max_prompt_tokens: UndefinedOr[int] = UNDEFINED,
80+
prompt_truncation_strategy: UndefinedOr[PromptTruncationStrategyType] = UNDEFINED,
7481
name: UndefinedOr[str] = UNDEFINED,
7582
description: UndefinedOr[str] = UNDEFINED,
7683
labels: UndefinedOr[dict[str, str]] = UNDEFINED,
@@ -104,16 +111,20 @@ async def _update(
104111
else:
105112
raise TypeError('model argument must be str, GPTModel object either undefined')
106113

114+
prompt_truncation_options = PromptTruncationOptions._coerce(
115+
max_prompt_tokens=max_prompt_tokens,
116+
strategy=prompt_truncation_strategy
117+
)
118+
proto_prompt_trunction_options = prompt_truncation_options._to_proto()
119+
107120
request = UpdateAssistantRequest(
108121
assistant_id=self.id,
109122
name=get_defined_value(name, ''),
110123
description=get_defined_value(description, ''),
111124
labels=get_defined_value(labels, {}),
112125
instruction=get_defined_value(instruction, ''),
113126
expiration_config=expiration_config.to_proto(),
114-
prompt_truncation_options=get_prompt_trunctation_options(
115-
max_prompt_tokens=get_defined_value(max_prompt_tokens, None)
116-
),
127+
prompt_truncation_options=proto_prompt_trunction_options,
117128
completion_options=get_completion_options(
118129
temperature=temperature,
119130
max_tokens=max_tokens,
@@ -135,9 +146,8 @@ async def _update(
135146
'model_uri': model_uri,
136147
'completion_options.temperature': temperature,
137148
'completion_options.max_tokens': max_tokens,
138-
'prompt_truncation_options.max_prompt_tokens': max_prompt_tokens,
139149
'tools': tools,
140-
}
150+
} | prompt_truncation_options._get_update_paths()
141151
)
142152

143153
async with self._client.get_service_stub(AssistantServiceStub, timeout=timeout) as stub:
@@ -215,6 +225,7 @@ async def _run_impl(
215225
custom_temperature: UndefinedOr[float] = UNDEFINED,
216226
custom_max_tokens: UndefinedOr[int] = UNDEFINED,
217227
custom_max_prompt_tokens: UndefinedOr[int] = UNDEFINED,
228+
custom_prompt_truncation_strategy: UndefinedOr[PromptTruncationStrategyType] = UNDEFINED,
218229
timeout: float = 60,
219230
) -> RunTypeT:
220231
return await self._sdk.runs._create(
@@ -224,6 +235,7 @@ async def _run_impl(
224235
custom_temperature=custom_temperature,
225236
custom_max_tokens=custom_max_tokens,
226237
custom_max_prompt_tokens=custom_max_prompt_tokens,
238+
custom_prompt_truncation_strategy=custom_prompt_truncation_strategy,
227239
timeout=timeout,
228240
)
229241

@@ -234,6 +246,7 @@ async def _run(
234246
custom_temperature: UndefinedOr[float] = UNDEFINED,
235247
custom_max_tokens: UndefinedOr[int] = UNDEFINED,
236248
custom_max_prompt_tokens: UndefinedOr[int] = UNDEFINED,
249+
custom_prompt_truncation_strategy: UndefinedOr[PromptTruncationStrategyType] = UNDEFINED,
237250
timeout: float = 60,
238251
) -> RunTypeT:
239252
return await self._run_impl(
@@ -242,6 +255,7 @@ async def _run(
242255
custom_temperature=custom_temperature,
243256
custom_max_tokens=custom_max_tokens,
244257
custom_max_prompt_tokens=custom_max_prompt_tokens,
258+
custom_prompt_truncation_strategy=custom_prompt_truncation_strategy,
245259
timeout=timeout,
246260
)
247261

@@ -252,6 +266,7 @@ async def _run_stream(
252266
custom_temperature: UndefinedOr[float] = UNDEFINED,
253267
custom_max_tokens: UndefinedOr[int] = UNDEFINED,
254268
custom_max_prompt_tokens: UndefinedOr[int] = UNDEFINED,
269+
custom_prompt_truncation_strategy: UndefinedOr[PromptTruncationStrategyType] = UNDEFINED,
255270
timeout: float = 60,
256271
) -> RunTypeT:
257272
return await self._run_impl(
@@ -260,6 +275,7 @@ async def _run_stream(
260275
custom_temperature=custom_temperature,
261276
custom_max_tokens=custom_max_tokens,
262277
custom_max_prompt_tokens=custom_max_prompt_tokens,
278+
custom_prompt_truncation_strategy=custom_prompt_truncation_strategy,
263279
timeout=timeout,
264280
)
265281

@@ -293,6 +309,7 @@ async def update(
293309
max_tokens: UndefinedOr[int] = UNDEFINED,
294310
instruction: UndefinedOr[str] = UNDEFINED,
295311
max_prompt_tokens: UndefinedOr[int] = UNDEFINED,
312+
prompt_truncation_strategy: UndefinedOr[PromptTruncationStrategyType] = UNDEFINED,
296313
name: UndefinedOr[str] = UNDEFINED,
297314
description: UndefinedOr[str] = UNDEFINED,
298315
labels: UndefinedOr[dict[str, str]] = UNDEFINED,
@@ -307,6 +324,7 @@ async def update(
307324
max_tokens=max_tokens,
308325
instruction=instruction,
309326
max_prompt_tokens=max_prompt_tokens,
327+
prompt_truncation_strategy=prompt_truncation_strategy,
310328
name=name,
311329
description=description,
312330
labels=labels,
@@ -343,13 +361,15 @@ async def run(
343361
custom_temperature: UndefinedOr[float] = UNDEFINED,
344362
custom_max_tokens: UndefinedOr[int] = UNDEFINED,
345363
custom_max_prompt_tokens: UndefinedOr[int] = UNDEFINED,
364+
custom_prompt_truncation_strategy: UndefinedOr[PromptTruncationStrategyType] = UNDEFINED,
346365
timeout: float = 60,
347366
) -> AsyncRun:
348367
return await self._run(
349368
thread=thread,
350369
custom_temperature=custom_temperature,
351370
custom_max_tokens=custom_max_tokens,
352371
custom_max_prompt_tokens=custom_max_prompt_tokens,
372+
custom_prompt_truncation_strategy=custom_prompt_truncation_strategy,
353373
timeout=timeout
354374
)
355375

@@ -360,13 +380,15 @@ async def run_stream(
360380
custom_temperature: UndefinedOr[float] = UNDEFINED,
361381
custom_max_tokens: UndefinedOr[int] = UNDEFINED,
362382
custom_max_prompt_tokens: UndefinedOr[int] = UNDEFINED,
383+
custom_prompt_truncation_strategy: UndefinedOr[PromptTruncationStrategyType] = UNDEFINED,
363384
timeout: float = 60,
364385
) -> AsyncRun:
365386
return await self._run_stream(
366387
thread=thread,
367388
custom_temperature=custom_temperature,
368389
custom_max_tokens=custom_max_tokens,
369390
custom_max_prompt_tokens=custom_max_prompt_tokens,
391+
custom_prompt_truncation_strategy=custom_prompt_truncation_strategy,
370392
timeout=timeout
371393
)
372394

@@ -380,6 +402,7 @@ def update(
380402
max_tokens: UndefinedOr[int] = UNDEFINED,
381403
instruction: UndefinedOr[str] = UNDEFINED,
382404
max_prompt_tokens: UndefinedOr[int] = UNDEFINED,
405+
prompt_truncation_strategy: UndefinedOr[PromptTruncationStrategyType] = UNDEFINED,
383406
name: UndefinedOr[str] = UNDEFINED,
384407
description: UndefinedOr[str] = UNDEFINED,
385408
labels: UndefinedOr[dict[str, str]] = UNDEFINED,
@@ -394,6 +417,7 @@ def update(
394417
max_tokens=max_tokens,
395418
instruction=instruction,
396419
max_prompt_tokens=max_prompt_tokens,
420+
prompt_truncation_strategy=prompt_truncation_strategy,
397421
name=name,
398422
description=description,
399423
labels=labels,
@@ -432,13 +456,15 @@ def run(
432456
custom_temperature: UndefinedOr[float] = UNDEFINED,
433457
custom_max_tokens: UndefinedOr[int] = UNDEFINED,
434458
custom_max_prompt_tokens: UndefinedOr[int] = UNDEFINED,
459+
custom_prompt_truncation_strategy: UndefinedOr[PromptTruncationStrategyType] = UNDEFINED,
435460
timeout: float = 60,
436461
) -> Run:
437462
return run_sync_impl(self._run(
438463
thread=thread,
439464
custom_temperature=custom_temperature,
440465
custom_max_tokens=custom_max_tokens,
441466
custom_max_prompt_tokens=custom_max_prompt_tokens,
467+
custom_prompt_truncation_strategy=custom_prompt_truncation_strategy,
442468
timeout=timeout
443469
), self._sdk)
444470

@@ -449,13 +475,15 @@ def run_stream(
449475
custom_temperature: UndefinedOr[float] = UNDEFINED,
450476
custom_max_tokens: UndefinedOr[int] = UNDEFINED,
451477
custom_max_prompt_tokens: UndefinedOr[int] = UNDEFINED,
478+
custom_prompt_truncation_strategy: UndefinedOr[PromptTruncationStrategyType] = UNDEFINED,
452479
timeout: float = 60,
453480
) -> Run:
454481
return run_sync_impl(self._run_stream(
455482
thread=thread,
456483
custom_temperature=custom_temperature,
457484
custom_max_tokens=custom_max_tokens,
458485
custom_max_prompt_tokens=custom_max_prompt_tokens,
486+
custom_prompt_truncation_strategy=custom_prompt_truncation_strategy,
459487
timeout=timeout
460488
), self._sdk)
461489

0 commit comments

Comments
 (0)