11from contextlib import asynccontextmanager
22from datetime import timedelta
33from typing import Any , AsyncIterator , Dict , Optional , Sequence , Tuple
4+ import logging
45
56from langchain_core .runnables import RunnableConfig
67from acouchbase .cluster import Cluster as ACluster
78from acouchbase .bucket import Bucket as ABucket
89from couchbase .auth import PasswordAuthenticator
910from couchbase .options import ClusterOptions , QueryOptions , UpsertOptions
11+ from couchbase .exceptions import CollectionAlreadyExistsException
1012
1113from langgraph .checkpoint .base import (
1214 BaseCheckpointSaver ,
1820)
1921from .utils import _encode_binary , _decode_binary
2022
23+ logger = logging .getLogger (__name__ )
24+
2125class AsyncCouchbaseSaver (BaseCheckpointSaver ):
2226 """A checkpoint saver that stores checkpoints in a Couchbase database."""
2327
@@ -35,9 +39,35 @@ def __init__(
3539 self .cluster = cluster
3640 self .bucket_name = bucket_name
3741 self .scope_name = scope_name
42+ self .bucket = self .cluster .bucket (bucket_name )
43+ self .scope = self .bucket .scope (scope_name )
3844 self .checkpoints_collection_name = checkpoints_collection_name
3945 self .checkpoint_writes_collection_name = checkpoint_writes_collection_name
4046
47+ async def create_collections (self ):
48+ """Create collections in the Couchbase bucket if they do not exist."""
49+
50+ collection_manager = self .bucket .collections ()
51+ try :
52+ await collection_manager .create_collection (self .scope_name , self .checkpoints_collection_name )
53+ except CollectionAlreadyExistsException as _ :
54+ pass
55+ except Exception as e :
56+ logger .exception ("Error creating collections" )
57+ raise e
58+ finally :
59+ self .checkpoints_collection = self .bucket .scope (self .scope_name ).collection (self .checkpoints_collection_name )
60+
61+ try :
62+ await collection_manager .create_collection (self .scope_name , self .checkpoint_writes_collection_name )
63+ except CollectionAlreadyExistsException as _ :
64+ pass
65+ except Exception as e :
66+ logger .exception ("Error creating collections" )
67+ raise e
68+ finally :
69+ self .checkpoint_writes_collection = self .bucket .scope (self .scope_name ).collection (self .checkpoint_writes_collection_name )
70+
4171 @classmethod
4272 @asynccontextmanager
4373 async def from_conn_info (
@@ -69,15 +99,25 @@ async def from_conn_info(
6999 cls .bucket_name = bucket_name
70100 cls .scope_name = scope_name
71101
72- saver = AsyncCouchbaseSaver (cluster , bucket_name , scope_name , checkpoints_collection_name , checkpoint_writes_collection_name )
73- cls .bucket = cluster .bucket (bucket_name )
74- await cls .bucket .on_connect ()
102+ bucket = cluster .bucket (bucket_name )
103+ await bucket .on_connect ()
104+
105+ saver = AsyncCouchbaseSaver (
106+ cluster ,
107+ bucket_name ,
108+ scope_name ,
109+ checkpoints_collection_name ,
110+ checkpoint_writes_collection_name ,
111+ )
112+
113+ await saver .create_collections ()
75114
76115 yield saver
77116 finally :
78117 if cluster :
79118 await cluster .close ()
80119
120+
81121 @classmethod
82122 @asynccontextmanager
83123 async def from_cluster (
@@ -98,9 +138,18 @@ async def from_cluster(
98138 AsyncCouchbaseSaver: An instance of the AsyncCouchbaseSaver
99139 """
100140
101- saver = AsyncCouchbaseSaver (cluster , bucket_name , scope_name , checkpoints_collection_name , checkpoint_writes_collection_name )
102- cls .bucket = cluster .bucket (bucket_name )
103- await cls .bucket .on_connect ()
141+ bucket = cluster .bucket (bucket_name )
142+ await bucket .on_connect ()
143+
144+ saver = AsyncCouchbaseSaver (
145+ cluster ,
146+ bucket_name ,
147+ scope_name ,
148+ checkpoints_collection_name ,
149+ checkpoint_writes_collection_name ,
150+ )
151+
152+ await saver .create_collections ()
104153
105154 yield saver
106155
@@ -149,7 +198,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
149198 async for write_doc in serialized_writes_result :
150199 checkpoint_writes = write_doc .get (self .checkpoint_writes_collection_name , {})
151200 if "task_id" not in checkpoint_writes :
152- print ( "Error: 'task_id' is not present in checkpoint_writes" )
201+ logger . warning ( " 'task_id' is not present in checkpoint_writes" )
153202 else :
154203 pending_writes .append (
155204 (
@@ -294,6 +343,8 @@ async def aput(
294343
295344 upsert_key = f"{ thread_id } ::{ checkpoint_ns } ::{ checkpoint_id } "
296345
346+ # ensure bucket connected (idempotent)
347+ await self .bucket .on_connect ()
297348 collection = self .bucket .scope (self .scope_name ).collection (self .checkpoints_collection_name )
298349 await collection .upsert (upsert_key , (doc ), UpsertOptions (timeout = timedelta (seconds = 5 )))
299350
@@ -324,6 +375,7 @@ async def aput_writes(
324375 checkpoint_ns = config ["configurable" ]["checkpoint_ns" ]
325376 checkpoint_id = config ["configurable" ]["checkpoint_id" ]
326377
378+ await self .bucket .on_connect ()
327379 collection = self .bucket .scope (self .scope_name ).collection (self .checkpoint_writes_collection_name )
328380
329381 for idx , (channel , value ) in enumerate (writes ):
0 commit comments