Skip to content

Commit 43e3041

Browse files
authored
feat: Add async websocket_connect to AsyncTestClient (#3328)
feat: Add `websocket_connect` method to AsyncTestClient Co-authored-by: kedod <kedod>
1 parent fac641a commit 43e3041

File tree

2 files changed

+107
-3
lines changed

2 files changed

+107
-3
lines changed

litestar/testing/client/async_client.py

+57-2
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from __future__ import annotations
22

33
from contextlib import AsyncExitStack
4-
from typing import TYPE_CHECKING, Any, Generic, Mapping, TypeVar
4+
from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar
5+
from urllib.parse import urljoin
56

67
from httpx import USE_CLIENT_DEFAULT, AsyncClient, Response
78

89
from litestar import HttpMethod
910
from litestar.testing.client.base import BaseTestClient
1011
from litestar.testing.life_span_handler import LifeSpanHandler
11-
from litestar.testing.transport import TestClientTransport
12+
from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport
1213
from litestar.types import AnyIOBackend, ASGIApp
1314

1415
if TYPE_CHECKING:
@@ -27,6 +28,7 @@
2728
from typing_extensions import Self
2829

2930
from litestar.middleware.session.base import BaseBackendConfig
31+
from litestar.testing.websocket_test_session import WebSocketTestSession
3032

3133

3234
T = TypeVar("T", bound=ASGIApp)
@@ -468,6 +470,59 @@ async def delete(
468470
extensions=None if extensions is None else dict(extensions),
469471
)
470472

473+
async def websocket_connect(
474+
self,
475+
url: str,
476+
subprotocols: Sequence[str] | None = None,
477+
params: QueryParamTypes | None = None,
478+
headers: HeaderTypes | None = None,
479+
cookies: CookieTypes | None = None,
480+
auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
481+
follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
482+
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
483+
extensions: Mapping[str, Any] | None = None,
484+
) -> WebSocketTestSession:
485+
"""Sends a GET request to establish a websocket connection.
486+
487+
Args:
488+
url: Request URL.
489+
subprotocols: Websocket subprotocols.
490+
params: Query parameters.
491+
headers: Request headers.
492+
cookies: Request cookies.
493+
auth: Auth headers.
494+
follow_redirects: Whether to follow redirects.
495+
timeout: Request timeout.
496+
extensions: Dictionary of ASGI extensions.
497+
498+
Returns:
499+
A `WebSocketTestSession <litestar.testing.WebSocketTestSession>` instance.
500+
"""
501+
url = urljoin("ws://testserver", url)
502+
default_headers: dict[str, str] = {}
503+
default_headers.setdefault("connection", "upgrade")
504+
default_headers.setdefault("sec-websocket-key", "testserver==")
505+
default_headers.setdefault("sec-websocket-version", "13")
506+
if subprotocols is not None:
507+
default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
508+
try:
509+
await AsyncClient.request(
510+
self,
511+
"GET",
512+
url,
513+
headers={**dict(headers or {}), **default_headers}, # type: ignore[misc]
514+
params=params,
515+
cookies=cookies,
516+
auth=auth,
517+
follow_redirects=follow_redirects,
518+
timeout=timeout,
519+
extensions=None if extensions is None else dict(extensions),
520+
)
521+
except ConnectionUpgradeExceptionError as exc:
522+
return exc.session
523+
524+
raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
525+
471526
async def get_session_data(self) -> dict[str, Any]:
472527
"""Get session data.
473528

tests/unit/test_testing/test_test_client.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from litestar import Controller, WebSocket, delete, head, patch, put, websocket
77
from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT
8-
from litestar.testing import AsyncTestClient, WebSocketTestSession, create_test_client
8+
from litestar.testing import AsyncTestClient, WebSocketTestSession, create_async_test_client, create_test_client
99

1010
if TYPE_CHECKING:
1111
from litestar.middleware.session.base import BaseBackendConfig
@@ -261,3 +261,52 @@ async def handler(socket: WebSocket) -> None:
261261
Empty
262262
), client.websocket_connect("/"):
263263
pass
264+
265+
266+
@pytest.mark.parametrize("block,timeout", [(False, None), (False, 0.001), (True, 0.001)])
267+
@pytest.mark.parametrize(
268+
"receive_method",
269+
[
270+
WebSocketTestSession.receive,
271+
WebSocketTestSession.receive_json,
272+
WebSocketTestSession.receive_text,
273+
WebSocketTestSession.receive_bytes,
274+
],
275+
)
276+
async def test_websocket_test_session_block_timeout_async(
277+
receive_method: Callable[..., Any], block: bool, timeout: Optional[float], anyio_backend: "AnyIOBackend"
278+
) -> None:
279+
@websocket()
280+
async def handler(socket: WebSocket) -> None:
281+
await socket.accept()
282+
283+
with pytest.raises(Empty):
284+
async with create_async_test_client(handler, backend=anyio_backend) as client:
285+
with await client.websocket_connect("/") as ws:
286+
receive_method(ws, timeout=timeout, block=block)
287+
288+
289+
async def test_websocket_accept_timeout_async(anyio_backend: "AnyIOBackend") -> None:
290+
@websocket()
291+
async def handler(socket: WebSocket) -> None:
292+
pass
293+
294+
async with create_async_test_client(handler, backend=anyio_backend, timeout=0.1) as client:
295+
with pytest.raises(Empty):
296+
with await client.websocket_connect("/"):
297+
pass
298+
299+
300+
async def test_websocket_connect_async(anyio_backend: "AnyIOBackend") -> None:
301+
@websocket()
302+
async def handler(socket: WebSocket) -> None:
303+
await socket.accept()
304+
data = await socket.receive_json()
305+
await socket.send_json(data)
306+
await socket.close()
307+
308+
async with create_async_test_client(handler, backend=anyio_backend, timeout=0.1) as client:
309+
with await client.websocket_connect("/", subprotocols="wamp") as ws:
310+
ws.send_json({"data": "123"})
311+
data = ws.receive_json()
312+
assert data == {"data": "123"}

0 commit comments

Comments
 (0)