Skip to content

Commit ccd28ef

Browse files
committed
fix: replace keys() with search indexes to prevent Redis CrossSlot errors (#25)
Fixes #25 Replace all redis.keys() pattern matching calls with search index queries to prevent CROSSSLOT errors with Redis Cluster. Changes: - Replace keys() usage in base.py _load_pending_writes method - Replace keys() usage in aio.py _aload_pending_writes method - Replace keys() usage in __init__.py get_tuple method - Fix checkpoint_ns sentinel value handling for RediSearch tag fields - Fix type attribute access in get_channel_values method - Add proper type annotations and handle blob data encoding
1 parent a21e107 commit ccd28ef

File tree

4 files changed

+280
-61
lines changed

4 files changed

+280
-61
lines changed

langgraph/checkpoint/redis/__init__.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -354,13 +354,42 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
354354
CHECKPOINT_WRITE_PREFIX,
355355
)
356356

357-
# Get the blob keys
358-
blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*"
359-
blob_keys = [key.decode() for key in self._redis.keys(blob_key_pattern)]
360-
361-
# Also get checkpoint write keys that should have the same TTL
362-
write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*"
363-
write_keys = [key.decode() for key in self._redis.keys(write_key_pattern)]
357+
# Get the blob keys using search index instead of keys()
358+
blob_query = FilterQuery(
359+
filter_expression=(
360+
Tag("thread_id") == to_storage_safe_id(doc_thread_id)
361+
)
362+
& (Tag("checkpoint_ns") == to_storage_safe_str(doc_checkpoint_ns)),
363+
return_fields=["key"], # Assuming the key field exists in the index
364+
num_results=1000,
365+
)
366+
blob_results = self.checkpoint_blobs_index.search(blob_query)
367+
blob_keys = [
368+
f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{getattr(doc, 'channel', '')}:{getattr(doc, 'version', '')}"
369+
for doc in blob_results.docs
370+
]
371+
372+
# Get checkpoint write keys using search index
373+
write_query = FilterQuery(
374+
filter_expression=(
375+
Tag("thread_id") == to_storage_safe_id(doc_thread_id)
376+
)
377+
& (Tag("checkpoint_ns") == to_storage_safe_str(doc_checkpoint_ns))
378+
& (Tag("checkpoint_id") == to_storage_safe_id(doc_checkpoint_id)),
379+
return_fields=["task_id", "idx"],
380+
num_results=1000,
381+
)
382+
write_results = self.checkpoint_writes_index.search(write_query)
383+
write_keys = [
384+
BaseRedisSaver._make_redis_checkpoint_writes_key(
385+
to_storage_safe_id(doc_thread_id),
386+
to_storage_safe_str(doc_checkpoint_ns),
387+
to_storage_safe_id(doc_checkpoint_id),
388+
getattr(doc, "task_id", ""),
389+
getattr(doc, "idx", 0),
390+
)
391+
for doc in write_results.docs
392+
]
364393

