
fix: replace keys() with search indexes to prevent Redis CrossSlot errors (#25) #48


Merged · 1 commit · May 29, 2025
52 changes: 42 additions & 10 deletions langgraph/checkpoint/redis/__init__.py
@@ -354,13 +354,42 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
CHECKPOINT_WRITE_PREFIX,
)

# Get the blob keys
blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*"
blob_keys = [key.decode() for key in self._redis.keys(blob_key_pattern)]

# Also get checkpoint write keys that should have the same TTL
write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*"
write_keys = [key.decode() for key in self._redis.keys(write_key_pattern)]
# Get the blob keys using search index instead of keys()
blob_query = FilterQuery(
filter_expression=(
Tag("thread_id") == to_storage_safe_id(doc_thread_id)
)
& (Tag("checkpoint_ns") == to_storage_safe_str(doc_checkpoint_ns)),
return_fields=["key"], # Assuming the key field exists in the index
num_results=1000,
)
blob_results = self.checkpoint_blobs_index.search(blob_query)
blob_keys = [
f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{getattr(doc, 'channel', '')}:{getattr(doc, 'version', '')}"
for doc in blob_results.docs
]

# Get checkpoint write keys using search index
write_query = FilterQuery(
filter_expression=(
Tag("thread_id") == to_storage_safe_id(doc_thread_id)
)
& (Tag("checkpoint_ns") == to_storage_safe_str(doc_checkpoint_ns))
& (Tag("checkpoint_id") == to_storage_safe_id(doc_checkpoint_id)),
return_fields=["task_id", "idx"],
num_results=1000,
)
write_results = self.checkpoint_writes_index.search(write_query)
write_keys = [
BaseRedisSaver._make_redis_checkpoint_writes_key(
to_storage_safe_id(doc_thread_id),
to_storage_safe_str(doc_checkpoint_ns),
to_storage_safe_id(doc_checkpoint_id),
getattr(doc, "task_id", ""),
getattr(doc, "idx", 0),
)
for doc in write_results.docs
]

# Apply TTL to checkpoint, blob keys, and write keys
all_related_keys = blob_keys + write_keys
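
For context on the change above: `keys()` does a glob scan of the keyspace, and the multi-key TTL operations driven by its results can touch keys in different hash slots, which Redis Cluster rejects with a CROSSSLOT error. The replacement queries the existing checkpoint indexes and rebuilds each key deterministically from indexed fields. A minimal sketch of the redisvl pattern, with placeholder tag values:

```python
from redisvl.query import FilterQuery
from redisvl.query.filter import Tag

# FilterQuery compiles to a single FT.SEARCH call served by the index,
# so no key-pattern scan is needed and no multi-slot key listing occurs.
blob_query = FilterQuery(
    filter_expression=(Tag("thread_id") == "thread-1")
    & (Tag("checkpoint_ns") == "ns"),
    return_fields=["channel", "version"],
    num_results=1000,
)
# An index's .search(blob_query) returns docs carrying the requested
# fields; each blob key is then rebuilt as
# f"{prefix}:{thread_id}:{ns}:{channel}:{version}"
# instead of being discovered via KEYS.
```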
@@ -489,12 +518,15 @@ def get_channel_values(
blob_results = self.checkpoint_blobs_index.search(blob_query)
if blob_results.docs:
blob_doc = blob_results.docs[0]
blob_type = blob_doc.type
blob_type = getattr(blob_doc, "type", None)
blob_data = getattr(blob_doc, "$.blob", None)

if blob_data and blob_type != "empty":
if blob_data and blob_type and blob_type != "empty":
# Ensure blob_data is bytes for deserialization
if isinstance(blob_data, str):
blob_data = blob_data.encode("utf-8")
channel_values[channel] = self.serde.loads_typed(
(blob_type, blob_data)
(str(blob_type), blob_data)
)

return channel_values
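
The `getattr` fallbacks and the str-to-bytes guard exist because fields fetched through RediSearch `return_fields` come back as strings, while the serializer expects a `(type, bytes)` pair. A hedged sketch of the round-trip, using langgraph's default `JsonPlusSerializer`:

```python
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

serde = JsonPlusSerializer()
type_, blob = serde.dumps_typed({"messages": ["hello"]})
assert isinstance(blob, bytes)  # loads_typed expects bytes back

restored = serde.loads_typed((type_, blob))
assert restored == {"messages": ["hello"]}
# When the blob field instead arrives as str from a search result, the
# code above re-encodes it with .encode("utf-8") before calling
# loads_typed, and skips deserialization entirely for "empty" blobs.
```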
58 changes: 33 additions & 25 deletions langgraph/checkpoint/redis/aio.py
@@ -890,30 +890,38 @@ async def _aload_pending_writes(
if checkpoint_id is None:
return [] # Early return if no checkpoint_id

writes_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
to_storage_safe_id(thread_id),
to_storage_safe_str(checkpoint_ns),
to_storage_safe_id(checkpoint_id),
"*",
None,
)
matching_keys = await self._redis.keys(pattern=writes_key)
# Use safely_decode to handle both string and bytes responses
decoded_keys = [safely_decode(key) for key in matching_keys]
parsed_keys = [
BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
for key in decoded_keys
]
pending_writes = BaseRedisSaver._load_writes(
self.serde,
{
(
parsed_key["task_id"],
parsed_key["idx"],
): await self._redis.json().get(key)
for key, parsed_key in sorted(
zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
)
},
# Use search index instead of keys() to avoid CrossSlot errors
# Note: For checkpoint_ns, we use the raw value for tag searches
# because RediSearch may not handle sentinel values correctly in tag fields
writes_query = FilterQuery(
filter_expression=(Tag("thread_id") == to_storage_safe_id(thread_id))
& (Tag("checkpoint_ns") == checkpoint_ns)
& (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)),
return_fields=["task_id", "idx", "channel", "type", "$.blob"],
num_results=1000, # Adjust as needed
)

writes_results = await self.checkpoint_writes_index.search(writes_query)

# Sort results by idx to maintain order
sorted_writes = sorted(writes_results.docs, key=lambda x: getattr(x, "idx", 0))

# Build the writes dictionary
writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {}
for doc in sorted_writes:
task_id = str(getattr(doc, "task_id", ""))
idx = str(getattr(doc, "idx", 0))
blob_data = getattr(doc, "$.blob", "")
# Ensure blob is bytes for deserialization
if isinstance(blob_data, str):
blob_data = blob_data.encode("utf-8")
writes_dict[(task_id, idx)] = {
"task_id": task_id,
"idx": idx,
"channel": str(getattr(doc, "channel", "")),
"type": str(getattr(doc, "type", "")),
"blob": blob_data,
}

pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict)
return pending_writes
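
The note about sentinel values matters because the write path and the query path must agree on how an empty `checkpoint_ns` is represented. An illustrative sketch (the sentinel shown is hypothetical, not the one defined in `langgraph.checkpoint.redis.util`):

```python
# Illustrative only — "__empty__" is a hypothetical sentinel.
def to_storage_safe_str_demo(s: str) -> str:
    return s if s else "__empty__"  # maps "" to a storable sentinel

raw_ns = ""  # the default (root) checkpoint namespace
# If the write path stored the sentinel but this query used the raw
# value, the tag filter would never match:
assert to_storage_safe_str_demo(raw_ns) != raw_ns
# put_writes (base.py) therefore stores checkpoint_ns raw, and both
# _load_pending_writes variants filter on the raw value, keeping the
# two paths consistent.
```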
61 changes: 35 additions & 26 deletions langgraph/checkpoint/redis/base.py
@@ -16,6 +16,8 @@
)
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.types import ChannelProtocol
from redisvl.query import FilterQuery
from redisvl.query.filter import Tag

from langgraph.checkpoint.redis.util import (
safely_decode,
@@ -440,7 +442,7 @@ def put_writes(
type_, blob = self.serde.dumps_typed(value)
write_obj = {
"thread_id": to_storage_safe_id(thread_id),
"checkpoint_ns": to_storage_safe_str(checkpoint_ns),
"checkpoint_ns": checkpoint_ns, # Don't use sentinel for tag fields in RediSearch
"checkpoint_id": to_storage_safe_id(checkpoint_id),
"task_id": task_id,
"task_path": task_path,
Expand All @@ -462,7 +464,7 @@ def put_writes(
checkpoint_ns,
checkpoint_id,
task_id,
write_obj["idx"], # type: ignore[arg-type]
write_obj["idx"],
)

# First check if key exists
@@ -499,33 +501,40 @@ def _load_pending_writes(
if checkpoint_id is None:
return [] # Early return if no checkpoint_id

writes_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
to_storage_safe_id(thread_id),
to_storage_safe_str(checkpoint_ns),
to_storage_safe_id(checkpoint_id),
"*",
None,
# Use search index instead of keys() to avoid CrossSlot errors
# Note: For checkpoint_ns, we use the raw value for tag searches
# because RediSearch may not handle sentinel values correctly in tag fields
writes_query = FilterQuery(
filter_expression=(Tag("thread_id") == to_storage_safe_id(thread_id))
& (Tag("checkpoint_ns") == checkpoint_ns)
& (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)),
return_fields=["task_id", "idx", "channel", "type", "$.blob"],
num_results=1000, # Adjust as needed
)

# Cast the result to List[bytes] to help type checker
matching_keys: List[bytes] = self._redis.keys(pattern=writes_key) # type: ignore[assignment]

# Use safely_decode to handle both string and bytes responses
decoded_keys = [safely_decode(key) for key in matching_keys]
writes_results = self.checkpoint_writes_index.search(writes_query)

# Sort results by idx to maintain order
sorted_writes = sorted(writes_results.docs, key=lambda x: getattr(x, "idx", 0))

# Build the writes dictionary
writes_dict: Dict[Tuple[str, str], Dict[str, Any]] = {}
for doc in sorted_writes:
task_id = str(getattr(doc, "task_id", ""))
idx = str(getattr(doc, "idx", 0))
blob_data = getattr(doc, "$.blob", "")
# Ensure blob is bytes for deserialization
if isinstance(blob_data, str):
blob_data = blob_data.encode("utf-8")
writes_dict[(task_id, idx)] = {
"task_id": task_id,
"idx": idx,
"channel": str(getattr(doc, "channel", "")),
"type": str(getattr(doc, "type", "")),
"blob": blob_data,
}

parsed_keys = [
BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
for key in decoded_keys
]
pending_writes = BaseRedisSaver._load_writes(
self.serde,
{
(parsed_key["task_id"], parsed_key["idx"]): self._redis.json().get(key)
for key, parsed_key in sorted(
zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
)
},
)
pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict)
return pending_writes

@staticmethod
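For reference, the chained `Tag` comparisons in both `_load_pending_writes` implementations compose into a single RediSearch filter string that `FT.SEARCH` evaluates server-side. A sketch of the serialization (the exact escaping and grouping may differ by redisvl version):

```python
from redisvl.query.filter import Tag

f = (
    (Tag("thread_id") == "t1")
    & (Tag("checkpoint_ns") == "ns")
    & (Tag("checkpoint_id") == "c1")
)
# str(f) renders RediSearch tag syntax, approximately:
#   ((@thread_id:{t1} @checkpoint_ns:{ns}) @checkpoint_id:{c1})
print(str(f))
```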
170 changes: 170 additions & 0 deletions tests/test_crossslot_integration.py
@@ -0,0 +1,170 @@
"""Integration tests for CrossSlot error fix in checkpoint operations."""

import pytest
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Checkpoint,
CheckpointMetadata,
create_checkpoint,
empty_checkpoint,
)

from langgraph.checkpoint.redis import RedisSaver


def test_checkpoint_operations_no_crossslot_errors(redis_url: str) -> None:
"""Test that checkpoint operations work without CrossSlot errors.
This test verifies that the fix for using search indexes instead of keys()
works correctly in a real Redis environment.
"""
# Create a saver
saver = RedisSaver(redis_url)
saver.setup()

# Create test data
thread_id = "test-thread-crossslot"
checkpoint_ns = "test-ns"

# Create checkpoints with unique IDs
checkpoint1 = create_checkpoint(empty_checkpoint(), {}, 1)
checkpoint2 = create_checkpoint(checkpoint1, {"messages": ["hello"]}, 2)
checkpoint3 = create_checkpoint(checkpoint2, {"messages": ["hello", "world"]}, 3)

# Create metadata
metadata1 = {"source": "input", "step": 1, "writes": {"task1": "value1"}}
metadata2 = {"source": "loop", "step": 2, "writes": {"task2": "value2"}}
metadata3 = {"source": "loop", "step": 3, "writes": {"task3": "value3"}}

# Put checkpoints with writes
config1 = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}
config2 = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}
config3 = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}

# Put checkpoints first to get configs with checkpoint_ids
saved_config1 = saver.put(config1, checkpoint1, metadata1, {})
saved_config2 = saver.put(config2, checkpoint2, metadata2, {})
saved_config3 = saver.put(config3, checkpoint3, metadata3, {})

# Add some pending writes using saved configs
saver.put_writes(
saved_config1,
[
("channel1", {"value": "data1"}),
("channel2", {"value": "data2"}),
],
"task-1",
)

# Now test operations that previously used keys() and would fail in cluster mode

# Test 1: Load pending writes (uses _load_pending_writes)
# This should work without CrossSlot errors
tuple1 = saver.get_tuple(saved_config1)
assert tuple1 is not None
# Verify pending writes were loaded
assert len(tuple1.pending_writes) == 2
pending_channels = [w[1] for w in tuple1.pending_writes]
assert "channel1" in pending_channels
assert "channel2" in pending_channels

# Test 2: Get tuple with TTL (uses get_tuple which searches for blob and write keys)
saver_with_ttl = RedisSaver(redis_url, ttl={"checkpoint": 3600})
saver_with_ttl.setup()

# Put a checkpoint with TTL
config_ttl = {
"configurable": {"thread_id": "ttl-thread", "checkpoint_ns": "ttl-ns"}
}
saver_with_ttl.put(config_ttl, checkpoint1, metadata1, {})

# Get the checkpoint - this triggers TTL application which uses key searches
tuple_ttl = saver_with_ttl.get_tuple(config_ttl)
assert tuple_ttl is not None

# Test 3: List checkpoints - this should work without CrossSlot errors
# List returns only the latest checkpoint by default
checkpoints = list(saver.list(config1))
assert len(checkpoints) >= 1

# The latest checkpoint should have the pending writes from checkpoint1
latest_checkpoint = checkpoints[0]
assert len(latest_checkpoint.pending_writes) == 2

# The important part is that all these operations work without CrossSlot errors
# In a Redis cluster, the old keys() based approach would have failed by now


def test_subgraph_checkpoint_operations(redis_url: str) -> None:
"""Test checkpoint operations with subgraphs work without CrossSlot errors."""
saver = RedisSaver(redis_url)
saver.setup()

# Create nested namespace checkpoints
thread_id = "test-thread-subgraph"

# Parent checkpoint
parent_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
}
}
parent_checkpoint = empty_checkpoint()
parent_metadata = {"source": "input", "step": 1}

# Child checkpoint in subgraph
child_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "subgraph1",
}
}
child_checkpoint = create_checkpoint(parent_checkpoint, {"subgraph": "data"}, 1)
child_metadata = {"source": "loop", "step": 1}

# Grandchild checkpoint in nested subgraph
grandchild_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "subgraph1:subgraph2",
}
}
grandchild_checkpoint = create_checkpoint(child_checkpoint, {"nested": "data"}, 2)
grandchild_metadata = {"source": "loop", "step": 2}

# Put all checkpoints first to get saved configs
saved_parent_config = saver.put(
parent_config, parent_checkpoint, parent_metadata, {}
)
saved_child_config = saver.put(child_config, child_checkpoint, child_metadata, {})
saved_grandchild_config = saver.put(
grandchild_config, grandchild_checkpoint, grandchild_metadata, {}
)

# Put checkpoints with writes using saved configs
saver.put_writes(
saved_parent_config, [("parent_channel", {"parent": "data"})], "parent-task"
)
saver.put_writes(
saved_child_config, [("child_channel", {"child": "data"})], "child-task"
)
saver.put_writes(
saved_grandchild_config,
[("grandchild_channel", {"grandchild": "data"})],
"grandchild-task",
)

# Test loading checkpoints with pending writes from different namespaces
parent_tuple = saver.get_tuple(parent_config)
assert parent_tuple is not None

child_tuple = saver.get_tuple(child_config)
assert child_tuple is not None

grandchild_tuple = saver.get_tuple(grandchild_config)
assert grandchild_tuple is not None

# List all checkpoints - should work without CrossSlot errors
all_checkpoints = list(saver.list({"configurable": {"thread_id": thread_id}}))
assert len(all_checkpoints) >= 3
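
These tests assume a `redis_url` pytest fixture provided elsewhere in the repo's test setup. A hypothetical minimal `conftest.py` for running them locally:

```python
import pytest


@pytest.fixture
def redis_url() -> str:
    # Assumes a local Redis Stack (RediSearch enabled) on the default
    # port; the repo's real fixture may instead spin up a container.
    return "redis://localhost:6379"
```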