|
36 | 36 |
|
37 | 37 | from __future__ import annotations |
38 | 38 |
|
| 39 | +import json |
39 | 40 | from collections.abc import AsyncIterator, Awaitable, Callable |
40 | 41 | from typing import Any, cast |
41 | 42 | from unittest.mock import ANY, AsyncMock, MagicMock |
@@ -242,3 +243,62 @@ async def mock_streaming( |
242 | 243 | assert response.status_code == 200 |
243 | 244 | assert response.headers['X-Genkit-Trace-Id'] == 'stream_trace_id' |
244 | 245 | assert response.headers['X-Genkit-Span-Id'] == 'stream_span_id' |
| 246 | + |
| 247 | + |
| 248 | +@pytest.mark.parametrize( |
| 249 | + 'chunks, expected_lines', |
| 250 | + [ |
| 251 | + (['string chunk 1', 'string chunk 2'], ['"string chunk 1"', '"string chunk 2"']), |
| 252 | + ([123, 456], ['123', '456']), |
| 253 | + ([12.3, 45.6], ['12.3', '45.6']), |
| 254 | + ([True, False], ['true', 'false']), |
| 255 | + ([None], ['null']), |
| 256 | + ([{'key': 'value'}], ['{"key": "value"}']), |
| 257 | + ], |
| 258 | +) |
| 259 | +@pytest.mark.asyncio |
| 260 | +async def test_run_action_streaming_primitive_types( |
| 261 | + asgi_client: AsyncClient, |
| 262 | + mock_registry: MagicMock, |
| 263 | + chunks: list[Any], |
| 264 | + expected_lines: list[str], |
| 265 | +) -> None: |
| 266 | + """Test that streaming actions with primitive type chunks work correctly.""" |
| 267 | + mock_action = AsyncMock() |
| 268 | + |
| 269 | + async def mock_streaming( |
| 270 | + input: object = None, |
| 271 | + on_chunk: object | None = None, |
| 272 | + context: object | None = None, |
| 273 | + on_trace_start: Callable[[str, str], None] | None = None, |
| 274 | + **kwargs: Any, # noqa: ANN401 |
| 275 | + ) -> MagicMock: |
| 276 | + if on_trace_start: |
| 277 | + on_trace_start('stream_trace_id', 'stream_span_id') |
| 278 | + if on_chunk: |
| 279 | + on_chunk_fn = cast(Callable[[object], None], on_chunk) |
| 280 | + for chunk in chunks: |
| 281 | + on_chunk_fn(chunk) |
| 282 | + mock_output = MagicMock() |
| 283 | + mock_output.response = {'final': 'result'} |
| 284 | + mock_output.trace_id = 'stream_trace_id' |
| 285 | + mock_output.span_id = 'stream_span_id' |
| 286 | + return mock_output |
| 287 | + |
| 288 | + mock_action.run.side_effect = mock_streaming |
| 289 | + mock_registry.resolve_action_by_key.return_value = mock_action |
| 290 | + |
| 291 | + response = await asgi_client.post( |
| 292 | + '/api/runAction?stream=true', |
| 293 | + json={'key': 'test_action', 'input': {'data': 'test'}}, |
| 294 | + ) |
| 295 | + |
| 296 | + assert response.status_code == 200 |
| 297 | + assert response.headers['X-Genkit-Trace-Id'] == 'stream_trace_id' |
| 298 | + assert response.headers['X-Genkit-Span-Id'] == 'stream_span_id' |
| 299 | + |
| 300 | + lines = response.text.strip().split('\n') |
| 301 | + assert lines[:-1] == expected_lines |
| 302 | + |
| 303 | + final_result = json.loads(lines[-1]) |
| 304 | + assert final_result['result'] == {'final': 'result'} |
0 commit comments