diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 6d1a7c9e44..fa624dbde0 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -13,7 +13,7 @@ cast, overload, ) -from typing_extensions import Literal, TypeGuard +from typing_extensions import Literal, Self, TypeGuard from graphql import GraphQLError @@ -113,7 +113,7 @@ class AsyncBaseHTTPView( request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter] websocket_adapter_class: Callable[ [ - "AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue]", + Self, WebSocketRequest, WebSocketResponse, ], @@ -361,7 +361,7 @@ async def run( response_data=response_data, sub_response=sub_response ) - def encode_multipart_data(self, data: Any, separator: str) -> str: + def encode_multipart_data(self, data: object, separator: str) -> str: return "".join( [ f"\r\n--{separator}\r\n", diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 1323363faa..be5996c0c0 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -2,20 +2,31 @@ import json import warnings +from collections.abc import AsyncGenerator, Sequence +from datetime import timedelta +from json.decoder import JSONDecodeError from typing import ( TYPE_CHECKING, Any, Callable, Optional, + Union, cast, ) from typing_extensions import TypeGuard -from sanic.request import Request -from sanic.response import HTTPResponse, html +from sanic import HTTPResponse, Request, Websocket, html from sanic.views import HTTPMethodView -from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter -from strawberry.http.exceptions import HTTPException +from strawberry.http.async_base_view import ( + AsyncBaseHTTPView, + AsyncHTTPRequestAdapter, + AsyncWebSocketAdapter, +) +from strawberry.http.exceptions import ( + HTTPException, + NonJsonMessageReceived, + NonTextMessageReceived, +) from strawberry.http.temporal_response import TemporalResponse from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import ( @@ -23,6 +34,7 @@ RootValue, ) from strawberry.sanic.utils import convert_request_to_files_dict +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL if TYPE_CHECKING: from collections.abc import AsyncGenerator, Mapping @@ -69,13 +81,40 @@ async def get_form_data(self) -> FormData: return FormData(form=self.request.form, files=files) +class SanicWebSocketAdapter(AsyncWebSocketAdapter): + def __init__( + self, view: AsyncBaseHTTPView, request: Websocket, response: Websocket + ) -> None: + super().__init__(view) + self.ws = request + + async def iter_json( + self, *, ignore_parsing_errors: bool = False + ) -> AsyncGenerator[object, None]: + async for message in self.ws: + if not isinstance(message, str): + raise NonTextMessageReceived + + try: + yield self.view.decode_json(message) + except JSONDecodeError as e: + if not ignore_parsing_errors: + raise NonJsonMessageReceived from e + + async def send_json(self, message: Mapping[str, object]) -> None: + await self.ws.send(self.view.encode_json(message)) + + async def close(self, code: int, reason: str) -> None: + await self.ws.close(code, reason) + + class GraphQLView( AsyncBaseHTTPView[ Request, HTTPResponse, TemporalResponse, - Request, - TemporalResponse, + Websocket, + Websocket, Context, RootValue, ], @@ -100,6 +139,7 @@ class GraphQLView( allow_queries_via_get = True request_adapter_class = SanicHTTPRequestAdapter + websocket_adapter_class = SanicWebSocketAdapter def __init__( self, @@ -107,12 +147,25 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, + keep_alive: bool = True, + keep_alive_interval: float = 1, + debug: bool = False, + subscription_protocols: Sequence[str] = ( + GRAPHQL_TRANSPORT_WS_PROTOCOL, + GRAPHQL_WS_PROTOCOL, + ), + connection_init_wait_timeout: timedelta = timedelta(minutes=1), json_encoder: Optional[type[json.JSONEncoder]] = None, json_dumps_params: Optional[dict[str, Any]] = None, multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get + self.keep_alive = keep_alive + self.keep_alive_interval = keep_alive_interval + self.debug = debug + self.protocols = subscription_protocols + self.connection_init_wait_timeout = connection_init_wait_timeout self.json_encoder = json_encoder self.json_dumps_params = json_dumps_params self.multipart_uploads_enabled = multipart_uploads_enabled @@ -143,11 +196,15 @@ def __init__( else: self.graphql_ide = graphql_ide - async def get_root_value(self, request: Request) -> Optional[RootValue]: + async def get_root_value( + self, request: Union[Request, Websocket] + ) -> Optional[RootValue]: return None async def get_context( - self, request: Request, response: TemporalResponse + self, + request: Union[Request, Websocket], + response: Union[TemporalResponse, Websocket], ) -> Context: return {"request": request, "response": response} # type: ignore @@ -187,6 +244,9 @@ async def get(self, request: Request) -> HTTPResponse: except HTTPException as e: return HTTPResponse(e.reason, status=e.status_code) + async def websocket(self, request: Request, ws: Websocket) -> Websocket: + return await self.run(ws) + async def create_streaming_response( self, request: Request, @@ -213,16 +273,19 @@ async def create_streaming_response( # corner case return None # type: ignore - def is_websocket_request(self, request: Request) -> TypeGuard[Request]: - return False + def is_websocket_request( + self, request: Union[Request, Websocket] + ) -> TypeGuard[Websocket]: + # TODO: sanic gives us a WebSocketConnection when ASGI is used, which has a completely different inferface??? + return isinstance(request, Websocket) - async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: - raise NotImplementedError + async def pick_websocket_subprotocol(self, request: Websocket) -> Optional[str]: + return None async def create_websocket_response( - self, request: Request, subprotocol: Optional[str] - ) -> TemporalResponse: - raise NotImplementedError + self, request: Websocket, subprotocol: Optional[str] + ) -> Websocket: + return request __all__ = ["GraphQLView"] diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index 5bd169998f..6ec6494685 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -1,13 +1,19 @@ from __future__ import annotations +import contextlib +import uuid +from collections.abc import AsyncGenerator from io import BytesIO from json import dumps from random import randint -from typing import Any, Optional +from typing import Any, Optional, Union from typing_extensions import Literal +from starlette.testclient import TestClient + +from sanic import Request as SanicRequest from sanic import Sanic -from sanic.request import Request as SanicRequest +from sanic import Websocket as SanicWebsocket from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.http.temporal_response import TemporalResponse @@ -15,24 +21,40 @@ from strawberry.types import ExecutionResult from tests.http.context import get_context from tests.views.schema import Query, schema +from tests.websockets.views import OnWSConnectMixin -from .base import JSON, HttpClient, Response, ResultOverrideFunction +from .asgi import AsgiWebSocketClient +from .base import ( + JSON, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, + HttpClient, + Response, + ResultOverrideFunction, + WebSocketClient, +) -class GraphQLView(BaseGraphQLView[object, Query]): +class GraphQLView(OnWSConnectMixin, BaseGraphQLView[dict[str, object], object]): result_override: ResultOverrideFunction = None + graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler + graphql_ws_handler_class = DebuggableGraphQLWSHandler def __init__(self, *args: Any, **kwargs: Any): - self.result_override = kwargs.pop("result_override") + self.result_override = kwargs.pop("result_override", None) super().__init__(*args, **kwargs) - async def get_root_value(self, request: SanicRequest) -> Query: + async def get_root_value( + self, request: Union[SanicRequest, SanicWebsocket] + ) -> Query: await super().get_root_value(request) # for coverage return Query() async def get_context( - self, request: SanicRequest, response: TemporalResponse - ) -> object: + self, + request: Union[SanicRequest, SanicWebsocket], + response: Union[TemporalResponse, SanicWebsocket], + ) -> dict[str, object]: context = await super().get_context(request, response) return get_context(context) @@ -58,17 +80,38 @@ def __init__( self.app = Sanic( f"test_{int(randint(0, 1000))}", # noqa: S311 ) - view = GraphQLView.as_view( + http_view = GraphQLView.as_view( schema=schema, graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + keep_alive=False, multipart_uploads_enabled=multipart_uploads_enabled, ) - self.app.add_route( - view, - "/graphql", + ws_view = GraphQLView( + schema=schema, + graphiql=graphiql, + graphql_ide=graphql_ide, + allow_queries_via_get=allow_queries_via_get, + result_override=result_override, + keep_alive=False, + multipart_uploads_enabled=multipart_uploads_enabled, + ) + # self.app.add_route(http_view, "/graphql") + + # TODO: do we need the ws view here even? + self.app.add_websocket_route(ws_view.websocket, "/graphql", subprotocols=[]) + + def create_app(self, **kwargs: Any) -> None: + self.app = Sanic(f"test-{uuid.uuid4().hex}") + http_view = GraphQLView.as_view(schema=schema, **kwargs) + ws_view = GraphQLView(schema=schema, **kwargs) + # self.app.add_route(http_view, "/graphql") + + protocols = kwargs.get("subscription_protocols", []) + self.app.add_websocket_route( + ws_view.websocket, "/graphql", subprotocols=protocols ) async def _graphql_request( @@ -153,3 +196,14 @@ async def post( data=response.content, headers=response.headers, ) + + @contextlib.asynccontextmanager + async def ws_connect( + self, + url: str, + *, + protocols: list[str], + ) -> AsyncGenerator[WebSocketClient, None]: + with TestClient(self.app) as client: + with client.websocket_connect(url, protocols) as ws: + yield AsgiWebSocketClient(ws) diff --git a/tests/websockets/conftest.py b/tests/websockets/conftest.py index 9fd56317b2..6c8b4fab2c 100644 --- a/tests/websockets/conftest.py +++ b/tests/websockets/conftest.py @@ -15,6 +15,7 @@ def _get_http_client_classes() -> Generator[Any, None, None]: ("FastAPIHttpClient", "fastapi", [pytest.mark.fastapi]), ("LitestarHttpClient", "litestar", [pytest.mark.litestar]), ("QuartHttpClient", "quart", [pytest.mark.quart]), + ("SanicHttpClient", "sanic", [pytest.mark.sanic]), ]: try: client_class = getattr(