Skip to content

Commit 9cf6085

Browse files
committed
Refacors state handling to use that-depends context over env vars
1 parent b5e86a3 commit 9cf6085

4 files changed

Lines changed: 41 additions & 55 deletions

File tree

plugboard/cli/process/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Plugboard Process CLI."""
22

33
import asyncio
4-
import os
54
from pathlib import Path
65

76
import msgspec
@@ -71,8 +70,7 @@ def run(
7170
config_spec = _read_yaml(config)
7271

7372
if job_id:
74-
# Override job ID in env and config file if set
75-
os.environ["PLUGBOARD_JOB_ID"] = job_id
73+
# Override job ID in config file if set
7674
config_spec.plugboard.process.args.state.args.job_id = job_id
7775

7876
with Progress(

plugboard/process/process_builder.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Provides `ProcessBuilder` to build `Process` objects."""
22

3-
import os
43
from pydoc import locate
54
import typing as _t
65

@@ -9,7 +8,7 @@
98
from plugboard.connector.connector_builder import ConnectorBuilder
109
from plugboard.events.event_connector_builder import EventConnectorBuilder
1110
from plugboard.process.process import Process
12-
from plugboard.schemas import ProcessSpec, StateBackendSpec
11+
from plugboard.schemas import ProcessSpec
1312
from plugboard.state import StateBackend
1413
from plugboard.utils import DI
1514

@@ -48,28 +47,8 @@ def _build_statebackend(cls, spec: ProcessSpec) -> StateBackend:
4847
statebackend_class: _t.Optional[_t.Any] = locate(state_spec.type)
4948
if not statebackend_class or not issubclass(statebackend_class, StateBackend):
5049
raise ValueError(f"StateBackend class {spec.args.state.type} not found.")
51-
cls._handle_job_id(state_spec)
5250
return statebackend_class(**dict(spec.args.state.args))
5351

54-
@staticmethod
55-
def _handle_job_id(state_spec: StateBackendSpec) -> None:
56-
"""Handle job ID for the state backend.
57-
58-
If a job ID is provided in the state spec, it will be set as an environment variable.
59-
If the job ID is already set in the environment, it will be checked against the one in the
60-
state spec. If they do not match, a RuntimeError will be raised.
61-
"""
62-
if state_spec.args.job_id is None:
63-
return
64-
if (
65-
env_job_id := os.environ.get("PLUGBOARD_JOB_ID")
66-
) is not None and env_job_id != state_spec.args.job_id:
67-
raise RuntimeError(
68-
f"Job ID {state_spec.args.job_id} does not match environment variable "
69-
f"PLUGBOARD_JOB_ID={env_job_id}"
70-
)
71-
os.environ["PLUGBOARD_JOB_ID"] = state_spec.args.job_id
72-
7352
@classmethod
7453
def _build_components(cls, spec: ProcessSpec) -> list[Component]:
7554
for c in spec.args.components:

plugboard/state/state_backend.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6+
from contextlib import AsyncExitStack
67
from datetime import datetime, timezone
7-
import os
88
from types import TracebackType
99
import typing as _t
1010

11+
from that_depends import Provide, container_context, inject
12+
13+
from plugboard.exceptions import NotFoundError
1114
from plugboard.utils import DI, ExportMixin
1215

1316

