11from aiokafka import AIOKafkaProducer
2+ from aiokafka .errors import KafkaConnectionError
23from binascii import Error as B64DecodeError
34from collections import namedtuple
45from http import HTTPStatus
@@ -36,13 +37,9 @@ class FormatError(Exception):
3637
3738
3839class KafkaRest (KarapaceBase ):
39- # pylint: disable=attribute-defined-outside-init
4040 def __init__ (self , config : dict ) -> None :
4141 super ().__init__ (config = config )
4242 self ._add_kafka_rest_routes ()
43- self ._init_kafka_rest (config = config )
44-
45- def _init_kafka_rest (self , config : dict ) -> None :
4643 self .serializer = SchemaRegistrySerializer (config = config )
4744 self .log = logging .getLogger ("KarapaceRest" )
4845 self ._cluster_metadata = None
@@ -54,8 +51,9 @@ def _init_kafka_rest(self, config: dict) -> None:
5451 self .schemas_cache = {}
5552 self .consumer_manager = ConsumerManager (config = config )
5653 self .init_admin_client ()
57- self .producer_refs = []
58- self .producer_queue = asyncio .Queue ()
54+
55+ self ._async_producer : Optional [AIOKafkaProducer ] = None
56+ self ._async_producer_lock = asyncio .Lock ()
5957
6058 def _add_kafka_rest_routes (self ) -> None :
6159 # Brokers
@@ -163,35 +161,41 @@ def _add_kafka_rest_routes(self) -> None:
163161 self .route ("/topics/<topic:path>" , callback = self .topic_details , method = "GET" , rest_request = True )
164162 self .route ("/topics/<topic:path>" , callback = self .topic_publish , method = "POST" , rest_request = True )
165163
166- async def get_producer (self ) -> AIOKafkaProducer :
167- if self .producer_queue .empty ():
168- for _ in range (self .config ["producer_count" ]):
169- self .log .info ("Creating async producers" )
170- p = await self ._create_async_producer ()
171- await self .producer_queue .put (p )
172- self .producer_refs .append (p )
173- return await self .producer_queue .get ()
164+ async def _maybe_create_async_producer (self ) -> AIOKafkaProducer :
165+ if self .config ["producer_acks" ] == "all" :
166+ acks = "all"
167+ else :
168+ acks = int (self .config ["producer_acks" ])
174169
175- async def _create_async_producer (self ) -> AIOKafkaProducer :
176- while True :
177- try :
178- acks = self .config ["producer_acks" ]
179- acks = acks if acks == "all" else int (acks )
180- p = AIOKafkaProducer (
170+ async with self ._async_producer_lock :
171+ while self ._async_producer is None :
172+ self .log .info ("Creating async producer" )
173+
174+ # Don't retry if creating the SSL context fails, likely a configuration issue with
175+ # ciphers or certificate chains
176+ ssl_context = create_client_ssl_context (self .config )
177+
178+ # Don't retry if instantiating the producer fails, likely a configuration error.
179+ producer = AIOKafkaProducer (
181180 bootstrap_servers = self .config ["bootstrap_uri" ],
182181 security_protocol = self .config ["security_protocol" ],
183- ssl_context = create_client_ssl_context ( self . config ) ,
182+ ssl_context = ssl_context ,
184183 metadata_max_age_ms = self .config ["metadata_max_age_ms" ],
185184 acks = acks ,
186185 compression_type = self .config ["producer_compression_type" ],
187186 linger_ms = self .config ["producer_linger_ms" ],
188187 connections_max_idle_ms = self .config ["connections_max_idle_ms" ],
189188 )
190- await p .start ()
191- return p
192- except : # pylint: disable=bare-except
193- self .log .exception ("Unable to start async producer, retrying" )
194- await asyncio .sleep (1 )
189+
190+ try :
191+ await producer .start ()
192+ except KafkaConnectionError :
193+ self .log .exception ("Unable to connect to the bootstrap servers, retrying" )
194+ await asyncio .sleep (1 )
195+ else :
196+ self ._async_producer = producer
197+
198+ return self ._async_producer
195199
196200 # CONSUMERS
197201 async def create_consumer (self , group_name : str , content_type : str , * , request : HTTPRequest ):
@@ -316,19 +320,14 @@ def init_admin_client(self):
316320 self .log .exception ("Unable to start admin client, retrying" )
317321 time .sleep (1 )
318322
319- async def close_producers (self ):
320- if not self .producer_refs :
321- return
322- for prod in self .producer_refs :
323- self .log .info ("Disposing of async producers" )
324- await prod .stop ()
325- self .producer_refs = None
326- self .producer_queue = None
327- return
328-
329323 async def close (self ) -> None :
330324 await super ().close ()
331- await self .close_producers ()
325+
326+ async with self ._async_producer_lock :
327+ if self ._async_producer is not None :
328+ self .log .info ("Disposing async producer" )
329+ await self ._async_producer .stop ()
330+
332331 if self .admin_client :
333332 self .admin_client .close ()
334333 self .admin_client = None
@@ -580,12 +579,18 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte
580579 )
581580
582581 async def produce_message (self , * , topic : str , key : bytes , value : bytes , partition : int = None ) -> dict :
583- prod = None
584582 try :
585- prod = await self .get_producer ()
586- result = await asyncio .wait_for (
587- fut = prod .send_and_wait (topic , key = key , value = value , partition = partition ), timeout = self .kafka_timeout
588- )
583+ producer = await self ._maybe_create_async_producer ()
584+
585+ # Cancelling the returned future **will not** stop event from being sent, but cancelling
586+ # the ``send`` coroutine itself **will**.
587+ coroutine = producer .send (topic , key = key , value = value , partition = partition )
588+
589+ # Schedule the co-routine, it will be cancelled if the it is not complete in
590+ # `self.kafka_timeout` seconds.
591+ future = await asyncio .wait_for (fut = coroutine , timeout = self .kafka_timeout )
592+
593+ result = await future
589594 return {
590595 "offset" : result .offset if result else - 1 ,
591596 "partition" : result .topic_partition .partition if result else 0 ,
@@ -603,9 +608,6 @@ async def produce_message(self, *, topic: str, key: bytes, value: bytes, partiti
603608 if hasattr (e , "retriable" ) and e .retriable :
604609 resp ["error_code" ] = 2
605610 return resp
606- finally :
607- if prod :
608- await self .producer_queue .put (prod )
609611
610612 def list_topics (self , content_type : str ):
611613 metadata = self .cluster_metadata ()
0 commit comments