Skip to content

Commit 77713df

Browse files
authored
Encode Kafka partition key (#553)
1 parent 9b970f7 commit 77713df

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

integration/test_kafka_integration.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,13 @@ def test_kafka_target(kafka_topic_setup_teardown):
9292
assert record.value.decode("UTF-8") == json.dumps(event.body, default=str)
9393

9494

95-
async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown):
95+
async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key):
9696
kafka_consumer = kafka_topic_setup_teardown
9797

9898
controller = build_flow(
9999
[
100100
AsyncEmitSource(),
101-
KafkaTarget(kafka_brokers, topic, sharding_func=lambda _: 0, full_event=True),
101+
KafkaTarget(kafka_brokers, topic, sharding_func=lambda _: partition_key, full_event=True),
102102
]
103103
).run()
104104
events = []
@@ -115,7 +115,10 @@ async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardo
115115
record = next(kafka_consumer)
116116
if event.key is None:
117117
if event.key is None:
118-
assert record.key is None
118+
if isinstance(partition_key, int):
119+
assert record.key is None
120+
else:
121+
assert record.key.decode("UTF-8") == partition_key
119122
else:
120123
assert record.key.decode("UTF-8") == event.key
121124
readback_records.append(json.loads(record.value.decode("UTF-8")))
@@ -143,5 +146,6 @@ async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardo
143146
not kafka_brokers,
144147
reason="KAFKA_BROKERS must be defined to run kafka tests",
145148
)
146-
def test_async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown):
147-
asyncio.run(async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown))
149+
@pytest.mark.parametrize("partition_key", [0, "some_string"])
150+
def test_async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key):
151+
asyncio.run(async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key))

storey/targets.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,9 +1345,7 @@ async def _do(self, event):
13451345
self._producer.close()
13461346
return await self._do_downstream(_termination_obj)
13471347
else:
1348-
key = None
1349-
if event.key is not None:
1350-
key = stringify_key(event.key).encode("UTF-8")
1348+
key = event.key
13511349
record = self._event_to_writer_entry(event)
13521350
if self._full_event:
13531351
record = wrap_event_for_serialization(event, record)
@@ -1359,6 +1357,10 @@ async def _do(self, event):
13591357
partition = sharding_func_result
13601358
else:
13611359
key = sharding_func_result
1360+
1361+
if key is not None:
1362+
key = stringify_key(key).encode("UTF-8")
1363+
13621364
future = self._producer.send(self._topic, record, key, partition=partition)
13631365
# Prevent garbage collection of event until persisted to kafka
13641366
future.add_callback(lambda x: event)

0 commit comments

Comments
 (0)