|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import asyncio |
| 6 | + |
| 7 | +from yandex_cloud_ml_sdk import AsyncYCloudML |
| 8 | + |
| 9 | + |
| 10 | +def calculator(expression: str) -> str: |
| 11 | + print(f'calculator got {expression=}') |
| 12 | + return "-5" |
| 13 | + |
| 14 | + |
| 15 | +def weather(location: str, date: str) -> str: |
| 16 | + print(f"weather func got {location=} and {date=}") |
| 17 | + return "-10 celsius" |
| 18 | + |
| 19 | + |
| 20 | +def process_tool_calls(tool_calls) -> dict[str, list[dict]]: |
| 21 | + """ |
| 22 | + This function is an example how you could organize |
| 23 | + dispatching of function calls in general case |
| 24 | + """ |
| 25 | + |
| 26 | + function_map = { |
| 27 | + 'calculator': calculator, |
| 28 | + 'Weather': weather |
| 29 | + } |
| 30 | + |
| 31 | + result = [] |
| 32 | + for tool_call in tool_calls: |
| 33 | + # only functions are available at the moment |
| 34 | + assert tool_call.function |
| 35 | + |
| 36 | + function = function_map[tool_call.function.name] |
| 37 | + |
| 38 | + answer = function(**tool_call.function.arguments) # type: ignore[operator] |
| 39 | + |
| 40 | + result.append({'name': tool_call.function.name, 'content': answer}) |
| 41 | + |
| 42 | + return {'tool_results': result} |
| 43 | + |
| 44 | + |
| 45 | +def create_tools(sdk: AsyncYCloudML): |
| 46 | + # it is imported inside only because yandex-cloud-ml-sdk does not require pydantic by default |
| 47 | + # pylint: disable=import-outside-toplevel |
| 48 | + from pydantic import BaseModel, Field |
| 49 | + |
| 50 | + calculator_tool = sdk.tools.function( |
| 51 | + name="calculator", |
| 52 | + description="A simple calculator that performs basic arithmetic operations.", |
| 53 | + # parameters could contain valid jsonschema with function parameters types and description |
| 54 | + parameters={ |
| 55 | + "type": "object", |
| 56 | + "properties": { |
| 57 | + "expression": { |
| 58 | + "type": "string", |
| 59 | + "description": "The mathematical expression to evaluate (e.g., '2 + 3 * 4').", |
| 60 | + } |
| 61 | + }, |
| 62 | + "required": ["expression"], |
| 63 | + } |
| 64 | + ) |
| 65 | + |
| 66 | + class Weather(BaseModel): |
| 67 | + """Getting the weather in specified location for specified date""" |
| 68 | + |
| 69 | + # It is important to describe all the arguments with a natural language |
| 70 | + location: str = Field(description="Name of the place for fetching weatcher at") |
| 71 | + date: str = Field(description="Date which a user interested in") |
| 72 | + |
| 73 | + weather_tool = sdk.tools.function(Weather) |
| 74 | + |
| 75 | + return [calculator_tool, weather_tool] |
| 76 | + |
| 77 | + |
| 78 | +async def main() -> None: |
| 79 | + sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64') |
| 80 | + sdk.setup_default_logging() |
| 81 | + |
| 82 | + model = sdk.models.completions('yandexgpt') |
| 83 | + |
| 84 | + # tools must be bound to a model object via .configure method and would be used in all |
| 85 | + # model calls from this model object. |
| 86 | + model = model.configure(tools=create_tools(sdk), temperature=0) |
| 87 | + |
| 88 | + for question in ["How much it would be 7@8?", "What is the weather like in Paris at 12 of March?"]: |
| 89 | + # it is required to carefully maintain context for passing tool_results back to the model after function call |
| 90 | + messages: list = [ |
| 91 | + {"role": "system", "text": "Please use English language for answer"}, |
| 92 | + question |
| 93 | + ] |
| 94 | + |
| 95 | + done = False |
| 96 | + result = None |
| 97 | + while not done: |
| 98 | + done = True |
| 99 | + async for event in model.run_stream(messages): |
| 100 | + print( |
| 101 | + 'Stream event - ' |
| 102 | + f'status={event.status.name!r}, text={event.text!r}, tool_calls={event.tool_calls}' |
| 103 | + ) |
| 104 | + |
| 105 | + if event.tool_calls: |
| 106 | + tool_results = process_tool_calls(event.tool_calls) |
| 107 | + |
| 108 | + # We need to enrich message history with tool_call record, tool_results record |
| 109 | + messages.append(event) |
| 110 | + messages.append(tool_results) |
| 111 | + # and launch another run_stream with a new message history |
| 112 | + done = False |
| 113 | + |
| 114 | + result = event |
| 115 | + |
| 116 | + assert result |
| 117 | + print(f"Model answer for {question=}:", result.text) |
| 118 | + |
| 119 | +if __name__ == '__main__': |
| 120 | + asyncio.run(main()) |
0 commit comments