Skip to content

Commit d26bee3

Browse files
author
z50053222
committed
kafka
1 parent 2c047f7 commit d26bee3

File tree

5 files changed

+1022
-20
lines changed

5 files changed

+1022
-20
lines changed

src/a2a/client/transports/kafka.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -493,10 +493,24 @@ async def set_task_callback(
493493
context: ClientCallContext | None = None,
494494
) -> TaskPushNotificationConfig:
495495
"""Set task push notification configuration."""
496-
# For Kafka, we can store the callback configuration locally
497-
# and use it when we receive push notifications
498-
# This is a simplified implementation
499-
return request
496+
correlation_id = await self._send_request('task_push_notification_config_set', request, context)
497+
future = await self.correlation_manager.register(correlation_id)
498+
499+
try:
500+
timeout = 30.0
501+
if context and context.timeout:
502+
timeout = context.timeout
503+
504+
result = await asyncio.wait_for(future, timeout=timeout)
505+
if isinstance(result, TaskPushNotificationConfig):
506+
return result
507+
raise A2AClientError(f"Expected TaskPushNotificationConfig, got {type(result)}")
508+
except asyncio.TimeoutError:
509+
await self.correlation_manager.complete_with_exception(
510+
correlation_id,
511+
A2AClientError(f"Set task callback request timed out after {timeout} seconds")
512+
)
513+
raise A2AClientError(f"Set task callback request timed out after {timeout} seconds")
500514

