33from __future__ import annotations
44
55from abc import ABC , abstractmethod
6+ from contextlib import AsyncExitStack
67from datetime import datetime , timezone
7- import os
88from types import TracebackType
99import typing as _t
1010
11+ from that_depends import Provide , container_context , inject
12+
13+ from plugboard .exceptions import NotFoundError
1114from plugboard .utils import DI , ExportMixin
1215
1316
@@ -33,13 +36,18 @@ def __init__(
3336 self ._local_state = {"job_id" : job_id , "metadata" : metadata , ** kwargs }
3437 self ._logger = DI .logger .sync_resolve ().bind (cls = self .__class__ .__name__ , job_id = job_id )
3538 self ._logger .info ("StateBackend created" )
39+ self ._ctx = AsyncExitStack ()
3640
3741 async def init (self ) -> None :
3842 """Initialises the `StateBackend`."""
43+ job_id = self ._local_state .pop ("job_id" , None )
44+ container_cm = container_context (global_context = {"job_id" : job_id })
45+ await self ._ctx .enter_async_context (container_cm )
3946 await self ._initialise_data (** self ._local_state )
4047
4148 async def destroy (self ) -> None :
4249 """Destroys the `StateBackend`."""
50+ await self ._ctx .aclose ()
4351 pass
4452
4553 async def __aenter__ (self ) -> StateBackend :
@@ -56,34 +64,23 @@ async def __aexit__(
5664 """Exits the context manager."""
5765 await self .destroy ()
5866
67+ @inject
5968 async def _initialise_data (
60- self , job_id : _t . Optional [ str ] = None , metadata : _t .Optional [dict ] = None , ** kwargs : _t .Any
69+ self , job_id : str = Provide [ DI . job_id ] , metadata : _t .Optional [dict ] = None , ** kwargs : _t .Any
6170 ) -> None :
6271 """Initialises the state data."""
63- if (_job_id := self ._resolve_job_id (job_id )) is not None :
64- job_data = await self ._get_job (_job_id )
65- else :
72+ try :
73+ # TODO : Requires state for if this is a new job to conditionally raise exception?
74+ job_data = await self ._get_job (job_id )
75+ except NotFoundError :
6676 job_data = {
67- "job_id" : DI . job_id . sync_resolve () ,
77+ "job_id" : job_id ,
6878 "created_at" : datetime .now (timezone .utc ).isoformat (),
6979 "metadata" : metadata or dict (),
7080 }
7181 await self ._upsert_job (job_data )
7282 self ._local_state .update (job_data )
7383
74- @staticmethod
75- def _resolve_job_id (job_id : _t .Optional [str ] = None ) -> _t .Optional [str ]:
76- """Resolves the job id from the environment or argument if present."""
77- env_job_id = os .environ .get ("PLUGBOARD_JOB_ID" )
78- if job_id is None :
79- return env_job_id
80- if env_job_id is not None and job_id != env_job_id :
81- raise RuntimeError (
82- f"Job ID { job_id } does not match environment variable PLUGBOARD_JOB_ID={ env_job_id } "
83- )
84- os .environ ["PLUGBOARD_JOB_ID" ] = job_id
85- return job_id
86-
8784 @abstractmethod
8885 async def _get (self , key : str | tuple [str , ...], value : _t .Optional [_t .Any ] = None ) -> _t .Any :
8986 """Returns a value from the state."""
0 commit comments