From b4e4668157ab78d8640b96073572aeba46378f2f Mon Sep 17 00:00:00 2001 From: Paul Dubs Date: Mon, 24 Mar 2025 22:02:02 +0100 Subject: [PATCH 1/6] Initial Quart Subscription Support --- strawberry/quart/views.py | 111 +++++++++++++++++- tests/http/clients/quart.py | 110 ++++++++++++++++- tests/websockets/conftest.py | 1 + tests/websockets/test_graphql_transport_ws.py | 1 - 4 files changed, 214 insertions(+), 9 deletions(-) diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index b0ceb4acca..d785f42fdf 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,15 +1,27 @@ import warnings from collections.abc import AsyncGenerator, Mapping +from datetime import timedelta from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast from typing_extensions import TypeGuard +from json.decoder import JSONDecodeError -from quart import Request, Response, request +from quart import Quart, Request, Response, websocket, request +from quart.ctx import has_websocket_context from quart.views import View -from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter -from strawberry.http.exceptions import HTTPException +from strawberry.http.async_base_view import ( + AsyncBaseHTTPView, + AsyncHTTPRequestAdapter, + AsyncWebSocketAdapter +) +from strawberry.http.exceptions import ( + HTTPException, + NonJsonMessageReceived, + WebSocketDisconnected +) from strawberry.http.ides import GraphQL_IDE from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import Context, RootValue +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL if TYPE_CHECKING: from quart.typing import ResponseReturnValue @@ -46,6 +58,35 @@ async def get_form_data(self) -> FormData: return FormData(files=files, form=form) +class QuartWebSocketAdapter(AsyncWebSocketAdapter): + def __init__(self, view: AsyncBaseHTTPView, request, ws) -> None: + super().__init__(view) + self.ws = websocket + + async def iter_json( + self, *, ignore_parsing_errors: bool = False + ) -> AsyncGenerator[object, None]: + while True: + try: + message = await self.ws.receive() + try: + yield self.view.decode_json(message) + except JSONDecodeError as e: + if not ignore_parsing_errors: + raise NonJsonMessageReceived from e + except Exception as exc: + raise WebSocketDisconnected from exc + + async def send_json(self, message: Mapping[str, object]) -> None: + try: + await self.ws.send(self.view.encode_json(message)) + except Exception as exc: + raise WebSocketDisconnected from exc + + async def close(self, code: int, reason: str) -> None: + await self.ws.close(code, reason=reason) + + class GraphQLView( AsyncBaseHTTPView[ Request, Response, Response, Request, Response, Context, RootValue @@ -55,6 +96,7 @@ class GraphQLView( methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = QuartHTTPRequestAdapter + websocket_adapter_class = QuartWebSocketAdapter def __init__( self, @@ -62,10 +104,23 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, + keep_alive: bool = True, + keep_alive_interval: float = 1, + debug: bool = False, + subscription_protocols: list[str] = [ + GRAPHQL_TRANSPORT_WS_PROTOCOL, + GRAPHQL_WS_PROTOCOL, + ], + connection_init_wait_timeout: timedelta = timedelta(minutes=1), multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get + self.keep_alive = keep_alive + self.keep_alive_interval = keep_alive_interval + self.debug = debug + self.subscription_protocols = subscription_protocols + self.connection_init_wait_timeout = connection_init_wait_timeout self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: @@ -123,15 +178,59 @@ async def create_streaming_response( ) def is_websocket_request(self, request: Request) -> TypeGuard[Request]: - return False + if has_websocket_context(): + return True + + # Check if the request is a WebSocket upgrade request + connection = request.headers.get("Connection", "").lower() + upgrade = request.headers.get("Upgrade", "").lower() + + return ("upgrade" in connection and "websocket" in upgrade) async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: - raise NotImplementedError + # Get the requested protocols + protocols_header = websocket.headers.get("Sec-WebSocket-Protocol", "") + if not protocols_header: + return None + + # Find the first matching protocol + requested_protocols = [p.strip() for p in protocols_header.split(",")] + for protocol in requested_protocols: + if protocol in self.subscription_protocols: + return protocol + + return None async def create_websocket_response( self, request: Request, subprotocol: Optional[str] ) -> Response: - raise NotImplementedError + if subprotocol: + # Set the WebSocket protocol if specified + await websocket.accept(subprotocol=subprotocol) + else: + await websocket.accept() + + # Return the current websocket context as the "response" + return None + + @classmethod + def register_route(cls, app: Quart, rule_name: str, path: str, **kwargs): + """ + Helper method to register both HTTP and WebSocket handlers for a given path. + + Args: + app: The Quart application + rule_name: The name of the rule + path: The path to register the handlers for + **kwargs: Parameters to pass to the GraphQLView constructor + """ + # Register both HTTP and WebSocket handler at the same path + view_func = cls.as_view(rule_name, **kwargs) + app.add_url_rule(path, view_func=view_func, methods=["GET", "POST"]) + + # Register the WebSocket handler using the same view function + # Quart will handle routing based on the WebSocket upgrade header + app.add_url_rule(path, view_func=view_func, methods=["GET"], websocket=True) __all__ = ["GraphQLView"] diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 1711e58b45..772fc6c744 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -1,9 +1,15 @@ +import asyncio +import contextlib import json import urllib.parse from io import BytesIO -from typing import Any, Optional +from typing import Any, Optional, AsyncGenerator, Mapping + +from quart.typing import TestWebsocketConnectionProtocol from typing_extensions import Literal +from starlette.testclient import TestClient, WebSocketTestSession + from quart import Quart from quart import Request as QuartRequest from quart import Response as QuartResponse @@ -15,7 +21,8 @@ from tests.http.context import get_context from tests.views.schema import Query, schema -from .base import JSON, HttpClient, Response, ResultOverrideFunction +from .base import JSON, HttpClient, Response, ResultOverrideFunction, WebSocketClient, \ + Message class GraphQLView(BaseGraphQLView[dict[str, object], object]): @@ -73,6 +80,34 @@ def __init__( "/graphql", view_func=view, ) + self.app.add_url_rule( + '/graphql', + view_func=view, + methods=["GET"], + websocket=True + ) + + self.client = TestClient(self.app) + + def create_app(self, **kwargs: Any) -> None: + self.app = Quart(__name__) + self.app.debug = True + + view = GraphQLView.as_view("graphql_view", schema=schema, **kwargs) + + self.app.add_url_rule( + "/graphql", + view_func=view, + ) + self.app.add_url_rule( + '/graphql', + view_func=view, + methods=["GET"], + websocket=True + ) + + self.client = TestClient(self.app) + async def _graphql_request( self, @@ -140,3 +175,74 @@ async def post( return await self.request( url, "post", **{k: v for k, v in kwargs.items() if v is not None} ) + + @contextlib.asynccontextmanager + async def ws_connect( + self, + url: str, + *, + protocols: list[str], + ) -> AsyncGenerator[WebSocketClient, None]: + with self.client.websocket_connect(url, protocols) as ws: + yield QuartWebSocketClient(ws) + + + +class QuartWebSocketClient(WebSocketClient): + def __init__(self, ws: WebSocketTestSession): + self.ws = ws + self._closed: bool = False + self._close_code: Optional[int] = None + self._close_reason: Optional[str] = None + + async def send_text(self, payload: str) -> None: + self.ws.send_text(payload) + + async def send_json(self, payload: Mapping[str, object]) -> None: + self.ws.send_json(payload) + + async def send_bytes(self, payload: bytes) -> None: + self.ws.send_bytes(payload) + + async def receive(self, timeout: Optional[float] = None) -> Message: + if self._closed: + # if close was received via exception, fake it so that recv works + return Message( + type="websocket.close", data=self._close_code, extra=self._close_reason + ) + m = self.ws.receive() + if m["type"] == "websocket.close": + self._closed = True + self._close_code = m["code"] + self._close_reason = m.get("reason", None) + return Message(type=m["type"], data=m["code"], extra=m.get("reason", None)) + if m["type"] == "websocket.send": + return Message(type=m["type"], data=m["text"]) + return Message(type=m["type"], data=m["data"], extra=m["extra"]) + + async def receive_json(self, timeout: Optional[float] = None) -> Any: + m = self.ws.receive() + assert m["type"] == "websocket.send" + assert "text" in m + return json.loads(m["text"]) + + async def close(self) -> None: + self.ws.close() + self._closed = True + + @property + def accepted_subprotocol(self) -> Optional[str]: + return self.ws.accepted_subprotocol + + @property + def closed(self) -> bool: + return self._closed + + @property + def close_code(self) -> int: + assert self._close_code is not None + return self._close_code + + @property + def close_reason(self) -> Optional[str]: + return self._close_reason diff --git a/tests/websockets/conftest.py b/tests/websockets/conftest.py index 7b784c2168..9fd56317b2 100644 --- a/tests/websockets/conftest.py +++ b/tests/websockets/conftest.py @@ -14,6 +14,7 @@ def _get_http_client_classes() -> Generator[Any, None, None]: ("ChannelsHttpClient", "channels", [pytest.mark.channels]), ("FastAPIHttpClient", "fastapi", [pytest.mark.fastapi]), ("LitestarHttpClient", "litestar", [pytest.mark.litestar]), + ("QuartHttpClient", "quart", [pytest.mark.quart]), ]: try: client_class = getattr( diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 2b5ea8afe4..3a7f94df0a 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -72,7 +72,6 @@ def assert_next( async def test_unknown_message_type(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_json({"type": "NOT_A_MESSAGE_TYPE"}) await ws.receive(timeout=2) From b61023dc28a5f48de4b0066dfefede5f7d0dafc7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Mar 2025 21:09:49 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/quart/views.py | 13 ++++++------- tests/http/clients/quart.py | 30 +++++++++++++----------------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index d785f42fdf..05808e019c 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,22 +1,22 @@ import warnings from collections.abc import AsyncGenerator, Mapping from datetime import timedelta +from json.decoder import JSONDecodeError from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast from typing_extensions import TypeGuard -from json.decoder import JSONDecodeError -from quart import Quart, Request, Response, websocket, request +from quart import Quart, Request, Response, request, websocket from quart.ctx import has_websocket_context from quart.views import View from strawberry.http.async_base_view import ( AsyncBaseHTTPView, AsyncHTTPRequestAdapter, - AsyncWebSocketAdapter + AsyncWebSocketAdapter, ) from strawberry.http.exceptions import ( HTTPException, NonJsonMessageReceived, - WebSocketDisconnected + WebSocketDisconnected, ) from strawberry.http.ides import GraphQL_IDE from strawberry.http.types import FormData, HTTPMethod, QueryParams @@ -185,7 +185,7 @@ def is_websocket_request(self, request: Request) -> TypeGuard[Request]: connection = request.headers.get("Connection", "").lower() upgrade = request.headers.get("Upgrade", "").lower() - return ("upgrade" in connection and "websocket" in upgrade) + return "upgrade" in connection and "websocket" in upgrade async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: # Get the requested protocols @@ -215,8 +215,7 @@ async def create_websocket_response( @classmethod def register_route(cls, app: Quart, rule_name: str, path: str, **kwargs): - """ - Helper method to register both HTTP and WebSocket handlers for a given path. + """Helper method to register both HTTP and WebSocket handlers for a given path. Args: app: The Quart application diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 772fc6c744..c3d8ded99f 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -1,11 +1,9 @@ -import asyncio import contextlib import json import urllib.parse +from collections.abc import AsyncGenerator, Mapping from io import BytesIO -from typing import Any, Optional, AsyncGenerator, Mapping - -from quart.typing import TestWebsocketConnectionProtocol +from typing import Any, Optional from typing_extensions import Literal from starlette.testclient import TestClient, WebSocketTestSession @@ -21,8 +19,14 @@ from tests.http.context import get_context from tests.views.schema import Query, schema -from .base import JSON, HttpClient, Response, ResultOverrideFunction, WebSocketClient, \ - Message +from .base import ( + JSON, + HttpClient, + Message, + Response, + ResultOverrideFunction, + WebSocketClient, +) class GraphQLView(BaseGraphQLView[dict[str, object], object]): @@ -81,10 +85,7 @@ def __init__( view_func=view, ) self.app.add_url_rule( - '/graphql', - view_func=view, - methods=["GET"], - websocket=True + "/graphql", view_func=view, methods=["GET"], websocket=True ) self.client = TestClient(self.app) @@ -100,15 +101,11 @@ def create_app(self, **kwargs: Any) -> None: view_func=view, ) self.app.add_url_rule( - '/graphql', - view_func=view, - methods=["GET"], - websocket=True + "/graphql", view_func=view, methods=["GET"], websocket=True ) self.client = TestClient(self.app) - async def _graphql_request( self, method: Literal["get", "post"], @@ -187,7 +184,6 @@ async def ws_connect( yield QuartWebSocketClient(ws) - class QuartWebSocketClient(WebSocketClient): def __init__(self, ws: WebSocketTestSession): self.ws = ws @@ -214,7 +210,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message: if m["type"] == "websocket.close": self._closed = True self._close_code = m["code"] - self._close_reason = m.get("reason", None) + self._close_reason = m.get("reason", None) return Message(type=m["type"], data=m["code"], extra=m.get("reason", None)) if m["type"] == "websocket.send": return Message(type=m["type"], data=m["text"]) From b01e13460399537e5eddbbb10b253a7ac5aeb834 Mon Sep 17 00:00:00 2001 From: Paul Dubs Date: Tue, 25 Mar 2025 11:28:43 +0100 Subject: [PATCH 3/6] Use Quart Test Client --- strawberry/quart/views.py | 16 +++--- tests/http/clients/quart.py | 97 ++++++++++++++++++++++++++++--------- 2 files changed, 83 insertions(+), 30 deletions(-) diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 05808e019c..5246531f84 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -17,6 +17,7 @@ HTTPException, NonJsonMessageReceived, WebSocketDisconnected, + NonTextMessageReceived ) from strawberry.http.ides import GraphQL_IDE from strawberry.http.types import FormData, HTTPMethod, QueryParams @@ -67,15 +68,14 @@ async def iter_json( self, *, ignore_parsing_errors: bool = False ) -> AsyncGenerator[object, None]: while True: + message = await self.ws.receive() + if type(message) is bytes: + raise NonTextMessageReceived try: - message = await self.ws.receive() - try: - yield self.view.decode_json(message) - except JSONDecodeError as e: - if not ignore_parsing_errors: - raise NonJsonMessageReceived from e - except Exception as exc: - raise WebSocketDisconnected from exc + yield self.view.decode_json(message) + except JSONDecodeError as e: + if not ignore_parsing_errors: + raise NonJsonMessageReceived from e async def send_json(self, message: Mapping[str, object]) -> None: try: diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index c3d8ded99f..3d5370ca0c 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -1,21 +1,30 @@ +import asyncio import contextlib import json import urllib.parse from collections.abc import AsyncGenerator, Mapping from io import BytesIO +from typing import Any, Optional, AsyncGenerator, Mapping, Union + +from asgiref.typing import ASGISendEvent +from hypercorn.typing import WebsocketScope +from quart.typing import TestWebsocketConnectionProtocol +from quart.utils import decode_headers from typing import Any, Optional from typing_extensions import Literal -from starlette.testclient import TestClient, WebSocketTestSession - from quart import Quart from quart import Request as QuartRequest from quart import Response as QuartResponse from quart.datastructures import FileStorage +from quart.testing.connections import TestWebsocketConnection + +from strawberry.exceptions import ConnectionRejectionError from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.quart.views import GraphQLView as BaseGraphQLView from strawberry.types import ExecutionResult +from strawberry.types.unset import UnsetType, UNSET from tests.http.context import get_context from tests.views.schema import Query, schema @@ -26,16 +35,19 @@ Response, ResultOverrideFunction, WebSocketClient, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler ) class GraphQLView(BaseGraphQLView[dict[str, object], object]): methods = ["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"] - + graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler + graphql_ws_handler_class = DebuggableGraphQLWSHandler result_override: ResultOverrideFunction = None def __init__(self, *args: Any, **kwargs: Any): - self.result_override = kwargs.pop("result_override") + self.result_override = kwargs.pop("result_override", None) super().__init__(*args, **kwargs) async def get_root_value(self, request: QuartRequest) -> Query: @@ -57,6 +69,28 @@ async def process_result( return await super().process_result(request, result) + async def on_ws_connect( + self, context: dict[str, object] + ) -> Union[UnsetType, None, dict[str, object]]: + connection_params = context["connection_params"] + + if isinstance(connection_params, dict): + if connection_params.get("test-reject"): + if "err-payload" in connection_params: + raise ConnectionRejectionError(connection_params["err-payload"]) + raise ConnectionRejectionError + + if connection_params.get("test-accept"): + if "ack-payload" in connection_params: + return connection_params["ack-payload"] + return UNSET + + if connection_params.get("test-modify"): + connection_params["modified"] = True + return UNSET + + return await super().on_ws_connect(context) + class QuartHttpClient(HttpClient): def __init__( @@ -88,8 +122,6 @@ def __init__( "/graphql", view_func=view, methods=["GET"], websocket=True ) - self.client = TestClient(self.app) - def create_app(self, **kwargs: Any) -> None: self.app = Quart(__name__) self.app.debug = True @@ -104,8 +136,6 @@ def create_app(self, **kwargs: Any) -> None: "/graphql", view_func=view, methods=["GET"], websocket=True ) - self.client = TestClient(self.app) - async def _graphql_request( self, method: Literal["get", "post"], @@ -180,25 +210,48 @@ async def ws_connect( *, protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: - with self.client.websocket_connect(url, protocols) as ws: - yield QuartWebSocketClient(ws) - + headers = { + 'sec-websocket-protocol': ", ".join(protocols), + } + async with self.app.test_app() as test_app: + client = test_app.test_client() + client.websocket_connection_class = QuartTestWebsocketConnection + async with client.websocket(url, headers=headers, subprotocols=protocols) as ws: + yield QuartWebSocketClient(ws) + +class QuartTestWebsocketConnection(TestWebsocketConnection): + def __init__(self, app: Quart, scope: WebsocketScope) -> None: + scope['asgi'] = {'spec_version': '2.3'} + super().__init__(app, scope) + + async def _asgi_send(self, message: ASGISendEvent) -> None: + if message["type"] == "websocket.accept": + self.accepted = True + elif message["type"] == "websocket.send": + await self._receive_queue.put(message.get("bytes") or message.get("text")) + elif message["type"] == "websocket.http.response.start": + self.headers = decode_headers(message["headers"]) + self.status_code = message["status"] + elif message["type"] == "websocket.http.response.body": + self.response_data.extend(message["body"]) + elif message["type"] == "websocket.close": + await self._receive_queue.put(json.dumps(message)) class QuartWebSocketClient(WebSocketClient): - def __init__(self, ws: WebSocketTestSession): + def __init__(self, ws: TestWebsocketConnectionProtocol): self.ws = ws self._closed: bool = False self._close_code: Optional[int] = None self._close_reason: Optional[str] = None async def send_text(self, payload: str) -> None: - self.ws.send_text(payload) + await self.ws.send(payload) async def send_json(self, payload: Mapping[str, object]) -> None: - self.ws.send_json(payload) + await self.ws.send_json(payload) async def send_bytes(self, payload: bytes) -> None: - self.ws.send_bytes(payload) + await self.ws.send(payload) async def receive(self, timeout: Optional[float] = None) -> Message: if self._closed: @@ -206,7 +259,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message: return Message( type="websocket.close", data=self._close_code, extra=self._close_reason ) - m = self.ws.receive() + m = await asyncio.wait_for(self.ws.receive_json(), timeout=timeout) if m["type"] == "websocket.close": self._closed = True self._close_code = m["code"] @@ -214,21 +267,21 @@ async def receive(self, timeout: Optional[float] = None) -> Message: return Message(type=m["type"], data=m["code"], extra=m.get("reason", None)) if m["type"] == "websocket.send": return Message(type=m["type"], data=m["text"]) + if m['type'] == "connection_ack": + return Message(type=m['type'], data='') return Message(type=m["type"], data=m["data"], extra=m["extra"]) async def receive_json(self, timeout: Optional[float] = None) -> Any: - m = self.ws.receive() - assert m["type"] == "websocket.send" - assert "text" in m - return json.loads(m["text"]) + m = await asyncio.wait_for(self.ws.receive_json(), timeout=timeout) + return m async def close(self) -> None: - self.ws.close() + await self.ws.close(1000) self._closed = True @property def accepted_subprotocol(self) -> Optional[str]: - return self.ws.accepted_subprotocol + return "" @property def closed(self) -> bool: From d3384b87eb5940e5a3b7da0c225b3813f9c12069 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Mar 2025 10:35:05 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/quart/views.py | 2 +- tests/http/clients/quart.py | 32 +++++++++++++++++--------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 5246531f84..9819bb0d25 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -16,8 +16,8 @@ from strawberry.http.exceptions import ( HTTPException, NonJsonMessageReceived, + NonTextMessageReceived, WebSocketDisconnected, - NonTextMessageReceived ) from strawberry.http.ides import GraphQL_IDE from strawberry.http.types import FormData, HTTPMethod, QueryParams diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 3d5370ca0c..a757c6f7e2 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -4,39 +4,37 @@ import urllib.parse from collections.abc import AsyncGenerator, Mapping from io import BytesIO -from typing import Any, Optional, AsyncGenerator, Mapping, Union +from typing import Any, Optional, Union +from typing_extensions import Literal from asgiref.typing import ASGISendEvent from hypercorn.typing import WebsocketScope -from quart.typing import TestWebsocketConnectionProtocol -from quart.utils import decode_headers -from typing import Any, Optional -from typing_extensions import Literal from quart import Quart from quart import Request as QuartRequest from quart import Response as QuartResponse from quart.datastructures import FileStorage from quart.testing.connections import TestWebsocketConnection - +from quart.typing import TestWebsocketConnectionProtocol +from quart.utils import decode_headers from strawberry.exceptions import ConnectionRejectionError from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.quart.views import GraphQLView as BaseGraphQLView from strawberry.types import ExecutionResult -from strawberry.types.unset import UnsetType, UNSET +from strawberry.types.unset import UNSET, UnsetType from tests.http.context import get_context from tests.views.schema import Query, schema from .base import ( JSON, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Message, Response, ResultOverrideFunction, WebSocketClient, - DebuggableGraphQLTransportWSHandler, - DebuggableGraphQLWSHandler ) @@ -211,17 +209,20 @@ async def ws_connect( protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: headers = { - 'sec-websocket-protocol': ", ".join(protocols), + "sec-websocket-protocol": ", ".join(protocols), } async with self.app.test_app() as test_app: client = test_app.test_client() client.websocket_connection_class = QuartTestWebsocketConnection - async with client.websocket(url, headers=headers, subprotocols=protocols) as ws: + async with client.websocket( + url, headers=headers, subprotocols=protocols + ) as ws: yield QuartWebSocketClient(ws) + class QuartTestWebsocketConnection(TestWebsocketConnection): def __init__(self, app: Quart, scope: WebsocketScope) -> None: - scope['asgi'] = {'spec_version': '2.3'} + scope["asgi"] = {"spec_version": "2.3"} super().__init__(app, scope) async def _asgi_send(self, message: ASGISendEvent) -> None: @@ -237,6 +238,7 @@ async def _asgi_send(self, message: ASGISendEvent) -> None: elif message["type"] == "websocket.close": await self._receive_queue.put(json.dumps(message)) + class QuartWebSocketClient(WebSocketClient): def __init__(self, ws: TestWebsocketConnectionProtocol): self.ws = ws @@ -267,12 +269,12 @@ async def receive(self, timeout: Optional[float] = None) -> Message: return Message(type=m["type"], data=m["code"], extra=m.get("reason", None)) if m["type"] == "websocket.send": return Message(type=m["type"], data=m["text"]) - if m['type'] == "connection_ack": - return Message(type=m['type'], data='') + if m["type"] == "connection_ack": + return Message(type=m["type"], data="") return Message(type=m["type"], data=m["data"], extra=m["extra"]) async def receive_json(self, timeout: Optional[float] = None) -> Any: - m = await asyncio.wait_for(self.ws.receive_json(), timeout=timeout) + m = await asyncio.wait_for(self.ws.receive_json(), timeout=timeout) return m async def close(self) -> None: From 7e751087b35f7a30d22e36cd71acba94d59bb709 Mon Sep 17 00:00:00 2001 From: Paul Dubs Date: Fri, 28 Mar 2025 11:44:04 +0100 Subject: [PATCH 5/6] Apply actionable feedback --- strawberry/quart/views.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 9819bb0d25..8eadc9e27e 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,8 +1,9 @@ +import asyncio import warnings from collections.abc import AsyncGenerator, Mapping from datetime import timedelta from json.decoder import JSONDecodeError -from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast +from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast, Sequence from typing_extensions import TypeGuard from quart import Quart, Request, Response, request, websocket @@ -80,7 +81,7 @@ async def iter_json( async def send_json(self, message: Mapping[str, object]) -> None: try: await self.ws.send(self.view.encode_json(message)) - except Exception as exc: + except asyncio.CancelledError as exc: raise WebSocketDisconnected from exc async def close(self, code: int, reason: str) -> None: @@ -107,7 +108,7 @@ def __init__( keep_alive: bool = True, keep_alive_interval: float = 1, debug: bool = False, - subscription_protocols: list[str] = [ + subscription_protocols: Sequence[str] = [ GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL, ], @@ -204,12 +205,7 @@ async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: async def create_websocket_response( self, request: Request, subprotocol: Optional[str] ) -> Response: - if subprotocol: - # Set the WebSocket protocol if specified - await websocket.accept(subprotocol=subprotocol) - else: - await websocket.accept() - + await websocket.accept(subprotocol=subprotocol) # Return the current websocket context as the "response" return None From f8b192c11e1f9edd002e389c21d55e2a407cd1e9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Mar 2025 10:45:26 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/quart/views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 8eadc9e27e..03a57c6bea 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,9 +1,9 @@ import asyncio import warnings -from collections.abc import AsyncGenerator, Mapping +from collections.abc import AsyncGenerator, Mapping, Sequence from datetime import timedelta from json.decoder import JSONDecodeError -from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast, Sequence +from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast from typing_extensions import TypeGuard from quart import Quart, Request, Response, request, websocket