|
| 1 | +import pathlib |
1 | 2 | from collections.abc import AsyncIterator, Iterator |
2 | 3 |
|
3 | 4 | import anyio |
| 5 | +import httpx |
4 | 6 | import pytest |
5 | 7 | from httpx_sse import ServerSentEvent as HTTPXServerSentEvent |
6 | 8 | from httpx_sse import aconnect_sse |
|
9 | 11 | from litestar.exceptions import ImproperlyConfiguredException |
10 | 12 | from litestar.response import ServerSentEvent |
11 | 13 | from litestar.response.sse import ServerSentEventMessage |
12 | | -from litestar.testing import create_async_test_client |
| 14 | +from litestar.testing import create_async_test_client, subprocess_async_client |
13 | 15 | from litestar.types import SSEData |
14 | 16 |
|
| 17 | +ROOT = pathlib.Path(__file__).parent |
| 18 | +APP = "demo:app" |
| 19 | + |
| 20 | + |
| 21 | +@pytest.fixture(name="async_client") |
| 22 | +async def fx_async_client() -> AsyncIterator[httpx.AsyncClient]: |
| 23 | + async with subprocess_async_client(workdir=ROOT, app=APP) as client: |
| 24 | + yield client |
| 25 | + |
15 | 26 |
|
16 | 27 | async def test_sse_steaming_response() -> None: |
17 | 28 | @get( |
@@ -96,28 +107,21 @@ async def numbers() -> AsyncIterator[SSEData]: |
96 | 107 | assert events[i].retry == expected_events[i].retry |
97 | 108 |
|
98 | 109 |
|
99 | | -async def test_sse_cleanup() -> None: |
100 | | - shared_state = [] |
| 110 | +async def test_sse_cleanup(async_client: httpx.AsyncClient) -> None: |
| 111 | + topic = "cleanup" |
101 | 112 |
|
102 | | - @get("/testme") |
103 | | - async def handler() -> ServerSentEvent: |
104 | | - async def numbers() -> AsyncIterator[SSEData]: |
105 | | - try: |
106 | | - yield 0 |
107 | | - await anyio.sleep(5) |
108 | | - yield 1 |
109 | | - finally: |
110 | | - shared_state.append(42) |
111 | | - |
112 | | - return ServerSentEvent(numbers(), event_type="special", event_id="123", retry_duration=1000) |
113 | | - |
114 | | - async with create_async_test_client(handler) as client: |
115 | | - async with aconnect_sse(client, "GET", f"{client.base_url}/testme") as event_source: |
116 | | - async for sse in event_source.aiter_sse(): |
117 | | - assert sse.data == "0" |
118 | | - break |
| 113 | + async with ( |
| 114 | + anyio.NamedTemporaryFile("r", suffix=topic, prefix="test") as file, |
| 115 | + async_client as client, |
| 116 | + aconnect_sse(client, "GET", f"notify/{topic}") as event_source, |
| 117 | + ): |
| 118 | + file_path = pathlib.Path(str(file.name)) |
| 119 | + assert not file_path.exists() |
| 120 | + async for sse in event_source.aiter_sse(): |
| 121 | + assert sse.data == topic |
| 122 | + assert file_path.read_text() == topic |
119 | 123 |
|
120 | | - assert shared_state.pop() == 42 |
| 124 | + assert not file_path.exists() |
121 | 125 |
|
122 | 126 |
|
123 | 127 | def test_invalid_content_type_raises() -> None: |
|
0 commit comments