Skip to content

Commit 5d96854

Browse files
authored
Support strict field for FC in completions (#121)
1 parent 7d8866b commit 5d96854

File tree

5 files changed

+59
-10
lines changed

5 files changed

+59
-10
lines changed

examples/async/function_calling/completions/raw_tool_calls_processing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ async def main() -> None:
5959
}
6060
},
6161
"required": ["expression"],
62-
}
62+
},
63+
strict=True,
6364
)
6465

6566
# it is imported inside only because yandex-cloud-ml-sdk does not require pydantic by default
@@ -78,7 +79,7 @@ class Weather(BaseModel):
7879
# be inferred from pydantic model
7980
weather_tool = sdk.tools.function(Weather)
8081

81-
model = sdk.models.completions('yandexgpt')
82+
model = sdk.models.completions('yandexgpt', model_version='rc')
8283

8384
# tools must be bound to model object via .configure method and would be used in all
8485
# model calls from this model object.

examples/sync/function_calling/completions/raw_tool_calls_processing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ def main() -> None:
5757
}
5858
},
5959
"required": ["expression"],
60-
}
60+
},
61+
strict=True,
6162
)
6263

6364
# it is imported inside only because yandex-cloud-ml-sdk does not require pydantic by default
@@ -76,7 +77,7 @@ class Weather(BaseModel):
7677
# be inferred from pydantic model
7778
weather_tool = sdk.tools.function(Weather)
7879

79-
model = sdk.models.completions('yandexgpt')
80+
model = sdk.models.completions('yandexgpt', model_version='rc')
8081

8182
# tools must be bound to model object via .configure method and would be used in all
8283
# model calls from this model object.

src/yandex_cloud_ml_sdk/_tools/function.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __call__(
1919
*,
2020
name: UndefinedOr[str] = UNDEFINED,
2121
description: UndefinedOr[str] = UNDEFINED,
22+
strict: UndefinedOr[bool] = UNDEFINED,
2223
) -> FunctionTool:
2324
schema = schema_from_parameters(parameters)
2425
description_ = (
@@ -29,6 +30,7 @@ def __call__(
2930
get_defined_value(name, None) or
3031
cast(Optional[str], schema.get('title'))
3132
)
33+
strict_: bool | None = get_defined_value(strict, None)
3234

3335
if not name_:
3436
raise TypeError(
@@ -39,7 +41,8 @@ def __call__(
3941
return FunctionTool(
4042
parameters=schema,
4143
name=name_,
42-
description=description_
44+
description=description_,
45+
strict=strict_,
4346
)
4447

4548

src/yandex_cloud_ml_sdk/_tools/tool.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class FunctionTool(BaseTool):
103103
name: str
104104
description: str | None
105105
parameters: JsonSchemaType
106+
strict: bool | None
106107

107108
@classmethod
108109
def _from_proto(
@@ -111,10 +112,16 @@ def _from_proto(
111112
sdk: BaseSDK
112113
) -> FunctionTool:
113114
parameters = MessageToDict(proto.parameters)
115+
116+
strict: bool | None = None
117+
if hasattr(proto, 'strict'):
118+
strict = proto.strict
119+
114120
return cls(
115121
name=proto.name,
116122
description=proto.description,
117123
parameters=parameters,
124+
strict=strict,
118125
)
119126

120127
def _to_proto(self, proto_type: type[ProtoToolTypeT]) -> ProtoToolTypeT:
@@ -126,10 +133,25 @@ def _to_proto(self, proto_type: type[ProtoToolTypeT]) -> ProtoToolTypeT:
126133
ProtoCompletionsTool: ProtoCompletionsFunctionTool,
127134
}[proto_type]
128135

136+
additional_kwargs = {}
137+
# TODO: remove this logic after strict would be supported in assistants
138+
if self.strict is not None:
139+
strict_field_present = 'strict' in {
140+
field.name
141+
for field in function_class.DESCRIPTOR.fields # type: ignore[attr-defined]
142+
}
143+
if strict_field_present:
144+
additional_kwargs['strict'] = self.strict
145+
else:
146+
raise ValueError(
147+
'"strict" field is not supported in sdk.assistants yet, only in sdk.models.completions'
148+
)
149+
129150
function = function_class(
130151
name=self.name,
131152
description=self.description or '',
132-
parameters=parameters
153+
parameters=parameters,
154+
**additional_kwargs,
133155
)
134156

135157
# i dunno how to properly describe this type of polymorphism to mypy

tests/tools/test_function.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
# pylint: disable=import-outside-toplevel
1+
# pylint: disable=import-outside-toplevel,no-name-in-module
22
from __future__ import annotations
33

44
import dataclasses
55

66
import pytest
7+
from yandex.cloud.ai.assistants.v1.common_pb2 import Tool as ProtoAssistantsTool
8+
from yandex.cloud.ai.foundation_models.v1.text_common_pb2 import Tool as ProtoCompletionsTool
79

810
from yandex_cloud_ml_sdk import AsyncYCloudML
911
from yandex_cloud_ml_sdk._tools.function import ParametersType
@@ -29,7 +31,8 @@ def test_pydantic_model_function_tool(async_sdk: AsyncYCloudML) -> None:
2931
"required": ["field"],
3032
"title": 'Function',
3133
"description": 'function description'
32-
}
34+
},
35+
strict=None,
3336
)
3437

3538
description: str = (
@@ -69,7 +72,8 @@ def test_pydantic_dataclass_function_tool(async_sdk: AsyncYCloudML) -> None:
6972
"required": ["field"],
7073
"title": 'Function',
7174
"description": 'function description'
72-
}
75+
},
76+
strict=None,
7377
)
7478

7579
description: str = (
@@ -109,7 +113,8 @@ def test_raw_json_function_tool(async_sdk: AsyncYCloudML) -> None:
109113
},
110114
},
111115
"required": ["field"],
112-
}
116+
},
117+
strict=None,
113118
)
114119

115120
parameters = dict(etalon.parameters)
@@ -144,3 +149,20 @@ def test_bad_types(async_sdk: AsyncYCloudML) -> None:
144149

145150
with pytest.raises(TypeError):
146151
async_sdk.tools.function([], name='foo') # type: ignore[arg-type]
152+
153+
154+
def test_strict(async_sdk: AsyncYCloudML) -> None:
155+
tool = async_sdk.tools.function({}, name='foo')
156+
157+
assistant_proto = tool._to_proto(ProtoAssistantsTool)
158+
assert not hasattr(assistant_proto.function, 'strict')
159+
160+
assert tool._to_proto(ProtoCompletionsTool).function.strict is False
161+
162+
for strict in (True, False):
163+
tool = async_sdk.tools.function({}, name='foo', strict=strict)
164+
proto = tool._to_proto(ProtoCompletionsTool)
165+
assert proto.function.strict is strict
166+
167+
with pytest.raises(ValueError, match='"strict" field is not supported'):
168+
tool._to_proto(ProtoAssistantsTool)

0 commit comments

Comments
 (0)