|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | from contextlib import AsyncExitStack
|
4 |
| -from typing import TYPE_CHECKING, Any, Generic, Mapping, TypeVar |
| 4 | +from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar |
| 5 | +from urllib.parse import urljoin |
5 | 6 |
|
6 | 7 | from httpx import USE_CLIENT_DEFAULT, AsyncClient, Response
|
7 | 8 |
|
8 | 9 | from litestar import HttpMethod
|
9 | 10 | from litestar.testing.client.base import BaseTestClient
|
10 | 11 | from litestar.testing.life_span_handler import LifeSpanHandler
|
11 |
| -from litestar.testing.transport import TestClientTransport |
| 12 | +from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport |
12 | 13 | from litestar.types import AnyIOBackend, ASGIApp
|
13 | 14 |
|
14 | 15 | if TYPE_CHECKING:
|
|
27 | 28 | from typing_extensions import Self
|
28 | 29 |
|
29 | 30 | from litestar.middleware.session.base import BaseBackendConfig
|
| 31 | + from litestar.testing.websocket_test_session import WebSocketTestSession |
30 | 32 |
|
31 | 33 |
|
32 | 34 | T = TypeVar("T", bound=ASGIApp)
|
@@ -468,6 +470,59 @@ async def delete(
|
468 | 470 | extensions=None if extensions is None else dict(extensions),
|
469 | 471 | )
|
470 | 472 |
|
| 473 | + async def websocket_connect( |
| 474 | + self, |
| 475 | + url: str, |
| 476 | + subprotocols: Sequence[str] | None = None, |
| 477 | + params: QueryParamTypes | None = None, |
| 478 | + headers: HeaderTypes | None = None, |
| 479 | + cookies: CookieTypes | None = None, |
| 480 | + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, |
| 481 | + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, |
| 482 | + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, |
| 483 | + extensions: Mapping[str, Any] | None = None, |
| 484 | + ) -> WebSocketTestSession: |
| 485 | + """Sends a GET request to establish a websocket connection. |
| 486 | +
|
| 487 | + Args: |
| 488 | + url: Request URL. |
| 489 | + subprotocols: Websocket subprotocols. |
| 490 | + params: Query parameters. |
| 491 | + headers: Request headers. |
| 492 | + cookies: Request cookies. |
| 493 | + auth: Auth headers. |
| 494 | + follow_redirects: Whether to follow redirects. |
| 495 | + timeout: Request timeout. |
| 496 | + extensions: Dictionary of ASGI extensions. |
| 497 | +
|
| 498 | + Returns: |
| 499 | + A `WebSocketTestSession <litestar.testing.WebSocketTestSession>` instance. |
| 500 | + """ |
| 501 | + url = urljoin("ws://testserver", url) |
| 502 | + default_headers: dict[str, str] = {} |
| 503 | + default_headers.setdefault("connection", "upgrade") |
| 504 | + default_headers.setdefault("sec-websocket-key", "testserver==") |
| 505 | + default_headers.setdefault("sec-websocket-version", "13") |
| 506 | + if subprotocols is not None: |
| 507 | + default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) |
| 508 | + try: |
| 509 | + await AsyncClient.request( |
| 510 | + self, |
| 511 | + "GET", |
| 512 | + url, |
| 513 | + headers={**dict(headers or {}), **default_headers}, # type: ignore[misc] |
| 514 | + params=params, |
| 515 | + cookies=cookies, |
| 516 | + auth=auth, |
| 517 | + follow_redirects=follow_redirects, |
| 518 | + timeout=timeout, |
| 519 | + extensions=None if extensions is None else dict(extensions), |
| 520 | + ) |
| 521 | + except ConnectionUpgradeExceptionError as exc: |
| 522 | + return exc.session |
| 523 | + |
| 524 | + raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover |
| 525 | + |
471 | 526 | async def get_session_data(self) -> dict[str, Any]:
|
472 | 527 | """Get session data.
|
473 | 528 |
|
|
0 commit comments