diff --git a/litestar/_openapi/responses.py b/litestar/_openapi/responses.py index c267f91a27..50e0b781f9 100644 --- a/litestar/_openapi/responses.py +++ b/litestar/_openapi/responses.py @@ -27,7 +27,6 @@ Response as LitestarResponse, ) from litestar.response.base import ASGIResponse -from litestar.types.builtin_types import NoneType from litestar.typing import FieldDefinition from litestar.utils import get_enum_string_value, get_name @@ -120,12 +119,12 @@ def create_description(self) -> str: def create_success_response(self) -> OpenAPIResponse: """Create the schema for a success response.""" - if self.field_definition.is_subclass_of((NoneType, ASGIResponse)): - response = OpenAPIResponse(content=None, description=self.create_description()) - elif self.field_definition.is_subclass_of(Redirect): + if self.field_definition.is_subclass_of(Redirect): response = self.create_redirect_response() elif self.field_definition.is_subclass_of(File): response = self.create_file_response() + elif self.field_definition.is_subclass_of(ASGIResponse) or not self.route_handler.returns_content: + response = self.create_empty_response() else: media_type = self.route_handler.media_type @@ -163,6 +162,13 @@ def create_success_response(self) -> OpenAPIResponse: self.set_success_response_headers(response) return response + def create_empty_response(self) -> OpenAPIResponse: + """Create the schema for a response with no content.""" + return OpenAPIResponse( + content=None, + description=self.create_description(), + ) + def create_redirect_response(self) -> OpenAPIResponse: """Create the schema for a redirect response.""" return OpenAPIResponse( diff --git a/litestar/handlers/http_handlers/base.py b/litestar/handlers/http_handlers/base.py index fae9b8d213..7e08f6d1fa 100644 --- a/litestar/handlers/http_handlers/base.py +++ b/litestar/handlers/http_handlers/base.py @@ -51,11 +51,10 @@ Send, TypeEncodersMap, ) -from litestar.types.builtin_types import NoneType from litestar.utils import deprecated as litestar_deprecated from litestar.utils import ensure_async_callable from litestar.utils.empty import value_or_default -from litestar.utils.predicates import is_async_callable, is_class_and_subclass +from litestar.utils.predicates import is_async_callable from litestar.utils.scope.state import ScopeState from litestar.utils.warnings import warn_implicit_sync_to_thread, warn_sync_to_thread_with_async_callable @@ -531,6 +530,15 @@ def resolve_request_max_body_size(self) -> int | None: def request_max_body_size(self) -> int | None: return value_or_default(self._request_max_body_size, None) # pyright: ignore + @property + def returns_content(self) -> bool: + """Whether the route handler returns any content in the response body.""" + return not ( + self.status_code < 200 + or self.status_code in {HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED} + or self.http_methods == {HttpMethod.HEAD} + ) + def on_registration(self, route: BaseRoute, app: Litestar) -> None: super().on_registration(route=route, app=app) @@ -579,6 +587,15 @@ def _validate_handler_function(self) -> None: f"If {self} should return a value, change the route handler status code to an appropriate value.", ) + if self.http_methods == {HttpMethod.HEAD} and not ( + is_empty_response_annotation(return_type) + or return_type.is_subclass_of(File) + or return_type.is_subclass_of(ASGIFileResponse) + ): + raise ImproperlyConfiguredException( + f"{self}: Handlers for 'HEAD' requests must not return a value. Either return 'None' or a response type without a body." + ) + if not self.media_type: if return_type.is_subclass_of((str, bytes)) or return_type.annotation is AnyStr: self.media_type = MediaType.TEXT @@ -591,23 +608,6 @@ def _validate_handler_function(self) -> None: if "data" in self.parsed_fn_signature.parameters and "GET" in self.http_methods: raise ImproperlyConfiguredException("'data' kwarg is unsupported for 'GET' request handlers") - if self.http_methods == {HttpMethod.HEAD} and not self.parsed_fn_signature.return_type.is_subclass_of( - ( - NoneType, - File, - ASGIFileResponse, - ) - ): - field_definition = self.parsed_fn_signature.return_type - if not ( - is_empty_response_annotation(field_definition) - or is_class_and_subclass(field_definition.annotation, File) - or is_class_and_subclass(field_definition.annotation, ASGIFileResponse) - ): - raise ImproperlyConfiguredException( - f"{self}: Handlers for 'HEAD' requests must not return a value. Either return 'None' or a response type without a body." - ) - if (body_param := self.parsed_fn_signature.parameters.get("body")) and not body_param.is_subclass_of(bytes): raise ImproperlyConfiguredException( f"Invalid type annotation for 'body' parameter in route handler {self}. 'body' will always receive the " diff --git a/tests/unit/test_openapi/test_config.py b/tests/unit/test_openapi/test_config.py index a489ed88af..fed5a463d8 100644 --- a/tests/unit/test_openapi/test_config.py +++ b/tests/unit/test_openapi/test_config.py @@ -55,24 +55,15 @@ def handler_2() -> None: openapi_config=OpenAPIConfig(title="my title", version="1.0.0", operation_id_creator=operation_id_creator), ) - assert app.openapi_schema.to_schema()["paths"] == { - "/1": { - "get": { - "deprecated": False, - "operationId": "id_x", - "responses": {"200": {"description": "Request fulfilled, document follows", "headers": {}}}, - "summary": "Handler1", - } - }, - "/2": { - "get": { - "deprecated": False, - "operationId": "id_y", - "responses": {"200": {"description": "Request fulfilled, document follows", "headers": {}}}, - "summary": "Handler2", - } - }, - } + assert app.openapi_schema.paths is not None + + assert "/1" in app.openapi_schema.paths + assert app.openapi_schema.paths["/1"].get is not None + assert app.openapi_schema.paths["/1"].get.operation_id == "id_x" + + assert "/2" in app.openapi_schema.paths + assert app.openapi_schema.paths["/2"].get is not None + assert app.openapi_schema.paths["/2"].get.operation_id == "id_y" def test_raises_exception_when_no_config_in_place() -> None: diff --git a/tests/unit/test_openapi/test_responses.py b/tests/unit/test_openapi/test_responses.py index 196b665661..cb186b4f2d 100644 --- a/tests/unit/test_openapi/test_responses.py +++ b/tests/unit/test_openapi/test_responses.py @@ -10,7 +10,7 @@ import pytest from typing_extensions import TypeAlias -from litestar import Controller, Litestar, MediaType, Response, delete, get, post +from litestar import Controller, Litestar, MediaType, Response, delete, get, head, post from litestar._openapi.datastructures import OpenAPIContext from litestar._openapi.responses import ( ResponseFactory, @@ -30,6 +30,7 @@ from litestar.openapi.spec import Example, OpenAPIHeader, OpenAPIMediaType, OpenAPIResponse, Reference, Schema from litestar.openapi.spec.enums import OpenAPIType from litestar.response import File, Redirect, Stream, Template +from litestar.response.base import ASGIResponse from litestar.routes import HTTPRoute from litestar.status_codes import ( HTTP_200_OK, @@ -290,6 +291,75 @@ def redirect_handler() -> Redirect: assert location.description +def test_create_success_response_asgi_response(create_factory: CreateFactoryFixture) -> None: + @get(path="/test", name="test") + def handler() -> ASGIResponse: + return ASGIResponse() + + handler = get_registered_route_handler(handler, "test") + response = create_factory(handler, True).create_success_response() + + assert response.content is None + + +def test_create_success_response_none(create_factory: CreateFactoryFixture) -> None: + @get(path="/test", name="test") + def handler() -> None: + return None + + handler = get_registered_route_handler(handler, "test") + response = create_factory(handler, True).create_success_response() + + assert response.content + schema = response.content[handler.media_type].schema + assert isinstance(schema, Schema) + assert schema.type == OpenAPIType.NULL + + +def test_create_success_response_none_no_content(create_factory: CreateFactoryFixture) -> None: + @get(path="/test", status_code=HTTP_204_NO_CONTENT, name="test") + def handler() -> None: + return None + + handler = get_registered_route_handler(handler, "test") + response = create_factory(handler, True).create_success_response() + + assert response.content is None + + +def test_create_success_response_none_head(create_factory: CreateFactoryFixture) -> None: + @head(path="/test", name="test") + def handler() -> None: + return None + + handler = get_registered_route_handler(handler, "test") + response = create_factory(handler, True).create_success_response() + + assert response.content is None + + +def test_create_success_response_response_none_no_content(create_factory: CreateFactoryFixture) -> None: + @get(path="/test", status_code=HTTP_204_NO_CONTENT, name="test") + def handler() -> Response[None]: + return Response(None) + + handler = get_registered_route_handler(handler, "test") + response = create_factory(handler, True).create_success_response() + + assert response.content is None + + +def test_create_success_response_response_none_head(create_factory: CreateFactoryFixture) -> None: + @head(path="/test", name="test") + def handler() -> Response[None]: + return Response(None) + + handler = get_registered_route_handler(handler, "test") + response = create_factory(handler, True).create_success_response() + + assert response.content is None + + def test_create_success_response_no_content_explicit_responsespec( create_factory: CreateFactoryFixture, ) -> None: