Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
cast,
overload,
)
from typing_extensions import Literal, TypeGuard
from typing_extensions import Literal, Self, TypeGuard

from graphql import GraphQLError

Expand Down Expand Up @@ -113,7 +113,7 @@ class AsyncBaseHTTPView(
request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter]
websocket_adapter_class: Callable[
[
"AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue]",
Self,
WebSocketRequest,
WebSocketResponse,
],
Expand Down Expand Up @@ -361,7 +361,7 @@ async def run(
response_data=response_data, sub_response=sub_response
)

def encode_multipart_data(self, data: Any, separator: str) -> str:
def encode_multipart_data(self, data: object, separator: str) -> str:
return "".join(
[
f"\r\n--{separator}\r\n",
Expand Down
93 changes: 78 additions & 15 deletions strawberry/sanic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,39 @@

import json
import warnings
from collections.abc import AsyncGenerator, Sequence
from datetime import timedelta
from json.decoder import JSONDecodeError
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
cast,
)
from typing_extensions import TypeGuard

from sanic.request import Request
from sanic.response import HTTPResponse, html
from sanic import HTTPResponse, Request, Websocket, html
from sanic.views import HTTPMethodView
from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter
from strawberry.http.exceptions import HTTPException
from strawberry.http.async_base_view import (
AsyncBaseHTTPView,
AsyncHTTPRequestAdapter,
AsyncWebSocketAdapter,
)
from strawberry.http.exceptions import (
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
)
from strawberry.http.temporal_response import TemporalResponse
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import (
Context,
RootValue,
)
from strawberry.sanic.utils import convert_request_to_files_dict
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Mapping
Expand Down Expand Up @@ -69,13 +81,40 @@
return FormData(form=self.request.form, files=files)


class SanicWebSocketAdapter(AsyncWebSocketAdapter):
def __init__(
self, view: AsyncBaseHTTPView, request: Websocket, response: Websocket
) -> None:
super().__init__(view)
self.ws = request

Check warning on line 89 in strawberry/sanic/views.py

View check run for this annotation

Codecov / codecov/patch

strawberry/sanic/views.py#L88-L89

Added lines #L88 - L89 were not covered by tests

async def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[object, None]:
async for message in self.ws:
if not isinstance(message, str):
raise NonTextMessageReceived

Check warning on line 96 in strawberry/sanic/views.py

View check run for this annotation

Codecov / codecov/patch

strawberry/sanic/views.py#L96

Added line #L96 was not covered by tests

try:
yield self.view.decode_json(message)
except JSONDecodeError as e:

Check warning on line 100 in strawberry/sanic/views.py

View check run for this annotation

Codecov / codecov/patch

strawberry/sanic/views.py#L98-L100

Added lines #L98 - L100 were not covered by tests
if not ignore_parsing_errors:
raise NonJsonMessageReceived from e

Check warning on line 102 in strawberry/sanic/views.py

View check run for this annotation

Codecov / codecov/patch

strawberry/sanic/views.py#L102

Added line #L102 was not covered by tests

async def send_json(self, message: Mapping[str, object]) -> None:
await self.ws.send(self.view.encode_json(message))

Check warning on line 105 in strawberry/sanic/views.py

View check run for this annotation

Codecov / codecov/patch

strawberry/sanic/views.py#L105

Added line #L105 was not covered by tests

async def close(self, code: int, reason: str) -> None:
await self.ws.close(code, reason)

Check warning on line 108 in strawberry/sanic/views.py

View check run for this annotation

Codecov / codecov/patch

strawberry/sanic/views.py#L108

Added line #L108 was not covered by tests


class GraphQLView(
AsyncBaseHTTPView[
Request,
HTTPResponse,
TemporalResponse,
Request,
TemporalResponse,
Websocket,
Websocket,
Context,
RootValue,
],
Expand All @@ -100,19 +139,33 @@

allow_queries_via_get = True
request_adapter_class = SanicHTTPRequestAdapter
websocket_adapter_class = SanicWebSocketAdapter

def __init__(
self,
schema: BaseSchema,
graphiql: Optional[bool] = None,
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
keep_alive: bool = True,
keep_alive_interval: float = 1,
debug: bool = False,
subscription_protocols: Sequence[str] = (
GRAPHQL_TRANSPORT_WS_PROTOCOL,
GRAPHQL_WS_PROTOCOL,
),
connection_init_wait_timeout: timedelta = timedelta(minutes=1),
json_encoder: Optional[type[json.JSONEncoder]] = None,
json_dumps_params: Optional[dict[str, Any]] = None,
multipart_uploads_enabled: bool = False,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
self.keep_alive = keep_alive
self.keep_alive_interval = keep_alive_interval
self.debug = debug
self.protocols = subscription_protocols
self.connection_init_wait_timeout = connection_init_wait_timeout
self.json_encoder = json_encoder
self.json_dumps_params = json_dumps_params
self.multipart_uploads_enabled = multipart_uploads_enabled
Expand Down Expand Up @@ -143,11 +196,15 @@
else:
self.graphql_ide = graphql_ide

async def get_root_value(self, request: Request) -> Optional[RootValue]:
async def get_root_value(
self, request: Union[Request, Websocket]
) -> Optional[RootValue]:
return None

async def get_context(
self, request: Request, response: TemporalResponse
self,
request: Union[Request, Websocket],
response: Union[TemporalResponse, Websocket],
) -> Context:
return {"request": request, "response": response} # type: ignore

Expand Down Expand Up @@ -187,6 +244,9 @@
except HTTPException as e:
return HTTPResponse(e.reason, status=e.status_code)

async def websocket(self, request: Request, ws: Websocket) -> Websocket:
return await self.run(ws)

Check warning on line 248 in strawberry/sanic/views.py

View check run for this annotation

Codecov / codecov/patch

strawberry/sanic/views.py#L248

Added line #L248 was not covered by tests

async def create_streaming_response(
self,
request: Request,
Expand All @@ -213,16 +273,19 @@
# corner case
return None # type: ignore

def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
return False
def is_websocket_request(
self, request: Union[Request, Websocket]
) -> TypeGuard[Websocket]:
# TODO: sanic gives us a WebSocketConnection when ASGI is used, which has a completely different inferface???
return isinstance(request, Websocket)

async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
raise NotImplementedError
async def pick_websocket_subprotocol(self, request: Websocket) -> Optional[str]:
return None

Check warning on line 283 in strawberry/sanic/views.py

View check run for this annotation

Codecov / codecov/patch

strawberry/sanic/views.py#L283

Added line #L283 was not covered by tests

async def create_websocket_response(
self, request: Request, subprotocol: Optional[str]
) -> TemporalResponse:
raise NotImplementedError
self, request: Websocket, subprotocol: Optional[str]
) -> Websocket:
return request

Check warning on line 288 in strawberry/sanic/views.py

View check run for this annotation

Codecov / codecov/patch

strawberry/sanic/views.py#L288

Added line #L288 was not covered by tests


__all__ = ["GraphQLView"]
78 changes: 66 additions & 12 deletions tests/http/clients/sanic.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,60 @@
from __future__ import annotations

import contextlib
import uuid
from collections.abc import AsyncGenerator
from io import BytesIO
from json import dumps
from random import randint
from typing import Any, Optional
from typing import Any, Optional, Union
from typing_extensions import Literal

from starlette.testclient import TestClient

from sanic import Request as SanicRequest
from sanic import Sanic
from sanic.request import Request as SanicRequest
from sanic import Websocket as SanicWebsocket
from strawberry.http import GraphQLHTTPResponse
from strawberry.http.ides import GraphQL_IDE
from strawberry.http.temporal_response import TemporalResponse
from strawberry.sanic.views import GraphQLView as BaseGraphQLView
from strawberry.types import ExecutionResult
from tests.http.context import get_context
from tests.views.schema import Query, schema
from tests.websockets.views import OnWSConnectMixin

from .base import JSON, HttpClient, Response, ResultOverrideFunction
from .asgi import AsgiWebSocketClient
from .base import (
JSON,
DebuggableGraphQLTransportWSHandler,
DebuggableGraphQLWSHandler,
HttpClient,
Response,
ResultOverrideFunction,
WebSocketClient,
)


class GraphQLView(BaseGraphQLView[object, Query]):
class GraphQLView(OnWSConnectMixin, BaseGraphQLView[dict[str, object], object]):
result_override: ResultOverrideFunction = None
graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler
graphql_ws_handler_class = DebuggableGraphQLWSHandler

def __init__(self, *args: Any, **kwargs: Any):
self.result_override = kwargs.pop("result_override")
self.result_override = kwargs.pop("result_override", None)
super().__init__(*args, **kwargs)

async def get_root_value(self, request: SanicRequest) -> Query:
async def get_root_value(
self, request: Union[SanicRequest, SanicWebsocket]
) -> Query:
await super().get_root_value(request) # for coverage
return Query()

async def get_context(
self, request: SanicRequest, response: TemporalResponse
) -> object:
self,
request: Union[SanicRequest, SanicWebsocket],
response: Union[TemporalResponse, SanicWebsocket],
) -> dict[str, object]:
context = await super().get_context(request, response)

return get_context(context)
Expand All @@ -58,17 +80,38 @@
self.app = Sanic(
f"test_{int(randint(0, 1000))}", # noqa: S311
)
view = GraphQLView.as_view(
http_view = GraphQLView.as_view(
schema=schema,
graphiql=graphiql,
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
result_override=result_override,
keep_alive=False,
multipart_uploads_enabled=multipart_uploads_enabled,
)
self.app.add_route(
view,
"/graphql",
ws_view = GraphQLView(
schema=schema,
graphiql=graphiql,
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
result_override=result_override,
keep_alive=False,
multipart_uploads_enabled=multipart_uploads_enabled,
)
# self.app.add_route(http_view, "/graphql")

# TODO: do we need the ws view here even?
self.app.add_websocket_route(ws_view.websocket, "/graphql", subprotocols=[])

def create_app(self, **kwargs: Any) -> None:
self.app = Sanic(f"test-{uuid.uuid4().hex}")
http_view = GraphQLView.as_view(schema=schema, **kwargs)
ws_view = GraphQLView(schema=schema, **kwargs)
# self.app.add_route(http_view, "/graphql")

protocols = kwargs.get("subscription_protocols", [])
self.app.add_websocket_route(
ws_view.websocket, "/graphql", subprotocols=protocols
)

async def _graphql_request(
Expand Down Expand Up @@ -153,3 +196,14 @@
data=response.content,
headers=response.headers,
)

@contextlib.asynccontextmanager
async def ws_connect(
self,
url: str,
*,
protocols: list[str],
) -> AsyncGenerator[WebSocketClient, None]:
with TestClient(self.app) as client:
with client.websocket_connect(url, protocols) as ws:
yield AsgiWebSocketClient(ws)

Check warning on line 209 in tests/http/clients/sanic.py

View check run for this annotation

Codecov / codecov/patch

tests/http/clients/sanic.py#L208-L209

Added lines #L208 - L209 were not covered by tests
1 change: 1 addition & 0 deletions tests/websockets/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def _get_http_client_classes() -> Generator[Any, None, None]:
("FastAPIHttpClient", "fastapi", [pytest.mark.fastapi]),
("LitestarHttpClient", "litestar", [pytest.mark.litestar]),
("QuartHttpClient", "quart", [pytest.mark.quart]),
("SanicHttpClient", "sanic", [pytest.mark.sanic]),
]:
try:
client_class = getattr(
Expand Down
Loading