3131from sqlalchemy .ext .asyncio import AsyncEngine
3232from sqlalchemy .ext .asyncio import AsyncSession as DatabaseSessionFactory
3333from sqlalchemy .ext .asyncio import create_async_engine
34- from sqlalchemy .inspection import inspect
3534from sqlalchemy .pool import StaticPool
3635from typing_extensions import override
3736from tzlocal import get_localzone
@@ -275,22 +274,29 @@ async def create_session(
275274 storage_user_state .state = storage_user_state .state | user_state_delta
276275
277276 # Store the session
277+ now = datetime .now (timezone .utc )
278+ is_sqlite = self .db_engine .dialect .name == "sqlite"
279+ if is_sqlite :
280+ now = now .replace (tzinfo = None )
281+
278282 storage_session = schema .StorageSession (
279283 app_name = app_name ,
280284 user_id = user_id ,
281285 id = session_id ,
282286 state = session_state ,
287+ create_time = now ,
288+ update_time = now ,
283289 )
284290 sql_session .add (storage_session )
285291 await sql_session .commit ()
286292
287- await sql_session .refresh (storage_session )
288-
289293 # Merge states for response
290294 merged_state = _merge_state (
291295 storage_app_state .state , storage_user_state .state , session_state
292296 )
293- session = storage_session .to_session (state = merged_state )
297+ session = storage_session .to_session (
298+ state = merged_state , is_sqlite = is_sqlite
299+ )
294300 return session
295301
296302 @override
@@ -350,7 +356,10 @@ async def get_session(
350356
351357 # Convert storage session to session
352358 events = [e .to_event () for e in reversed (storage_events )]
353- session = storage_session .to_session (state = merged_state , events = events )
359+ is_sqlite = self .db_engine .dialect .name == "sqlite"
360+ session = storage_session .to_session (
361+ state = merged_state , events = events , is_sqlite = is_sqlite
362+ )
354363 return session
355364
356365 @override
@@ -393,11 +402,14 @@ async def list_sessions(
393402 user_states_map [storage_user_state .user_id ] = storage_user_state .state
394403
395404 sessions = []
405+ is_sqlite = self .db_engine .dialect .name == "sqlite"
396406 for storage_session in results :
397407 session_state = storage_session .state
398408 user_state = user_states_map .get (storage_session .user_id , {})
399409 merged_state = _merge_state (app_state , user_state , session_state )
400- sessions .append (storage_session .to_session (state = merged_state ))
410+ sessions .append (
411+ storage_session .to_session (state = merged_state , is_sqlite = is_sqlite )
412+ )
401413 return ListSessionsResponse (sessions = sessions )
402414
403415 @override
@@ -433,15 +445,6 @@ async def append_event(self, session: Session, event: Event) -> Event:
433445 schema .StorageSession , (session .app_name , session .user_id , session .id )
434446 )
435447
436- if storage_session .update_timestamp_tz > session .last_update_time :
437- raise ValueError (
438- "The last_update_time provided in the session object"
439- f" { datetime .fromtimestamp (session .last_update_time ):'%Y-%m-%d %H:%M:%S'} is"
440- " earlier than the update_time in the storage_session"
441- f" { datetime .fromtimestamp (storage_session .update_timestamp_tz ):'%Y-%m-%d %H:%M:%S'} ."
442- " Please check if it is a stale session."
443- )
444-
445448 # Fetch states from storage
446449 storage_app_state = await sql_session .get (
447450 schema .StorageAppState , (session .app_name )
@@ -450,6 +453,29 @@ async def append_event(self, session: Session, event: Event) -> Event:
450453 schema .StorageUserState , (session .app_name , session .user_id )
451454 )
452455
456+ is_sqlite = self .db_engine .dialect .name == "sqlite"
457+ if (
458+ storage_session .get_update_timestamp (is_sqlite )
459+ > session .last_update_time
460+ ):
461+ # Reload the session from storage if it has been updated since it was
462+ # loaded.
463+ app_state = storage_app_state .state if storage_app_state else {}
464+ user_state = storage_user_state .state if storage_user_state else {}
465+ session_state = storage_session .state
466+ session .state = _merge_state (app_state , user_state , session_state )
467+
468+ stmt = (
469+ select (schema .StorageEvent )
470+ .filter (schema .StorageEvent .app_name == session .app_name )
471+ .filter (schema .StorageEvent .session_id == session .id )
472+ .filter (schema .StorageEvent .user_id == session .user_id )
473+ .order_by (schema .StorageEvent .timestamp .asc ())
474+ )
475+ result = await sql_session .stream_scalars (stmt )
476+ storage_events = [e async for e in result ]
477+ session .events = [e .to_event () for e in storage_events ]
478+
453479 # Extract state delta
454480 if event .actions and event .actions .state_delta :
455481 state_deltas = _session_util .extract_state_delta (
@@ -466,7 +492,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
466492 if session_state_delta :
467493 storage_session .state = storage_session .state | session_state_delta
468494
469- if storage_session . _dialect_name == "sqlite" :
495+ if is_sqlite :
470496 update_time = datetime .fromtimestamp (
471497 event .timestamp , timezone .utc
472498 ).replace (tzinfo = None )
@@ -476,10 +502,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
476502 sql_session .add (schema .StorageEvent .from_event (session , event ))
477503
478504 await sql_session .commit ()
479- await sql_session .refresh (storage_session )
480505
481506 # Update timestamp with commit time
482- session .last_update_time = storage_session .update_timestamp_tz
507+ session .last_update_time = storage_session .get_update_timestamp ( is_sqlite )
483508
484509 # Also update the in-memory session
485510 await super ().append_event (session = session , event = event )
0 commit comments