Skip to content

fix: cleanup blobs and writes for shallow classes #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions langgraph/checkpoint/redis/ashallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import json
import os
from contextlib import asynccontextmanager
from functools import partial
from types import TracebackType
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast

Expand All @@ -25,7 +24,6 @@
from redisvl.index import AsyncSearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Num, Tag
from redisvl.redis.connection import RedisConnectionFactory

from langgraph.checkpoint.redis.base import (
CHECKPOINT_BLOB_PREFIX,
Expand All @@ -34,6 +32,10 @@
REDIS_KEY_SEPARATOR,
BaseRedisSaver,
)
from langgraph.checkpoint.redis.util import (
to_storage_safe_id,
to_storage_safe_str,
)

SCHEMAS = [
{
Expand Down Expand Up @@ -794,26 +796,38 @@ def put_writes(
@staticmethod
def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str:
"""Create a key for shallow checkpoints using only thread_id and checkpoint_ns."""
return REDIS_KEY_SEPARATOR.join([CHECKPOINT_PREFIX, thread_id, checkpoint_ns])
return REDIS_KEY_SEPARATOR.join(
[
CHECKPOINT_PREFIX,
str(to_storage_safe_id(thread_id)),
to_storage_safe_str(checkpoint_ns),
]
)

@staticmethod
def _make_shallow_redis_checkpoint_blob_key_pattern(
thread_id: str, checkpoint_ns: str
) -> str:
"""Create a pattern to match all blob keys for a thread and namespace."""
return (
REDIS_KEY_SEPARATOR.join([CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns])
+ ":*"
return REDIS_KEY_SEPARATOR.join(
[
CHECKPOINT_BLOB_PREFIX,
str(to_storage_safe_id(thread_id)),
to_storage_safe_str(checkpoint_ns),
"*",
]
)

@staticmethod
def _make_shallow_redis_checkpoint_writes_key_pattern(
thread_id: str, checkpoint_ns: str
) -> str:
"""Create a pattern to match all writes keys for a thread and namespace."""
return (
REDIS_KEY_SEPARATOR.join(
[CHECKPOINT_WRITE_PREFIX, thread_id, checkpoint_ns]
)
+ ":*"
return REDIS_KEY_SEPARATOR.join(
[
CHECKPOINT_WRITE_PREFIX,
str(to_storage_safe_id(thread_id)),
to_storage_safe_str(checkpoint_ns),
"*",
]
)
39 changes: 23 additions & 16 deletions langgraph/checkpoint/redis/shallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
REDIS_KEY_SEPARATOR,
BaseRedisSaver,
)
from langgraph.checkpoint.redis.util import (
to_storage_safe_id,
to_storage_safe_str,
)

SCHEMAS = [
{
Expand Down Expand Up @@ -688,35 +692,38 @@ def _load_pending_sends(
@staticmethod
def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str:
"""Create a key for shallow checkpoints using only thread_id and checkpoint_ns."""
return REDIS_KEY_SEPARATOR.join([CHECKPOINT_PREFIX, thread_id, checkpoint_ns])

@staticmethod
def _make_shallow_redis_checkpoint_blob_key(
thread_id: str, checkpoint_ns: str, channel: str
) -> str:
"""Create a key for a blob in a shallow checkpoint."""
return REDIS_KEY_SEPARATOR.join(
[CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns, channel]
[
CHECKPOINT_PREFIX,
str(to_storage_safe_id(thread_id)),
to_storage_safe_str(checkpoint_ns),
]
)

@staticmethod
def _make_shallow_redis_checkpoint_blob_key_pattern(
thread_id: str, checkpoint_ns: str
) -> str:
"""Create a pattern to match all blob keys for a thread and namespace."""
return (
REDIS_KEY_SEPARATOR.join([CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns])
+ ":*"
return REDIS_KEY_SEPARATOR.join(
[
CHECKPOINT_BLOB_PREFIX,
str(to_storage_safe_id(thread_id)),
to_storage_safe_str(checkpoint_ns),
"*",
]
)

@staticmethod
def _make_shallow_redis_checkpoint_writes_key_pattern(
thread_id: str, checkpoint_ns: str
) -> str:
"""Create a pattern to match all writes keys for a thread and namespace."""
return (
REDIS_KEY_SEPARATOR.join(
[CHECKPOINT_WRITE_PREFIX, thread_id, checkpoint_ns]
)
+ ":*"
return REDIS_KEY_SEPARATOR.join(
[
CHECKPOINT_WRITE_PREFIX,
str(to_storage_safe_id(thread_id)),
to_storage_safe_str(checkpoint_ns),
"*",
]
)
16 changes: 13 additions & 3 deletions tests/test_shallow_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from redis.exceptions import ConnectionError as RedisConnectionError

from langgraph.checkpoint.redis.ashallow import AsyncShallowRedisSaver
from langgraph.checkpoint.redis.base import CHECKPOINT_BLOB_PREFIX


@pytest.fixture
Expand Down Expand Up @@ -96,7 +97,10 @@ async def test_only_latest_checkpoint(
}
)
checkpoint_1 = test_data["checkpoints"][0]
await saver.aput(config_1, checkpoint_1, test_data["metadata"][0], {})
channel_versions_1 = {"test_channel": "1"}
await saver.aput(
config_1, checkpoint_1, test_data["metadata"][0], channel_versions_1
)

# Create second checkpoint
config_2 = RunnableConfig(
Expand All @@ -108,13 +112,19 @@ async def test_only_latest_checkpoint(
}
)
checkpoint_2 = test_data["checkpoints"][1]
await saver.aput(config_2, checkpoint_2, test_data["metadata"][1], {})
channel_versions_2 = {"test_channel": "2"}
await saver.aput(
config_2, checkpoint_2, test_data["metadata"][1], channel_versions_2
)

# Verify only latest checkpoint exists
# Verify only latest checkpoint and blobs exists
results = [c async for c in saver.alist(None)]
assert len(results) == 1
assert results[0].config["configurable"]["checkpoint_id"] == checkpoint_2["id"]

blobs = list(await saver._redis.keys(CHECKPOINT_BLOB_PREFIX + ":*"))
assert len(blobs) == 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down
12 changes: 9 additions & 3 deletions tests/test_shallow_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from redis import Redis
from redis.exceptions import ConnectionError as RedisConnectionError

from langgraph.checkpoint.redis.base import CHECKPOINT_BLOB_PREFIX
from langgraph.checkpoint.redis.shallow import ShallowRedisSaver


Expand Down Expand Up @@ -102,7 +103,8 @@ def test_only_latest_checkpoint(
}
}
checkpoint_1 = test_data["checkpoints"][0]
saver.put(config_1, checkpoint_1, test_data["metadata"][0], {})
channel_versions_1 = {"test_channel": "1"}
saver.put(config_1, checkpoint_1, test_data["metadata"][0], channel_versions_1)

# Create second checkpoint
config_2 = {
Expand All @@ -112,13 +114,17 @@ def test_only_latest_checkpoint(
}
}
checkpoint_2 = test_data["checkpoints"][1]
saver.put(config_2, checkpoint_2, test_data["metadata"][1], {})
channel_versions_2 = {"test_channel": "2"}
saver.put(config_2, checkpoint_2, test_data["metadata"][1], channel_versions_2)

# Verify only latest checkpoint exists
# Verify only latest checkpoint and blobs exists
results = list(saver.list(None))
assert len(results) == 1
assert results[0].config["configurable"]["checkpoint_id"] == checkpoint_2["id"]

blobs = list(saver._redis.keys(CHECKPOINT_BLOB_PREFIX + ":*"))
assert len(blobs) == 1


@pytest.mark.parametrize(
"query, expected_count",
Expand Down