Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions aidial_sdk/chat_completion/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from aidial_sdk.chat_completion.enums import Status
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.exceptions import InvalidRequestError
from aidial_sdk.utils.pydantic import ExtraAllowModel
from aidial_sdk.utils.pydantic import ExtraAllowModel, IgnoreIndex


class Attachment(ExtraAllowModel):
class Attachment(ExtraAllowModel, IgnoreIndex):
type: Optional[StrictStr] = "text/markdown"
title: Optional[StrictStr] = None
data: Optional[StrictStr] = None
Expand All @@ -43,7 +43,7 @@ def check_data_or_url(cls, values: Any):
return values


class Stage(ExtraAllowModel):
class Stage(ExtraAllowModel, IgnoreIndex):
name: StrictStr
status: Status
content: Optional[StrictStr] = None
Expand All @@ -63,9 +63,7 @@ class FunctionCall(ExtraAllowModel):
arguments: str


class ToolCall(ExtraAllowModel):
# OpenAI API doesn't strictly specify existence of the index field
index: Optional[int]
class ToolCall(ExtraAllowModel, IgnoreIndex):
id: StrictStr
type: Literal["function"]
function: FunctionCall
Expand Down
19 changes: 17 additions & 2 deletions aidial_sdk/utils/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union

from aidial_sdk._pydantic import PYDANTIC_V2, ConfigDict, FieldInfo
from aidial_sdk._pydantic._compat import BaseModel
from aidial_sdk._pydantic._compat import BaseModel, model_validator


class ExtraAllowModel(BaseModel):
Expand All @@ -14,6 +14,21 @@ class Config:
extra = "allow"


class IgnoreIndex(BaseModel):
@model_validator(mode="before")
@classmethod
def strip_index(cls, data: Any) -> Any:
if (
isinstance(data, Mapping)
and (idx := data.get("index")) is not None
and isinstance(idx, int)
):
d = dict(data)
d.pop("index")
return d
return data


_Loc = Tuple[Union[int, str], ...]


Expand Down
96 changes: 96 additions & 0 deletions tests/test_request_indices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import dataclasses
from typing import List

import pytest

from aidial_sdk._pydantic import BaseModel, ValidationError
from aidial_sdk.chat_completion import (
Attachment,
FunctionCall,
Status,
ToolCall,
)
from aidial_sdk.chat_completion.request import Stage
from tests.utils.pydantic import model_dump, model_parse


@dataclasses.dataclass
class TestCase:
__test__ = False
obj: BaseModel
dct: dict

def get_id(self) -> str:
return type(self.obj).__name__


_test_cases: List[TestCase] = [
TestCase(
ToolCall(
id="tool-call-id",
type="function",
function=FunctionCall(name="func-name", arguments="{}"),
),
{
"id": "tool-call-id",
"type": "function",
"function": {"name": "func-name", "arguments": "{}"},
},
),
TestCase(
Attachment(type="text/plain", data="test"),
{"type": "text/plain", "data": "test"},
),
TestCase(
Stage(name="Testing", status=Status.COMPLETED, content="test"),
{"name": "Testing", "status": "completed", "content": "test"},
),
]


@pytest.fixture(params=_test_cases, ids=lambda x: x.get_id())
def test_case(request) -> TestCase:
return request.param


def _check_ser_deser(obj: BaseModel):
dct = model_dump(obj)
obj2 = model_parse(type(obj), dct, allow_extra_fields=False)
assert obj == obj2


def test_index_field_ser_deser(test_case: TestCase):
_check_ser_deser(test_case.obj)


def test_index_field_ignore_int(test_case: TestCase):
obj = model_parse(
type(test_case.obj),
{**test_case.dct, **{"index": 101}},
allow_extra_fields=False,
)
_check_ser_deser(obj)


def test_index_field_fail_on_str(test_case: TestCase):
with pytest.raises(
ValidationError,
match=r"(Extra inputs are not permitted|extra fields not permitted)",
):
model_parse(
type(test_case.obj),
{**test_case.dct, **{"index": "value"}},
allow_extra_fields=False,
)


def test_index_field_fail_on_extra_fields(test_case: TestCase):
with pytest.raises(
ValidationError,
match=r"(Extra inputs are not permitted|extra fields not permitted)",
):
model_parse(
type(test_case.obj),
{**test_case.dct, **{"index2": "whatever"}},
allow_extra_fields=False,
)
25 changes: 19 additions & 6 deletions tests/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from aidial_sdk._pydantic import PYDANTIC_V2, BaseModel
from aidial_sdk._pydantic import Field as PydField
from aidial_sdk.utils.pydantic import model_validate_extra_fields

_ModelT = TypeVar("_ModelT", bound=BaseModel)

Expand All @@ -15,16 +16,28 @@ def Field(*args, **kwargs) -> Any:
return PydField(*args, **kwargs)


def model_parse(model: Type[_ModelT], data: Any) -> _ModelT:
def model_parse(
model: Type[_ModelT], data: Any, *, allow_extra_fields=True
) -> _ModelT:
if PYDANTIC_V2:
return model.model_validate(data)
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
obj = model.model_validate(data)
else:
obj = model.parse_obj(data) # pyright: ignore[reportDeprecated]
if not allow_extra_fields:
model_validate_extra_fields(obj) # type: ignore
return obj


def model_parse_json(model: Type[_ModelT], data: Union[str, bytes]) -> _ModelT:
def model_parse_json(
model: Type[_ModelT], data: Union[str, bytes], *, allow_extra_fields=True
) -> _ModelT:
if PYDANTIC_V2:
return model.model_validate_json(data)
return model.parse_raw(data) # pyright: ignore[reportDeprecated]
obj = model.model_validate_json(data)
else:
obj = model.parse_raw(data) # pyright: ignore[reportDeprecated]
if not allow_extra_fields:
model_validate_extra_fields(obj) # type: ignore
return obj


def model_json_schema(model: Type[_ModelT]) -> Dict[str, Any]:
Expand Down
Loading