365394
# Apply TTL to checkpoint, blob keys, and write keys
366395
all_related_keys = blob_keys + write_keys
@@ -489,12 +518,15 @@ def get_channel_values(
489518
blob_results = self.checkpoint_blobs_index.search(blob_query)
490519
if blob_results.docs:
491520
blob_doc = blob_results.docs[0]
492-
blob_type = blob_doc.type
521+
blob_type = getattr(blob_doc, "type", None)
493522
blob_data = getattr(blob_doc, "$.blob", None)
494523

495-
if blob_data and blob_type != "empty":
524+
if blob_data and blob_type and blob_type != "empty":
525+
# Ensure blob_data is bytes for deserialization
526+
if isinstance(blob_data, str):
527+
blob_data = blob_data.encode("utf-8")
496528
channel_values[channel] = self.serde.loads_typed(
497-
(blob_type, blob_data)
529+
(str(blob_type), blob_data)
498530
)
499531

500532
return channel_values

langgraph/checkpoint/redis/aio.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -890,30 +890,38 @@ async def _aload_pending_writes(
890890
if checkpoint_id is None:
891891
return [] # Early return if no checkpoint_id
892892

893-
writes_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
894-
to_storage_safe_id(thread_id),
895-
to_storage_safe_str(checkpoint_ns),
896-
to_storage_safe_id(checkpoint_id),
897-
"*",
898-
None,
899-
)
900-
matching_keys = await self._redis.keys(pattern=writes_key)
901-
# Use safely_decode to handle both string and bytes responses
902-
decoded_keys = [safely_decode(key) for key in matching_keys]
903-
parsed_keys = [
904-
BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
905-
for key in decoded_keys
906-
]
907-
pending_writes = BaseRedisSaver._load_writes(
908-
self.serde,
909-
{
910-
(
911-
parsed_key["task_id"],
912-
parsed_key["idx"],
913-
): await self._redis.json().get(key)
914-
for key, parsed_key in sorted(
915-
zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
916-
)
917-
},
893+
# Use search index instead of keys() to avoid CrossSlot errors
894+
# Note: For checkpoint_ns, we use the raw value for tag searches
895+
# because RediSearch may not handle sentinel values correctly in tag fields
896+
writes_query = FilterQuery(
897+
filter_expression=(Tag("thread_id") == to_storage_safe_id(thread_id))
898+
& (Tag("checkpoint_ns") == checkpoint_ns)
899+
& (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)),
900+
return_fields=["task_id", "idx", "channel", "type", "$.blob"],
901+
num_results=1000, # Adjust as needed
918902
)
903+
904+
writes_results = await self.checkpoint_writes_index.search(writes_query)
905+
906+
# Sort results by idx to maintain order
907+
sorted_writes = sorted(writes_results.docs, key=lambda x: getattr(x, "idx", 0))
908+
909+
# Build the writes dictionary
910+
writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {}
911+
for doc in sorted_writes:
912+
task_id = str(getattr(doc, "task_id", ""))
913+
idx = str(getattr(doc, "idx", 0))
914+
blob_data = getattr(doc, "$.blob", "")
915+
# Ensure blob is bytes for deserialization
916+
if isinstance(blob_data, str):
917+
blob_data = blob_data.encode("utf-8")
918+
writes_dict[(task_id, idx)] = {
919+
"task_id": task_id,
920+
"idx": idx,
921+
"channel": str(getattr(doc, "channel", "")),
922+
"type": str(getattr(doc, "type", "")),
923+
"blob": blob_data,
924+
}
925+
926+
pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict)
919927
return pending_writes

langgraph/checkpoint/redis/base.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
)
1717
from langgraph.checkpoint.serde.base import SerializerProtocol
1818
from langgraph.checkpoint.serde.types import ChannelProtocol
19+
from redisvl.query import FilterQuery
20+
from redisvl.query.filter import Tag
1921

