@@ -92,13 +92,13 @@ def test_kafka_target(kafka_topic_setup_teardown):
9292 assert record .value .decode ("UTF-8" ) == json .dumps (event .body , default = str )
9393
9494
95- async def async_test_write_to_kafka_full_event_readback (kafka_topic_setup_teardown ):
95+ async def async_test_write_to_kafka_full_event_readback (kafka_topic_setup_teardown , partition_key ):
9696 kafka_consumer = kafka_topic_setup_teardown
9797
9898 controller = build_flow (
9999 [
100100 AsyncEmitSource (),
101- KafkaTarget (kafka_brokers , topic , sharding_func = lambda _ : 0 , full_event = True ),
101+ KafkaTarget (kafka_brokers , topic , sharding_func = lambda _ : partition_key , full_event = True ),
102102 ]
103103 ).run ()
104104 events = []
@@ -115,7 +115,10 @@ async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardo
115115 record = next (kafka_consumer )
116116 if event .key is None :
117117 if event .key is None :
118- assert record .key is None
118+ if isinstance (partition_key , int ):
119+ assert record .key is None
120+ else :
121+ assert record .key .decode ("UTF-8" ) == partition_key
119122 else :
120123 assert record .key .decode ("UTF-8" ) == event .key
121124 readback_records .append (json .loads (record .value .decode ("UTF-8" )))
@@ -143,5 +146,6 @@ async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardo
143146 not kafka_brokers ,
144147 reason = "KAFKA_BROKERS must be defined to run kafka tests" ,
145148)
146- def test_async_test_write_to_kafka_full_event_readback (kafka_topic_setup_teardown ):
147- asyncio .run (async_test_write_to_kafka_full_event_readback (kafka_topic_setup_teardown ))
149+ @pytest .mark .parametrize ("partition_key" , [0 , "some_string" ])
150+ def test_async_test_write_to_kafka_full_event_readback (kafka_topic_setup_teardown , partition_key ):
151+ asyncio .run (async_test_write_to_kafka_full_event_readback (kafka_topic_setup_teardown , partition_key ))
0 commit comments