@@ -33,13 +36,18 @@ def __init__(
3336
self._local_state = {"job_id": job_id, "metadata": metadata, **kwargs}
3437
self._logger = DI.logger.sync_resolve().bind(cls=self.__class__.__name__, job_id=job_id)
3538
self._logger.info("StateBackend created")
39+
self._ctx = AsyncExitStack()
3640

3741
async def init(self) -> None:
3842
"""Initialises the `StateBackend`."""
43+
job_id = self._local_state.pop("job_id", None)
44+
container_cm = container_context(global_context={"job_id": job_id})
45+
await self._ctx.enter_async_context(container_cm)
3946
await self._initialise_data(**self._local_state)
4047

4148
async def destroy(self) -> None:
4249
"""Destroys the `StateBackend`."""
50+
await self._ctx.aclose()
4351
pass
4452

4553
async def __aenter__(self) -> StateBackend:
@@ -56,34 +64,23 @@ async def __aexit__(
5664
"""Exits the context manager."""
5765
await self.destroy()
5866

67+
@inject
5968
async def _initialise_data(
60-
self, job_id: _t.Optional[str] = None, metadata: _t.Optional[dict] = None, **kwargs: _t.Any
69+
self, job_id: str = Provide[DI.job_id], metadata: _t.Optional[dict] = None, **kwargs: _t.Any
6170
) -> None:
6271
"""Initialises the state data."""
63-
if (_job_id := self._resolve_job_id(job_id)) is not None:
64-
job_data = await self._get_job(_job_id)
65-
else:
72+
try:
73+
# TODO : Requires state for if this is a new job to conditionally raise exception?
74+
job_data = await self._get_job(job_id)
75+
except NotFoundError:
6676
job_data = {
67-
"job_id": DI.job_id.sync_resolve(),
77+
"job_id": job_id,
6878
"created_at": datetime.now(timezone.utc).isoformat(),
6979
"metadata": metadata or dict(),
7080
}
7181
await self._upsert_job(job_data)
7282
self._local_state.update(job_data)
7383

74-
@staticmethod
75-
def _resolve_job_id(job_id: _t.Optional[str] = None) -> _t.Optional[str]:
76-
"""Resolves the job id from the environment or argument if present."""
77-
env_job_id = os.environ.get("PLUGBOARD_JOB_ID")
78-
if job_id is None:
79-
return env_job_id
80-
if env_job_id is not None and job_id != env_job_id:
81-
raise RuntimeError(
82-
f"Job ID {job_id} does not match environment variable PLUGBOARD_JOB_ID={env_job_id}"
83-
)
84-
os.environ["PLUGBOARD_JOB_ID"] = job_id
85-
return job_id
86-
8784
@abstractmethod
8885
async def _get(self, key: str | tuple[str, ...], value: _t.Optional[_t.Any] = None) -> _t.Any:
8986
"""Returns a value from the state."""

plugboard/utils/di.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
import aio_pika
88
import structlog
9-
from that_depends import BaseContainer
10-
from that_depends.providers import Resource, Singleton
9+
from that_depends import BaseContainer, fetch_context_item
10+
from that_depends.providers import ContextResource, Resource, Singleton
1111
from yarl import URL
1212

1313
from plugboard._zmq.zmq_proxy import ZMQProxy
@@ -59,17 +59,29 @@ async def _rabbitmq_conn(
5959
await conn.close() # pragma: no cover
6060

6161

62-
def _job_id() -> str:
62+
def _job_id() -> _t.Iterator[str]:
6363
"""Returns a job ID which uniquely identifies the current plugboard run.
6464
65-
If a job ID is set in the environment variable `PLUGBOARD_JOB_ID`, it will be used.
66-
Otherwise, a new unique job ID will be generated and set in the environment.
65+
If a job ID is available in the context (from the cli, the state spec, or an argument to the
66+
StateBackend), it will take precedence. If the job ID is set in the env var `PLUGBOARD_JOB_ID`,
67+
it will be checked against the one in the context, if present. If they do not match, a
68+
RuntimeError will be raised. If the job ID is not set in the context or the env var, a new
69+
unique job ID will be generated.
6770
"""
68-
# TODO : Should the env var be unset on DI teardown? Consider notebook execution.
69-
# : Where multiple process runs may be executed from the same os process.
70-
if (job_id := os.environ.get("PLUGBOARD_JOB_ID")) is None:
71-
os.environ["PLUGBOARD_JOB_ID"] = job_id = EntityIdGen.job_id()
72-
return job_id
71+
arg_job_id = fetch_context_item("job_id")
72+
env_job_id = os.environ.get("PLUGBOARD_JOB_ID")
73+
if arg_job_id is not None:
74+
if env_job_id is not None and arg_job_id != env_job_id:
75+
raise RuntimeError(
76+
f"Job ID {arg_job_id} does not match environment variable "
77+
f"PLUGBOARD_JOB_ID={env_job_id}"
78+
)
79+
job_id = arg_job_id
80+
elif env_job_id is not None:
81+
job_id = env_job_id
82+
else:
83+
job_id = EntityIdGen.job_id()
84+
yield job_id
7385

7486

7587
class DI(BaseContainer):
@@ -84,4 +96,4 @@ class DI(BaseContainer):
8496
rabbitmq_conn: Resource[aio_pika.abc.AbstractRobustConnection] = Resource(
8597
_rabbitmq_conn, logger, url=settings.rabbitmq.url
8698
)
87-
job_id: Singleton[str] = Singleton(_job_id)
99+
job_id: ContextResource[str] = ContextResource(_job_id)

0 commit comments

Comments
 (0)