Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions dmr/validation/controller.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import types
from typing import TYPE_CHECKING

Expand All @@ -19,12 +20,30 @@ def __call__(
controller: type['Controller[BaseSerializer]'],
) -> bool | None:
"""Run the validation."""
self._validate_async_generator_endpoints(controller)
is_async = self._validate_endpoints_color(controller)
self._validate_error_handlers(controller, is_async=is_async)
self._validate_meta_mixins(controller)
self._validate_non_endpoints(controller)
return is_async

def _validate_async_generator_endpoints(
self,
controller: type['Controller[BaseSerializer]'],
) -> None:
for method_name in controller.allowed_http_methods:
method = getattr(controller, method_name, None)
if isinstance(
method,
types.FunctionType,
) and inspect.isasyncgenfunction(method):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While we are at it: let's all check for sync generators. Because they also won't work.

raise EndpointMetadataError(
f'{controller!r}.{method_name} is an async generator. '
'HTTP endpoints cannot use `yield` in method body. '
'Return an iterator object instead, '
'for example: `return self._events()`',
)

def _validate_endpoints_color(
self,
controller: type['Controller[BaseSerializer]'],
Expand Down
24 changes: 24 additions & 0 deletions tests/test_unit/test_controllers/test_controller_validation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import AsyncIterator
from http import HTTPStatus
from typing import Any

Expand All @@ -9,6 +10,7 @@
from dmr.exceptions import EndpointMetadataError
from dmr.options_mixins import AsyncMetaMixin, MetaMixin
from dmr.plugins.pydantic import PydanticSerializer
from dmr.plugins.pydantic.serializer import PydanticEndpointOptimizer


def test_controller_either_sync_or_async() -> None:
Expand Down Expand Up @@ -174,3 +176,25 @@ def handle_error(
exc: Exception,
) -> Any:
raise NotImplementedError


def test_endpoint_rejects_async_gen() -> None:
"""Ensure endpoints cannot be async generators."""

class _NoOpOptimizer(PydanticEndpointOptimizer):
@override
@classmethod
def optimize_endpoint(cls, metadata: Any) -> None: # noqa: WPS324
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, why do you need this type?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If i use controller like this

    with pytest.raises(
        EndpointMetadataError,
        match='is an async generator',
    ):

        class _BadController(Controller[PydanticSerializer]):
            async def get(self) -> AsyncIterator[int]:
                yield 1  # pragma: no cover

I have pydantic error =(

.venv/lib/python3.13/site-packages/pydantic/_internal/_generate_schema.py:659: in _unknown_type_schema
    raise PydanticSchemaGenerationError(
E   pydantic.errors.PydanticSchemaGenerationError: Unable to generate pydantic-core schema for collections.abc.AsyncIterator[int]. Set `arbitrary_types_allowed=True` in the model_config to ignore this error or implement `__get_pydantic_core_schema__` on your type to fully support it.
E
E   If you got this error by calling handler(<some type>) within `__get_pydantic_core_schema__` then you likely need to call `handler.generate_schema(<some type>)` since we do not call `__get_pydantic_core_schema__` on `<some type>` otherwise to avoid infinite recursion.
E
E   For further information visit https://errors.pydantic.dev/2.12/u/schema-for-unknown-type

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it!

Please, then also test that SSEController an JsonLinesController do not need this.

return None # noqa: WPS324

class _NoOpPydanticSerializer(PydanticSerializer):
optimizer = _NoOpOptimizer

with pytest.raises(
EndpointMetadataError,
match='is an async generator',
):

class _BadController(Controller[_NoOpPydanticSerializer]):
async def get(self) -> AsyncIterator[int]:
yield 1 # pragma: no cover
14 changes: 14 additions & 0 deletions tests/test_unit/test_plugins/test_msgspec/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from collections.abc import AsyncIterator
from http import HTTPMethod, HTTPStatus
from typing import final

Expand All @@ -14,6 +15,7 @@
from inline_snapshot import snapshot

from dmr import Body, Controller, ResponseSpec, modify, validate
from dmr.exceptions import EndpointMetadataError
from dmr.plugins.msgspec import MsgspecSerializer
from dmr.test import DMRRequestFactory

Expand Down Expand Up @@ -206,3 +208,15 @@ def test_msgspec_struct_renames_work(
assert response.status_code == HTTPStatus.CREATED, response.content
assert response.headers == {'Content-Type': 'application/json'}
assert json.loads(response.content) == request_data


def test_msgspec_rejects_async_gen() -> None:
"""Ensure msgspec controllers cannot define async generator endpoints."""
with pytest.raises(
EndpointMetadataError,
match='is an async generator',
):

class _BadController(Controller[MsgspecSerializer]):
async def get(self) -> AsyncIterator[int]:
yield 1 # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def test_sync_sse_prod_from_sync_method(

class _SSEWithClose(SSEController[PydanticSerializer]):
async def get(self) -> AsyncIterator[SSEvent[str | bytes | int]]:
return self._events()

async def _events(self) -> AsyncIterator[SSEvent[str | bytes | int]]:
yield SSEvent(b'event', serialize=False)
yield SSEvent(b'second', serialize=False)
raise StreamingCloseError
Expand All @@ -168,7 +171,9 @@ def test_sync_sse_dev_with_close(
settings.DEBUG = True
request = dmr_rf.get('/whatever/')

response = _SSEWithClose.as_view()(request)
response: StreamingResponse = async_to_sync(
_SSEWithClose.as_view(), # type: ignore[arg-type]
)(request)

assert isinstance(response, StreamingResponse)
assert response.streaming
Expand Down
Loading