|
| 1 | +import asyncio |
| 2 | +import sys |
| 3 | +import types |
| 4 | +from dataclasses import dataclass |
| 5 | +from typing import Any, List, Optional |
| 6 | + |
| 7 | +import pytest |
| 8 | +from unittest.mock import AsyncMock |
| 9 | + |
| 10 | +# Inject a fake aiokafka module before importing the app under test |
| 11 | +fake_aiokafka = types.ModuleType("aiokafka") |
| 12 | +fake_aiokafka_errors = types.ModuleType("aiokafka.errors") |
| 13 | + |
| 14 | + |
| 15 | +class FakeKafkaError(Exception): |
| 16 | + pass |
| 17 | + |
| 18 | + |
| 19 | +fake_aiokafka_errors.KafkaError = FakeKafkaError |
| 20 | + |
| 21 | + |
| 22 | +class FakeProducer: |
| 23 | + def __init__(self, *args, **kwargs): |
| 24 | + self.args = args |
| 25 | + self.kwargs = kwargs |
| 26 | + self.started = False |
| 27 | + self.sent: List[tuple] = [] # (topic, value, headers) |
| 28 | + |
| 29 | + async def start(self): |
| 30 | + self.started = True |
| 31 | + |
| 32 | + async def stop(self): |
| 33 | + self.started = False |
| 34 | + |
| 35 | + async def send_and_wait(self, topic: str, value: Any, headers: list[tuple[str, bytes]] | None = None): |
| 36 | + self.sent.append((topic, value, headers or [])) |
| 37 | + |
| 38 | + |
| 39 | +@dataclass |
| 40 | +class FakeMessage: |
| 41 | + value: Any |
| 42 | + headers: Optional[List[tuple[str, bytes]]] = None |
| 43 | + |
| 44 | + |
| 45 | +class FakeConsumer: |
| 46 | + def __init__(self, *args, **kwargs): |
| 47 | + self.args = args |
| 48 | + self.kwargs = kwargs |
| 49 | + self.started = False |
| 50 | + # queue of messages to yield |
| 51 | + self._messages: List[FakeMessage] = [] |
| 52 | + |
| 53 | + def add_message(self, value: Any, headers: Optional[List[tuple[str, bytes]]] = None): |
| 54 | + self._messages.append(FakeMessage(value=value, headers=headers)) |
| 55 | + |
| 56 | + async def start(self): |
| 57 | + self.started = True |
| 58 | + |
| 59 | + async def stop(self): |
| 60 | + self.started = False |
| 61 | + |
| 62 | + def __aiter__(self): |
| 63 | + self._iter = iter(self._messages) |
| 64 | + return self |
| 65 | + |
| 66 | + async def __anext__(self): |
| 67 | + try: |
| 68 | + return next(self._iter) |
| 69 | + except StopIteration: |
| 70 | + raise StopAsyncIteration |
| 71 | + |
| 72 | + |
| 73 | +fake_aiokafka.AIOKafkaProducer = FakeProducer |
| 74 | +fake_aiokafka.AIOKafkaConsumer = FakeConsumer |
| 75 | + |
| 76 | +sys.modules.setdefault("aiokafka", fake_aiokafka) |
| 77 | +sys.modules.setdefault("aiokafka.errors", fake_aiokafka_errors) |
| 78 | + |
| 79 | +# Now safe to import the module under test |
| 80 | +from a2a.server.apps.kafka.kafka_app import KafkaServerApp, KafkaHandler |
| 81 | +from a2a.server.request_handlers.request_handler import RequestHandler |
| 82 | +from a2a.server.request_handlers.kafka_handler import KafkaMessage |
| 83 | + |
| 84 | + |
| 85 | +class DummyHandler: |
| 86 | + """A minimal KafkaHandler drop-in used to capture consumed messages.""" |
| 87 | + |
| 88 | + def __init__(self): |
| 89 | + self.handled: list[KafkaMessage] = [] |
| 90 | + |
| 91 | + async def handle_request(self, message: KafkaMessage) -> None: |
| 92 | + self.handled.append(message) |
| 93 | + |
| 94 | + |
| 95 | +@pytest.fixture |
| 96 | +def request_handler(): |
| 97 | + return AsyncMock(spec=RequestHandler) |
| 98 | + |
| 99 | + |
| 100 | +@pytest.fixture |
| 101 | +def app(monkeypatch, request_handler): |
| 102 | + # Replace KafkaHandler inside the kafka_app module to our DummyHandler |
| 103 | + dummy = DummyHandler() |
| 104 | + |
| 105 | + def _fake_kafka_handler_ctor(rh, response_sender): |
| 106 | + # validate response_sender is the app instance later |
| 107 | + return dummy |
| 108 | + |
| 109 | + # Patch the symbol used by kafka_app |
| 110 | + monkeypatch.setattr( |
| 111 | + "a2a.server.apps.kafka.kafka_app.KafkaHandler", _fake_kafka_handler_ctor |
| 112 | + ) |
| 113 | + |
| 114 | + a = KafkaServerApp( |
| 115 | + request_handler=request_handler, |
| 116 | + bootstrap_servers="dummy:9092", |
| 117 | + request_topic="a2a-requests", |
| 118 | + consumer_group_id="a2a-server", |
| 119 | + ) |
| 120 | + # expose dummy for assertions |
| 121 | + a._dummy_handler = dummy |
| 122 | + return a |
| 123 | + |
| 124 | + |
| 125 | +@pytest.mark.asyncio |
| 126 | +async def test_start_initializes_components(app: KafkaServerApp): |
| 127 | + await app.start() |
| 128 | + assert app._running is True |
| 129 | + assert isinstance(app.producer, FakeProducer) and app.producer.started |
| 130 | + assert isinstance(app.consumer, FakeConsumer) and app.consumer.started |
| 131 | + # handler constructed |
| 132 | + assert app.handler is app._dummy_handler |
| 133 | + |
| 134 | + |
| 135 | +@pytest.mark.asyncio |
| 136 | +async def test_stop_closes_components(app: KafkaServerApp): |
| 137 | + await app.start() |
| 138 | + await app.stop() |
| 139 | + assert app._running is False |
| 140 | + assert app.producer is not None and app.producer.started is False |
| 141 | + assert app.consumer is not None and app.consumer.started is False |
| 142 | + |
| 143 | + |
| 144 | +@pytest.mark.asyncio |
| 145 | +async def test_send_response_uses_producer_headers_and_payload(app: KafkaServerApp): |
| 146 | + await app.start() |
| 147 | + await app.send_response("reply-topic", "corr-1", {"k": 1}, "task") |
| 148 | + assert len(app.producer.sent) == 1 |
| 149 | + topic, value, headers = app.producer.sent[0] |
| 150 | + assert topic == "reply-topic" |
| 151 | + assert value["type"] == "task" and value["data"] == {"k": 1} |
| 152 | + assert ("correlation_id", b"corr-1") in headers |
| 153 | + |
| 154 | + |
| 155 | +@pytest.mark.asyncio |
| 156 | +async def test_send_stream_complete_uses_producer(app: KafkaServerApp): |
| 157 | + await app.start() |
| 158 | + await app.send_stream_complete("reply-topic", "corr-2") |
| 159 | + topic, value, headers = app.producer.sent[-1] |
| 160 | + assert topic == "reply-topic" |
| 161 | + assert value["type"] == "stream_complete" |
| 162 | + assert ("correlation_id", b"corr-2") in headers |
| 163 | + |
| 164 | + |
| 165 | +@pytest.mark.asyncio |
| 166 | +async def test_send_error_response_uses_producer(app: KafkaServerApp): |
| 167 | + await app.start() |
| 168 | + await app.send_error_response("reply-topic", "corr-3", "boom") |
| 169 | + topic, value, headers = app.producer.sent[-1] |
| 170 | + assert topic == "reply-topic" |
| 171 | + assert value["type"] == "error" |
| 172 | + assert value["data"]["error"] == "boom" |
| 173 | + assert ("correlation_id", b"corr-3") in headers |
| 174 | + |
| 175 | + |
| 176 | +@pytest.mark.asyncio |
| 177 | +async def test_consume_requests_converts_and_delegates(app: KafkaServerApp): |
| 178 | + await app.start() |
| 179 | + # Prepare a message for the consumer |
| 180 | + assert isinstance(app.consumer, FakeConsumer) |
| 181 | + app.consumer.add_message( |
| 182 | + value={"method": "message_send", "params": {}, "streaming": False}, |
| 183 | + headers=[("reply_topic", b"replies"), ("correlation_id", b"cid-1")], |
| 184 | + ) |
| 185 | + |
| 186 | + # Run consume loop once; since FakeConsumer yields finite messages, it will end |
| 187 | + await app._consume_requests() |
| 188 | + |
| 189 | + # Verify our dummy handler saw the converted KafkaMessage |
| 190 | + handled = app._dummy_handler.handled |
| 191 | + assert len(handled) == 1 |
| 192 | + km: KafkaMessage = handled[0] |
| 193 | + assert km.get_header("reply_topic") == "replies" |
| 194 | + assert km.get_header("correlation_id") == "cid-1" |
| 195 | + assert km.value["method"] == "message_send" |
0 commit comments