33import logging
44from collections import defaultdict
55from concurrent .futures import Future
6+ from itertools import chain
67from threading import Lock , Semaphore
7- from typing import NamedTuple , Any
8+ from typing import Any , Iterable
89
9- from confluent_kafka import Message , TopicPartition
10+ from confluent_kafka import TopicPartition
11+
12+ from retriable_kafka_client .kafka_utils import TrackingInfo , MessageGroup
1013
1114LOGGER = logging .getLogger (__name__ )
1215
1316
14- class _PartitionInfo (NamedTuple ):
15- """
16- Consistently hashable dataclass for storing information about a partition,
17- namely offset information. Can be used as keys in a dictionary.
18- """
19-
20- topic : str
21- partition : int
22-
23- @staticmethod
24- def from_message (message : Message ) -> "_PartitionInfo" :
25- """
26- Create a PartitionInfo from a Kafka message.
27- Args:
28- message: Kafka message object
29- Returns: hashable info about a partition
30- """
31- message_topic = message .topic ()
32- message_partition = message .partition ()
33- # This should never happen with polled messages. Polled messages need
34- # the information asserted to be valid Kafka messages. This can
35- # happen only for custom-created messages objects, which this
36- # method is not intended to be used for
37- assert message_topic is not None and message_partition is not None , (
38- "Invalid message cannot be converted to partition info"
39- )
40- return _PartitionInfo (message_topic , message_partition )
41-
42- def to_offset_info (self , offset : int ) -> TopicPartition :
43- """
44- Create a Kafka-committable object using the provided offset.
45- Args:
46- offset: The offset to be committed. Make sure to commit
47- offset one higher than the latest processed message.
48- Returns: The committable Kafka object
49- """
50- return TopicPartition (topic = self .topic , partition = self .partition , offset = offset )
17+ def _flatten_offsets (done_offsets : Iterable [tuple [int , ...]]) -> list [int ]:
18+ return list (chain (* done_offsets ))
5119
5220
5321class TrackingManager :
@@ -85,8 +53,10 @@ class TrackingManager:
8553 """
8654
8755 def __init__ (self , concurrency : int , cancel_wait_time : float ):
88- self .__to_process : dict [_PartitionInfo , dict [int , Future ]] = defaultdict (dict )
89- self .__to_commit : dict [_PartitionInfo , set [int ]] = defaultdict (set )
56+ self .__to_process : dict [TrackingInfo , dict [tuple [int , ...], Future ]] = (
57+ defaultdict (dict )
58+ )
59+ self .__to_commit : dict [TrackingInfo , set [tuple [int , ...]]] = defaultdict (set )
9060 self .__access_lock = Lock () # For handling multithreaded access to this object
9161 self .__semaphore = Semaphore (concurrency )
9262 self .__cancel_wait_time = cancel_wait_time
@@ -105,15 +75,15 @@ def pop_committable(self) -> list[TopicPartition]:
10575 """
10676 to_commit = []
10777 with self .__access_lock :
108- for partition_info , pending_to_commit in self .__to_commit .items ():
109- if not pending_to_commit :
78+ for partition_info , tuples_pending_to_commit in self .__to_commit .items ():
79+ if not tuples_pending_to_commit :
11080 # Nothing to commit
11181 continue
11282
113- pending_to_process = self .__to_process .get (partition_info , None )
114- if not pending_to_process :
83+ tuples_pending_to_process = self .__to_process .get (partition_info , None )
84+ if not tuples_pending_to_process :
11585 # Nothing is blocking the committing
116- max_to_commit = max (pending_to_commit )
86+ max_to_commit = max (_flatten_offsets ( tuples_pending_to_commit ) )
11787 to_commit .append (
11888 TopicPartition (
11989 topic = partition_info .topic ,
@@ -123,17 +93,18 @@ def pop_committable(self) -> list[TopicPartition]:
12393 )
12494 self .__to_commit [partition_info ] = set ()
12595 continue
126-
127- min_pending_to_process = min (pending_to_process )
96+ min_pending_to_process = min (
97+ _flatten_offsets (tuples_pending_to_process )
98+ )
12899 commit_candidates = {
129- offset
130- for offset in pending_to_commit
131- if offset < min_pending_to_process
100+ offset_tuple
101+ for offset_tuple in tuples_pending_to_commit
102+ if all ( offset < min_pending_to_process for offset in offset_tuple )
132103 }
133104 if not commit_candidates :
134105 # Nothing to commit
135106 continue
136- max_to_commit = max (commit_candidates )
107+ max_to_commit = max (_flatten_offsets ( commit_candidates ) )
137108 to_commit .append (
138109 TopicPartition (
139110 topic = partition_info .topic ,
@@ -157,11 +128,11 @@ def reschedule_uncommittable(
157128 failed_committable: list of data that failed to be committed
158129 """
159130 for failed in failed_committable :
160- self .__to_commit . setdefault (
161- _PartitionInfo (topic = failed .topic , partition = failed .partition ), set ( )
162- ) .add (failed .offset )
131+ self .__to_commit [
132+ TrackingInfo (topic = failed .topic , partition = failed .partition )
133+ ] .add (( failed .offset ,) )
163134
164- def process_message (self , message : Message , future : Future [Any ]) -> None :
135+ def process_message (self , message : MessageGroup , future : Future [Any ]) -> None :
165136 """
166137 Mark message as pending for processing.
167138 Args:
@@ -170,15 +141,16 @@ def process_message(self, message: Message, future: Future[Any]) -> None:
170141 """
171142 # We cannot really use context manager, the semaphore is released in
172143 # future's callback or when the future is cancelled
144+
173145 self .__semaphore .acquire () # pylint: disable=consider-using-with
174- message_offset : int = message .offset () # type: ignore[assignment]
146+ message_offsets = message .offsets
175147 with self .__access_lock :
176148 # Mark the message as being processed
177- self .__to_process [_PartitionInfo . from_message (message )][
178- message_offset + 1
149+ self .__to_process [TrackingInfo . from_message_group (message )][
150+ tuple ( message_offset + 1 for message_offset in message_offsets )
179151 ] = future
180152
181- def schedule_commit (self , message : Message ) -> bool :
153+ def schedule_commit (self , message : MessageGroup ) -> bool :
182154 """
183155 Mark message as pending for committing when its processing is fully done.
184156 Args:
@@ -188,12 +160,12 @@ def schedule_commit(self, message: Message) -> bool:
188160 as pending for processing), False otherwise
189161 """
190162 self .__semaphore .release ()
191- partition_info = _PartitionInfo . from_message (message )
192- message_offset : int = message .offset () # type: ignore[assignment]
193- stored_offset = message_offset + 1
163+ partition_info = TrackingInfo . from_message_group (message )
164+ message_offsets = message .offsets
165+ stored_offsets = tuple ( message_offset + 1 for message_offset in message_offsets )
194166 with self .__access_lock :
195- self .__to_process [partition_info ].pop (stored_offset , None )
196- self .__to_commit .setdefault (partition_info , set ()).add (stored_offset )
167+ self .__to_process [partition_info ].pop (stored_offsets , None )
168+ self .__to_commit .setdefault (partition_info , set ()).add (stored_offsets )
197169 self ._cleanup ()
198170 return True
199171
@@ -212,7 +184,7 @@ def _cleanup(self) -> None:
212184 cache_to_clean .pop (key , None )
213185
214186 def _revoke_processing (
215- self , revoked_partitions : set [_PartitionInfo ]
187+ self , revoked_partitions : set [TrackingInfo ]
216188 ) -> list [Future [Any ]]:
217189 """
218190 Cancel all pending tracked futures related to the given partitions.
@@ -253,7 +225,7 @@ def register_revoke(self, partitions: list[TopicPartition] | None = None) -> Non
253225 revoked_partition_keys = set (self .__to_process .keys ())
254226 else :
255227 revoked_partition_keys = {
256- _PartitionInfo (partition = partition .partition , topic = partition .topic )
228+ TrackingInfo (partition = partition .partition , topic = partition .topic )
257229 for partition in partitions
258230 }
259231 pending_futures = self ._revoke_processing (revoked_partition_keys )