2022
from langgraph.checkpoint.redis.util import (
2123
safely_decode,
@@ -440,7 +442,7 @@ def put_writes(
440442
type_, blob = self.serde.dumps_typed(value)
441443
write_obj = {
442444
"thread_id": to_storage_safe_id(thread_id),
443-
"checkpoint_ns": to_storage_safe_str(checkpoint_ns),
445+
"checkpoint_ns": checkpoint_ns, # Don't use sentinel for tag fields in RediSearch
444446
"checkpoint_id": to_storage_safe_id(checkpoint_id),
445447
"task_id": task_id,
446448
"task_path": task_path,
@@ -462,7 +464,7 @@ def put_writes(
462464
checkpoint_ns,
463465
checkpoint_id,
464466
task_id,
465-
write_obj["idx"], # type: ignore[arg-type]
467+
write_obj["idx"],
466468
)
467469

468470
# First check if key exists
@@ -499,33 +501,40 @@ def _load_pending_writes(
499501
if checkpoint_id is None:
500502
return [] # Early return if no checkpoint_id
501503

502-
writes_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
503-
to_storage_safe_id(thread_id),
504-
to_storage_safe_str(checkpoint_ns),
505-
to_storage_safe_id(checkpoint_id),
506-
"*",
507-
None,
504+
# Use search index instead of keys() to avoid CrossSlot errors
505+
# Note: For checkpoint_ns, we use the raw value for tag searches
506+
# because RediSearch may not handle sentinel values correctly in tag fields
507+
writes_query = FilterQuery(
508+
filter_expression=(Tag("thread_id") == to_storage_safe_id(thread_id))
509+
& (Tag("checkpoint_ns") == checkpoint_ns)
510+
& (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)),
511+
return_fields=["task_id", "idx", "channel", "type", "$.blob"],
512+
num_results=1000, # Adjust as needed
508513
)
509514

510-
# Cast the result to List[bytes] to help type checker
511-
matching_keys: List[bytes] = self._redis.keys(pattern=writes_key) # type: ignore[assignment]
512-
513-
# Use safely_decode to handle both string and bytes responses
514-
decoded_keys = [safely_decode(key) for key in matching_keys]
515+
writes_results = self.checkpoint_writes_index.search(writes_query)
516+
517+
# Sort results by idx to maintain order
518+
sorted_writes = sorted(writes_results.docs, key=lambda x: getattr(x, "idx", 0))
519+
520+
# Build the writes dictionary
521+
writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {}
522+
for doc in sorted_writes:
523+
task_id = str(getattr(doc, "task_id", ""))
524+
idx = str(getattr(doc, "idx", 0))
525+
blob_data = getattr(doc, "$.blob", "")
526+
# Ensure blob is bytes for deserialization
527+
if isinstance(blob_data, str):
528+
blob_data = blob_data.encode("utf-8")
529+
writes_dict[(task_id, idx)] = {
530+
"task_id": task_id,
531+
"idx": idx,
532+
"channel": str(getattr(doc, "channel", "")),
533+
"type": str(getattr(doc, "type", "")),
534+
"blob": blob_data,
535+
}
515536

516-
parsed_keys = [
517-
BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
518-
for key in decoded_keys
519-
]
520-
pending_writes = BaseRedisSaver._load_writes(
521-
self.serde,
522-
{
523-
(parsed_key["task_id"], parsed_key["idx"]): self._redis.json().get(key)
524-
for key, parsed_key in sorted(
525-
zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
526-
)
527-
},
528-
)
537+
pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict)
529538
return pending_writes
530539

531540
@staticmethod

tests/test_crossslot_integration.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
"""Integration tests for CrossSlot error fix in checkpoint operations."""
2+
3+
import pytest
4+
from langchain_core.runnables import RunnableConfig
5+
from langgraph.checkpoint.base import (
6+
Checkpoint,
7+
CheckpointMetadata,
8+
create_checkpoint,
9+
empty_checkpoint,
10+
)
11+
12+
from langgraph.checkpoint.redis import RedisSaver
13+
14+
15+
def test_checkpoint_operations_no_crossslot_errors(redis_url: str) -> None:
16+
"""Test that checkpoint operations work without CrossSlot errors.
17+
18+
This test verifies that the fix for using search indexes instead of keys()
19+
works correctly in a real Redis environment.
20+
"""
21+
# Create a saver
22+
saver = RedisSaver(redis_url)
23+
saver.setup()
24+
25+
# Create test data
26+
thread_id = "test-thread-crossslot"
27+
checkpoint_ns = "test-ns"
28+
29+
# Create checkpoints with unique IDs
30+
checkpoint1 = create_checkpoint(empty_checkpoint(), {}, 1)
31+
checkpoint2 = create_checkpoint(checkpoint1, {"messages": ["hello"]}, 2)
32+
checkpoint3 = create_checkpoint(checkpoint2, {"messages": ["hello", "world"]}, 3)
33+
34+
# Create metadata
35+
metadata1 = {"source": "input", "step": 1, "writes": {"task1": "value1"}}
36+
metadata2 = {"source": "loop", "step": 2, "writes": {"task2": "value2"}}
37+
metadata3 = {"source": "loop", "step": 3, "writes": {"task3": "value3"}}
38+
39+
# Put checkpoints with writes
40+
config1 = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}
41+
config2 = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}
42+
config3 = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}
43+
44+
# Put checkpoints first to get configs with checkpoint_ids
45+
saved_config1 = saver.put(config1, checkpoint1, metadata1, {})
46+
saved_config2 = saver.put(config2, checkpoint2, metadata2, {})
47+
saved_config3 = saver.put(config3, checkpoint3, metadata3, {})
48+
49+
# Add some pending writes using saved configs
50+
saver.put_writes(
51+
saved_config1,
52+
[
53+
("channel1", {"value": "data1"}),
54+
("channel2", {"value": "data2"}),
55+
],
56+
"task-1",
57+
)
58+
59+
# Now test operations that previously used keys() and would fail in cluster mode
60+
61+
# Test 1: Load pending writes (uses _load_pending_writes)
62+
# This should work without CrossSlot errors
63+
tuple1 = saver.get_tuple(saved_config1)
64+
assert tuple1 is not None
65+
# Verify pending writes were loaded
66+
assert len(tuple1.pending_writes) == 2
67+
pending_channels = [w[1] for w in tuple1.pending_writes]
68+
assert "channel1" in pending_channels
69+
assert "channel2" in pending_channels
70+
71+
# Test 2: Get tuple with TTL (uses get_tuple which searches for blob and write keys)
72+
saver_with_ttl = RedisSaver(redis_url, ttl={"checkpoint": 3600})
73+
saver_with_ttl.setup()
74+
75+
# Put a checkpoint with TTL
76+
config_ttl = {
77+
"configurable": {"thread_id": "ttl-thread", "checkpoint_ns": "ttl-ns"}
78+
}
79+
saver_with_ttl.put(config_ttl, checkpoint1, metadata1, {})
80+
81+
# Get the checkpoint - this triggers TTL application which uses key searches
82+
tuple_ttl = saver_with_ttl.get_tuple(config_ttl)
83+
assert tuple_ttl is not None
84+
85+
# Test 3: List checkpoints - this should work without CrossSlot errors
86+
# List returns only the latest checkpoint by default
87+
checkpoints = list(saver.list(config1))
88+
assert len(checkpoints) >= 1
89+
90+
# The latest checkpoint should have the pending writes from checkpoint1
91+
latest_checkpoint = checkpoints[0]
92+
assert len(latest_checkpoint.pending_writes) == 2
93+
94+
# The important part is that all these operations work without CrossSlot errors
95+
# In a Redis cluster, the old keys() based approach would have failed by now
96+
97+
98+
def test_subgraph_checkpoint_operations(redis_url: str) -> None:
99+
"""Test checkpoint operations with subgraphs work without CrossSlot errors."""
100+
saver = RedisSaver(redis_url)
101+
saver.setup()
102+
103+
# Create nested namespace checkpoints
104+
thread_id = "test-thread-subgraph"
105+
106+
# Parent checkpoint
107+
parent_config = {
108+
"configurable": {
109+
"thread_id": thread_id,
110+
"checkpoint_ns": "",
111+
}
112+
}
113+
parent_checkpoint = empty_checkpoint()
114+
parent_metadata = {"source": "input", "step": 1}
115+
116+
# Child checkpoint in subgraph
117+
child_config = {
118+
"configurable": {
119+
"thread_id": thread_id,
120+
"checkpoint_ns": "subgraph1",
121+
}
122+
}
123+
child_checkpoint = create_checkpoint(parent_checkpoint, {"subgraph": "data"}, 1)
124+
child_metadata = {"source": "loop", "step": 1}
125+
126+
# Grandchild checkpoint in nested subgraph
127+
grandchild_config = {
128+
"configurable": {
129+
"thread_id": thread_id,
130+
"checkpoint_ns": "subgraph1:subgraph2",
131+
}
132+
}
133+
grandchild_checkpoint = create_checkpoint(child_checkpoint, {"nested": "data"}, 2)
134+
grandchild_metadata = {"source": "loop", "step": 2}
135+
136+
# Put all checkpoints first to get saved configs
137+
saved_parent_config = saver.put(
138+
parent_config, parent_checkpoint, parent_metadata, {}
139+
)
140+
saved_child_config = saver.put(child_config, child_checkpoint, child_metadata, {})
141+
saved_grandchild_config = saver.put(
142+
grandchild_config, grandchild_checkpoint, grandchild_metadata, {}
143+
)
144+
145+
# Put checkpoints with writes using saved configs
146+
saver.put_writes(
147+
saved_parent_config, [("parent_channel", {"parent": "data"})], "parent-task"
148+
)
149+
saver.put_writes(
150+
saved_child_config, [("child_channel", {"child": "data"})], "child-task"
151+
)
152+
saver.put_writes(
153+
saved_grandchild_config,
154+
[("grandchild_channel", {"grandchild": "data"})],
155+
"grandchild-task",
156+
)
157+
158+
# Test loading checkpoints with pending writes from different namespaces
159+
parent_tuple = saver.get_tuple(parent_config)
160+
assert parent_tuple is not None
161+
162+
child_tuple = saver.get_tuple(child_config)
163+
assert child_tuple is not None
164+
165+
grandchild_tuple = saver.get_tuple(grandchild_config)
166+
assert grandchild_tuple is not None
167+
168+
# List all checkpoints - should work without CrossSlot errors
169+
all_checkpoints = list(saver.list({"configurable": {"thread_id": thread_id}}))
170+
assert len(all_checkpoints) >= 3

0 commit comments

Comments (0)