Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ async def main() -> None:
}
},
"required": ["expression"],
}
},
strict=True,
)

# it is imported inside only because yandex-cloud-ml-sdk does not require pydantic by default
Expand All @@ -78,7 +79,7 @@ class Weather(BaseModel):
# be inferred from pydantic model
weather_tool = sdk.tools.function(Weather)

model = sdk.models.completions('yandexgpt')
model = sdk.models.completions('yandexgpt', model_version='rc')

# tools must be bound to model object via .configure method and would be used in all
# model calls from this model object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def main() -> None:
}
},
"required": ["expression"],
}
},
strict=True,
)

# it is imported inside only because yandex-cloud-ml-sdk does not require pydantic by default
Expand All @@ -76,7 +77,7 @@ class Weather(BaseModel):
# be inferred from pydantic model
weather_tool = sdk.tools.function(Weather)

model = sdk.models.completions('yandexgpt')
model = sdk.models.completions('yandexgpt', model_version='rc')

# tools must be bound to model object via .configure method and would be used in all
# model calls from this model object.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ classifiers = [
requires-python = ">=3.9"
dynamic = ["version"]
dependencies = [
"yandexcloud>=0.343.0",
"yandexcloud>=0.350.0",
"grpcio>=1.70.0",
"get-annotations",
"httpx>=0.27,<1",
Expand Down
5 changes: 4 additions & 1 deletion src/yandex_cloud_ml_sdk/_tools/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __call__(
*,
name: UndefinedOr[str] = UNDEFINED,
description: UndefinedOr[str] = UNDEFINED,
strict: UndefinedOr[bool] = UNDEFINED,
) -> FunctionTool:
schema = schema_from_parameters(parameters)
description_ = (
Expand All @@ -29,6 +30,7 @@ def __call__(
get_defined_value(name, None) or
cast(Optional[str], schema.get('title'))
)
strict_: bool | None = get_defined_value(strict, None)

if not name_:
raise TypeError(
Expand All @@ -39,7 +41,8 @@ def __call__(
return FunctionTool(
parameters=schema,
name=name_,
description=description_
description=description_,
strict=strict_,
)


Expand Down
24 changes: 23 additions & 1 deletion src/yandex_cloud_ml_sdk/_tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class FunctionTool(BaseTool):
name: str
description: str | None
parameters: JsonSchemaType
strict: bool | None

@classmethod
def _from_proto(
Expand All @@ -111,10 +112,16 @@ def _from_proto(
sdk: BaseSDK
) -> FunctionTool:
parameters = MessageToDict(proto.parameters)

strict: bool | None = None
if hasattr(proto, 'strict'):
strict = proto.strict

return cls(
name=proto.name,
description=proto.description,
parameters=parameters,
strict=strict,
)

def _to_proto(self, proto_type: type[ProtoToolTypeT]) -> ProtoToolTypeT:
Expand All @@ -126,10 +133,25 @@ def _to_proto(self, proto_type: type[ProtoToolTypeT]) -> ProtoToolTypeT:
ProtoCompletionsTool: ProtoCompletionsFunctionTool,
}[proto_type]

additional_kwargs = {}
# TODO: remove this logic after strict would be supported in assistants
if self.strict is not None:
strict_field_present = 'strict' in {
field.name
for field in function_class.DESCRIPTOR.fields # type: ignore[attr-defined]
}
if strict_field_present:
additional_kwargs['strict'] = self.strict
else:
raise ValueError(
'"strict" field is not supported in sdk.assistants yet, only in sdk.models.completions'
)

function = function_class(
name=self.name,
description=self.description or '',
parameters=parameters
parameters=parameters,
**additional_kwargs,
)

# i dunno how to properly describe this type of polymorphism to mypy
Expand Down
30 changes: 26 additions & 4 deletions tests/tools/test_function.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel,no-name-in-module
from __future__ import annotations

import dataclasses

import pytest
from yandex.cloud.ai.assistants.v1.common_pb2 import Tool as ProtoAssistantsTool
from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import Tool as ProtoCompletionsTool

from yandex_cloud_ml_sdk import AsyncYCloudML
from yandex_cloud_ml_sdk._tools.function import ParametersType
Expand All @@ -29,7 +31,8 @@ def test_pydantic_model_function_tool(async_sdk: AsyncYCloudML) -> None:
"required": ["field"],
"title": 'Function',
"description": 'function description'
}
},
strict=None,
)

description: str = (
Expand Down Expand Up @@ -69,7 +72,8 @@ def test_pydantic_dataclass_function_tool(async_sdk: AsyncYCloudML) -> None:
"required": ["field"],
"title": 'Function',
"description": 'function description'
}
},
strict=None,
)

description: str = (
Expand Down Expand Up @@ -109,7 +113,8 @@ def test_raw_json_function_tool(async_sdk: AsyncYCloudML) -> None:
},
},
"required": ["field"],
}
},
strict=None,
)

parameters = dict(etalon.parameters)
Expand Down Expand Up @@ -144,3 +149,20 @@ def test_bad_types(async_sdk: AsyncYCloudML) -> None:

with pytest.raises(TypeError):
async_sdk.tools.function([], name='foo') # type: ignore[arg-type]


def test_strict(async_sdk: AsyncYCloudML) -> None:
tool = async_sdk.tools.function({}, name='foo')

assistant_proto = tool._to_proto(ProtoAssistantsTool)
assert not hasattr(assistant_proto.function, 'strict')

assert tool._to_proto(ProtoCompletionsTool).function.strict is False

for strict in (True, False):
tool = async_sdk.tools.function({}, name='foo', strict=strict)
proto = tool._to_proto(ProtoCompletionsTool)
assert proto.function.strict is strict

with pytest.raises(ValueError, match='"strict" field is not supported'):
tool._to_proto(ProtoAssistantsTool)