Commit 4698577

Merge pull request #775 from Aiven-Open/matyaskuti/confluent_kafka_asyncio
Implement async confluent-kafka producer
2 parents bc6e0aa + a51027d commit 4698577

8 files changed

Lines changed: 180 additions & 27 deletions

karapace/kafka/common.py

Lines changed: 10 additions & 1 deletion
@@ -71,9 +71,12 @@ def token_with_expiry(self, config: str | None) -> tuple[str, int | None]:
 
 
 class KafkaClientParams(TypedDict, total=False):
+    acks: int | None
     client_id: str | None
     connections_max_idle_ms: int | None
-    max_block_ms: int | None
+    compression_type: str | None
+    linger_ms: int | None
+    message_max_bytes: int | None
     metadata_max_age_ms: int | None
     retries: int | None
     sasl_mechanism: str | None
@@ -83,6 +86,7 @@ class KafkaClientParams(TypedDict, total=False):
     socket_timeout_ms: int | None
     ssl_cafile: str | None
     ssl_certfile: str | None
+    ssl_crlfile: str | None
     ssl_keyfile: str | None
     sasl_oauth_token_provider: TokenWithExpiryProvider
     # Consumer-only
@@ -121,8 +125,12 @@ def _get_config_from_params(self, bootstrap_servers: Iterable[str] | str, **para
 
         config: dict[str, int | str | Callable | None] = {
             "bootstrap.servers": bootstrap_servers,
+            "acks": params.get("acks"),
             "client.id": params.get("client_id"),
             "connections.max.idle.ms": params.get("connections_max_idle_ms"),
+            "compression.type": params.get("compression_type"),
+            "linger.ms": params.get("linger_ms"),
+            "message.max.bytes": params.get("message_max_bytes"),
             "metadata.max.age.ms": params.get("metadata_max_age_ms"),
             "retries": params.get("retries"),
             "sasl.mechanism": params.get("sasl_mechanism"),
@@ -132,6 +140,7 @@ def _get_config_from_params(self, bootstrap_servers: Iterable[str] | str, **para
             "socket.timeout.ms": params.get("socket_timeout_ms"),
             "ssl.ca.location": params.get("ssl_cafile"),
             "ssl.certificate.location": params.get("ssl_certfile"),
+            "ssl.crl.location": params.get("ssl_crlfile"),
             "ssl.key.location": params.get("ssl_keyfile"),
             "error_cb": self._error_callback,
             # Consumer-only
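
The new producer options follow the existing pattern in `_get_config_from_params`: snake_case `KafkaClientParams` keys are translated into librdkafka's dotted configuration names. Below is a minimal sketch of that translation pattern, not Karapace code; the helper name, the key subset, and the None-dropping behaviour are illustrative assumptions.

from typing import Any


def to_librdkafka_config(**params: Any) -> dict[str, Any]:
    # Illustrative subset of the snake_case -> dotted-key mapping.
    mapping = {
        "acks": "acks",
        "client_id": "client.id",
        "compression_type": "compression.type",
        "linger_ms": "linger.ms",
        "message_max_bytes": "message.max.bytes",
    }
    config = {dotted: params.get(snake) for snake, dotted in mapping.items()}
    # Drop options the caller did not set so librdkafka only sees real values.
    return {key: value for key, value in config.items() if value is not None}


# Example: to_librdkafka_config(acks=-1, linger_ms=100)
# -> {"acks": -1, "linger.ms": 100}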

karapace/kafka/producer.py

Lines changed: 71 additions & 1 deletion
@@ -5,15 +5,18 @@
 
 from __future__ import annotations
 
+from collections.abc import Iterable
 from concurrent.futures import Future
 from confluent_kafka import Message, Producer
 from confluent_kafka.admin import PartitionMetadata
 from confluent_kafka.error import KafkaError, KafkaException
 from functools import partial
-from karapace.kafka.common import _KafkaConfigMixin, raise_from_kafkaexception, translate_from_kafkaerror
+from karapace.kafka.common import _KafkaConfigMixin, KafkaClientParams, raise_from_kafkaexception, translate_from_kafkaerror
+from threading import Event, Thread
 from typing import cast, TypedDict
 from typing_extensions import Unpack
 
+import asyncio
 import logging
 
 LOG = logging.getLogger(__name__)
@@ -59,3 +62,70 @@ def partitions_for(self, topic: str) -> dict[int, PartitionMetadata]:
             return self.list_topics(topic).topics[topic].partitions
         except KafkaException as exc:
             raise_from_kafkaexception(exc)
+
+
+class AsyncKafkaProducer:
+    """An async wrapper around `KafkaProducer` built on top of confluent-kafka.
+
+    Calling `start` on an `AsyncKafkaProducer` instantiates a `KafkaProducer`
+    and starts a poll-thread.
+
+    The poll-thread continuously polls the underlying producer so buffered messages
+    are sent and asyncio futures returned by the `send` method can be awaited.
+    """
+
+    def __init__(
+        self,
+        bootstrap_servers: Iterable[str] | str,
+        loop: asyncio.AbstractEventLoop | None = None,
+        **params: Unpack[KafkaClientParams],
+    ) -> None:
+        self.loop = loop or asyncio.get_running_loop()
+
+        self.stopped = Event()
+        self.poll_thread = Thread(target=self.poll_loop)
+
+        self.producer: KafkaProducer | None = None
+        self._bootstrap_servers = bootstrap_servers
+        self._producer_params = params
+
+    def _start(self) -> None:
+        assert not self.stopped.is_set(), "The async producer cannot be restarted"
+
+        self.producer = KafkaProducer(self._bootstrap_servers, **self._producer_params)
+        self.poll_thread.start()
+
+    async def start(self) -> None:
+        # The `KafkaProducer` instantiation tries to establish a connection with
+        # retries, thus can block for a relatively long time. Running it in the
+        # default executor and awaiting makes it async compatible.
+        await self.loop.run_in_executor(None, self._start)
+
+    def _stop(self) -> None:
+        self.stopped.set()
+        if self.poll_thread.is_alive():
+            self.poll_thread.join()
+        self.producer = None
+
+    async def stop(self) -> None:
+        # Run all actions needed to stop in the default executor, since
+        # some of them can be blocking.
+        await self.loop.run_in_executor(None, self._stop)
+
+    def poll_loop(self) -> None:
+        """Target of the poll-thread."""
+        assert self.producer is not None, "The async producer must be started"
+
+        while not self.stopped.is_set():
+            # The call to `poll` is blocking, necessitating running this loop in its own thread.
+            # In case there are messages to be sent, this loop sends them (equivalent to
+            # a `flush` call), otherwise it sleeps for the given timeout (seconds).
+            self.producer.poll(timeout=0.1)
+
+    async def send(self, topic: str, **params: Unpack[ProducerSendParams]) -> asyncio.Future[Message]:
+        assert self.producer is not None, "The async producer must be started"
+
+        return asyncio.wrap_future(
+            self.producer.send(topic, **params),
+            loop=self.loop,
+        )
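
As the docstring above describes, `send` returns an asyncio future that the poll-thread resolves via the underlying delivery callback. Below is a hedged usage sketch under the new API; the broker address and topic name are illustrative, not from this commit.

import asyncio

from karapace.kafka.producer import AsyncKafkaProducer


async def main() -> None:
    producer = AsyncKafkaProducer(bootstrap_servers="localhost:9092")
    await producer.start()  # instantiates KafkaProducer and starts the poll-thread
    try:
        # `send` returns an asyncio future; the poll-thread serves the
        # delivery callback that eventually resolves it.
        fut = await producer.send("my-topic", key=b"key", value=b"value")
        message = await fut
        print(message.topic(), message.partition(), message.offset())
    finally:
        await producer.stop()


asyncio.run(main())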

karapace/kafka_rest_apis/__init__.py

Lines changed: 23 additions & 21 deletions
@@ -1,5 +1,3 @@
-from aiokafka import AIOKafkaProducer
-from aiokafka.errors import KafkaConnectionError
 from binascii import Error as B64DecodeError
 from collections import namedtuple
 from confluent_kafka.error import KafkaException
@@ -13,9 +11,10 @@
     TopicAuthorizationFailedError,
     UnknownTopicOrPartitionError,
 )
-from karapace.config import Config, create_client_ssl_context
+from karapace.config import Config
 from karapace.errors import InvalidSchema
 from karapace.kafka.admin import KafkaAdminClient
+from karapace.kafka.producer import AsyncKafkaProducer
 from karapace.kafka_rest_apis.authentication import (
     get_auth_config_from_header,
     get_expiration_time_from_header,
@@ -36,7 +35,7 @@
     SchemaRetrievalError,
 )
 from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType
-from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
+from karapace.utils import convert_to_int, json_encode
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import asyncio
@@ -73,6 +72,7 @@ def __init__(self, config: Config) -> None:
         self._idle_proxy_janitor_task: Optional[asyncio.Task] = None
 
     async def close(self) -> None:
+        log.info("Closing REST proxy application")
        if self._idle_proxy_janitor_task is not None:
            self._idle_proxy_janitor_task.cancel()
            self._idle_proxy_janitor_task = None
@@ -441,7 +441,7 @@ def __init__(
         self._auth_expiry = auth_expiry
 
         self._async_producer_lock = asyncio.Lock()
-        self._async_producer: Optional[AIOKafkaProducer] = None
+        self._async_producer: Optional[AsyncKafkaProducer] = None
         self.naming_strategy = NameStrategy(self.config["name_strategy"])
 
     def __str__(self) -> str:
@@ -461,12 +461,12 @@ def auth_expiry(self) -> datetime.datetime:
     def num_consumers(self) -> int:
         return len(self.consumer_manager.consumers)
 
-    async def _maybe_create_async_producer(self) -> AIOKafkaProducer:
+    async def _maybe_create_async_producer(self) -> AsyncKafkaProducer:
         if self._async_producer is not None:
             return self._async_producer
 
         if self.config["producer_acks"] == "all":
-            acks = "all"
+            acks = -1
         else:
             acks = int(self.config["producer_acks"])
 
@@ -477,33 +477,34 @@ async def _maybe_create_async_producer(self) -> AIOKafkaProducer:
 
         log.info("Creating async producer")
 
-        # Don't retry if creating the SSL context fails, likely a configuration issue with
-        # ciphers or certificate chains
-        ssl_context = create_client_ssl_context(self.config)
-
-        # Don't retry if instantiating the producer fails, likely a configuration error.
-        producer = AIOKafkaProducer(
+        producer = AsyncKafkaProducer(
             acks=acks,
             bootstrap_servers=self.config["bootstrap_uri"],
             compression_type=self.config["producer_compression_type"],
             connections_max_idle_ms=self.config["connections_max_idle_ms"],
             linger_ms=self.config["producer_linger_ms"],
-            max_request_size=self.config["producer_max_request_size"],
+            message_max_bytes=self.config["producer_max_request_size"],
             metadata_max_age_ms=self.config["metadata_max_age_ms"],
             security_protocol=self.config["security_protocol"],
-            ssl_context=ssl_context,
+            ssl_cafile=self.config["ssl_cafile"],
+            ssl_certfile=self.config["ssl_certfile"],
+            ssl_keyfile=self.config["ssl_keyfile"],
+            ssl_crlfile=self.config["ssl_crlfile"],
             **get_kafka_client_auth_parameters_from_config(self.config),
         )
-
         try:
             await producer.start()
-        except KafkaConnectionError:
+        except (NoBrokersAvailable, AuthenticationFailedError):
+            await producer.stop()
             if retry:
                 log.exception("Unable to connect to the bootstrap servers, retrying")
             else:
                 log.exception("Giving up after trying to connect to the bootstrap servers")
                 raise
             await asyncio.sleep(1)
+        except Exception:
+            await producer.stop()
+            raise
         else:
             self._async_producer = producer
 
@@ -645,10 +646,8 @@ def init_admin_client(self):
                 ssl_cafile=self.config["ssl_cafile"],
                 ssl_certfile=self.config["ssl_certfile"],
                 ssl_keyfile=self.config["ssl_keyfile"],
-                api_version=(1, 0, 0),
                 metadata_max_age_ms=self.config["metadata_max_age_ms"],
                 connections_max_idle_ms=self.config["connections_max_idle_ms"],
-                kafka_client=KarapaceKafkaClient,
                 **get_kafka_client_auth_parameters_from_config(self.config, async_client=False),
             )
             break
@@ -1069,8 +1068,11 @@ async def produce_messages(self, *, topic: str, prepared_records: List) -> List:
                 if not isinstance(result, Exception):
                     produce_results.append(
                         {
-                            "offset": result.offset if result else -1,
-                            "partition": result.topic_partition.partition if result else 0,
+                            # In case the offset is not available, `confluent_kafka.Message.offset()` is
+                            # `None`. To preserve backwards compatibility, we replace this with -1.
+                            # -1 was the default `aiokafka` behaviour.
+                            "offset": result.offset() if result and result.offset() is not None else -1,
+                            "partition": result.partition() if result else 0,
                         }
                     )
 
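One subtle change above: aiokafka accepted the string `acks="all"`, while the confluent-kafka/librdkafka configuration takes the integer `-1` to mean "wait for all in-sync replicas", hence the `acks = -1` mapping. A minimal sketch of that mapping; the function name is illustrative.

def resolve_acks(producer_acks: str) -> int:
    # "all" maps to librdkafka's -1 (all in-sync replicas must acknowledge).
    if producer_acks == "all":
        return -1
    return int(producer_acks)


assert resolve_acks("all") == -1
assert resolve_acks("1") == 1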
karapace/kafka_rest_apis/authentication.py

Lines changed: 3 additions & 0 deletions
@@ -142,6 +142,9 @@ class SimpleOauthTokenProviderAsync(AbstractTokenProviderAsync):
     async def token(self) -> str:
         return self._token
 
+    def token_with_expiry(self, _config: str | None = None) -> tuple[str, int | None]:
+        return (self._token, get_expiration_timestamp_from_jwt(self._token))
+
 
 class SASLOauthParams(TypedDict):
     sasl_mechanism: str
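
The new `token_with_expiry` pairs the configured token with an expiry read from the JWT itself. Below is a hedged sketch of what `get_expiration_timestamp_from_jwt` plausibly does, inferred from the unit test added in this commit; the implementation is an assumption, not the actual Karapace helper.

import jwt  # PyJWT


def get_expiration_timestamp_from_jwt(token: str) -> int | None:
    # Decode without signature verification just to read the `exp` claim.
    claims = jwt.decode(token, options={"verify_signature": False})
    return claims.get("exp")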

karapace/rapu.py

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ def __init__(
         self.app = self._create_aiohttp_application(config=config)
         self.log = logging.getLogger(self.app_name)
         self.stats = StatsClient(config=config)
-        self.app.on_cleanup.append(self.close_by_app)
+        self.app.on_shutdown.append(self.close_by_app)
         self.not_ready_handler = not_ready_handler
 
     def _create_aiohttp_application(self, *, config: Config) -> aiohttp.web.Application:
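
The `on_cleanup` to `on_shutdown` switch matters for ordering: aiohttp fires `on_shutdown` handlers before the server and its connections are torn down, and `on_cleanup` only afterwards, so async teardown such as stopping the producer can still complete on a working event loop. A minimal registration sketch; the handler body is illustrative.

import aiohttp.web


async def close_by_app(app: aiohttp.web.Application) -> None:
    # Async teardown (e.g. stopping producers) runs here, while the event
    # loop is still fully operational.
    ...


app = aiohttp.web.Application()
# on_shutdown handlers fire before on_cleanup during application teardown.
app.on_shutdown.append(close_by_app)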

tests/integration/kafka/test_producer.py

Lines changed: 62 additions & 1 deletion
@@ -7,9 +7,12 @@
 
 from confluent_kafka.admin import NewTopic
 from kafka.errors import MessageSizeTooLargeError, UnknownTopicOrPartitionError
-from karapace.kafka.producer import KafkaProducer
+from karapace.kafka.producer import AsyncKafkaProducer, KafkaProducer
 from karapace.kafka.types import Timestamp
+from tests.integration.utils.kafka_server import KafkaServers
+from typing import Iterator
 
+import asyncio
 import pytest
 import time
 
@@ -71,3 +74,61 @@ def test_partitions_for(self, producer: KafkaProducer, new_topic: NewTopic) -> N
         assert partitions[0].id == 0
         assert partitions[0].replicas == [1]
         assert partitions[0].isrs == [1]
+
+
+@pytest.fixture(scope="function", name="asyncproducer")
+async def fixture_asyncproducer(
+    kafka_servers: KafkaServers,
+    loop: asyncio.AbstractEventLoop,
+) -> Iterator[AsyncKafkaProducer]:
+    asyncproducer = AsyncKafkaProducer(bootstrap_servers=kafka_servers.bootstrap_servers, loop=loop)
+    await asyncproducer.start()
+    yield asyncproducer
+    await asyncproducer.stop()
+
+
+class TestAsyncSend:
+    async def test_async_send(self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic) -> None:
+        key = b"key"
+        value = b"value"
+        partition = 0
+        timestamp = int(time.time() * 1000)
+        headers = [("something", b"123"), (None, "foobar")]
+
+        aiofut = await asyncproducer.send(
+            new_topic.topic,
+            key=key,
+            value=value,
+            partition=partition,
+            timestamp=timestamp,
+            headers=headers,
+        )
+        message = await aiofut
+
+        assert message.offset() == 0
+        assert message.partition() == partition
+        assert message.topic() == new_topic.topic
+        assert message.key() == key
+        assert message.value() == value
+        assert message.timestamp()[0] == Timestamp.CREATE_TIME
+        assert message.timestamp()[1] == timestamp
+
+    async def test_async_send_raises_for_unknown_topic(self, asyncproducer: AsyncKafkaProducer) -> None:
+        aiofut = await asyncproducer.send("nonexistent")
+
+        with pytest.raises(UnknownTopicOrPartitionError):
+            _ = await aiofut
+
+    async def test_async_send_raises_for_unknown_partition(
+        self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic
+    ) -> None:
+        aiofut = await asyncproducer.send(new_topic.topic, partition=99)
+
+        with pytest.raises(UnknownTopicOrPartitionError):
+            _ = await aiofut
+
+    async def test_async_send_raises_for_too_large_message(
+        self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic
+    ) -> None:
+        with pytest.raises(MessageSizeTooLargeError):
+            await asyncproducer.send(new_topic.topic, value=b"x" * 1000001)

tests/integration/test_rest.py

Lines changed: 2 additions & 2 deletions
@@ -227,9 +227,9 @@ async def test_internal(rest_async: KafkaRest | None, admin_client: KafkaAdminCl
     assert len(results) == 1
     for result in results:
         assert "error" in result, "Invalid result missing 'error' key"
-        assert result["error"] == "Unrecognized partition"
+        assert result["error"] == "This request is for a topic or partition that does not exist on this broker."
         assert "error_code" in result, "Invalid result missing 'error_code' key"
-        assert result["error_code"] == 1
+        assert result["error_code"] == 2
 
     assert rest_async_proxy.all_empty({"records": [{"key": {"foo": "bar"}}]}, "key") is False
     assert rest_async_proxy.all_empty({"records": [{"value": {"foo": "bar"}}]}, "value") is False

tests/unit/test_authentication.py

Lines changed: 8 additions & 0 deletions
@@ -120,6 +120,14 @@ async def test_simple_oauth_token_provider_async_returns_configured_token() -> N
     assert await token_provider.token() == "TOKEN"
 
 
+def test_simple_oauth_token_provider_async_returns_configured_token_and_expiry() -> None:
+    expiry_timestamp = 1697013997
+    token = jwt.encode({"exp": expiry_timestamp}, "secret")
+    token_provider = SimpleOauthTokenProviderAsync(token)
+
+    assert token_provider.token_with_expiry() == (token, expiry_timestamp)
+
+
 def test_get_client_auth_parameters_from_config_sasl_plain() -> None:
     config = set_config_defaults(
         {"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}
