Skip to content

Commit 7f3cd45

Browse files
committed
Fix RuntimeError in ClusterPubSub sharded message generator
1 parent 0894218 commit 7f3cd45

File tree

2 files changed

+133
-2
lines changed

2 files changed

+133
-2
lines changed

redis/cluster.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2199,16 +2199,31 @@ def _get_node_pubsub(self, node):
21992199
return pubsub
22002200

22012201
def _sharded_message_generator(self):
2202-
for _ in range(len(self.node_pubsub_mapping)):
2202+
"""
2203+
Iterate through pubsubs until a complete cycle is done.
2204+
"""
2205+
while True:
22032206
pubsub = next(self._pubsubs_generator)
2207+
2208+
# None marks end of cycle
2209+
if pubsub is None:
2210+
break
2211+
22042212
message = pubsub.get_message()
22052213
if message is not None:
22062214
return message
2215+
22072216
return None
22082217

22092218
def _pubsubs_generator(self):
2219+
"""
2220+
Generator that yields pubsubs in round-robin fashion.
2221+
Yields None to mark cycle boundaries.
2222+
"""
22102223
while True:
2211-
yield from self.node_pubsub_mapping.values()
2224+
current_nodes = list(self.node_pubsub_mapping.values())
2225+
yield from current_nodes
2226+
yield None # Cycle marker
22122227

22132228
def get_sharded_message(
22142229
self, ignore_subscribe_messages=False, timeout=0.0, target_node=None

tests/test_pubsub.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,122 @@ def test_pubsub_shardnumsub(self, r):
871871
channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)]
872872
assert r.pubsub_shardnumsub("foo", "bar", "baz", target_nodes="all") == channels
873873

874+
@pytest.mark.onlycluster
875+
@skip_if_server_version_lt("7.0.0")
876+
def test_ssubscribe_multiple_channels_different_nodes(self, r):
877+
"""
878+
Test subscribing to multiple sharded channels on different nodes.
879+
Validates that the generator properly handles multiple node_pubsub_mapping entries.
880+
"""
881+
pubsub = r.pubsub()
882+
channel1 = "test-channel:{0}"
883+
channel2 = "test-channel:{6}"
884+
885+
# Subscribe to first channel
886+
pubsub.ssubscribe(channel1)
887+
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
888+
assert msg is not None
889+
assert msg["type"] == "ssubscribe"
890+
891+
# Subscribe to second channel (likely different node)
892+
pubsub.ssubscribe(channel2)
893+
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
894+
assert msg is not None
895+
assert msg["type"] == "ssubscribe"
896+
897+
# Verify both channels are in shard_channels
898+
assert channel1.encode() in pubsub.shard_channels
899+
assert channel2.encode() in pubsub.shard_channels
900+
901+
pubsub.close()
902+
903+
@pytest.mark.onlycluster
904+
@skip_if_server_version_lt("7.0.0")
905+
def test_ssubscribe_multiple_channels_publish_and_read(self, r):
906+
"""
907+
Test publishing to multiple sharded channels and reading messages.
908+
Validates that _sharded_message_generator properly cycles through
909+
multiple node_pubsub_mapping entries.
910+
"""
911+
pubsub = r.pubsub()
912+
channel1 = "test-channel:{0}"
913+
channel2 = "test-channel:{6}"
914+
msg1_data = "message-1"
915+
msg2_data = "message-2"
916+
917+
# Subscribe to both channels
918+
pubsub.ssubscribe(channel1, channel2)
919+
920+
# Read subscription confirmations
921+
for _ in range(2):
922+
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
923+
assert msg is not None
924+
assert msg["type"] == "ssubscribe"
925+
926+
# Publish messages to both channels
927+
r.spublish(channel1, msg1_data)
928+
r.spublish(channel2, msg2_data)
929+
930+
# Read messages - should get both messages
931+
messages = []
932+
for _ in range(2):
933+
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
934+
assert msg is not None
935+
assert msg["type"] == "smessage"
936+
messages.append(msg)
937+
938+
# Verify we got messages from both channels
939+
channels_received = {msg["channel"] for msg in messages}
940+
assert channel1.encode() in channels_received
941+
assert channel2.encode() in channels_received
942+
943+
pubsub.close()
944+
945+
@pytest.mark.onlycluster
946+
@skip_if_server_version_lt("7.0.0")
947+
def test_generator_handles_concurrent_mapping_changes(self, r):
948+
"""
949+
Test that the generator properly handles mapping changes during iteration.
950+
This validates the fix for the RuntimeError: dictionary changed size during iteration.
951+
"""
952+
pubsub = r.pubsub()
953+
channel1 = "test-channel:{0}"
954+
channel2 = "test-channel:{6}"
955+
956+
# Subscribe to first channel
957+
pubsub.ssubscribe(channel1)
958+
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
959+
assert msg is not None
960+
assert msg["type"] == "ssubscribe"
961+
962+
# Get initial mapping size (if available)
963+
initial_size = 0
964+
if hasattr(pubsub, "node_pubsub_mapping"):
965+
initial_size = len(pubsub.node_pubsub_mapping)
966+
967+
# Subscribe to second channel (modifies mapping during potential iteration)
968+
pubsub.ssubscribe(channel2)
969+
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
970+
assert msg is not None
971+
assert msg["type"] == "ssubscribe"
972+
973+
# Verify mapping was updated (if available)
974+
if hasattr(pubsub, "node_pubsub_mapping"):
975+
assert len(pubsub.node_pubsub_mapping) >= initial_size
976+
977+
# Publish and read messages - should not raise RuntimeError
978+
r.spublish(channel1, "msg1")
979+
r.spublish(channel2, "msg2")
980+
981+
messages_received = 0
982+
for _ in range(2):
983+
msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message)
984+
if msg and msg["type"] == "smessage":
985+
messages_received += 1
986+
987+
assert messages_received == 2
988+
pubsub.close()
989+
874990

875991
class TestPubSubPings:
876992
@skip_if_server_version_lt("3.0.0")

0 commit comments

Comments
 (0)