501515
async def get_task_callback(
502516
self,
@@ -505,7 +519,10 @@ async def get_task_callback(
505519
context: ClientCallContext | None = None,
506520
) -> TaskPushNotificationConfig:
507521
"""Get task push notification configuration."""
508-
return await self.get_task_push_notification_config(request, context=context)
522+
result = await self.get_task_push_notification_config(request, context=context)
523+
if result is None:
524+
raise A2AClientError(f"No task callback configuration found for task {request.task_id}")
525+
return result
509526

510527
async def resubscribe(
511528
self,
@@ -514,12 +531,44 @@ async def resubscribe(
514531
context: ClientCallContext | None = None,
515532
) -> AsyncGenerator[Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent]:
516533
"""Reconnect to get task updates."""
517-
# For Kafka, resubscription is handled automatically by the consumer
518-
# This method can be used to request task updates
519-
task_request = TaskQueryParams(task_id=request.task_id)
520-
task = await self.get_task(task_request, context=context)
521-
if task:
522-
yield task
534+
# For Kafka, we send a resubscribe request to get streaming updates
535+
correlation_id = await self._send_request('task_resubscribe', request, context, streaming=True)
536+
537+
# Register streaming request
538+
streaming_future = await self.correlation_manager.register_streaming(correlation_id)
539+
540+
try:
541+
timeout = 30.0
542+
if context and context.timeout:
543+
timeout = context.timeout
544+
545+
# First, get the current task state
546+
task_request = TaskQueryParams(task_id=request.task_id)
547+
try:
548+
task = await self.get_task(task_request, context=context)
549+
yield task
550+
except Exception as e:
551+
logger.warning(f"Failed to get initial task state: {e}")
552+
553+
# Then yield streaming updates as they arrive
554+
while not streaming_future.is_done():
555+
try:
556+
# Wait for next response with timeout
557+
result = await asyncio.wait_for(streaming_future.get(), timeout=5.0)
558+
yield result
559+
except asyncio.TimeoutError:
560+
# Check if stream is done or if we've exceeded total timeout
561+
if streaming_future.is_done():
562+
break
563+
# Continue waiting for more responses
564+
continue
565+
566+
except Exception as e:
567+
await self.correlation_manager.complete_with_exception(
568+
correlation_id,
569+
A2AClientError(f"Resubscribe request failed: {e}")
570+
)
571+
raise A2AClientError(f"Resubscribe request failed: {e}") from e
523572

524573
async def get_card(
525574
self,

src/a2a/server/request_handlers/kafka_handler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,11 @@ async def _handle_single_request(
188188
result = await self.request_handler.on_list_task_push_notification_config(request, context)
189189
response_type = "task_push_notification_config_list"
190190

191+
elif method == "task_push_notification_config_set":
192+
request = TaskPushNotificationConfig.model_validate(params)
193+
result = await self.request_handler.on_set_task_push_notification_config(request, context)
194+
response_type = "task_push_notification_config"
195+
191196
elif method == "task_push_notification_config_delete":
192197
request = DeleteTaskPushNotificationConfigParams.model_validate(params)
193198
await self.request_handler.on_delete_task_push_notification_config(request, context)
@@ -232,6 +237,25 @@ async def _handle_streaming_request(
232237

233238
# Send stream completion signal
234239
await self.response_sender.send_stream_complete(reply_topic, correlation_id)
240+
241+
elif method == "task_resubscribe":
242+
request = TaskIdParams.model_validate(params)
243+
244+
# Handle streaming resubscription
245+
async for event in self.request_handler.on_resubscribe_to_task(request, context):
246+
if isinstance(event, TaskStatusUpdateEvent):
247+
response_type = "task_status_update"
248+
elif isinstance(event, TaskArtifactUpdateEvent):
249+
response_type = "task_artifact_update"
250+
elif isinstance(event, Task):
251+
response_type = "task"
252+
else:
253+
response_type = "message"
254+
255+
await self.response_sender.send_response(reply_topic, correlation_id, event, response_type)
256+
257+
# Send stream completion signal
258+
await self.response_sender.send_stream_complete(reply_topic, correlation_id)
235259

236260
else:
237261
raise ServerError(f"Streaming not supported for method: {method}")

tests/client/test_kafka_client.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ async def test_internal_start_stop(self, mock_consumer_class, mock_producer_clas
155155
mock_producer_class.return_value = mock_producer
156156
mock_consumer_class.return_value = mock_consumer
157157

158-
# Start transport using internal method
159-
await kafka_transport._start()
158+
# Start transport
159+
await kafka_transport.start()
160160

161161
assert kafka_transport._running is True
162162
assert kafka_transport.producer == mock_producer
@@ -167,8 +167,8 @@ async def test_internal_start_stop(self, mock_consumer_class, mock_producer_clas
167167
mock_producer.start.assert_called_once()
168168
mock_consumer.start.assert_called_once()
169169

170-
# Stop transport using internal method
171-
await kafka_transport._stop()
170+
# Stop transport
171+
await kafka_transport.stop()
172172

173173
assert kafka_transport._running is False
174174
mock_producer.stop.assert_called_once()
@@ -279,8 +279,8 @@ def test_parse_response(self, kafka_transport):
279279
@pytest.mark.asyncio
280280
async def test_context_manager(self, kafka_transport):
281281
"""Test async context manager."""
282-
with patch.object(kafka_transport, '_start') as mock_start, \
283-
patch.object(kafka_transport, '_stop') as mock_stop:
282+
with patch.object(kafka_transport, 'start') as mock_start, \
283+
patch.object(kafka_transport, 'stop') as mock_stop:
284284

285285
async with kafka_transport:
286286
mock_start.assert_called_once()
@@ -426,7 +426,6 @@ def test_create_classmethod(self, agent_card):
426426
interceptors=[]
427427
)
428428

429-
430429
@pytest.mark.integration
431430
class TestKafkaIntegration:
432431
"""Integration tests for Kafka transport (requires running Kafka)."""
@@ -441,8 +440,8 @@ async def test_real_kafka_connection(self, agent_card):
441440
)
442441

443442
try:
444-
await transport._start()
443+
await transport.start()
445444
assert transport._running is True
446445
finally:
447-
await transport._stop()
446+
await transport.stop()
448447
assert transport._running is False
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import asyncio
2+
import sys
3+
import types
4+
from dataclasses import dataclass
5+
from typing import Any, List, Optional
6+
7+
import pytest
8+
from unittest.mock import AsyncMock
9+
10+
# Inject a fake aiokafka module before importing the app under test
11+
fake_aiokafka = types.ModuleType("aiokafka")
12+
fake_aiokafka_errors = types.ModuleType("aiokafka.errors")
13+
14+
15+
class FakeKafkaError(Exception):
16+
pass
17+
18+
19+
fake_aiokafka_errors.KafkaError = FakeKafkaError
20+
21+
22+
class FakeProducer:
23+
def __init__(self, *args, **kwargs):
24+
self.args = args
25+
self.kwargs = kwargs
26+
self.started = False
27+
self.sent: List[tuple] = [] # (topic, value, headers)
28+
29+
async def start(self):
30+
self.started = True
31+
32+
async def stop(self):
33+
self.started = False
34+
35+
async def send_and_wait(self, topic: str, value: Any, headers: list[tuple[str, bytes]] | None = None):
36+
self.sent.append((topic, value, headers or []))
37+
38+
39+
@dataclass
40+
class FakeMessage:
41+
value: Any
42+
headers: Optional[List[tuple[str, bytes]]] = None
43+
44+
45+
class FakeConsumer:
46+
def __init__(self, *args, **kwargs):
47+
self.args = args
48+
self.kwargs = kwargs
49+
self.started = False
50+
# queue of messages to yield
51+
self._messages: List[FakeMessage] = []
52+
53+
def add_message(self, value: Any, headers: Optional[List[tuple[str, bytes]]] = None):
54+
self._messages.append(FakeMessage(value=value, headers=headers))
55+
56+
async def start(self):
57+
self.started = True
58+
59+
async def stop(self):
60+
self.started = False
61+
62+
def __aiter__(self):
63+
self._iter = iter(self._messages)
64+
return self
65+
66+
async def __anext__(self):
67+
try:
68+
return next(self._iter)
69+
except StopIteration:
70+
raise StopAsyncIteration
71+
72+
73+
fake_aiokafka.AIOKafkaProducer = FakeProducer
74+
fake_aiokafka.AIOKafkaConsumer = FakeConsumer
75+
76+
sys.modules.setdefault("aiokafka", fake_aiokafka)
77+
sys.modules.setdefault("aiokafka.errors", fake_aiokafka_errors)
78+
79+
# Now safe to import the module under test
80+
from a2a.server.apps.kafka.kafka_app import KafkaServerApp, KafkaHandler
81+
from a2a.server.request_handlers.request_handler import RequestHandler
82+
from a2a.server.request_handlers.kafka_handler import KafkaMessage
83+
84+
85+
class DummyHandler:
86+
"""A minimal KafkaHandler drop-in used to capture consumed messages."""
87+
88+
def __init__(self):
89+
self.handled: list[KafkaMessage] = []
90+
91+
async def handle_request(self, message: KafkaMessage) -> None:
92+
self.handled.append(message)
93+
94+
95+
@pytest.fixture
96+
def request_handler():
97+
return AsyncMock(spec=RequestHandler)
98+
99+
100+
@pytest.fixture
101+
def app(monkeypatch, request_handler):
102+
# Replace KafkaHandler inside the kafka_app module to our DummyHandler
103+
dummy = DummyHandler()
104+
105+
def _fake_kafka_handler_ctor(rh, response_sender):
106+
# validate response_sender is the app instance later
107+
return dummy
108+
109+
# Patch the symbol used by kafka_app
110+
monkeypatch.setattr(
111+
"a2a.server.apps.kafka.kafka_app.KafkaHandler", _fake_kafka_handler_ctor
112+
)
113+
114+
a = KafkaServerApp(
115+
request_handler=request_handler,
116+
bootstrap_servers="dummy:9092",
117+
request_topic="a2a-requests",
118+
consumer_group_id="a2a-server",
119+
)
120+
# expose dummy for assertions
121+
a._dummy_handler = dummy
122+
return a
123+
124+
125+
@pytest.mark.asyncio
126+
async def test_start_initializes_components(app: KafkaServerApp):
127+
await app.start()
128+
assert app._running is True
129+
assert isinstance(app.producer, FakeProducer) and app.producer.started
130+
assert isinstance(app.consumer, FakeConsumer) and app.consumer.started
131+
# handler constructed
132+
assert app.handler is app._dummy_handler
133+
134+
135+
@pytest.mark.asyncio
136+
async def test_stop_closes_components(app: KafkaServerApp):
137+
await app.start()
138+
await app.stop()
139+
assert app._running is False
140+
assert app.producer is not None and app.producer.started is False
141+
assert app.consumer is not None and app.consumer.started is False
142+
143+
144+
@pytest.mark.asyncio
145+
async def test_send_response_uses_producer_headers_and_payload(app: KafkaServerApp):
146+
await app.start()
147+
await app.send_response("reply-topic", "corr-1", {"k": 1}, "task")
148+
assert len(app.producer.sent) == 1
149+
topic, value, headers = app.producer.sent[0]
150+
assert topic == "reply-topic"
151+
assert value["type"] == "task" and value["data"] == {"k": 1}
152+
assert ("correlation_id", b"corr-1") in headers
153+
154+
155+
@pytest.mark.asyncio
156+
async def test_send_stream_complete_uses_producer(app: KafkaServerApp):
157+
await app.start()
158+
await app.send_stream_complete("reply-topic", "corr-2")
159+
topic, value, headers = app.producer.sent[-1]
160+
assert topic == "reply-topic"
161+
assert value["type"] == "stream_complete"
162+
assert ("correlation_id", b"corr-2") in headers
163+
164+
165+
@pytest.mark.asyncio
166+
async def test_send_error_response_uses_producer(app: KafkaServerApp):
167+
await app.start()
168+
await app.send_error_response("reply-topic", "corr-3", "boom")
169+
topic, value, headers = app.producer.sent[-1]
170+
assert topic == "reply-topic"
171+
assert value["type"] == "error"
172+
assert value["data"]["error"] == "boom"
173+
assert ("correlation_id", b"corr-3") in headers
174+
175+
176+
@pytest.mark.asyncio
177+
async def test_consume_requests_converts_and_delegates(app: KafkaServerApp):
178+
await app.start()
179+
# Prepare a message for the consumer
180+
assert isinstance(app.consumer, FakeConsumer)
181+
app.consumer.add_message(
182+
value={"method": "message_send", "params": {}, "streaming": False},
183+
headers=[("reply_topic", b"replies"), ("correlation_id", b"cid-1")],
184+
)
185+
186+
# Run consume loop once; since FakeConsumer yields finite messages, it will end
187+
await app._consume_requests()
188+
189+
# Verify our dummy handler saw the converted KafkaMessage
190+
handled = app._dummy_handler.handled
191+
assert len(handled) == 1
192+
km: KafkaMessage = handled[0]
193+
assert km.get_header("reply_topic") == "replies"
194+
assert km.get_header("correlation_id") == "cid-1"
195+
assert km.value["method"] == "message_send"

0 commit comments

Comments
 (0)