Skip to content

Commit 716140a

Browse files
committed
Fix all old tests
1 parent d9f57c6 commit 716140a

File tree

10 files changed

+36
-27
lines changed

10 files changed

+36
-27
lines changed

examples/async/completions/batch_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ async def main() -> None:
5757

5858
with tqdm.tqdm(total=len(input_data), desc='creatig tasks') as t:
5959
operations = await run_chunked_tasks(
60-
function=model.run_async,
60+
function=model.run_deferred,
6161
data=input_data,
6262
chunk_size=100,
6363
tqdm_callback=t.update,

examples/async/completions/chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ async def main() -> None:
1414
model = sdk.models.completions('yandexgpt')
1515
model = model.configure(temperature=0.5)
1616

17-
messages: list[dict[str, str] | str] = [{'role': 'system', 'text': 'Ты - Аркадий'}]
17+
messages: list = [{'role': 'system', 'text': 'Ты - Аркадий'}]
1818
while True:
1919
message = input()
2020
messages.append(message)

examples/async/function_calling/completions/raw_tool_calls_processing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def weather(location: str, date: str) -> str:
1717
return "-10 celsius"
1818

1919

20-
def process_tool_calls(tool_calls) -> list[dict]:
20+
def process_tool_calls(tool_calls) -> dict[str, list[dict]]:
2121
"""
2222
This function is an example how you could organize
2323
dispatching of function calls in general case
@@ -35,7 +35,7 @@ def process_tool_calls(tool_calls) -> list[dict]:
3535

3636
function = function_map[tool_call.function.name]
3737

38-
answer = function(**tool_call.function.arguments)
38+
answer = function(**tool_call.function.arguments) # type: ignore[operator]
3939

4040
result.append({'name': tool_call.function.name, 'content': answer})
4141

@@ -86,7 +86,7 @@ class Weather(BaseModel):
8686

8787
for question in ["How much it would be 7@8?", "What is the weather like in Paris at 12 of March?"]:
8888
# it is required to carefully maintain context for passing tool_results back to the model after function call
89-
messages = [
89+
messages: list = [
9090
{"role": "system", "text": "Please use English language for answer"},
9191
question
9292
]

src/yandex_cloud_ml_sdk/_models/completions/message.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def messages_to_proto(messages: MessageInputType) -> list[ProtoMessage]:
7474
kwargs = {'role': 'user', 'text': message}
7575
elif isinstance(message, TextMessageProtocol):
7676
if isinstance(message, TextMessageWithToolCallsProtocol) and message.tool_calls:
77-
# pylint: disable[protected-access]
77+
# pylint: disable=protected-access
7878
kwargs = {'role': message.role, 'tool_call_list': message.tool_calls._proto_origin}
7979
else:
8080
kwargs = {'role': message.role, 'text': message.text}
@@ -91,7 +91,7 @@ def messages_to_proto(messages: MessageInputType) -> list[ProtoMessage]:
9191
'tool_result_list': tool_results,
9292
}
9393
else:
94-
raise ValueError(f'{message=!r} should have a "text" or "tool_results" key')
94+
raise TypeError(f'{message=!r} should have a "text" or "tool_results" key')
9595
else:
9696
raise TypeError(f'{message=!r} should be str, dict with "text" or "tool_results" key or TextMessage instance')
9797

src/yandex_cloud_ml_sdk/_tools/function.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TypeVar
3+
from typing import TypeVar, cast
44

55
from yandex_cloud_ml_sdk._types.domain import BaseDomain
66
from yandex_cloud_ml_sdk._types.misc import UNDEFINED, UndefinedOr, get_defined_value
@@ -21,8 +21,14 @@ def __call__(
2121
description: UndefinedOr[str] = UNDEFINED,
2222
) -> FunctionTool:
2323
schema = schema_from_parameters(parameters)
24-
description_ = get_defined_value(description, None) or schema.get('description')
25-
name_ = get_defined_value(name, None) or schema.get('title')
24+
description_ = (
25+
get_defined_value(description, None) or
26+
cast(str | None, schema.get('description'))
27+
)
28+
name_ = (
29+
get_defined_value(name, None) or
30+
cast(str | None, schema.get('title'))
31+
)
2632

2733
if not name_:
2834
raise TypeError(

src/yandex_cloud_ml_sdk/_tools/tool_call.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# pylint: disable=no-name-in-module
22
from __future__ import annotations
33

4-
from collections.abc import Iterable
54
from dataclasses import dataclass, field
65
from typing import Generic, TypeVar, Union
76

src/yandex_cloud_ml_sdk/_tools/tool_result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def tool_results_to_proto(
7878
ProtoCompletionsToolResultList: ProtoCompletionsToolResult,
7979
}[proto_type])
8080

81-
tool_results = coerce_tuple(tool_results, (dict, ))
81+
tool_results = coerce_tuple(tool_results, cast(type[FunctionResultDict], dict))
8282

8383
proto_tool_results: list[object] = []
8484
for tool_result in tool_results:

src/yandex_cloud_ml_sdk/_types/operation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import abc
55
import asyncio
66
from dataclasses import dataclass, field
7-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generic, Iterable, TypeVar, cast
7+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generic, Iterable, TypeVar, cast, get_origin
88

99
from typing_extensions import Self
1010
# pylint: disable-next=no-name-in-module
@@ -186,13 +186,17 @@ async def _default_result_transofrmer(self, proto: Any, timeout: float) -> Resul
186186
# NB: default_result_transformer should be used only with _result_type
187187
# which are BaseResult-compatible, but I don't know how to express it with typing,
188188
# maybe we need special operation class, which support transforming (probably a base one)
189-
assert isinstance(self._result_type, ProtoBasedType)
189+
# NB: issubclass don't like if instead of SomeClass object pass SomeClass[T];
190+
# because we use _result_type also for a generic typing reasons, sometimes it requires
191+
# unwrapping for issubclass check
192+
result_type = get_origin(self._result_type) or self._result_type
193+
assert issubclass(result_type, ProtoBasedType), f'{self._result_type} is not ProtoBasedType'
190194

191195
# NB: mypy can't figure out that self._result_type._from_proto is
192196
# returning instance of self._result_type which is also is a ResultTypeT_co
193197
return cast(
194198
ResultTypeT_co,
195-
self._result_type._from_proto(proto=proto, sdk=self._sdk)
199+
self._result_type._from_proto(proto=proto, sdk=self._sdk) # type: ignore[attr-defined]
196200
)
197201

198202
@property

tests/models/test_image_generation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def check_messages(messages, expected):
6363
assert messages[0].weight == 2
6464
assert messages[1].weight == 1
6565

66-
messages = messages_to_proto(Alternative(role='foo', text='bar', status=None))
66+
messages = messages_to_proto(Alternative(role='foo', text='bar', status=None, tool_calls=None))
6767
check_messages(messages, ['bar'])
6868
assert messages[0].weight == 0
6969

@@ -73,8 +73,8 @@ def check_messages(messages, expected):
7373

7474
gpt_model_result = GPTModelResult(
7575
alternatives=[
76-
Alternative(role='1', text='1', status=None),
77-
Alternative(role='2', text='2', status=None),
76+
Alternative(role='1', text='1', status=None, tool_calls=None),
77+
Alternative(role='2', text='2', status=None, tool_calls=None),
7878
],
7979
usage=None,
8080
model_version=''

tests/types/test_structured_output.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
import pytest
99

10-
import yandex_cloud_ml_sdk._types.structured_output
11-
from yandex_cloud_ml_sdk._types.structured_output import schema_from_response_format
10+
import yandex_cloud_ml_sdk._types.schemas
11+
from yandex_cloud_ml_sdk._types.schemas import schema_from_response_format
1212

1313

1414
def test_string_type() -> None:
@@ -25,7 +25,7 @@ def test_dict_type() -> None:
2525

2626
@pytest.mark.require_env('pydantic')
2727
def test_pydantic_model() -> None:
28-
assert yandex_cloud_ml_sdk._types.structured_output.PYDANTIC is True
28+
assert yandex_cloud_ml_sdk._types.schemas.PYDANTIC is True
2929

3030
import pydantic
3131

@@ -116,25 +116,25 @@ class B:
116116
@pytest.fixture(name='no_pydantic')
117117
def fixture_no_pydantic(monkeypatch) -> typing.Iterator[bool]:
118118
# pylint: disable=reimported
119-
sys.modules.pop(yandex_cloud_ml_sdk._types.structured_output.__name__, None)
119+
sys.modules.pop(yandex_cloud_ml_sdk._types.schemas.__name__, None)
120120

121121
with monkeypatch.context() as m:
122122
m.setitem(sys.modules, 'pydantic', None)
123-
import yandex_cloud_ml_sdk._types.structured_output as _m
123+
import yandex_cloud_ml_sdk._types.schemas as _m
124124

125125
yield True
126126

127-
sys.modules.pop(yandex_cloud_ml_sdk._types.structured_output.__name__, None)
128-
import yandex_cloud_ml_sdk._types.structured_output as _m2
127+
sys.modules.pop(yandex_cloud_ml_sdk._types.schemas.__name__, None)
128+
import yandex_cloud_ml_sdk._types.schemas as _m2
129129

130130
assert _m
131131
assert _m2
132132

133133

134134
def test_no_pydantic(no_pydantic) -> None:
135135
assert no_pydantic
136-
assert yandex_cloud_ml_sdk._types.structured_output.PYDANTIC is False
137-
assert yandex_cloud_ml_sdk._types.structured_output.PYDANTIC_V2 is False
136+
assert yandex_cloud_ml_sdk._types.schemas.PYDANTIC is False
137+
assert yandex_cloud_ml_sdk._types.schemas.PYDANTIC_V2 is False
138138

139139
class A:
140140
a: int

0 commit comments

Comments
 (0)