diff --git a/aidial_sdk/chat_completion/request.py b/aidial_sdk/chat_completion/request.py index 2186b50..f731b6b 100644 --- a/aidial_sdk/chat_completion/request.py +++ b/aidial_sdk/chat_completion/request.py @@ -15,10 +15,10 @@ from aidial_sdk.chat_completion.enums import Status from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin from aidial_sdk.exceptions import InvalidRequestError -from aidial_sdk.utils.pydantic import ExtraAllowModel +from aidial_sdk.utils.pydantic import ExtraAllowModel, IgnoreIndex -class Attachment(ExtraAllowModel): +class Attachment(ExtraAllowModel, IgnoreIndex): type: Optional[StrictStr] = "text/markdown" title: Optional[StrictStr] = None data: Optional[StrictStr] = None @@ -43,7 +43,7 @@ def check_data_or_url(cls, values: Any): return values -class Stage(ExtraAllowModel): +class Stage(ExtraAllowModel, IgnoreIndex): name: StrictStr status: Status content: Optional[StrictStr] = None @@ -63,9 +63,7 @@ class FunctionCall(ExtraAllowModel): arguments: str -class ToolCall(ExtraAllowModel): - # OpenAI API doesn't strictly specify existence of the index field - index: Optional[int] +class ToolCall(ExtraAllowModel, IgnoreIndex): id: StrictStr type: Literal["function"] function: FunctionCall diff --git a/aidial_sdk/utils/pydantic.py b/aidial_sdk/utils/pydantic.py index b430380..c6a54b1 100644 --- a/aidial_sdk/utils/pydantic.py +++ b/aidial_sdk/utils/pydantic.py @@ -1,8 +1,8 @@ from enum import Enum -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union from aidial_sdk._pydantic import PYDANTIC_V2, ConfigDict, FieldInfo -from aidial_sdk._pydantic._compat import BaseModel +from aidial_sdk._pydantic._compat import BaseModel, model_validator class ExtraAllowModel(BaseModel): @@ -14,6 +14,21 @@ class Config: extra = "allow" +class IgnoreIndex(BaseModel): + @model_validator(mode="before") + @classmethod + def strip_index(cls, data: Any) -> Any: + if ( + isinstance(data, Mapping) + and (idx := data.get("index")) is not None + and isinstance(idx, int) + ): + d = dict(data) + d.pop("index") + return d + return data + + _Loc = Tuple[Union[int, str], ...] diff --git a/tests/test_request_indices.py b/tests/test_request_indices.py new file mode 100644 index 0000000..bb52e47 --- /dev/null +++ b/tests/test_request_indices.py @@ -0,0 +1,96 @@ +import dataclasses +from typing import List + +import pytest + +from aidial_sdk._pydantic import BaseModel, ValidationError +from aidial_sdk.chat_completion import ( + Attachment, + FunctionCall, + Status, + ToolCall, +) +from aidial_sdk.chat_completion.request import Stage +from tests.utils.pydantic import model_dump, model_parse + + +@dataclasses.dataclass +class TestCase: + __test__ = False + obj: BaseModel + dct: dict + + def get_id(self) -> str: + return type(self.obj).__name__ + + +_test_cases: List[TestCase] = [ + TestCase( + ToolCall( + id="tool-call-id", + type="function", + function=FunctionCall(name="func-name", arguments="{}"), + ), + { + "id": "tool-call-id", + "type": "function", + "function": {"name": "func-name", "arguments": "{}"}, + }, + ), + TestCase( + Attachment(type="text/plain", data="test"), + {"type": "text/plain", "data": "test"}, + ), + TestCase( + Stage(name="Testing", status=Status.COMPLETED, content="test"), + {"name": "Testing", "status": "completed", "content": "test"}, + ), +] + + +@pytest.fixture(params=_test_cases, ids=lambda x: x.get_id()) +def test_case(request) -> TestCase: + return request.param + + +def _check_ser_deser(obj: BaseModel): + dct = model_dump(obj) + obj2 = model_parse(type(obj), dct, allow_extra_fields=False) + assert obj == obj2 + + +def test_index_field_ser_deser(test_case: TestCase): + _check_ser_deser(test_case.obj) + + +def test_index_field_ignore_int(test_case: TestCase): + obj = model_parse( + type(test_case.obj), + {**test_case.dct, **{"index": 101}}, + allow_extra_fields=False, + ) + _check_ser_deser(obj) + + +def test_index_field_fail_on_str(test_case: TestCase): + with pytest.raises( + ValidationError, + match=r"(Extra inputs are not permitted|extra fields not permitted)", + ): + model_parse( + type(test_case.obj), + {**test_case.dct, **{"index": "value"}}, + allow_extra_fields=False, + ) + + +def test_index_field_fail_on_extra_fields(test_case: TestCase): + with pytest.raises( + ValidationError, + match=r"(Extra inputs are not permitted|extra fields not permitted)", + ): + model_parse( + type(test_case.obj), + {**test_case.dct, **{"index2": "whatever"}}, + allow_extra_fields=False, + ) diff --git a/tests/utils/pydantic.py b/tests/utils/pydantic.py index 2518e30..3a28606 100644 --- a/tests/utils/pydantic.py +++ b/tests/utils/pydantic.py @@ -2,6 +2,7 @@ from aidial_sdk._pydantic import PYDANTIC_V2, BaseModel from aidial_sdk._pydantic import Field as PydField +from aidial_sdk.utils.pydantic import model_validate_extra_fields _ModelT = TypeVar("_ModelT", bound=BaseModel) @@ -15,16 +16,28 @@ def Field(*args, **kwargs) -> Any: return PydField(*args, **kwargs) -def model_parse(model: Type[_ModelT], data: Any) -> _ModelT: +def model_parse( + model: Type[_ModelT], data: Any, *, allow_extra_fields=True +) -> _ModelT: if PYDANTIC_V2: - return model.model_validate(data) - return model.parse_obj(data) # pyright: ignore[reportDeprecated] + obj = model.model_validate(data) + else: + obj = model.parse_obj(data) # pyright: ignore[reportDeprecated] + if not allow_extra_fields: + model_validate_extra_fields(obj) # type: ignore + return obj -def model_parse_json(model: Type[_ModelT], data: Union[str, bytes]) -> _ModelT: +def model_parse_json( + model: Type[_ModelT], data: Union[str, bytes], *, allow_extra_fields=True +) -> _ModelT: if PYDANTIC_V2: - return model.model_validate_json(data) - return model.parse_raw(data) # pyright: ignore[reportDeprecated] + obj = model.model_validate_json(data) + else: + obj = model.parse_raw(data) # pyright: ignore[reportDeprecated] + if not allow_extra_fields: + model_validate_extra_fields(obj) # type: ignore + return obj def model_json_schema(model: Type[_ModelT]) -> Dict[str, Any]: