From 32ffcfe16a11fe289634d0542a7f8294435b3b3c Mon Sep 17 00:00:00 2001 From: Marek Szymutko Date: Wed, 25 Mar 2026 15:19:24 +0100 Subject: [PATCH 1/5] feat(ISV-7020): add chunking capabilities * chunking enable only in retry topics * tests assisted with Claude Signed-off-by: Marek Szymutko Assisted-by: Claude-4.6-Opus --- README.md | 19 +++ src/retriable_kafka_client/chunking.py | 102 +++++++++++ src/retriable_kafka_client/config.py | 1 + src/retriable_kafka_client/consumer.py | 27 ++- .../consumer_tracking.py | 106 +++++------- src/retriable_kafka_client/headers.py | 43 +++++ src/retriable_kafka_client/kafka_settings.py | 13 +- src/retriable_kafka_client/kafka_utils.py | 91 ++++++++++ src/retriable_kafka_client/producer.py | 158 ++++++++++++------ src/retriable_kafka_client/retry_utils.py | 60 ++++--- tests/integration/integration_utils.py | 38 ++++- tests/integration/test_chunking.py | 88 ++++++++++ .../integration/test_user_function_filter.py | 5 +- tests/unit/test_chunking.py | 44 +++++ tests/unit/test_consumer.py | 106 +++++++++--- tests/unit/test_consumer_tracking.py | 133 +++++++-------- tests/unit/test_producer.py | 108 ++++++++++-- tests/unit/test_retry_utils.py | 58 +++++-- uv.lock | 2 +- 19 files changed, 898 insertions(+), 304 deletions(-) create mode 100644 src/retriable_kafka_client/chunking.py create mode 100644 src/retriable_kafka_client/headers.py create mode 100644 tests/integration/test_chunking.py create mode 100644 tests/unit/test_chunking.py diff --git a/README.md b/README.md index 9bb017a..c25f975 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,25 @@ This whole mechanism **does not ensure message ordering**. When a message is sent to be retried, another message processing from the same topic is still unblocked. +### Message splitting + +This library also supports sending large messages to topics which don't have +the capacity to process these messages as whole. 
To do this, producers are +configurable to automatically split messages according to message size. + +This feature is custom and is therefore turned off by default. The only place +this feature is always enabled are retry-topics, which are meant to be consumed +only by clients using this library. + +The chunked messages have 3 additional headers: + +* Group ID (uuid4 value) +* Chunk ID (serial number of the chunk within group, starting with 0) +* Number of chunks (is always +1 from the last chunk ID) + +Message is deserialized and processed only if all expected chunks have been +found. + ## Contributing guidelines To check contributing guidelines, please check `CONTRIBUTING.md` in the diff --git a/src/retriable_kafka_client/chunking.py b/src/retriable_kafka_client/chunking.py new file mode 100644 index 0000000..2a0468d --- /dev/null +++ b/src/retriable_kafka_client/chunking.py @@ -0,0 +1,102 @@ +"""Module for tracking chunked messages""" + +import logging +import sys +import uuid +from collections import defaultdict + +from confluent_kafka import Message + +from retriable_kafka_client.headers import ( + CHUNK_GROUP_HEADER, + NUMBER_OF_CHUNKS_HEADER, + CHUNK_ID_HEADER, + deserialize_number_from_bytes, + get_header_value, +) +from retriable_kafka_client.kafka_utils import MessageGroup + +LOGGER = logging.getLogger(__name__) + + +def generate_group_id() -> bytes: + """Generate a random group id.""" + return uuid.uuid4().bytes + + +def calculate_header_size( + headers: dict[str, str | bytes] | list[tuple[str, str | bytes]] | None, +) -> int: + """Approximate the space needed for headers within a message.""" + if not headers: + return 0 + result = 0 + if not isinstance(headers, dict): + headers = dict(headers) + for header_key, header_value in headers.items(): + result += len(header_key) + sys.getsizeof(header_value) + 8 + # Two 32-bit numbers specifying lengths of fields are also + # present, therefore adding 8 + return result + + +class ChunkingCache: + # pylint: 
disable=too-few-public-methods + """Class for storing information about received message fragments.""" + + def __init__(self): + # The tuple holds group ID, topic and partition + # in case of group ID collision (that could mean an attack attempt). + # If the same consumer was used for different topics, an adversary + # may want to override split messages to execute different operations. + # By using group id as well as topic and partition, this attack is + # made impossible. + self._message_chunks: dict[tuple[bytes, str, int], dict[int, Message]] = ( + defaultdict(dict) + ) + + def receive(self, message: Message) -> MessageGroup | None: + """ + Receive a message. If the message is whole, or it is + the last fragment, returns the whole message group + and flushes cache of this group. Otherwise, returns None + and the message fragment shall not be processed. + """ + topic: str = message.topic() # type: ignore[assignment] + partition: int = message.partition() # type: ignore[assignment] + if ( + (group_id := get_header_value(message, CHUNK_GROUP_HEADER)) is not None + and ( + number_of_chunks_raw := get_header_value( + message, NUMBER_OF_CHUNKS_HEADER + ) + ) + is not None + and (chunk_id_raw := get_header_value(message, CHUNK_ID_HEADER)) is not None + ): + number_of_chunks = deserialize_number_from_bytes(number_of_chunks_raw) + chunk_id = deserialize_number_from_bytes(chunk_id_raw) + identifier = (group_id, topic, partition) + stored_message_ids_from_group = set( + self._message_chunks.get(identifier, {}).keys() + ) + stored_message_ids_from_group.add(chunk_id) + if not all( + i in stored_message_ids_from_group for i in range(number_of_chunks) + ): + LOGGER.debug( + "Received a message chunk, waiting for the other chunks..." 
+ ) + self._message_chunks[identifier][chunk_id] = message + return None + LOGGER.debug( + "Received all message chunks, assembling group %s composed of %s messages.", + group_id.hex(), + number_of_chunks, + ) + # Clear cache and reassemble, cache can be empty if + # this is just one-message-sized value + messages = self._message_chunks.pop(identifier, {}) + messages[chunk_id] = message + return MessageGroup(topic, partition, messages, group_id) + return MessageGroup(topic, partition, {0: message}, None) diff --git a/src/retriable_kafka_client/config.py b/src/retriable_kafka_client/config.py index b02ff2b..ff646a1 100644 --- a/src/retriable_kafka_client/config.py +++ b/src/retriable_kafka_client/config.py @@ -42,6 +42,7 @@ class ProducerConfig(CommonConfig): retries: int = field(default=3) fallback_factor: float = field(default=2.0) fallback_base: float = field(default=5.0) + split_messages: bool = field(default=False) @dataclass diff --git a/src/retriable_kafka_client/consumer.py b/src/retriable_kafka_client/consumer.py index c230877..4a14863 100644 --- a/src/retriable_kafka_client/consumer.py +++ b/src/retriable_kafka_client/consumer.py @@ -1,6 +1,5 @@ """Base Kafka Consumer module""" -import json import logging import sys from concurrent.futures import Executor, Future @@ -9,8 +8,9 @@ from confluent_kafka import Consumer, Message, KafkaException, TopicPartition +from .chunking import ChunkingCache from .health import perform_healthcheck_using_client -from .kafka_utils import message_to_partition +from .kafka_utils import message_to_partition, MessageGroup from .kafka_settings import KafkaOptions, DEFAULT_CONSUMER_SETTINGS from .consumer_tracking import TrackingManager from .config import ConsumerConfig @@ -91,6 +91,8 @@ def __init__( self.__retry_manager = RetryManager(config) # Store information about pending retried messages self.__schedule_cache = RetryScheduleCache() + # Store chunking information + self.__chunk_tracker = ChunkingCache() @property def 
_consumer(self) -> Consumer: @@ -140,7 +142,7 @@ def __on_revoke(self, _: Consumer, partitions: list[TopicPartition]) -> None: self.__tracking_manager.register_revoke(partitions) self.__perform_commits() - def __ack_message(self, message: Message, finished_future: Future) -> None: + def __ack_message(self, message: MessageGroup, finished_future: Future) -> None: """ Private method only ever intended to be used from within _process_message(). It commits offsets and releases @@ -155,7 +157,7 @@ def __ack_message(self, message: Message, finished_future: Future) -> None: if problem := finished_future.exception(): LOGGER.error( "Message could not be processed! Message: %s.", - message.value(), + message.deserialize(), exc_info=problem, ) self.__retry_manager.resend_message(message) @@ -212,20 +214,17 @@ def _process_message(self, message: Message) -> Future[Any] | None: Returns: Future of the target execution if the message can be processed. None otherwise. """ - message_value = message.value() - if not message_value: - # Discard empty messages + message_group = self.__chunk_tracker.receive(message) + if not message_group: return None - try: - message_data = json.loads(message_value) - except json.decoder.JSONDecodeError: - # This message cannot be deserialized, just log and discard it - LOGGER.exception("Decoding error: not a valid JSON: %s", message.value()) + message_data = message_group.deserialize() + if not message_data: + self.__tracking_manager.schedule_commit(message_group) return None future = self._executor.submit(self._config.target, message_data) - self.__tracking_manager.process_message(message, future) + self.__tracking_manager.process_message(message_group, future) # The semaphore is released within this callback - future.add_done_callback(lambda res: self.__ack_message(message, res)) + future.add_done_callback(lambda res: self.__ack_message(message_group, res)) return future ### Public methods ### diff --git 
a/src/retriable_kafka_client/consumer_tracking.py b/src/retriable_kafka_client/consumer_tracking.py index 9c8e140..1ccc30e 100644 --- a/src/retriable_kafka_client/consumer_tracking.py +++ b/src/retriable_kafka_client/consumer_tracking.py @@ -3,51 +3,19 @@ import logging from collections import defaultdict from concurrent.futures import Future +from itertools import chain from threading import Lock, Semaphore -from typing import NamedTuple, Any +from typing import Any, Iterable -from confluent_kafka import Message, TopicPartition +from confluent_kafka import TopicPartition + +from retriable_kafka_client.kafka_utils import TrackingInfo, MessageGroup LOGGER = logging.getLogger(__name__) -class _PartitionInfo(NamedTuple): - """ - Consistently hashable dataclass for storing information about a partition, - namely offset information. Can be used as keys in a dictionary. - """ - - topic: str - partition: int - - @staticmethod - def from_message(message: Message) -> "_PartitionInfo": - """ - Create a PartitionInfo from a Kafka message. - Args: - message: Kafka message object - Returns: hashable info about a partition - """ - message_topic = message.topic() - message_partition = message.partition() - # This should never happen with polled messages. Polled messages need - # the information asserted to be valid Kafka messages. This can - # happen only for custom-created messages objects, which this - # method is not intended to be used for - assert message_topic is not None and message_partition is not None, ( - "Invalid message cannot be converted to partition info" - ) - return _PartitionInfo(message_topic, message_partition) - - def to_offset_info(self, offset: int) -> TopicPartition: - """ - Create a Kafka-committable object using the provided offset. - Args: - offset: The offset to be committed. Make sure to commit - offset one higher than the latest processed message. 
- Returns: The committable Kafka object - """ - return TopicPartition(topic=self.topic, partition=self.partition, offset=offset) +def _flatten_offsets(done_offsets: Iterable[tuple[int, ...]]) -> list[int]: + return list(chain(*done_offsets)) class TrackingManager: @@ -85,8 +53,10 @@ class TrackingManager: """ def __init__(self, concurrency: int, cancel_wait_time: float): - self.__to_process: dict[_PartitionInfo, dict[int, Future]] = defaultdict(dict) - self.__to_commit: dict[_PartitionInfo, set[int]] = defaultdict(set) + self.__to_process: dict[TrackingInfo, dict[tuple[int, ...], Future]] = ( + defaultdict(dict) + ) + self.__to_commit: dict[TrackingInfo, set[tuple[int, ...]]] = defaultdict(set) self.__access_lock = Lock() # For handling multithreaded access to this object self.__semaphore = Semaphore(concurrency) self.__cancel_wait_time = cancel_wait_time @@ -105,15 +75,15 @@ def pop_committable(self) -> list[TopicPartition]: """ to_commit = [] with self.__access_lock: - for partition_info, pending_to_commit in self.__to_commit.items(): - if not pending_to_commit: + for partition_info, tuples_pending_to_commit in self.__to_commit.items(): + if not tuples_pending_to_commit: # Nothing to commit continue - pending_to_process = self.__to_process.get(partition_info, None) - if not pending_to_process: + tuples_pending_to_process = self.__to_process.get(partition_info, None) + if not tuples_pending_to_process: # Nothing is blocking the committing - max_to_commit = max(pending_to_commit) + max_to_commit = max(_flatten_offsets(tuples_pending_to_commit)) to_commit.append( TopicPartition( topic=partition_info.topic, @@ -123,17 +93,18 @@ def pop_committable(self) -> list[TopicPartition]: ) self.__to_commit[partition_info] = set() continue - - min_pending_to_process = min(pending_to_process) + min_pending_to_process = min( + _flatten_offsets(tuples_pending_to_process) + ) commit_candidates = { - offset - for offset in pending_to_commit - if offset < min_pending_to_process + 
offset_tuple + for offset_tuple in tuples_pending_to_commit + if all(offset < min_pending_to_process for offset in offset_tuple) } if not commit_candidates: # Nothing to commit continue - max_to_commit = max(commit_candidates) + max_to_commit = max(_flatten_offsets(commit_candidates)) to_commit.append( TopicPartition( topic=partition_info.topic, @@ -157,11 +128,11 @@ def reschedule_uncommittable( failed_committable: list of data that failed to be committed """ for failed in failed_committable: - self.__to_commit.setdefault( - _PartitionInfo(topic=failed.topic, partition=failed.partition), set() - ).add(failed.offset) + self.__to_commit[ + TrackingInfo(topic=failed.topic, partition=failed.partition) + ].add((failed.offset,)) - def process_message(self, message: Message, future: Future[Any]) -> None: + def process_message(self, message: MessageGroup, future: Future[Any]) -> None: """ Mark message as pending for processing. Args: @@ -170,15 +141,16 @@ def process_message(self, message: Message, future: Future[Any]) -> None: """ # We cannot really use context manager, the semaphore is released in # future's callback or when the future is cancelled + self.__semaphore.acquire() # pylint: disable=consider-using-with - message_offset: int = message.offset() # type: ignore[assignment] + message_offsets = message.offsets with self.__access_lock: # Mark the message as being processed - self.__to_process[_PartitionInfo.from_message(message)][ - message_offset + 1 + self.__to_process[TrackingInfo.from_message_group(message)][ + tuple(message_offset + 1 for message_offset in message_offsets) ] = future - def schedule_commit(self, message: Message) -> bool: + def schedule_commit(self, message: MessageGroup) -> bool: """ Mark message as pending for committing when its processing is fully done. 
Args: @@ -188,12 +160,12 @@ def schedule_commit(self, message: Message) -> bool: as pending for processing), False otherwise """ self.__semaphore.release() - partition_info = _PartitionInfo.from_message(message) - message_offset: int = message.offset() # type: ignore[assignment] - stored_offset = message_offset + 1 + partition_info = TrackingInfo.from_message_group(message) + message_offsets = message.offsets + stored_offsets = tuple(message_offset + 1 for message_offset in message_offsets) with self.__access_lock: - self.__to_process[partition_info].pop(stored_offset, None) - self.__to_commit.setdefault(partition_info, set()).add(stored_offset) + self.__to_process[partition_info].pop(stored_offsets, None) + self.__to_commit.setdefault(partition_info, set()).add(stored_offsets) self._cleanup() return True @@ -212,7 +184,7 @@ def _cleanup(self) -> None: cache_to_clean.pop(key, None) def _revoke_processing( - self, revoked_partitions: set[_PartitionInfo] + self, revoked_partitions: set[TrackingInfo] ) -> list[Future[Any]]: """ Cancel all pending tracked futures related to the given partitions. 
@@ -253,7 +225,7 @@ def register_revoke(self, partitions: list[TopicPartition] | None = None) -> Non revoked_partition_keys = set(self.__to_process.keys()) else: revoked_partition_keys = { - _PartitionInfo(partition=partition.partition, topic=partition.topic) + TrackingInfo(partition=partition.partition, topic=partition.topic) for partition in partitions } pending_futures = self._revoke_processing(revoked_partition_keys) diff --git a/src/retriable_kafka_client/headers.py b/src/retriable_kafka_client/headers.py new file mode 100644 index 0000000..ce8fe0d --- /dev/null +++ b/src/retriable_kafka_client/headers.py @@ -0,0 +1,43 @@ +""" +Module with definitions and utilities +related to Kafka headers +""" + +from confluent_kafka import Message + +from retriable_kafka_client.kafka_utils import MessageGroup + +_HEADER_PREFIX = "retriable_kafka_" +# Header name which holds the number of retry attempts +ATTEMPT_HEADER = _HEADER_PREFIX + "attempt" +# Header name which holds the timestamp of next reprocessing +TIMESTAMP_HEADER = _HEADER_PREFIX + "timestamp" +# Header name which holds the ID of a chunk +CHUNK_GROUP_HEADER = _HEADER_PREFIX + "chunk_group" +# Header name which holds the total number of chunks within group +NUMBER_OF_CHUNKS_HEADER = _HEADER_PREFIX + "number_of_chunks" +# Header name which holds the serial number of chunk within a group +CHUNK_ID_HEADER = _HEADER_PREFIX + "chunk_id" + + +def serialize_number_to_bytes(value: int | float) -> bytes: + """Store a number as bytes, converts to integers first.""" + return int(value).to_bytes(length=8, byteorder="big") + + +def deserialize_number_from_bytes(value: bytes) -> int: + """Restore integer from bytes.""" + return int.from_bytes(value, byteorder="big") + + +def get_header_value( + message: Message | MessageGroup, searched_header_name: str +) -> bytes | None: + """Fetch header value from message.""" + message_headers = message.headers() + if not message_headers: + return None + for header_name, header_value in 
message_headers: + if header_name == searched_header_name: + return header_value + return None diff --git a/src/retriable_kafka_client/kafka_settings.py b/src/retriable_kafka_client/kafka_settings.py index 4997caa..b2205f5 100644 --- a/src/retriable_kafka_client/kafka_settings.py +++ b/src/retriable_kafka_client/kafka_settings.py @@ -1,5 +1,12 @@ """Settings definitions for Kafka. Introduces intended defaults.""" +DEFAULT_MESSAGE_SIZE = 1000000 # default in librdkafka +MESSAGE_OVERHEAD = 500 # When splitting messages, +# the exact length cannot be easily computed. +# Therefore, we check the size of passed objects +# and subtract additional 500 B out of default 1 MB +# This number also includes custom headers from this library + class KafkaOptions: """ @@ -16,6 +23,7 @@ class KafkaOptions: USERNAME = "sasl.username" PASSWORD = "sasl.password" PARTITION_ASSIGNMENT_STRAT = "partition.assignment.strategy" + MAX_MESSAGE_SIZE = "message.max.bytes" _DEFAULT_COMMON_SETTINGS = { @@ -30,4 +38,7 @@ class KafkaOptions: **_DEFAULT_COMMON_SETTINGS, } -DEFAULT_PRODUCER_SETTINGS = {**_DEFAULT_COMMON_SETTINGS} +DEFAULT_PRODUCER_SETTINGS = { + KafkaOptions.MAX_MESSAGE_SIZE: DEFAULT_MESSAGE_SIZE, + **_DEFAULT_COMMON_SETTINGS, +} diff --git a/src/retriable_kafka_client/kafka_utils.py b/src/retriable_kafka_client/kafka_utils.py index 2d2d6bc..b249dd8 100644 --- a/src/retriable_kafka_client/kafka_utils.py +++ b/src/retriable_kafka_client/kafka_utils.py @@ -1,7 +1,14 @@ """Module for kafka utility functions""" +import json +import logging +from dataclasses import dataclass +from typing import NamedTuple, Any + from confluent_kafka import Message, TopicPartition +LOGGER = logging.getLogger(__name__) + def message_to_partition(message: Message) -> TopicPartition: """Convert message to info about a partition.""" @@ -13,3 +20,87 @@ def message_to_partition(message: Message) -> TopicPartition: "Maybe an error was unchecked?" 
) return TopicPartition(topic, partition, offset) + + +@dataclass +class MessageGroup: + """ + Class for grouping messages which share chunked data. + Attributes: + topic: The topic of the messages. + partition: The partition of the messages + (all should share the same partition). + messages: The messages within this group. + The keys are chunk IDs (different from + offsets). + group_id: The group ID, UUID4. Is None + when the received message is standalone. + """ + + topic: str + partition: int + messages: dict[int, Message] + group_id: bytes | None = None + + @property + def all_chunks(self) -> list[bytes]: + """ + Return all data, chunked according to messages. + """ + result = [] + for _, message in sorted(self.messages.items(), key=lambda i: i[0]): + message_value = message.value() + if message_value is not None: + result.append(message_value) + return result + + def deserialize(self) -> dict[str, Any] | None: + """Deserialize messages into dict.""" + cumulative_value = b"" + for _, message in sorted(self.messages.items(), key=lambda i: i[0]): + message_value = message.value() + if not message_value: + # Discard empty messages + continue + cumulative_value += message_value + try: + return json.loads(cumulative_value) + except json.decoder.JSONDecodeError: + # This message cannot be deserialized, just log and discard it + LOGGER.exception("Decoding error: not a valid JSON: %s", cumulative_value) + return None + + @property + def offsets(self) -> list[int]: + """Return all offsets of messages within this group.""" + return list( + message.offset() # type: ignore[misc] + for message in self.messages.values() + ) + + def headers(self) -> list[tuple[str, bytes]] | None: + """Return headers of these messages (first message's headers).""" + first_message = next(iter(self.messages.values()), None) + return first_message.headers() if first_message else None + + +class TrackingInfo(NamedTuple): + """ + Consistently hashable dataclass for storing information about a 
partition, + namely offset information. Can be used as keys in a dictionary. + """ + + topic: str + partition: int + + @staticmethod + def from_message_group(message_group: MessageGroup) -> "TrackingInfo": + """ + Create a PartitionInfo from a Kafka message. + Args: + message_group: Kafka composite message object + Returns: hashable info about a partition + """ + message_topic = message_group.topic + message_partition = message_group.partition + return TrackingInfo(message_topic, message_partition) diff --git a/src/retriable_kafka_client/producer.py b/src/retriable_kafka_client/producer.py index 278430b..0355c2f 100644 --- a/src/retriable_kafka_client/producer.py +++ b/src/retriable_kafka_client/producer.py @@ -4,12 +4,25 @@ import json import logging import time +from copy import copy from typing import Any from confluent_kafka import Producer, KafkaException +from .chunking import generate_group_id, calculate_header_size +from .headers import ( + CHUNK_GROUP_HEADER, + NUMBER_OF_CHUNKS_HEADER, + serialize_number_to_bytes, + CHUNK_ID_HEADER, +) from .health import perform_healthcheck_using_client -from .kafka_settings import KafkaOptions, DEFAULT_PRODUCER_SETTINGS +from .kafka_settings import ( + KafkaOptions, + DEFAULT_PRODUCER_SETTINGS, + DEFAULT_MESSAGE_SIZE, + MESSAGE_OVERHEAD, +) from .config import ProducerConfig LOGGER = logging.getLogger(__name__) @@ -28,6 +41,13 @@ def __init__(self, config: ProducerConfig): """ self._config = config self.__producer_object: Producer | None = None + self._config_dict = { + KafkaOptions.KAFKA_NODES: ",".join(self._config.kafka_hosts), + KafkaOptions.USERNAME: self._config.username, + KafkaOptions.PASSWORD: self._config.password, + **DEFAULT_PRODUCER_SETTINGS, + } + self._config_dict.update(**self._config.additional_settings) @property def topics(self) -> list[str]: @@ -40,22 +60,44 @@ def _producer(self) -> Producer: Get and cache the producer object. 
""" if not self.__producer_object: - config_dict = { - KafkaOptions.KAFKA_NODES: ",".join(self._config.kafka_hosts), - KafkaOptions.USERNAME: self._config.username, - KafkaOptions.PASSWORD: self._config.password, - **DEFAULT_PRODUCER_SETTINGS, - } - config_dict.update(**self._config.additional_settings) - self.__producer_object = Producer(config_dict) + self.__producer_object = Producer(self._config_dict) return self.__producer_object - @staticmethod - def __serialize_message(message: dict[str, Any] | bytes) -> bytes: + def _get_chunk_size( + self, + headers: list[tuple[str, str | bytes]] | dict[str, str | bytes] | None, + ) -> int: + """Calculate chunk size to fit messages into Kafka.""" + chunk_size_base: int = self._config_dict.get( # type: ignore[assignment] + KafkaOptions.MAX_MESSAGE_SIZE, DEFAULT_MESSAGE_SIZE + ) + return chunk_size_base - calculate_header_size(headers) - MESSAGE_OVERHEAD + + def __serialize_message( + self, + message: dict[str, Any] | bytes | list[bytes], + headers: list[tuple[str, str | bytes]] | dict[str, str | bytes] | None, + split_messages: bool, + ) -> list[bytes]: """Convert message to bytes if needed.""" + # Get the information from rendered config dict + # to take user overrides into consideration + chunk_size = self._get_chunk_size(headers) + if isinstance(message, list): + if all(len(chunk) <= chunk_size for chunk in message) and split_messages: + return message + # Split is wrong, needs re-chunking + message = b"".join(message) if isinstance(message, bytes): - return message - return json.dumps(message).encode("utf-8") + full_bytes = message + else: + full_bytes = json.dumps(message).encode("utf-8") + if split_messages: + result = [] + for i in range(0, len(full_bytes), chunk_size): + result.append(full_bytes[i : i + chunk_size]) + return result + return [full_bytes] def __calculate_backoff(self, attempt_idx: int) -> float: """Calculate exponential backoff time for a given attempt.""" @@ -82,7 +124,7 @@ def 
__handle_problems(problems: dict[str, Exception]) -> None: def send_sync( self, - message: dict[str, Any] | bytes, + message: dict[str, Any] | bytes | list[bytes], headers: dict[str, str | bytes] | None = None, ) -> None: """ @@ -97,32 +139,43 @@ def send_sync( BufferError: if Kafka queue is full even after all attempts KafkaException: if some Kafka error occurs even after all attempts """ - byte_message = self.__serialize_message(message) + chunks = self.__serialize_message(message, headers, self._config.split_messages) + number_of_chunks = len(chunks) problems: dict[str, Exception] = {} timestamp = int(time.time() * 1000) # Kafka expects milliseconds for topic in self._config.topics: - for attempt_idx in range(self._config.retries + 1): - try: - self._producer.produce( - topic=topic, - value=byte_message, - timestamp=timestamp, - headers=headers, + group_id = generate_group_id() + for chunk_id, chunk in enumerate(chunks): + headers = copy(headers) if headers else {} + if self._config.split_messages: + headers[CHUNK_GROUP_HEADER] = group_id + headers[NUMBER_OF_CHUNKS_HEADER] = serialize_number_to_bytes( + number_of_chunks ) - break - except (BufferError, KafkaException) as err: - if attempt_idx < self._config.retries: - backoff_time = self.__calculate_backoff(attempt_idx) - self.__log_retry(attempt_idx, backoff_time) - time.sleep(backoff_time) - continue - problems[topic] = err + headers[CHUNK_ID_HEADER] = serialize_number_to_bytes(chunk_id) + for attempt_idx in range(self._config.retries + 1): + try: + self._producer.produce( + topic=topic, + value=chunk, + timestamp=timestamp, + headers=headers, + key=group_id, + ) + break + except (BufferError, KafkaException) as err: + if attempt_idx < self._config.retries: + backoff_time = self.__calculate_backoff(attempt_idx) + self.__log_retry(attempt_idx, backoff_time) + time.sleep(backoff_time) + continue + problems[topic] = err self._producer.flush() self.__handle_problems(problems) async def send( self, - message: 
dict[str, Any] | bytes, + message: dict[str, Any] | bytes | list[bytes], headers: dict[str, str | bytes] | None = None, ) -> None: """ @@ -137,26 +190,37 @@ async def send( BufferError: if Kafka queue is full even after all attempts KafkaException: if some Kafka error occurs even after all attempts """ - byte_message = self.__serialize_message(message) + chunks = self.__serialize_message(message, headers, self._config.split_messages) + number_of_chunks = len(chunks) problems: dict[str, Exception] = {} timestamp = int(time.time() * 1000) # Kafka expects milliseconds for topic in self._config.topics: - for attempt_idx in range(self._config.retries + 1): - try: - self._producer.produce( - topic=topic, - value=byte_message, - timestamp=timestamp, - headers=headers, + group_id = generate_group_id() + for chunk_id, chunk in enumerate(chunks): + headers = copy(headers) if headers else {} + if self._config.split_messages: + headers[CHUNK_GROUP_HEADER] = group_id + headers[NUMBER_OF_CHUNKS_HEADER] = serialize_number_to_bytes( + number_of_chunks ) - break - except (BufferError, KafkaException) as err: - if attempt_idx < self._config.retries: - backoff_time = self.__calculate_backoff(attempt_idx) - self.__log_retry(attempt_idx, backoff_time) - await asyncio.sleep(backoff_time) - continue - problems[topic] = err + headers[CHUNK_ID_HEADER] = serialize_number_to_bytes(chunk_id) + for attempt_idx in range(self._config.retries + 1): + try: + self._producer.produce( + topic=topic, + value=chunk, + timestamp=timestamp, + headers=headers, + key=group_id, + ) + break + except (BufferError, KafkaException) as err: + if attempt_idx < self._config.retries: + backoff_time = self.__calculate_backoff(attempt_idx) + self.__log_retry(attempt_idx, backoff_time) + await asyncio.sleep(backoff_time) + continue + problems[topic] = err self._producer.flush() self.__handle_problems(problems) diff --git a/src/retriable_kafka_client/retry_utils.py b/src/retriable_kafka_client/retry_utils.py index 
8e4d5ec..76a8959 100644 --- a/src/retriable_kafka_client/retry_utils.py +++ b/src/retriable_kafka_client/retry_utils.py @@ -24,11 +24,16 @@ from confluent_kafka import Message, KafkaException, TopicPartition from .config import ProducerConfig, ConsumerConfig, ConsumeTopicConfig +from .headers import ( + TIMESTAMP_HEADER, + ATTEMPT_HEADER, + serialize_number_to_bytes, + deserialize_number_from_bytes, + get_header_value, +) +from .kafka_utils import MessageGroup from .producer import BaseProducer -_HEADER_PREFIX = "retriable_kafka_" -ATTEMPT_HEADER = _HEADER_PREFIX + "attempt" -TIMESTAMP_HEADER = _HEADER_PREFIX + "timestamp" LOGGER = logging.getLogger(__name__) @@ -42,29 +47,23 @@ def _get_retry_timestamp(message: Message) -> float | None: Returns: the timestamp in POSIX format or None if no timestamp was found """ - headers = message.headers() - if headers is None: - return None - for header_name, header_value in headers: - if header_name == TIMESTAMP_HEADER: - return int.from_bytes(header_value, "big") + header_val = get_header_value(message, TIMESTAMP_HEADER) + if header_val is not None: + return deserialize_number_from_bytes(header_val) return None -def _get_retry_attempt(message: Message) -> int: +def _get_retry_attempt(message_group: MessageGroup) -> int: """ Retrieves the attempt number from the message's header. 
Args: - message: Kafka message object + message_group: Kafka message group object Returns: the number of attempt or 0 if no attempt header was found """ - headers = message.headers() - if headers is None: - return 0 - for header_name, header_value in headers: - if header_name == ATTEMPT_HEADER: - return int.from_bytes(header_value, "big") + header_val = get_header_value(message_group, ATTEMPT_HEADER) + if header_val is not None: + return deserialize_number_from_bytes(header_val) return 0 @@ -239,12 +238,15 @@ def __populate_topics_and_producers(self) -> None: username=self.__config.username, password=self.__config.password, additional_settings=self.__config.additional_settings, + split_messages=True, ) producer = BaseProducer(config=producer_config) self.__retry_producers[topic_config.retry_topic] = producer self.__retry_producers[topic_config.base_topic] = producer - def _get_retry_headers(self, message: Message) -> dict[str, str | bytes] | None: + def _get_retry_headers( + self, message: MessageGroup + ) -> dict[str, str | bytes] | None: """ Create a dictionary of retry headers that will be used for the retried mechanism. The headers are generated based on the headers from the previous message. 
@@ -252,18 +254,17 @@ def _get_retry_headers(self, message: Message) -> dict[str, str | bytes] | None: message: Kafka message that will be retried Returns: dictionary of retry headers used for next sending """ - message_topic: str = message.topic() # type: ignore[assignment] - relevant_config = self.__topic_lookup.get(message_topic) + relevant_config = self.__topic_lookup.get(message.topic) if relevant_config is None: return None previous_attempt = _get_retry_attempt(message) retry_timestamp = _get_current_timestamp() + relevant_config.fallback_delay return { - ATTEMPT_HEADER: (previous_attempt + 1).to_bytes(length=8, byteorder="big"), - TIMESTAMP_HEADER: int(retry_timestamp).to_bytes(length=8, byteorder="big"), + ATTEMPT_HEADER: serialize_number_to_bytes(previous_attempt + 1), + TIMESTAMP_HEADER: serialize_number_to_bytes(retry_timestamp), } - def resend_message(self, message: Message) -> None: + def resend_message(self, message: MessageGroup) -> None: """ Send the message's copy to the specified retry topic. 
Also update its headers so that it will only be retried after @@ -275,10 +276,7 @@ def resend_message(self, message: Message) -> None: Args: message: the Kafka message that failed to be processed """ - message_topic = message.topic() - message_value = message.value() - if message_topic is None or message_value is None: - return + message_topic = message.topic relevant_producer = self.__retry_producers.get(message_topic) if relevant_producer is None: LOGGER.debug( @@ -298,23 +296,23 @@ def resend_message(self, message: Message) -> None: "Message will not be retried.", message_topic, relevant_config.retries, - extra={"message_raw": str(message.value())}, + extra={"message_raw": str(message.all_chunks)}, ) return try: relevant_producer.send_sync( - message_value, headers=self._get_retry_headers(message) + message.all_chunks, headers=self._get_retry_headers(message) ) LOGGER.debug( "Message from topic sent for reprocessing, %s", message_topic, - extra={"message_raw": str(message.value())}, + extra={"message_raw": str(message.all_chunks)}, ) except (TypeError, BufferError, KafkaException): LOGGER.exception( "Cannot resend message from topic: %s to its retry topic %s", message_topic, relevant_producer.topics, - extra={"message_raw": str(message.value())}, + extra={"message_raw": str(message.all_chunks)}, ) diff --git a/tests/integration/integration_utils.py b/tests/integration/integration_utils.py index 47155a2..c9d9ee9 100644 --- a/tests/integration/integration_utils.py +++ b/tests/integration/integration_utils.py @@ -5,7 +5,7 @@ import threading import time from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Callable from confluent_kafka.admin import AdminClient, NewTopic @@ -44,15 +44,25 @@ class MessageGenerator: def __init__(self) -> None: self._call_count = 0 - def generate(self, count: int) -> list[dict[str, Any]]: + def generate( + self, count: int, extra_fields: 
dict[str, Any] | None = None + ) -> list[dict[str, Any]]: """ Generate a chunk of messages. Args: count: Number of messages to generate in this chunk + extra_fields: Extra fields to add to the generated messages """ + extra_fields = extra_fields or {} result = [] for _ in range(count): - result.append({"id": self._call_count, "message": "This is a test message"}) + result.append( + { + "id": self._call_count, + "message": "This is a test message", + **extra_fields, + } + ) self._call_count += 1 return result @@ -226,6 +236,8 @@ class ScaffoldConfig: topics: list[ConsumeTopicConfig] group_id: str timeout: float = 15.0 + split_messages: bool = False + additional_settings: dict[str, Any] = field(default_factory=dict) class IntegrationTestScaffold: @@ -301,7 +313,11 @@ def _create_producer(self) -> BaseProducer: username=self.kafka_config[KafkaOptions.USERNAME], password=self.kafka_config[KafkaOptions.PASSWORD], fallback_base=0.1, - additional_settings={KafkaOptions.SECURITY_PROTO: "SASL_PLAINTEXT"}, + split_messages=self.config.split_messages, + additional_settings={ + KafkaOptions.SECURITY_PROTO: "SASL_PLAINTEXT", + **self.config.additional_settings, + }, ) return BaseProducer(producer_config) @@ -349,7 +365,10 @@ def start_consumer( group_id=self.config.group_id, target=target, filter_function=filter_function, - additional_settings={KafkaOptions.SECURITY_PROTO: "SASL_PLAINTEXT"}, + additional_settings={ + KafkaOptions.SECURITY_PROTO: "SASL_PLAINTEXT", + **self.config.additional_settings, + }, ) consumer = BaseConsumer( @@ -388,12 +407,17 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: pass async def send_messages( - self, count: int, headers: dict[str, bytes] | None = None) -> list[dict[str, Any]]: + self, + count: int, + extra_fields: dict[str, Any] | None = None, + headers: dict[str, bytes] | None = None, + ) -> list[dict[str, Any]]: """ Generate and send messages. Args: count: Number of messages to send. 
+ extra_fields: Extra fields to add to each message. headers: Message headers (optional parameter) Returns: @@ -402,7 +426,7 @@ async def send_messages( if self._producer is None: raise RuntimeError("Harness not started. Use 'async with' context manager.") - messages = self.generator.generate(count) + messages = self.generator.generate(count, extra_fields=extra_fields) for msg in messages: await self._producer.send(msg, headers=headers) diff --git a/tests/integration/test_chunking.py b/tests/integration/test_chunking.py new file mode 100644 index 0000000..ac2b553 --- /dev/null +++ b/tests/integration/test_chunking.py @@ -0,0 +1,88 @@ +""" +Integration tests for Kafka producer and consumer +using larger amount of messages than other tests. +""" + +import asyncio +import re +from typing import Any + +import pytest +from confluent_kafka.admin import AdminClient + +from retriable_kafka_client.kafka_settings import KafkaOptions +from retriable_kafka_client import ConsumeTopicConfig + +from .integration_utils import ( + IntegrationTestScaffold, + ScaffoldConfig, + RandomDelay, +) + + +@pytest.mark.asyncio +async def test_chunking( + kafka_config: dict[str, Any], + admin_client: AdminClient, + caplog: pytest.LogCaptureFixture, +) -> None: + """ + Test that a large number of messages (30+) can be processed without deadlocks + when using: + - A small number of workers (2) + - Random delays in processing + - Exceptions on first attempt that trigger retries + + This test ensures the consumer can handle concurrent processing with retries + without getting into a deadlock state. 
+ """ + message_count = 5 + very_large_message = {"sample": "a" * 10000} + + config = ScaffoldConfig( + topics=[ + ConsumeTopicConfig( + base_topic="test-chunks-base-topic", + retry_topic="test-chunks-retry-topic", + retries=3, + fallback_delay=0.5, + ), + ], + group_id="test-chunks", + timeout=30.0, + split_messages=True, + additional_settings={KafkaOptions.MAX_MESSAGE_SIZE: 1000}, + ) + + async with IntegrationTestScaffold(kafka_config, admin_client, config) as scaffold: + scaffold.start_consumer( + delay=RandomDelay(min_delay=0, max_delay=0.05), + fail_chance_on_first=0.7, + max_workers=2, + max_concurrency=4, + ) + await asyncio.sleep(2) # Wait for consumer to be ready + + await scaffold.send_messages(message_count, extra_fields=very_large_message) + success = await scaffold.wait_for_success() + + assert success, ( + f"Expected {scaffold.messages_sent} successful messages after retries, " + f"got {len(scaffold.tracker.success_counts)}. " + f"This may indicate a deadlock. " + f"Call counts: {scaffold.tracker.call_counts}" + ) + + # Verify all messages were processed + for msg_id in range(scaffold.messages_sent): + assert scaffold.tracker.success_counts.get(msg_id) == 1, ( + f"Message {msg_id} should have succeeded 1 time, " + f"got {scaffold.tracker.success_counts.get(msg_id)}" + ) + assert any( + re.match( + r"Received all message chunks, assembling group [a-f0-9]+ composed of [0-9]{2} messages\.", + message, + ) + for message in caplog.messages + ) diff --git a/tests/integration/test_user_function_filter.py b/tests/integration/test_user_function_filter.py index 5912379..1e67e52 100644 --- a/tests/integration/test_user_function_filter.py +++ b/tests/integration/test_user_function_filter.py @@ -2,6 +2,7 @@ Integration test to check for messages in headers """ + import asyncio from typing import Any @@ -48,7 +49,9 @@ def filter_headers(msg: Message) -> bool: await scaffold.send_messages(1, headers={"repository_name": b"helm-charts"}) await 
scaffold.send_messages(1, headers={"repository_name": b"other-repo"}) - await scaffold.send_messages(1, headers={"repository_name": b"my/helm-charts/repo"}) + await scaffold.send_messages( + 1, headers={"repository_name": b"my/helm-charts/repo"} + ) await scaffold.send_messages(1) # Helper function that processes only filtered related messages diff --git a/tests/unit/test_chunking.py b/tests/unit/test_chunking.py new file mode 100644 index 0000000..5aec0bc --- /dev/null +++ b/tests/unit/test_chunking.py @@ -0,0 +1,44 @@ +from unittest.mock import MagicMock + +import pytest +from confluent_kafka import Message + +from retriable_kafka_client.chunking import ChunkingCache, calculate_header_size +from retriable_kafka_client.headers import ( + CHUNK_GROUP_HEADER, + NUMBER_OF_CHUNKS_HEADER, + CHUNK_ID_HEADER, + serialize_number_to_bytes, +) + + +def test_chunking_receive(): + message1 = MagicMock(spec=Message) + message2 = MagicMock(spec=Message) + message3 = MagicMock(spec=Message) + for message_id, message in enumerate([message1, message2, message3]): + message.headers.return_value = [ + (CHUNK_GROUP_HEADER, b"foo"), + (NUMBER_OF_CHUNKS_HEADER, serialize_number_to_bytes(3)), + (CHUNK_ID_HEADER, serialize_number_to_bytes(message_id)), + ] + message.topic.return_value = "foo_topic" + message.partition.return_value = 0 + message1.value.return_value = b'{"hello' + message2.value.return_value = b'": "wor' + message3.value.return_value = b'ld"}' + chunking_cache = ChunkingCache() + assert chunking_cache.receive(message1) is None + assert chunking_cache.receive(message2) is None + assert chunking_cache.receive(message3).deserialize() == {"hello": "world"} + + +@pytest.mark.parametrize( + ["headers", "expected_size"], + [(None, 0), ({"foo": "bar"}, 55), ([("foo", "bar"), ("spam", "ham")], 111)], +) +def test_calculate_header_size( + headers: dict[str, str | bytes] | list[tuple[str, str | bytes]] | None, + expected_size: int, +) -> None: + assert calculate_header_size(headers) 
== expected_size diff --git a/tests/unit/test_consumer.py b/tests/unit/test_consumer.py index e81f23d..16c8d23 100644 --- a/tests/unit/test_consumer.py +++ b/tests/unit/test_consumer.py @@ -10,7 +10,13 @@ from retriable_kafka_client import BaseConsumer, ConsumerConfig from retriable_kafka_client.config import ConsumeTopicConfig -from retriable_kafka_client.consumer_tracking import _PartitionInfo as PartitionInfo +from retriable_kafka_client.headers import ( + CHUNK_GROUP_HEADER, + NUMBER_OF_CHUNKS_HEADER, + serialize_number_to_bytes, + CHUNK_ID_HEADER, +) +from retriable_kafka_client.kafka_utils import TrackingInfo, MessageGroup @pytest.fixture @@ -55,23 +61,57 @@ def test_consumer__process_message_decode_fail( ) -> None: caplog.set_level(logging.DEBUG) mock_message = MagicMock( - spec=Message, value=lambda: b"---\nthis: is not a json\n", error=lambda: None + spec=Message, + value=lambda: b"---\nthis: is not a json\n", + error=lambda: None, + topic=lambda: "test-topic", + partition=lambda: 0, + offset=lambda: 0, + headers=lambda: None, ) assert base_consumer._process_message(mock_message) is None assert "Decoding error: not a valid JSON" in caplog.messages[-1] +def test_process_message_skip_chunk( + base_consumer: BaseConsumer, +): + mock_message = MagicMock( + spec=Message, + ) + mock_message.headers.return_value = [ + (CHUNK_GROUP_HEADER, b"foo"), + (NUMBER_OF_CHUNKS_HEADER, serialize_number_to_bytes(3)), + (CHUNK_ID_HEADER, serialize_number_to_bytes(0)), + ] + mock_message.topic.return_value = "foo_topic" + mock_message.partition.return_value = 0 + result = base_consumer._process_message(mock_message) + assert result is None + + def test_consumer__process_message_empty_value( base_consumer: BaseConsumer, ) -> None: - """Test that empty messages are discarded and committed.""" - mock_message = MagicMock(spec=Message, value=lambda: None, error=lambda: None) - mock_consumer = base_consumer._consumer - mock_consumer.commit = MagicMock() + """Test that empty messages 
are discarded and scheduled for commit.""" + mock_message = MagicMock( + spec=Message, + value=lambda: None, + error=lambda: None, + topic=lambda: "test-topic", + partition=lambda: 0, + offset=lambda: 0, + headers=lambda: None, + ) result = base_consumer._process_message(mock_message) assert result is None + tracking_manager = base_consumer._BaseConsumer__tracking_manager + partition_info = TrackingInfo("test-topic", 0) + assert (1,) in tracking_manager._TrackingManager__to_commit.get( + partition_info, set() + ) def test_consumer__process_message_valid_json( @@ -81,7 +121,13 @@ def test_consumer__process_message_valid_json( message_data = {"key": "value", "number": 42} message_json = json.dumps(message_data).encode() mock_message = MagicMock( - spec=Message, value=lambda: message_json, error=lambda: None + spec=Message, + value=lambda: message_json, + error=lambda: None, + topic=lambda: "test-topic", + partition=lambda: 0, + offset=lambda: 0, + headers=lambda: None, ) with patch.object(base_consumer._executor, "submit") as mock_submit: @@ -123,9 +169,11 @@ def test_consumer__graceful_shutdown( base_consumer._executor = MagicMock() # Pre-fill the tracking manager - partition_info = PartitionInfo("test-topic", 0) + partition_info = TrackingInfo("test-topic", 0) tracking_manager = base_consumer._BaseConsumer__tracking_manager - tracking_manager._TrackingManager__to_commit[partition_info].update({100, 101, 102}) + tracking_manager._TrackingManager__to_commit[partition_info].update( + {(100,), (101,), (102,)} + ) # Verify cache has data assert len(tracking_manager._TrackingManager__to_commit[partition_info]) == 3 @@ -299,17 +347,16 @@ def test_perform_commits_logic( mock_consumer = base_consumer._consumer mock_consumer.commit = MagicMock() - partition_info = PartitionInfo("test-topic", 0) + partition_info = TrackingInfo("test-topic", 0) tracking_manager = base_consumer._BaseConsumer__tracking_manager if to_process_offsets: - # __to_process now stores dict[int, Future] 
for offset in to_process_offsets: - tracking_manager._TrackingManager__to_process[partition_info][offset] = ( + tracking_manager._TrackingManager__to_process[partition_info][(offset,)] = ( MagicMock(spec=Future) ) if to_commit_offsets: tracking_manager._TrackingManager__to_commit[partition_info].update( - to_commit_offsets + {(offset,) for offset in to_commit_offsets} ) base_consumer._BaseConsumer__perform_commits() @@ -326,11 +373,11 @@ def test_perform_commits_failed(base_consumer: BaseConsumer) -> None: mock_consumer = base_consumer._consumer mock_consumer.commit = MagicMock(side_effect=KafkaException()) - partition_info = PartitionInfo("test-topic", 0) + partition_info = TrackingInfo("test-topic", 0) tracking_manager = base_consumer._BaseConsumer__tracking_manager - tracking_manager._TrackingManager__to_commit[partition_info] = {1} + tracking_manager._TrackingManager__to_commit[partition_info] = {(1,)} base_consumer._BaseConsumer__perform_commits() - assert tracking_manager._TrackingManager__to_commit[partition_info] == {1} + assert tracking_manager._TrackingManager__to_commit[partition_info] == {(1,)} @pytest.mark.parametrize( @@ -350,11 +397,11 @@ def test_on_revoke( mock_consumer.commit = MagicMock() # Pre-fill the tracking manager if needed - partition_info = PartitionInfo("test-topic", 0) + partition_info = TrackingInfo("test-topic", 0) tracking_manager = base_consumer._BaseConsumer__tracking_manager if to_commit_offsets: tracking_manager._TrackingManager__to_commit[partition_info].update( - to_commit_offsets + {(offset,) for offset in to_commit_offsets} ) # Create mock partitions list (required by on_revoke signature) @@ -381,7 +428,8 @@ def test_on_revoke( def test_ack_message_cancelled(base_consumer: BaseConsumer) -> None: mock_future = MagicMock(spec=Future) mock_future.cancelled = lambda: True - base_consumer._BaseConsumer__ack_message(MagicMock(), mock_future) + mock_message_group = MagicMock(spec=MessageGroup) + 
base_consumer._BaseConsumer__ack_message(mock_message_group, mock_future) mock_future.exception.assert_not_called() @@ -400,14 +448,17 @@ def test_ack_message_with_exception( mock_retry_manager = MagicMock() base_consumer._BaseConsumer__retry_manager = mock_retry_manager - # Create a mock message - mock_message = MagicMock( + # Create a MessageGroup + mock_inner_message = MagicMock( spec=Message, value=lambda: b'{"test": "data"}', - topic=lambda: "test-topic", - partition=lambda: 0, offset=lambda: 42, ) + message_group = MessageGroup( + topic="test-topic", + partition=0, + messages={0: mock_inner_message}, + ) # Create a mock future that raises an exception test_exception = ValueError("Test processing error") @@ -416,19 +467,19 @@ def test_ack_message_with_exception( mock_future.cancelled.return_value = False # Call __ack_message - base_consumer._BaseConsumer__ack_message(mock_message, mock_future) + base_consumer._BaseConsumer__ack_message(message_group, mock_future) # Verify tracking manager schedule_commit was called - mock_tracking_manager.schedule_commit.assert_called_once_with(mock_message) + mock_tracking_manager.schedule_commit.assert_called_once_with(message_group) # Verify retry manager was called to resend the message - mock_retry_manager.resend_message.assert_called_once_with(mock_message) + mock_retry_manager.resend_message.assert_called_once_with(message_group) # Verify error was logged assert len(caplog.records) == 1 assert caplog.records[0].levelname == "ERROR" assert "Message could not be processed!" 
in caplog.records[0].message - assert '{"test": "data"}' in caplog.records[0].message + assert "test" in caplog.records[0].message assert caplog.records[0].exc_info[1] is test_exception @@ -574,6 +625,7 @@ def test_consumer_with_filter_function( Test that filter_function receives message with accessible headers and filters messages correctly """ + # Helper filter function to check for message in headers def filter_with_header_access(msg: Message) -> bool: headers = msg.headers() diff --git a/tests/unit/test_consumer_tracking.py b/tests/unit/test_consumer_tracking.py index 5adc69c..4e20699 100644 --- a/tests/unit/test_consumer_tracking.py +++ b/tests/unit/test_consumer_tracking.py @@ -8,35 +8,8 @@ from retriable_kafka_client.consumer_tracking import ( TrackingManager, - _PartitionInfo as PartitionInfo, ) - - -def test_partition_info_round_trip() -> None: - """Test that _PartitionInfo can be converted to TopicPartition and maintains data integrity.""" - # Create a mock Kafka message - mock_message = MagicMock( - spec=Message, - topic=lambda: "test-topic", - partition=lambda: 3, - offset=lambda: 150, - ) - - # Convert Message to _PartitionInfo - partition_info = PartitionInfo.from_message(mock_message) - - # Verify extracted data - assert partition_info.topic == "test-topic" - assert partition_info.partition == 3 - - # Convert back to TopicPartition with the same offset - topic_partition = partition_info.to_offset_info(150) - - # Verify data integrity after round-trip - assert topic_partition.topic == "test-topic" - assert topic_partition.partition == 3 - assert topic_partition.offset == 150 - assert isinstance(topic_partition, TopicPartition) +from retriable_kafka_client.kafka_utils import TrackingInfo, MessageGroup @pytest.mark.parametrize( @@ -101,12 +74,12 @@ def test_offset_cache_pop_committable( # Setup the cache state for state in partition_states: - partition_info = PartitionInfo("test-topic", state["partition"]) - # Handle special case for explicitly creating 
empty set in to_commit - cache._TrackingManager__to_commit[partition_info] = state["to_commit"] - # to_process now stores dict[int, Future], convert sets to dicts with mock futures + partition_info = TrackingInfo("test-topic", state["partition"]) + cache._TrackingManager__to_commit[partition_info] = { + (offset,) for offset in state["to_commit"] + } cache._TrackingManager__to_process[partition_info] = { - offset: MagicMock(spec=Future) for offset in state["to_process"] + (offset,): MagicMock(spec=Future) for offset in state["to_process"] } # Call pop_committable @@ -128,10 +101,11 @@ def test_offset_cache_pop_committable( # Verify remaining state in to_commit for partition, expected_offsets in expected_remaining_to_commit.items(): - partition_info = PartitionInfo("test-topic", partition) + partition_info = TrackingInfo("test-topic", partition) actual_offsets = cache._TrackingManager__to_commit.get(partition_info, set()) - assert actual_offsets == expected_offsets, ( - f"Partition {partition}: expected {expected_offsets}, got {actual_offsets}" + expected_as_tuples = {(offset,) for offset in expected_offsets} + assert actual_offsets == expected_as_tuples, ( + f"Partition {partition}: expected {expected_as_tuples}, got {actual_offsets}" ) @@ -141,35 +115,38 @@ def test_offset_cache_schedule_commit_success( """Test schedule_commit successfully moves offset from to_process to to_commit.""" cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) - # Create a mock message - mock_message = MagicMock( + # Create a MessageGroup + mock_inner_message = MagicMock( spec=Message, topic=lambda: "test-topic", partition=lambda: 0, offset=lambda: 42, value=lambda: b'{"test": "data"}', ) + message_group = MessageGroup( + topic="test-topic", partition=0, messages={0: mock_inner_message} + ) # Create a mock future mock_future = MagicMock(spec=Future) # First, mark the message as being processed - cache.process_message(mock_message, mock_future) + 
cache.process_message(message_group, mock_future) # Verify it's in to_process - partition_info = PartitionInfo("test-topic", 0) - assert 43 in cache._TrackingManager__to_process[partition_info] - assert 43 not in cache._TrackingManager__to_commit.get(partition_info, set()) + partition_info = TrackingInfo("test-topic", 0) + assert (43,) in cache._TrackingManager__to_process[partition_info] + assert (43,) not in cache._TrackingManager__to_commit.get(partition_info, set()) # Now schedule it for commit - result = cache.schedule_commit(mock_message) + result = cache.schedule_commit(message_group) # Verify success assert result is True # Verify it moved from to_process to to_commit - assert 43 not in cache._TrackingManager__to_process[partition_info] - assert 43 in cache._TrackingManager__to_commit[partition_info] + assert (43,) not in cache._TrackingManager__to_process.get(partition_info, {}) + assert (43,) in cache._TrackingManager__to_commit[partition_info] # No warning should be logged assert len(caplog.records) == 0 @@ -182,25 +159,28 @@ def test_offset_cache_schedule_commit_without_prior_processing( caplog.set_level(logging.WARNING) cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) - # Create a mock message - mock_message = MagicMock( + # Create a MessageGroup + mock_inner_message = MagicMock( spec=Message, topic=lambda: "test-topic", partition=lambda: 0, offset=lambda: 42, value=lambda: b'{"test": "data"}', ) + message_group = MessageGroup( + topic="test-topic", partition=0, messages={0: mock_inner_message} + ) # Don't add it to to_process (simulating it was never marked for processing) # Try to schedule commit - result = cache.schedule_commit(mock_message) + result = cache.schedule_commit(message_group) # schedule_commit always returns True in the new implementation assert result is True # Verify offset was added to to_commit (new behavior) - partition_info = PartitionInfo("test-topic", 0) - assert 43 in 
cache._TrackingManager__to_commit.get(partition_info, set()) + partition_info = TrackingInfo("test-topic", 0) + assert (43,) in cache._TrackingManager__to_commit.get(partition_info, set()) # No warning is logged in the new implementation assert len(caplog.records) == 0 @@ -211,32 +191,32 @@ def test_offset_cache_schedule_commit_without_prior_processing( [ pytest.param( { - PartitionInfo("topic-a", 0): { + TrackingInfo("topic-a", 0): { 100: "future_1", 101: "future_2", 102: "future_3", }, }, - {PartitionInfo("topic-a", 0)}, + {TrackingInfo("topic-a", 0)}, ["future_1", "future_3"], ["future_2"], id="mixed_some_cancelled_some_not", ), pytest.param( { - PartitionInfo("topic-a", 0): {100: "future_1", 101: "future_2"}, - PartitionInfo("topic-b", 0): {200: "future_3"}, + TrackingInfo("topic-a", 0): {100: "future_1", 101: "future_2"}, + TrackingInfo("topic-b", 0): {200: "future_3"}, }, - {PartitionInfo("topic-a", 0)}, + {TrackingInfo("topic-a", 0)}, ["future_1", "future_2"], [], id="partial_revoke_keeps_non_revoked", ), pytest.param( { - PartitionInfo("topic-a", 0): {100: "future_1", 101: "future_2"}, + TrackingInfo("topic-a", 0): {100: "future_1", 101: "future_2"}, }, - {PartitionInfo("topic-b", 0)}, + {TrackingInfo("topic-b", 0)}, [], ["future_1", "future_2"], id="different_topic", @@ -244,8 +224,8 @@ def test_offset_cache_schedule_commit_without_prior_processing( ], ) def offset_cache_revoke_processing( - to_process_data: dict[PartitionInfo, dict[int, str]], - revoked_partitions: set[PartitionInfo], + to_process_data: dict[TrackingInfo, dict[int, str]], + revoked_partitions: set[TrackingInfo], expected_cancelled: list[str], expected_not_cancelled: list[str], ) -> None: @@ -325,23 +305,23 @@ def offset_cache_revoke_processing( id="only_to_commit_has_data_remains_after_revoke", ), pytest.param( - {PartitionInfo("test-topic", 1): {50: MagicMock(), 51: MagicMock()}}, + {TrackingInfo("test-topic", 1): {50: MagicMock(), 51: MagicMock()}}, [TopicPartition("test-topic", 1)], 
{}, id="only_to_process_has_data_all_cancelled", ), pytest.param( { - PartitionInfo("topic-a", 0): {12: MagicMock()}, - PartitionInfo("topic-b", 1): {22: "FakeFuture"}, + TrackingInfo("topic-a", 0): {12: MagicMock()}, + TrackingInfo("topic-b", 1): {22: "FakeFuture"}, }, [TopicPartition("topic-a", 0)], - {PartitionInfo("topic-b", 1): {22: "FakeFuture"}}, + {TrackingInfo("topic-b", 1): {22: "FakeFuture"}}, id="partial", ), pytest.param( { - PartitionInfo("topic-a", 0): {12: MagicMock(cancel=lambda: False)}, + TrackingInfo("topic-a", 0): {12: MagicMock(cancel=lambda: False)}, }, [TopicPartition("topic-a", 0)], {}, @@ -352,7 +332,7 @@ def offset_cache_revoke_processing( def test_offset_cache_register_revoke( to_process_data: dict, partitions_to_revoke: list[TopicPartition], - expected_remaining_process: dict[PartitionInfo, set[int]], + expected_remaining_process: dict[TrackingInfo, set[int]], ) -> None: cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) cache._TrackingManager__to_process = to_process_data @@ -373,7 +353,7 @@ def test_offset_cache_register_revoke( def test_offset_cache_register_revoke_err( error: type[Exception], ) -> None: - stub_partition = PartitionInfo("topic-a", 0) + stub_partition = TrackingInfo("topic-a", 0) cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) mock_future = MagicMock(spec=Future) mock_future.cancel.return_value = False @@ -394,24 +374,27 @@ def test_offset_cache_schedule_commit_offset_not_in_partition( caplog.set_level(logging.WARNING) cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) - # Create a mock message - mock_message = MagicMock( + # Create a MessageGroup + mock_inner_message = MagicMock( spec=Message, topic=lambda: "test-topic", partition=lambda: 0, offset=lambda: 42, value=lambda: b'{"test": "data"}', ) + message_group = MessageGroup( + topic="test-topic", partition=0, messages={0: mock_inner_message} + ) - partition_info = PartitionInfo("test-topic", 0) + partition_info = 
TrackingInfo("test-topic", 0) # Add partition but with DIFFERENT offsets (not including 42) for offset in [100, 101, 102]: mock_future = MagicMock(spec=Future) - cache._TrackingManager__to_process[partition_info][offset] = mock_future + cache._TrackingManager__to_process[partition_info][(offset,)] = mock_future # Try to schedule commit for offset 42 which doesn't exist in the partition - result = cache.schedule_commit(mock_message) + result = cache.schedule_commit(message_group) # In new implementation, schedule_commit always succeeds assert result is True @@ -420,11 +403,11 @@ def test_offset_cache_schedule_commit_offset_not_in_partition( assert len(caplog.records) == 0 # The offset IS added to to_commit (new behavior) - assert 43 in cache._TrackingManager__to_commit.get(partition_info, set()) + assert (43,) in cache._TrackingManager__to_commit.get(partition_info, set()) # Verify original offsets remain in to_process (42 wasn't there to remove) assert set(cache._TrackingManager__to_process[partition_info].keys()) == { - 100, - 101, - 102, + (100,), + (101,), + (102,), } diff --git a/tests/unit/test_producer.py b/tests/unit/test_producer.py index 13bbed5..aea0695 100644 --- a/tests/unit/test_producer.py +++ b/tests/unit/test_producer.py @@ -10,6 +10,12 @@ from retriable_kafka_client.producer import BaseProducer from retriable_kafka_client.config import ProducerConfig +from retriable_kafka_client.headers import ( + CHUNK_GROUP_HEADER, + NUMBER_OF_CHUNKS_HEADER, + CHUNK_ID_HEADER, + serialize_number_to_bytes, +) @pytest.fixture @@ -58,15 +64,14 @@ def test_producer_send_success_first_attempt( _run_send_method(base_producer, send_method, message) mock_kafka_producer: MagicMock = base_producer._producer - # Should produce to all topics once assert mock_kafka_producer.produce.call_count == len(base_producer._config.topics) - for topic in base_producer._config.topics: - mock_kafka_producer.produce.assert_any_call( - topic=topic, - 
value=json.dumps(message).encode("utf-8"), - timestamp=1000, - headers=None, - ) + for call in mock_kafka_producer.produce.call_args_list: + kwargs = call[1] + assert kwargs["topic"] in base_producer._config.topics + assert kwargs["value"] == json.dumps(message).encode("utf-8") + assert kwargs["timestamp"] == 1000 + assert kwargs["key"] is not None + assert isinstance(kwargs["headers"], dict) @pytest.mark.parametrize("send_method", ["send", "send_sync"]) @@ -174,18 +179,89 @@ def test_producer_close(base_producer: BaseProducer) -> None: @pytest.mark.parametrize( "input_message,expected_output", [ - pytest.param(b"raw bytes", b"raw bytes", id="bytes_passthrough"), - pytest.param({"key": "value"}, b'{"key": "value"}', id="dict_to_json"), + pytest.param(b"raw bytes", [b"raw bytes"], id="bytes_passthrough"), + pytest.param({"key": "value"}, [b'{"key": "value"}'], id="dict_to_json"), pytest.param( {"num": 42, "flag": True}, - b'{"num": 42, "flag": true}', + [b'{"num": 42, "flag": true}'], id="dict_with_types", ), - pytest.param([], b"[]", id="empty_list"), + pytest.param([], [], id="empty_chunk_list"), + ], +) +def test_serialize_message( + base_producer: BaseProducer, input_message, expected_output +) -> None: + """Test that messages are correctly serialized to byte chunks.""" + result = base_producer._BaseProducer__serialize_message(input_message, None, True) + assert result == expected_output + + +@pytest.mark.parametrize("send_method", ["send", "send_sync"]) +@pytest.mark.parametrize("split_messages", [True, False]) +def test_producer_send_chunking_headers(send_method: str, split_messages: bool) -> None: + """Test that chunking headers are added only when split_messages is enabled.""" + config = ProducerConfig( + kafka_hosts=["example.com:9092"], + topics=["test_topic"], + username="test_user", + password="test_pass", + retries=2, + fallback_factor=1.1, + fallback_base=0.01, + additional_settings={}, + split_messages=split_messages, + ) + with 
patch("retriable_kafka_client.producer.Producer"): + producer = BaseProducer(config=config) + _ = producer._producer + + message = {"key": "value"} + + with patch("time.time", return_value=1): + _run_send_method(producer, send_method, message) + + mock_kafka_producer: MagicMock = producer._producer + assert mock_kafka_producer.produce.call_count == 1 + call_kwargs = mock_kafka_producer.produce.call_args[1] + headers = call_kwargs["headers"] + + chunking_headers = {CHUNK_GROUP_HEADER, NUMBER_OF_CHUNKS_HEADER, CHUNK_ID_HEADER} + + if split_messages: + for header in chunking_headers: + assert header in headers, ( + f"Expected header {header!r} when split_messages=True" + ) + assert isinstance(headers[CHUNK_GROUP_HEADER], bytes) + assert len(headers[CHUNK_GROUP_HEADER]) == 16 # UUID bytes + assert headers[NUMBER_OF_CHUNKS_HEADER] == serialize_number_to_bytes(1) + assert headers[CHUNK_ID_HEADER] == serialize_number_to_bytes(0) + else: + for header in chunking_headers: + assert header not in headers, ( + f"Unexpected header {header!r} when split_messages=False" + ) + + +@pytest.mark.parametrize( + ["chunking_size", "input_message", "expected_output"], + [ + (5, [b"AAAAAA", b"AA"], [b"AAAAA", b"AAA"]), + (0, [b"AAAAAA", b"AA"], [b"AAAAAAAA"]), ], ) -def test_serialize_message(input_message, expected_output) -> None: - """Test that messages are correctly serialized to bytes.""" - # Access the private static method via name mangling - result = BaseProducer._BaseProducer__serialize_message(input_message) +def test_serialize_message_resize_chunks( + chunking_size: int, + input_message: list[bytes] | bytes | dict[str, Any], + expected_output: list[bytes], + base_producer: BaseProducer, +) -> None: + """Test that messages are correctly serialized to byte chunks.""" + with patch.object(base_producer, "_get_chunk_size", return_value=chunking_size): + result = base_producer._BaseProducer__serialize_message( + input_message, + None, + chunking_size != 0, + ) assert result == 
expected_output diff --git a/tests/unit/test_retry_utils.py b/tests/unit/test_retry_utils.py index 6187852..b9cc8b9 100644 --- a/tests/unit/test_retry_utils.py +++ b/tests/unit/test_retry_utils.py @@ -4,9 +4,10 @@ from unittest.mock import MagicMock, patch import pytest -from confluent_kafka import KafkaException, KafkaError, TopicPartition +from confluent_kafka import KafkaException, KafkaError, Message, TopicPartition from retriable_kafka_client.config import ConsumerConfig, ConsumeTopicConfig +from retriable_kafka_client.kafka_utils import MessageGroup from retriable_kafka_client.retry_utils import ( ATTEMPT_HEADER, TIMESTAMP_HEADER, @@ -39,7 +40,7 @@ def _make_message( error: KafkaError | None = None, ) -> MagicMock: """Helper to create a mock Kafka message.""" - message = MagicMock() + message = MagicMock(spec=Message) message.topic.return_value = topic message.partition.return_value = partition message.offset.return_value = offset @@ -49,6 +50,26 @@ def _make_message( return message +def _make_message_group( + topic: str = "test-topic", + partition: int = 0, + offset: int = 0, + value: bytes | None = b"test", + headers: list[tuple[str, bytes]] | None = None, +) -> MessageGroup: + """Helper to create a MessageGroup wrapping a mock Kafka message. 
+ Uses a plain MagicMock (no spec) for the inner message so that + it stays truthy, which MessageGroup.headers property relies on.""" + inner = MagicMock() + inner.topic.return_value = topic + inner.partition.return_value = partition + inner.offset.return_value = offset + inner.value.return_value = value + inner.headers.return_value = headers + inner.error.return_value = None + return MessageGroup(topic=topic, partition=partition, messages={0: inner}) + + @pytest.mark.parametrize( "headers,expected", [ @@ -101,8 +122,8 @@ def test_get_retry_attempt( headers: list[tuple[str, bytes]] | None, expected: int ) -> None: """Test _get_retry_attempt extracts attempt number from headers correctly.""" - message = _make_message(headers=headers) - assert _get_retry_attempt(message) == expected + message_group = _make_message_group(headers=headers) + assert _get_retry_attempt(message_group) == expected class TestRetryScheduleCache: @@ -189,11 +210,13 @@ def test_clean_empty_removes_empty_timestamp_keys(self) -> None: cache = RetryScheduleCache() cache._RetryScheduleCache__schedule = { 100: [], - 200: _make_message( - topic="topic-a", - partition=0, - headers=[(TIMESTAMP_HEADER, int(200).to_bytes(8, "big"))], - ), + 200: [ + _make_message( + topic="topic-a", + partition=0, + headers=[(TIMESTAMP_HEADER, int(200).to_bytes(8, "big"))], + ) + ], } cache._cleanup() @@ -424,7 +447,9 @@ def test__get_retry_headers( return_value=timestamp, ): manager = RetryManager(config) - result = manager._get_retry_headers(_make_message(topic="t", headers=headers)) + result = manager._get_retry_headers( + _make_message_group(topic="t", headers=headers) + ) assert int.from_bytes(result[ATTEMPT_HEADER], "big") == expected_attempt assert int.from_bytes(result[TIMESTAMP_HEADER], "big") == expected_ts @@ -437,7 +462,7 @@ def test__get_retry_headers_no_config() -> None: ] ) manager = RetryManager(config) - result = manager._get_retry_headers(_make_message(topic="t", headers=[])) + result = 
manager._get_retry_headers(_make_message_group(topic="t", headers=[])) assert result is None @@ -480,21 +505,21 @@ def test_resend_message(topics, message_topic, send_error, expect_send) -> None: with patch("retriable_kafka_client.retry_utils.BaseProducer") as mock_cls: mock_cls.return_value.send_sync.side_effect = send_error manager = RetryManager(config) - manager.resend_message(_make_message(topic=message_topic)) + manager.resend_message(_make_message_group(topic=message_topic)) assert mock_cls.return_value.send_sync.called == expect_send def test_resend_message_no_value() -> None: - message = _make_message(value=None) + message_group = _make_message_group(value=None) config = _make_config([]) with patch("retriable_kafka_client.retry_utils.BaseProducer") as mock_cls: manager = RetryManager(config) - manager.resend_message(message) + manager.resend_message(message_group) mock_cls.return_value.send_sync.assert_not_called() def test_resend_message_exhausted_attempts(caplog: pytest.LogCaptureFixture) -> None: - message = _make_message( + message_group = _make_message_group( topic="t-retry", value=b'{"foo": "bar"}', headers=[ @@ -508,7 +533,6 @@ def test_resend_message_exhausted_attempts(caplog: pytest.LogCaptureFixture) -> ) with patch("retriable_kafka_client.retry_utils.BaseProducer") as mock_cls: manager = RetryManager(config) - manager.resend_message(message) + manager.resend_message(message_group) mock_cls.return_value.send_sync.assert_not_called() - print(caplog.messages) assert "Message will not be retried." in caplog.messages[-1] diff --git a/uv.lock b/uv.lock index 20be67a..360939e 100644 --- a/uv.lock +++ b/uv.lock @@ -894,7 +894,7 @@ wheels = [ [[package]] name = "retriable-kafka-client" -version = "0.2.0" +version = "0.4.0" source = { editable = "." 
} dependencies = [ { name = "confluent-kafka" }, From 6922de310f50e47e70722f3ad6b670df534c9da7 Mon Sep 17 00:00:00 2001 From: Marek Szymutko Date: Tue, 7 Apr 2026 14:43:50 +0200 Subject: [PATCH 2/5] feat(ISV-7020): enhance chunking capabilities * enable chunk expiration * tests assisted with Claude Signed-off-by: Marek Szymutko Assisted-by: Claude-4.6-Opus --- src/retriable_kafka_client/chunking.py | 147 +++++++++++------- src/retriable_kafka_client/config.py | 7 + src/retriable_kafka_client/consumer.py | 16 +- .../consumer_tracking.py | 110 ++++++++++++- tests/integration/integration_utils.py | 5 +- tests/integration/test_lost_chunk_recover.py | 93 +++++++++++ tests/unit/test_chunking.py | 89 ++++++++--- tests/unit/test_consumer.py | 4 +- tests/unit/test_consumer_tracking.py | 122 +++++++++++++-- 9 files changed, 492 insertions(+), 101 deletions(-) create mode 100644 tests/integration/test_lost_chunk_recover.py diff --git a/src/retriable_kafka_client/chunking.py b/src/retriable_kafka_client/chunking.py index 2a0468d..b9f2915 100644 --- a/src/retriable_kafka_client/chunking.py +++ b/src/retriable_kafka_client/chunking.py @@ -1,9 +1,9 @@ """Module for tracking chunked messages""" +from datetime import datetime, timedelta, timezone import logging import sys import uuid -from collections import defaultdict from confluent_kafka import Message @@ -14,7 +14,7 @@ deserialize_number_from_bytes, get_header_value, ) -from retriable_kafka_client.kafka_utils import MessageGroup +from retriable_kafka_client.kafka_utils import MessageGroup, TrackingInfo LOGGER = logging.getLogger(__name__) @@ -40,63 +40,100 @@ def calculate_header_size( return result -class ChunkingCache: - # pylint: disable=too-few-public-methods - """Class for storing information about received message fragments.""" +class MessageGroupBuilder: + """Class for gathering data from message chunks.""" - def __init__(self): - # The tuple holds group ID, topic and partition - # in case of group ID collision (that 
could mean an attack attempt). - # If the same consumer was used for different topics, an adversary - # may want to override split messages to execute different operations. - # By using group id as well as topic and partition, this attack is - # made impossible. - self._message_chunks: dict[tuple[bytes, str, int], dict[int, Message]] = ( - defaultdict(dict) - ) + def __init__(self, max_wait_time: timedelta) -> None: + self._last_update_time = datetime.now(tz=timezone.utc) + self._max_wait_time = max_wait_time + self._messages: list[Message] = [] + self.partition_info: TrackingInfo | None = None + self.group_id: bytes | None = None + self.full_length: int = 0 + self._present_chunks: set[int] = set() + + @staticmethod + def get_group_id_from_message(message: Message) -> bytes | None: + """Get the group id from the message.""" + return get_header_value(message, CHUNK_GROUP_HEADER) + + @staticmethod + def get_chunk_id_from_message(message: Message) -> int | None: + """Get the chunk id from the message.""" + if header_value := get_header_value(message, CHUNK_ID_HEADER): + return deserialize_number_from_bytes(header_value) + return None + + @staticmethod + def get_number_of_chunks_from_message(message: Message) -> int | None: + """Get the number of chunks from the message.""" + if header_value := get_header_value(message, NUMBER_OF_CHUNKS_HEADER): + return deserialize_number_from_bytes(header_value) + return None - def receive(self, message: Message) -> MessageGroup | None: + def add(self, message: Message) -> None: """ - Receive a message. If the message is whole, or it is - the last fragment, returns the whole message group - and flushes cache of this group. Otherwise, returns None - and the message fragment shall not be processed. + Add a message to the message group builder. + Args: + message: The Kafka message to add. 
""" - topic: str = message.topic() # type: ignore[assignment] - partition: int = message.partition() # type: ignore[assignment] if ( - (group_id := get_header_value(message, CHUNK_GROUP_HEADER)) is not None - and ( - number_of_chunks_raw := get_header_value( - message, NUMBER_OF_CHUNKS_HEADER - ) - ) - is not None - and (chunk_id_raw := get_header_value(message, CHUNK_ID_HEADER)) is not None + (new_group_id := self.get_group_id_from_message(message)) is None + or (new_number_of_chunks := self.get_number_of_chunks_from_message(message)) + is None + or ((new_chunk_id := self.get_chunk_id_from_message(message)) is None) ): - number_of_chunks = deserialize_number_from_bytes(number_of_chunks_raw) - chunk_id = deserialize_number_from_bytes(chunk_id_raw) - identifier = (group_id, topic, partition) - stored_message_ids_from_group = set( - self._message_chunks.get(identifier, {}).keys() - ) - stored_message_ids_from_group.add(chunk_id) - if not all( - i in stored_message_ids_from_group for i in range(number_of_chunks) - ): - LOGGER.debug( - "Received a message chunk, waiting for the other chunks..." 
- ) - self._message_chunks[identifier][chunk_id] = message - return None - LOGGER.debug( - "Received all message chunks, assembling group %s composed of %s messages.", - group_id.hex(), - number_of_chunks, + raise ValueError("The new message is missing required chunk headers!") + if self.group_id is None: + self.group_id = new_group_id + self.full_length = new_number_of_chunks + self.partition_info = TrackingInfo( + message.topic(), # type: ignore[arg-type] + message.partition(), # type: ignore[arg-type] ) - # Clear cache and reassemble, cache can be empty if - # this is just one-message-sized value - messages = self._message_chunks.pop(identifier, {}) - messages[chunk_id] = message - return MessageGroup(topic, partition, messages, group_id) - return MessageGroup(topic, partition, {0: message}, None) + self._last_update_time = datetime.now(tz=timezone.utc) + self._messages.append(message) + self._present_chunks.add(new_chunk_id) + + @property + def is_complete(self) -> bool: + """Does this builder contain all the needed chunks?""" + if not self._messages: + return False + return all(i in self._present_chunks for i in range(self.full_length)) + + @property + def offsets(self) -> tuple[int, ...]: + """Return the offsets of the chunks.""" + return tuple(message.offset() for message in self._messages) # type: ignore[misc] + + def is_still_valid(self) -> bool: + """Isn't this builder stale? Useful for discarding corrupted data.""" + return ( + datetime.now(tz=timezone.utc) - self._last_update_time < self._max_wait_time + ) + + def get_message_group(self, allow_incomplete: bool = False) -> MessageGroup | None: + """ + Generate the message group object from the builder. + Args: + allow_incomplete: If true, this will return a messageGroup + object even if not all chunks have been gathered. + Returns: MessageGroup object if all required data is available, + None otherwise. 
+ """ + if ( + (not allow_incomplete and not self.is_complete) + or not self._messages + or not self.partition_info + ): + return None + return MessageGroup( + topic=self.partition_info.topic, + partition=self.partition_info.partition, + messages={ + self.get_chunk_id_from_message(message): message # type: ignore[misc] + for message in self._messages + }, + group_id=self.group_id, + ) diff --git a/src/retriable_kafka_client/config.py b/src/retriable_kafka_client/config.py index ff646a1..5b69e18 100644 --- a/src/retriable_kafka_client/config.py +++ b/src/retriable_kafka_client/config.py @@ -1,6 +1,7 @@ """Types used in this library""" from dataclasses import dataclass, field +from datetime import timedelta from typing import Callable, Any from confluent_kafka import Message @@ -85,6 +86,11 @@ class ConsumerConfig(CommonConfig): Returns True if the message will be processed or False if skipped. In case False or exception is returned, message will be committed without processing. + max_chunk_reassembly_wait_time: Maximal time to wait for all the chunks + of the message to arrive. Has any effect only if chunking is enabled. + If some chunks are still waiting for reassembly after this threshold, + they are deleted and a warning is logged. This happens if the producer + crashed during producing of the chunked message, data cannot be salvaged. 
""" group_id: str @@ -92,3 +98,4 @@ class ConsumerConfig(CommonConfig): topics: list[ConsumeTopicConfig] = field(default_factory=list) cancel_future_wait_time: float = field(default=30.0) filter_function: Callable[[Message], bool] | None = field(default=None) + max_chunk_reassembly_wait_time: timedelta = field(default=timedelta(minutes=15)) diff --git a/src/retriable_kafka_client/consumer.py b/src/retriable_kafka_client/consumer.py index 4a14863..77edfcf 100644 --- a/src/retriable_kafka_client/consumer.py +++ b/src/retriable_kafka_client/consumer.py @@ -8,7 +8,6 @@ from confluent_kafka import Consumer, Message, KafkaException, TopicPartition -from .chunking import ChunkingCache from .health import perform_healthcheck_using_client from .kafka_utils import message_to_partition, MessageGroup from .kafka_settings import KafkaOptions, DEFAULT_CONSUMER_SETTINGS @@ -85,14 +84,14 @@ def __init__( self.__stop_flag: bool = False # Store information about offsets and tasks self.__tracking_manager = TrackingManager( - max_concurrency, config.cancel_future_wait_time + max_concurrency, + config.cancel_future_wait_time, + self._config.max_chunk_reassembly_wait_time, ) # Manage re-sending messages to retry topics self.__retry_manager = RetryManager(config) # Store information about pending retried messages self.__schedule_cache = RetryScheduleCache() - # Store chunking information - self.__chunk_tracker = ChunkingCache() @property def _consumer(self) -> Consumer: @@ -162,7 +161,7 @@ def __ack_message(self, message: MessageGroup, finished_future: Future) -> None: ) self.__retry_manager.resend_message(message) finally: - self.__tracking_manager.schedule_commit(message) + self.__tracking_manager.schedule_commit(message, release_semaphore=True) def __graceful_shutdown(self) -> None: """ @@ -214,12 +213,15 @@ def _process_message(self, message: Message) -> Future[Any] | None: Returns: Future of the target execution if the message can be processed. None otherwise. 
""" - message_group = self.__chunk_tracker.receive(message) + message_group = self.__tracking_manager.receive(message) if not message_group: return None message_data = message_group.deserialize() if not message_data: - self.__tracking_manager.schedule_commit(message_group) + # Semaphore was not acquired + self.__tracking_manager.schedule_commit( + message_group, release_semaphore=False + ) return None future = self._executor.submit(self._config.target, message_data) self.__tracking_manager.process_message(message_group, future) diff --git a/src/retriable_kafka_client/consumer_tracking.py b/src/retriable_kafka_client/consumer_tracking.py index 1ccc30e..01c98f1 100644 --- a/src/retriable_kafka_client/consumer_tracking.py +++ b/src/retriable_kafka_client/consumer_tracking.py @@ -3,12 +3,20 @@ import logging from collections import defaultdict from concurrent.futures import Future +from datetime import timedelta from itertools import chain from threading import Lock, Semaphore from typing import Any, Iterable -from confluent_kafka import TopicPartition +from confluent_kafka import TopicPartition, Message +from retriable_kafka_client.chunking import MessageGroupBuilder +from retriable_kafka_client.headers import ( + get_header_value, + CHUNK_GROUP_HEADER, + NUMBER_OF_CHUNKS_HEADER, + CHUNK_ID_HEADER, +) from retriable_kafka_client.kafka_utils import TrackingInfo, MessageGroup LOGGER = logging.getLogger(__name__) @@ -25,6 +33,7 @@ class TrackingManager: Each message can be either: - untracked (committed or not polled) + - pending for reassembly (if not all message chunks have been received) - pending for processing (then we also track its task object) - pending for committing @@ -52,7 +61,18 @@ class TrackingManager: using semaphore. On task finish or cancellation, the semaphore is released. 
""" - def __init__(self, concurrency: int, cancel_wait_time: float): + def __init__( + self, concurrency: int, cancel_wait_time: float, max_chunk_wait_time: timedelta + ): + # The tuple holds group ID, topic and partition + # in case of group ID collision (that could mean an attack attempt). + # If the same consumer was used for different topics, an adversary + # may want to override split messages to execute different operations. + # By using group id as well as topic and partition, this attack is + # made impossible. + self.__message_builders: dict[ + tuple[bytes, TrackingInfo], MessageGroupBuilder + ] = {} self.__to_process: dict[TrackingInfo, dict[tuple[int, ...], Future]] = ( defaultdict(dict) ) @@ -60,6 +80,74 @@ def __init__(self, concurrency: int, cancel_wait_time: float): self.__access_lock = Lock() # For handling multithreaded access to this object self.__semaphore = Semaphore(concurrency) self.__cancel_wait_time = cancel_wait_time + self.__max_chunk_wait_time = max_chunk_wait_time + + def _cleanup_stale_builders(self) -> None: + """ + Delete stale message builders if the wait time exceeded the configured + maximum to release resources. + + Returns: + """ + to_pop = set() + for builder_id, builder in self.__message_builders.items(): + if not builder.is_still_valid(): + to_pop.add(builder_id) + for item in to_pop: + deleted_builder = self.__message_builders.pop(item) + message_group = deleted_builder.get_message_group(allow_incomplete=True) + if message_group is not None: + self.schedule_commit( + message_group, + release_semaphore=False, + ) + if deleted_builder.partition_info: + LOGGER.warning( + "Removing stale message builder that failed to assemble message in time. " + "Lost message topic: %s, offsets: %s, group: %s", + deleted_builder.partition_info.topic, + ",".join(str(offset) for offset in deleted_builder.offsets), + deleted_builder.group_id, + ) + + def receive(self, message: Message) -> MessageGroup | None: + """ + Receive a message. 
If the message is whole, or it is + the last fragment, returns the whole message group + and flushes cache of this group. Otherwise, returns None + and the message fragment shall not be processed. + """ + topic: str = message.topic() # type: ignore[assignment] + partition: int = message.partition() # type: ignore[assignment] + self._cleanup_stale_builders() + if ( + (group_id := get_header_value(message, CHUNK_GROUP_HEADER)) is not None + and get_header_value(message, NUMBER_OF_CHUNKS_HEADER) is not None + and get_header_value(message, CHUNK_ID_HEADER) is not None + ): + builder_id = (group_id, TrackingInfo(topic, partition)) + message_builder = self.__message_builders.get(builder_id, None) + if message_builder is None: + message_builder = MessageGroupBuilder(self.__max_chunk_wait_time) + self.__message_builders[builder_id] = message_builder + message_builder.add(message) + if message_builder.is_complete: + complete_group = message_builder.get_message_group() + self.__message_builders.pop(builder_id) + LOGGER.debug( + "Received all message chunks, assembling group %s composed of %s messages.", + group_id.hex(), + message_builder.full_length, + ) + return complete_group + LOGGER.debug( + "Received a message chunk from group %s in topic %s and partition %s.", + group_id.hex(), + topic, + partition, + ) + return None + return MessageGroup(topic, partition, {0: message}, None) def pop_committable(self) -> list[TopicPartition]: """ @@ -73,6 +161,7 @@ def pop_committable(self) -> list[TopicPartition]: Returns: list of committable message offsets """ + self._cleanup_stale_builders() to_commit = [] with self.__access_lock: for partition_info, tuples_pending_to_commit in self.__to_commit.items(): @@ -80,7 +169,14 @@ def pop_committable(self) -> list[TopicPartition]: # Nothing to commit continue - tuples_pending_to_process = self.__to_process.get(partition_info, None) + tuples_pending_to_process = set( + self.__to_process.get(partition_info, {}).keys() + ) + 
tuples_pending_to_process.update( + message_builder.offsets + for builder_id, message_builder in self.__message_builders.items() + if builder_id[1] == partition_info + ) if not tuples_pending_to_process: # Nothing is blocking the committing max_to_commit = max(_flatten_offsets(tuples_pending_to_commit)) @@ -150,16 +246,20 @@ def process_message(self, message: MessageGroup, future: Future[Any]) -> None: tuple(message_offset + 1 for message_offset in message_offsets) ] = future - def schedule_commit(self, message: MessageGroup) -> bool: + def schedule_commit(self, message: MessageGroup, release_semaphore: bool) -> bool: """ Mark message as pending for committing when its processing is fully done. Args: message: Kafka message object + release_semaphore: Should the semaphore be released? + Must be set to False if this is called before the + message was sent to processing. Returns: True if successful (the message was previously marked as pending for processing), False otherwise """ - self.__semaphore.release() + if release_semaphore: + self.__semaphore.release() partition_info = TrackingInfo.from_message_group(message) message_offsets = message.offsets stored_offsets = tuple(message_offset + 1 for message_offset in message_offsets) diff --git a/tests/integration/integration_utils.py b/tests/integration/integration_utils.py index c9d9ee9..19c1808 100644 --- a/tests/integration/integration_utils.py +++ b/tests/integration/integration_utils.py @@ -6,6 +6,7 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field +from datetime import timedelta from typing import Any, Callable from confluent_kafka.admin import AdminClient, NewTopic @@ -187,7 +188,7 @@ def __call__(self, message: dict[str, Any]) -> None: call_count = self.tracker.call_counts.get(message_id, 0) self.tracker.call_counts[message_id] = call_count + 1 if self.fail_consistently or ( - call_count == 0 and random.random() < self.fail_chance_on_first + call_count == 0 
and (random.random() < self.fail_chance_on_first) ): raise ValueError("Simulated error") with self.tracker.lock: @@ -237,6 +238,7 @@ class ScaffoldConfig: group_id: str timeout: float = 15.0 split_messages: bool = False + max_chunk_reassembly_wait_time: timedelta = field(default=timedelta(seconds=10)) additional_settings: dict[str, Any] = field(default_factory=dict) @@ -365,6 +367,7 @@ def start_consumer( group_id=self.config.group_id, target=target, filter_function=filter_function, + max_chunk_reassembly_wait_time=self.config.max_chunk_reassembly_wait_time, additional_settings={ KafkaOptions.SECURITY_PROTO: "SASL_PLAINTEXT", **self.config.additional_settings, diff --git a/tests/integration/test_lost_chunk_recover.py b/tests/integration/test_lost_chunk_recover.py new file mode 100644 index 0000000..3680099 --- /dev/null +++ b/tests/integration/test_lost_chunk_recover.py @@ -0,0 +1,93 @@ +""" +Integration tests for Kafka producer and consumer +using larger amount of messages than other tests. 
+"""
+
+import asyncio
+from datetime import timedelta
+from typing import Any
+
+import pytest
+from confluent_kafka.admin import AdminClient
+
+from retriable_kafka_client.chunking import generate_group_id
+from retriable_kafka_client.headers import (
+    CHUNK_GROUP_HEADER,
+    CHUNK_ID_HEADER,
+    serialize_number_to_bytes,
+    NUMBER_OF_CHUNKS_HEADER,
+)
+from retriable_kafka_client import ConsumeTopicConfig
+
+from .integration_utils import (
+    IntegrationTestScaffold,
+    ScaffoldConfig,
+    RandomDelay,
+)
+
+
+@pytest.mark.asyncio
+async def test_lost_chunk_recover(
+    kafka_config: dict[str, Any],
+    admin_client: AdminClient,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """
+    Test that an incomplete chunk group (one chunk never arrives) is
+    discarded after max_chunk_reassembly_wait_time and the consumer recovers:
+    - A partial group (chunk 1 of 2) is sent without chunk 0
+    - After the configured wait time expires, the stale builder is removed
+    - A warning about the lost message chunks is logged
+
+    A later standalone message is then processed normally, proving the
+    consumer did not deadlock on the incomplete chunk group. 
+ """ + + config = ScaffoldConfig( + topics=[ + ConsumeTopicConfig( + base_topic="test-chunks-recover", + ), + ], + group_id="test-chunks", + timeout=30.0, + split_messages=False, + max_chunk_reassembly_wait_time=timedelta(seconds=0.5), + ) + + async with IntegrationTestScaffold(kafka_config, admin_client, config) as scaffold: + group_id = generate_group_id() + consumer = scaffold.start_consumer( + delay=RandomDelay(min_delay=0, max_delay=0.05), + max_workers=2, + max_concurrency=4, + ) + await asyncio.sleep(2) + + # This is missing the first message from the group + await scaffold.send_messages( + 1, + headers={ + CHUNK_GROUP_HEADER: group_id, + CHUNK_ID_HEADER: serialize_number_to_bytes(1), + NUMBER_OF_CHUNKS_HEADER: serialize_number_to_bytes(2), + }, + ) + # Send the next message only after the first message builder expired + # and this expiration can be handled during the next message processing + await asyncio.sleep(5) + await scaffold.send_messages(1) + + assert await scaffold.wait_for( + lambda _, success_counts: ( + any( + "Removing stale message builder that failed to assemble message in time. 
" + "Lost message topic: test-chunks-recover" in message + for message in caplog.messages + ) + and success_counts.get(1) == 1 + ), + config.timeout, + ) + # Check that the cache has been cleared + assert not consumer._consumer_thread.consumer._BaseConsumer__tracking_manager._TrackingManager__message_builders diff --git a/tests/unit/test_chunking.py b/tests/unit/test_chunking.py index 5aec0bc..fab6537 100644 --- a/tests/unit/test_chunking.py +++ b/tests/unit/test_chunking.py @@ -1,9 +1,11 @@ +from datetime import timedelta from unittest.mock import MagicMock import pytest from confluent_kafka import Message -from retriable_kafka_client.chunking import ChunkingCache, calculate_header_size +from retriable_kafka_client.chunking import MessageGroupBuilder, calculate_header_size +from retriable_kafka_client.consumer_tracking import TrackingManager from retriable_kafka_client.headers import ( CHUNK_GROUP_HEADER, NUMBER_OF_CHUNKS_HEADER, @@ -12,25 +14,72 @@ ) -def test_chunking_receive(): - message1 = MagicMock(spec=Message) - message2 = MagicMock(spec=Message) - message3 = MagicMock(spec=Message) - for message_id, message in enumerate([message1, message2, message3]): - message.headers.return_value = [ - (CHUNK_GROUP_HEADER, b"foo"), - (NUMBER_OF_CHUNKS_HEADER, serialize_number_to_bytes(3)), - (CHUNK_ID_HEADER, serialize_number_to_bytes(message_id)), - ] - message.topic.return_value = "foo_topic" - message.partition.return_value = 0 - message1.value.return_value = b'{"hello' - message2.value.return_value = b'": "wor' - message3.value.return_value = b'ld"}' - chunking_cache = ChunkingCache() - assert chunking_cache.receive(message1) is None - assert chunking_cache.receive(message2) is None - assert chunking_cache.receive(message3).deserialize() == {"hello": "world"} +def _make_chunk_message( + group_id: bytes, + chunk_id: int, + total_chunks: int, + value: bytes, + topic: str = "foo_topic", + partition: int = 0, + offset: int = 0, +) -> MagicMock: + message = 
MagicMock(spec=Message) + message.headers.return_value = [ + (CHUNK_GROUP_HEADER, group_id), + (NUMBER_OF_CHUNKS_HEADER, serialize_number_to_bytes(total_chunks)), + (CHUNK_ID_HEADER, serialize_number_to_bytes(chunk_id)), + ] + message.topic.return_value = topic + message.partition.return_value = partition + message.offset.return_value = offset + message.value.return_value = value + return message + + +def test_message_group_builder_complete(): + builder = MessageGroupBuilder(max_wait_time=timedelta(minutes=15)) + messages = [ + _make_chunk_message(b"foo", i, 3, v, offset=i) + for i, v in enumerate([b'{"hello', b'": "wor', b'ld"}']) + ] + for msg in messages[:2]: + builder.add(msg) + assert not builder.is_complete + builder.add(messages[2]) + assert builder.is_complete + group = builder.get_message_group() + assert group is not None + assert group.deserialize() == {"hello": "world"} + + +def test_message_group_builder_complete_empty(): + builder = MessageGroupBuilder(max_wait_time=timedelta(minutes=15)) + assert builder.is_complete is False + + +def test_message_group_builder_incomplete_returns_none(): + builder = MessageGroupBuilder(max_wait_time=timedelta(minutes=15)) + msg = _make_chunk_message(b"foo", 0, 3, b'{"hello', offset=0) + builder.add(msg) + assert not builder.is_complete + assert builder.get_message_group() is None + + +def test_tracking_manager_receive_chunked(): + tracker = TrackingManager( + concurrency=16, + cancel_wait_time=30.0, + max_chunk_wait_time=timedelta(minutes=15), + ) + messages = [ + _make_chunk_message(b"foo", i, 3, v, offset=i) + for i, v in enumerate([b'{"hello', b'": "wor', b'ld"}']) + ] + assert tracker.receive(messages[0]) is None + assert tracker.receive(messages[1]) is None + result = tracker.receive(messages[2]) + assert result is not None + assert result.deserialize() == {"hello": "world"} @pytest.mark.parametrize( diff --git a/tests/unit/test_consumer.py b/tests/unit/test_consumer.py index 16c8d23..06525bc 100644 --- 
a/tests/unit/test_consumer.py +++ b/tests/unit/test_consumer.py @@ -470,7 +470,9 @@ def test_ack_message_with_exception( base_consumer._BaseConsumer__ack_message(message_group, mock_future) # Verify tracking manager schedule_commit was called - mock_tracking_manager.schedule_commit.assert_called_once_with(message_group) + mock_tracking_manager.schedule_commit.assert_called_once_with( + message_group, release_semaphore=True + ) # Verify retry manager was called to resend the message mock_retry_manager.resend_message.assert_called_once_with(message_group) diff --git a/tests/unit/test_consumer_tracking.py b/tests/unit/test_consumer_tracking.py index 4e20699..cd50b59 100644 --- a/tests/unit/test_consumer_tracking.py +++ b/tests/unit/test_consumer_tracking.py @@ -1,16 +1,26 @@ """Tests for TrackingManager module""" import logging +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import MagicMock from concurrent.futures import Future from confluent_kafka import Message, TopicPartition from retriable_kafka_client.consumer_tracking import ( TrackingManager, ) +from retriable_kafka_client.headers import ( + CHUNK_GROUP_HEADER, + NUMBER_OF_CHUNKS_HEADER, + CHUNK_ID_HEADER, + serialize_number_to_bytes, +) from retriable_kafka_client.kafka_utils import TrackingInfo, MessageGroup +DEFAULT_CHUNK_WAIT = timedelta(minutes=15) + @pytest.mark.parametrize( "partition_states,expected_commits,expected_remaining_to_commit", @@ -70,7 +80,9 @@ def test_offset_cache_pop_committable( expected_remaining_to_commit: dict[int, set[int]], ) -> None: """Test pop_committable covers all branches with various partition states.""" - cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) + cache = TrackingManager( + concurrency=16, cancel_wait_time=30.0, max_chunk_wait_time=DEFAULT_CHUNK_WAIT + ) # Setup the cache state for state in partition_states: @@ -113,7 +125,9 @@ def 
test_offset_cache_schedule_commit_success( caplog: pytest.LogCaptureFixture, ) -> None: """Test schedule_commit successfully moves offset from to_process to to_commit.""" - cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) + cache = TrackingManager( + concurrency=16, cancel_wait_time=30.0, max_chunk_wait_time=DEFAULT_CHUNK_WAIT + ) # Create a MessageGroup mock_inner_message = MagicMock( @@ -139,7 +153,7 @@ def test_offset_cache_schedule_commit_success( assert (43,) not in cache._TrackingManager__to_commit.get(partition_info, set()) # Now schedule it for commit - result = cache.schedule_commit(message_group) + result = cache.schedule_commit(message_group, release_semaphore=True) # Verify success assert result is True @@ -157,7 +171,9 @@ def test_offset_cache_schedule_commit_without_prior_processing( ) -> None: """Test schedule_commit behavior when message wasn't marked for processing first.""" caplog.set_level(logging.WARNING) - cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) + cache = TrackingManager( + concurrency=16, cancel_wait_time=30.0, max_chunk_wait_time=DEFAULT_CHUNK_WAIT + ) # Create a MessageGroup mock_inner_message = MagicMock( @@ -172,8 +188,8 @@ def test_offset_cache_schedule_commit_without_prior_processing( ) # Don't add it to to_process (simulating it was never marked for processing) - # Try to schedule commit - result = cache.schedule_commit(message_group) + # Try to schedule commit without releasing semaphore (wasn't acquired) + result = cache.schedule_commit(message_group, release_semaphore=False) # schedule_commit always returns True in the new implementation assert result is True @@ -229,7 +245,9 @@ def offset_cache_revoke_processing( expected_cancelled: list[str], expected_not_cancelled: list[str], ) -> None: - tracking_manager = TrackingManager(concurrency=16, cancel_wait_time=30.0) + tracking_manager = TrackingManager( + concurrency=16, cancel_wait_time=30.0, max_chunk_wait_time=DEFAULT_CHUNK_WAIT + ) # Track 
created mock futures by name future_registry = {} @@ -334,7 +352,9 @@ def test_offset_cache_register_revoke( partitions_to_revoke: list[TopicPartition], expected_remaining_process: dict[TrackingInfo, set[int]], ) -> None: - cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) + cache = TrackingManager( + concurrency=16, cancel_wait_time=30.0, max_chunk_wait_time=DEFAULT_CHUNK_WAIT + ) cache._TrackingManager__to_process = to_process_data # Call register_revoke with partitions @@ -354,7 +374,9 @@ def test_offset_cache_register_revoke_err( error: type[Exception], ) -> None: stub_partition = TrackingInfo("topic-a", 0) - cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) + cache = TrackingManager( + concurrency=16, cancel_wait_time=30.0, max_chunk_wait_time=DEFAULT_CHUNK_WAIT + ) mock_future = MagicMock(spec=Future) mock_future.cancel.return_value = False mock_future.result.side_effect = error("Whoops") @@ -372,7 +394,9 @@ def test_offset_cache_schedule_commit_offset_not_in_partition( In the new implementation, schedule_commit always succeeds and adds to to_commit. 
""" caplog.set_level(logging.WARNING) - cache = TrackingManager(concurrency=16, cancel_wait_time=30.0) + cache = TrackingManager( + concurrency=16, cancel_wait_time=30.0, max_chunk_wait_time=DEFAULT_CHUNK_WAIT + ) # Create a MessageGroup mock_inner_message = MagicMock( @@ -394,7 +418,7 @@ def test_offset_cache_schedule_commit_offset_not_in_partition( cache._TrackingManager__to_process[partition_info][(offset,)] = mock_future # Try to schedule commit for offset 42 which doesn't exist in the partition - result = cache.schedule_commit(message_group) + result = cache.schedule_commit(message_group, release_semaphore=False) # In new implementation, schedule_commit always succeeds assert result is True @@ -411,3 +435,77 @@ def test_offset_cache_schedule_commit_offset_not_in_partition( (101,), (102,), } + + +def _make_chunk_message( + group_id: bytes, + chunk_id: int, + total_chunks: int, + topic: str = "test-topic", + partition: int = 0, + offset: int = 0, +) -> MagicMock: + msg = MagicMock(spec=Message) + msg.headers.return_value = [ + (CHUNK_GROUP_HEADER, group_id), + (NUMBER_OF_CHUNKS_HEADER, serialize_number_to_bytes(total_chunks)), + (CHUNK_ID_HEADER, serialize_number_to_bytes(chunk_id)), + ] + msg.topic.return_value = topic + msg.partition.return_value = partition + msg.offset.return_value = offset + msg.value.return_value = b'{"partial"' + return msg + + +@pytest.mark.parametrize( + "fresh_group_ids,stale_group_ids", + [ + pytest.param( + [], + [b"group_a", b"group_b"], + id="all_stale_discard_everything", + ), + pytest.param( + [b"group_b"], + [b"group_a"], + id="one_stale_one_fresh_discard_one", + ), + ], +) +def test_tracking_manager__cleanup_stale_builders( + stale_group_ids: list[bytes], + fresh_group_ids: list[bytes], +) -> None: + max_wait = timedelta(minutes=10) + base_time = datetime(2025, 1, 1, tzinfo=timezone.utc) + + with patch("retriable_kafka_client.chunking.datetime") as mock_dt: + mock_dt.now.return_value = base_time + + tracker = TrackingManager( + 
concurrency=16, + cancel_wait_time=30.0, + max_chunk_wait_time=max_wait, + ) + + for i, group_id in enumerate(stale_group_ids): + tracker.receive(_make_chunk_message(group_id, 0, 3, offset=i)) + + mock_dt.now.return_value = base_time + timedelta(minutes=8) + for i, group_id in enumerate(fresh_group_ids): + tracker.receive(_make_chunk_message(group_id, 0, 3, offset=100 + i)) + + assert len(tracker._TrackingManager__message_builders) == ( + len(stale_group_ids) + len(fresh_group_ids) + ) + + mock_dt.now.return_value = base_time + timedelta(minutes=11) + tracker._cleanup_stale_builders() + + assert len(tracker._TrackingManager__message_builders) == len(fresh_group_ids) + remaining_group_ids = { + builder.group_id + for builder in tracker._TrackingManager__message_builders.values() + } + assert remaining_group_ids == set(fresh_group_ids) From 9437a14ab3d056481e65648504a22a7444f63087 Mon Sep 17 00:00:00 2001 From: Brian Lindner Date: Thu, 2 Apr 2026 11:08:13 -0400 Subject: [PATCH 3/5] fix: prepare chunks with separate function to avoid code duplication --- src/retriable_kafka_client/producer.py | 46 +++++++++++++------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/retriable_kafka_client/producer.py b/src/retriable_kafka_client/producer.py index 0355c2f..183a6a9 100644 --- a/src/retriable_kafka_client/producer.py +++ b/src/retriable_kafka_client/producer.py @@ -5,7 +5,7 @@ import logging import time from copy import copy -from typing import Any +from typing import Any, Generator from confluent_kafka import Producer, KafkaException @@ -122,6 +122,24 @@ def __handle_problems(problems: dict[str, Exception]) -> None: LOGGER.error("Cannot produce to topic %s: %s", problem_topic, problem) raise next(iter(problems.values())) + def _prepare_chunks( + self, + group_id: bytes, + message: dict[str, Any] | bytes | list[bytes], + headers: dict[str, str | bytes] | None = None, + ) -> Generator[tuple[bytes, dict[str, str | bytes] | None]]: + chunks = 
self.__serialize_message(message, headers, self._config.split_messages) + number_of_chunks = len(chunks) + for chunk_id, chunk in enumerate(chunks): + chunk_headers = copy(headers) if headers else {} + if self._config.split_messages: + chunk_headers[CHUNK_GROUP_HEADER] = group_id + chunk_headers[NUMBER_OF_CHUNKS_HEADER] = serialize_number_to_bytes( + number_of_chunks + ) + chunk_headers[CHUNK_ID_HEADER] = serialize_number_to_bytes(chunk_id) + yield (chunk, chunk_headers) + def send_sync( self, message: dict[str, Any] | bytes | list[bytes], @@ -139,27 +157,18 @@ def send_sync( BufferError: if Kafka queue is full even after all attempts KafkaException: if some Kafka error occurs even after all attempts """ - chunks = self.__serialize_message(message, headers, self._config.split_messages) - number_of_chunks = len(chunks) problems: dict[str, Exception] = {} timestamp = int(time.time() * 1000) # Kafka expects milliseconds for topic in self._config.topics: group_id = generate_group_id() - for chunk_id, chunk in enumerate(chunks): - headers = copy(headers) if headers else {} - if self._config.split_messages: - headers[CHUNK_GROUP_HEADER] = group_id - headers[NUMBER_OF_CHUNKS_HEADER] = serialize_number_to_bytes( - number_of_chunks - ) - headers[CHUNK_ID_HEADER] = serialize_number_to_bytes(chunk_id) + for chunk, chunk_headers in self._prepare_chunks(group_id, message, headers): for attempt_idx in range(self._config.retries + 1): try: self._producer.produce( topic=topic, value=chunk, timestamp=timestamp, - headers=headers, + headers=chunk_headers, key=group_id, ) break @@ -190,27 +199,18 @@ async def send( BufferError: if Kafka queue is full even after all attempts KafkaException: if some Kafka error occurs even after all attempts """ - chunks = self.__serialize_message(message, headers, self._config.split_messages) - number_of_chunks = len(chunks) problems: dict[str, Exception] = {} timestamp = int(time.time() * 1000) # Kafka expects milliseconds for topic in 
self._config.topics: group_id = generate_group_id() - for chunk_id, chunk in enumerate(chunks): - headers = copy(headers) if headers else {} - if self._config.split_messages: - headers[CHUNK_GROUP_HEADER] = group_id - headers[NUMBER_OF_CHUNKS_HEADER] = serialize_number_to_bytes( - number_of_chunks - ) - headers[CHUNK_ID_HEADER] = serialize_number_to_bytes(chunk_id) + for chunk, chunk_headers in self._prepare_chunks(group_id, message, headers): for attempt_idx in range(self._config.retries + 1): try: self._producer.produce( topic=topic, value=chunk, timestamp=timestamp, - headers=headers, + headers=chunk_headers, key=group_id, ) break From 3015405de101b157b40b79405e5f591b84c00586 Mon Sep 17 00:00:00 2001 From: Marek Szymutko Date: Tue, 7 Apr 2026 14:48:56 +0200 Subject: [PATCH 4/5] fix(ISV-7020): fix formatting and address a review comment --- src/retriable_kafka_client/producer.py | 8 ++++++-- tests/unit/test_consumer.py | 5 +++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/retriable_kafka_client/producer.py b/src/retriable_kafka_client/producer.py index 183a6a9..e5481f2 100644 --- a/src/retriable_kafka_client/producer.py +++ b/src/retriable_kafka_client/producer.py @@ -161,7 +161,9 @@ def send_sync( timestamp = int(time.time() * 1000) # Kafka expects milliseconds for topic in self._config.topics: group_id = generate_group_id() - for chunk, chunk_headers in self._prepare_chunks(group_id, message, headers): + for chunk, chunk_headers in self._prepare_chunks( + group_id, message, headers + ): for attempt_idx in range(self._config.retries + 1): try: self._producer.produce( @@ -203,7 +205,9 @@ async def send( timestamp = int(time.time() * 1000) # Kafka expects milliseconds for topic in self._config.topics: group_id = generate_group_id() - for chunk, chunk_headers in self._prepare_chunks(group_id, message, headers): + for chunk, chunk_headers in self._prepare_chunks( + group_id, message, headers + ): for attempt_idx in range(self._config.retries + 
1): try: self._producer.produce( diff --git a/tests/unit/test_consumer.py b/tests/unit/test_consumer.py index 06525bc..b09be7f 100644 --- a/tests/unit/test_consumer.py +++ b/tests/unit/test_consumer.py @@ -77,6 +77,11 @@ def test_consumer__process_message_decode_fail( def test_process_message_skip_chunk( base_consumer: BaseConsumer, ): + """ + Check that the message chunk is not sent to processing + by checking that its _process_message call doesn't + return a future. + """ mock_message = MagicMock( spec=Message, ) From e52da651e537549fd38df3c70727772c03efd299 Mon Sep 17 00:00:00 2001 From: Marek Szymutko Date: Wed, 8 Apr 2026 09:39:16 +0200 Subject: [PATCH 5/5] feat(ISV-7020): add cluster-error handling --- src/retriable_kafka_client/__init__.py | 2 + src/retriable_kafka_client/error.py | 23 ++++++++ src/retriable_kafka_client/producer.py | 76 +++++++++++++++++++++++--- tests/integration/integration_utils.py | 8 ++- tests/integration/test_producer.py | 40 ++++++++++++++ tests/unit/test_producer.py | 37 +++++++++++++ 6 files changed, 177 insertions(+), 9 deletions(-) create mode 100644 src/retriable_kafka_client/error.py create mode 100644 tests/integration/test_producer.py diff --git a/src/retriable_kafka_client/__init__.py b/src/retriable_kafka_client/__init__.py index a1caf7b..dd78cb4 100644 --- a/src/retriable_kafka_client/__init__.py +++ b/src/retriable_kafka_client/__init__.py @@ -5,6 +5,7 @@ from .orchestrate import consume_topics, ConsumerThread from .producer import BaseProducer from .health import HealthCheckClient +from .error import SendError __all__ = ( "BaseConsumer", @@ -16,4 +17,5 @@ "ProducerConfig", "ConsumeTopicConfig", "HealthCheckClient", + "SendError", ) diff --git a/src/retriable_kafka_client/error.py b/src/retriable_kafka_client/error.py new file mode 100644 index 0000000..0c644c2 --- /dev/null +++ b/src/retriable_kafka_client/error.py @@ -0,0 +1,23 @@ +"""Module for error definitions""" + +from confluent_kafka import KafkaError + + +class 
SendError(RuntimeError): + """Class for raising problems with message producing.""" + + def __init__( + self, *args, retriable: bool, fatal: bool, kafka_error: KafkaError | None + ) -> None: + self.retriable = retriable + self.fatal = fatal + self.kafka_error = kafka_error + super().__init__(*args) + + @staticmethod + def format_err(err: KafkaError) -> str: + """Propose human-readable error message constructed from Kafka error.""" + return ( + f"Underlying error details: {err.__class__.__name__}(name={err.name()}, " + f"retriable={err.retriable()}, fatal={err.fatal()})" + ) diff --git a/src/retriable_kafka_client/producer.py b/src/retriable_kafka_client/producer.py index e5481f2..a71ebd5 100644 --- a/src/retriable_kafka_client/producer.py +++ b/src/retriable_kafka_client/producer.py @@ -5,11 +5,12 @@ import logging import time from copy import copy -from typing import Any, Generator +from typing import Any, Generator, Callable -from confluent_kafka import Producer, KafkaException +from confluent_kafka import Producer, KafkaException, Message, KafkaError from .chunking import generate_group_id, calculate_header_size +from .error import SendError from .headers import ( CHUNK_GROUP_HEADER, NUMBER_OF_CHUNKS_HEADER, @@ -49,6 +50,46 @@ def __init__(self, config: ProducerConfig): } self._config_dict.update(**self._config.additional_settings) + @staticmethod + def _get_delivery_callback( + topic: str, + ) -> Callable[[KafkaError | None, Message | None], None]: + """ + Gets the callback which should be called upon delivery of messages. + Handles errors or logs information about messages. + Args: + topic: The topic to which this callback should be called. + Used for logging information. + + Returns: The callable which should be used as the actual callback. 
+ """ + + def callback(err: KafkaError | None, msg: Message | None) -> None: + if err is not None: + raise SendError( + f"Failed to flush produced message to topic " + f"{topic}\n{SendError.format_err(err)}", + retriable=err.retriable(), + fatal=err.fatal(), + kafka_error=err, + ) + if msg is not None: + LOGGER.info( + "Message delivered to topic: %s, partition: %s, offset %s", + msg.topic(), + msg.partition(), + msg.offset(), + ) + return None + raise SendError( + "Failed to flush produced message to topic, no details from the server.", + retriable=True, + fatal=False, + kafka_error=None, + ) + + return callback + @property def topics(self) -> list[str]: """Return topics this producer produces to.""" @@ -122,6 +163,13 @@ def __handle_problems(problems: dict[str, Exception]) -> None: LOGGER.error("Cannot produce to topic %s: %s", problem_topic, problem) raise next(iter(problems.values())) + @staticmethod + def _is_problem_retriable(problem: Exception) -> bool: + is_retriable = True + if isinstance(problem, SendError): + is_retriable = problem.retriable + return is_retriable + def _prepare_chunks( self, group_id: bytes, @@ -156,6 +204,8 @@ def send_sync( TypeError: if message is not a JSON-serializable object nor bytes BufferError: if Kafka queue is full even after all attempts KafkaException: if some Kafka error occurs even after all attempts + SendError: if any problems appear in Kafka cluster after sending + a message """ problems: dict[str, Exception] = {} timestamp = int(time.time() * 1000) # Kafka expects milliseconds @@ -172,16 +222,20 @@ def send_sync( timestamp=timestamp, headers=chunk_headers, key=group_id, + on_delivery=self._get_delivery_callback(topic), ) + self._producer.flush() break - except (BufferError, KafkaException) as err: - if attempt_idx < self._config.retries: + except (BufferError, KafkaException, SendError) as err: + if ( + self._is_problem_retriable(err) + and attempt_idx < self._config.retries + ): backoff_time = 
self.__calculate_backoff(attempt_idx) self.__log_retry(attempt_idx, backoff_time) time.sleep(backoff_time) continue problems[topic] = err - self._producer.flush() self.__handle_problems(problems) async def send( @@ -200,6 +254,8 @@ async def send( TypeError: if message is not a JSON-serializable object nor bytes BufferError: if Kafka queue is full even after all attempts KafkaException: if some Kafka error occurs even after all attempts + SendError: if any problems appear in Kafka cluster after sending + a message """ problems: dict[str, Exception] = {} timestamp = int(time.time() * 1000) # Kafka expects milliseconds @@ -216,16 +272,20 @@ async def send( timestamp=timestamp, headers=chunk_headers, key=group_id, + on_delivery=self._get_delivery_callback(topic), ) + self._producer.flush() break - except (BufferError, KafkaException) as err: - if attempt_idx < self._config.retries: + except (BufferError, KafkaException, SendError) as err: + if ( + self._is_problem_retriable(err) + and attempt_idx < self._config.retries + ): backoff_time = self.__calculate_backoff(attempt_idx) self.__log_retry(attempt_idx, backoff_time) await asyncio.sleep(backoff_time) continue problems[topic] = err - self._producer.flush() self.__handle_problems(problems) def connection_healthcheck(self) -> bool: diff --git a/tests/integration/integration_utils.py b/tests/integration/integration_utils.py index 19c1808..48eb233 100644 --- a/tests/integration/integration_utils.py +++ b/tests/integration/integration_utils.py @@ -240,6 +240,7 @@ class ScaffoldConfig: split_messages: bool = False max_chunk_reassembly_wait_time: timedelta = field(default=timedelta(seconds=10)) additional_settings: dict[str, Any] = field(default_factory=dict) + topic_config: dict[str, str] = field(default_factory=dict) class IntegrationTestScaffold: @@ -292,7 +293,12 @@ def _create_topics(self) -> None: all_topic_names.append(tc.retry_topic) new_topics = [ - NewTopic(topic, num_partitions=1, replication_factor=1) + 
NewTopic( + topic, + num_partitions=1, + replication_factor=1, + config=self.config.topic_config, + ) for topic in all_topic_names ] diff --git a/tests/integration/test_producer.py b/tests/integration/test_producer.py new file mode 100644 index 0000000..d6893b3 --- /dev/null +++ b/tests/integration/test_producer.py @@ -0,0 +1,40 @@ +""" +Integration tests to check sender + +""" + +from typing import Any + +import pytest +from confluent_kafka.admin import AdminClient + +from retriable_kafka_client import ConsumeTopicConfig, SendError + +from .integration_utils import ( + IntegrationTestScaffold, + ScaffoldConfig, +) + + +@pytest.mark.asyncio +async def test_send_error( + kafka_config: dict[str, Any], admin_client: AdminClient +) -> None: + """ + Test that send_error is raised if the topic cannot handle + messages + """ + config = ScaffoldConfig( + topics=[ + ConsumeTopicConfig(base_topic="test-send-error-topic"), + ], + group_id="test-send-error-group", + topic_config={"max.message.bytes": "1000"}, + ) + # We set topic-level constraint, producer is not aware of it. 
Without callbacks, + # this wouldn't raise any exceptions, it would just silently ignore it + async with IntegrationTestScaffold(kafka_config, admin_client, config) as scaffold: + with pytest.raises(SendError) as err: + await scaffold.send_messages(1, extra_fields={"large": 10000 * "a"}) + assert err.value.retriable is False # This problem cannot be easily retried + assert err.value.kafka_error.name() == "MSG_SIZE_TOO_LARGE" diff --git a/tests/unit/test_producer.py b/tests/unit/test_producer.py index aea0695..0b88f91 100644 --- a/tests/unit/test_producer.py +++ b/tests/unit/test_producer.py @@ -2,12 +2,14 @@ import asyncio import json +import logging from typing import Generator, Any from unittest.mock import patch, MagicMock import pytest from confluent_kafka import KafkaException +from retriable_kafka_client import SendError from retriable_kafka_client.producer import BaseProducer from retriable_kafka_client.config import ProducerConfig from retriable_kafka_client.headers import ( @@ -265,3 +267,38 @@ def test_serialize_message_resize_chunks( chunking_size != 0, ) assert result == expected_output + + +def test__get_delivery_callback(caplog: pytest.LogCaptureFixture) -> None: + """Test that _get_delivery_callback works as expected.""" + caplog.set_level(logging.INFO) + topic_name = "foo" + callback = BaseProducer._get_delivery_callback(topic_name) + mock_err = MagicMock() + with pytest.raises(SendError): + callback(mock_err, None) + mock_msg = MagicMock( + topic=lambda: topic_name, partition=lambda: 0, offset=lambda: 67 + ) + callback(None, mock_msg) + assert ( + f"Message delivered to topic: {topic_name}, partition: 0, offset 67" + in caplog.messages + ) + with pytest.raises(SendError): + callback(None, None) + + +@pytest.mark.parametrize( + ["problem", "is_retriable"], + [ + (SendError("big bad", retriable=False, fatal=True, kafka_error=None), False), + (BufferError("buffer is no more more"), True), + ( + SendError("big bad", retriable=True, fatal=False, 
kafka_error=MagicMock()), + True, + ), + ], +) +def test__is_problem_retriable(problem: Exception, is_retriable: bool) -> None: + assert BaseProducer._is_problem_retriable(problem) == is_retriable