Skip to content

Commit ed3acb1

Browse files
authored
fix(py/core/reflection): correctly serialize primitive types in the action stream (#4958)
1 parent ce52a47 commit ed3acb1

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

py/packages/genkit/src/genkit/_core/_reflection.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,15 @@ def on_trace_start(self, tid: str, sid: str) -> None:
8080

8181
async def execute(self) -> None:
8282
try:
83-
on_chunk = (lambda c: self.queue.put_nowait(f'{c.model_dump_json()}\n')) if self.stream else None
83+
on_chunk = (
84+
(
85+
lambda c: self.queue.put_nowait(
86+
f'{c.model_dump_json() if isinstance(c, BaseModel) else json.dumps(c)}\n'
87+
)
88+
)
89+
if self.stream
90+
else None
91+
)
8492
output = await self.action.run(
8593
input=self.payload.get('input'),
8694
on_chunk=on_chunk,

py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
from __future__ import annotations
3838

39+
import json
3940
from collections.abc import AsyncIterator, Awaitable, Callable
4041
from typing import Any, cast
4142
from unittest.mock import ANY, AsyncMock, MagicMock
@@ -242,3 +243,62 @@ async def mock_streaming(
242243
assert response.status_code == 200
243244
assert response.headers['X-Genkit-Trace-Id'] == 'stream_trace_id'
244245
assert response.headers['X-Genkit-Span-Id'] == 'stream_span_id'
246+
247+
248+
@pytest.mark.parametrize(
249+
'chunks, expected_lines',
250+
[
251+
(['string chunk 1', 'string chunk 2'], ['"string chunk 1"', '"string chunk 2"']),
252+
([123, 456], ['123', '456']),
253+
([12.3, 45.6], ['12.3', '45.6']),
254+
([True, False], ['true', 'false']),
255+
([None], ['null']),
256+
([{'key': 'value'}], ['{"key": "value"}']),
257+
],
258+
)
259+
@pytest.mark.asyncio
260+
async def test_run_action_streaming_primitive_types(
261+
asgi_client: AsyncClient,
262+
mock_registry: MagicMock,
263+
chunks: list[Any],
264+
expected_lines: list[str],
265+
) -> None:
266+
"""Test that streaming actions with primitive type chunks work correctly."""
267+
mock_action = AsyncMock()
268+
269+
async def mock_streaming(
270+
input: object = None,
271+
on_chunk: object | None = None,
272+
context: object | None = None,
273+
on_trace_start: Callable[[str, str], None] | None = None,
274+
**kwargs: Any, # noqa: ANN401
275+
) -> MagicMock:
276+
if on_trace_start:
277+
on_trace_start('stream_trace_id', 'stream_span_id')
278+
if on_chunk:
279+
on_chunk_fn = cast(Callable[[object], None], on_chunk)
280+
for chunk in chunks:
281+
on_chunk_fn(chunk)
282+
mock_output = MagicMock()
283+
mock_output.response = {'final': 'result'}
284+
mock_output.trace_id = 'stream_trace_id'
285+
mock_output.span_id = 'stream_span_id'
286+
return mock_output
287+
288+
mock_action.run.side_effect = mock_streaming
289+
mock_registry.resolve_action_by_key.return_value = mock_action
290+
291+
response = await asgi_client.post(
292+
'/api/runAction?stream=true',
293+
json={'key': 'test_action', 'input': {'data': 'test'}},
294+
)
295+
296+
assert response.status_code == 200
297+
assert response.headers['X-Genkit-Trace-Id'] == 'stream_trace_id'
298+
assert response.headers['X-Genkit-Span-Id'] == 'stream_span_id'
299+
300+
lines = response.text.strip().split('\n')
301+
assert lines[:-1] == expected_lines
302+
303+
final_result = json.loads(lines[-1])
304+
assert final_result['result'] == {'final': 'result'}

0 commit comments

Comments
 (0)