diff --git a/api/pyproject.toml b/api/pyproject.toml
index 734d2036..44f14f2c 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -35,6 +35,7 @@ dependencies = [
     "damnit~=0.2.1",
     "aiohttp~=3.11",
     "aiofiles>=24.1.0",
+    "aiokafka>=0.13",
     "anyio>=4.11.0",
     "python-ulid[pydantic]>=3.1.0",
     "sqlmodel>=0.0.31",
diff --git a/api/src/damnit_api/db.py b/api/src/damnit_api/db.py
index d3d06d1c..d2301743 100644
--- a/api/src/damnit_api/db.py
+++ b/api/src/damnit_api/db.py
@@ -199,6 +199,42 @@ async def async_variable_tags(proposal):
     return variable_tags
 
 
+async def async_config_value(proposal, key):
+    """Return the `value` stored under `key` in the proposal's `metameta` table.
+
+    The result is None when no row matches `key`.
+    """
+    metameta = await async_table(proposal, name="metameta")
+    selection = select(metameta.c.value).where(metameta.c.key == key)
+    async with get_session(proposal) as session:
+        result = await session.execute(selection)
+        return result.scalar()
+
+
+async def async_changed_values(proposal, data_proposal, run, values):
+    """Fetch the `run_variables` rows of one run for the given variables.
+
+    `proposal` selects the database; `data_proposal` and `run` identify the
+    run inside it. `values` is presumably the collection of changed variable
+    names (a dict keyed by name also works, only its keys are used).
+    """
+    run_vars = await async_table(proposal, name="run_variables")
+
+    # Filter on the variable name: testing `run` against `values` in addition
+    # to `run == run` could never select the changed variables.
+    # NOTE(review): assumes the variable-name column is `name` - confirm
+    # against the run_variables schema.
+    selection = select(run_vars).where(
+        run_vars.c.proposal == data_proposal,
+        run_vars.c.run == run,
+        run_vars.c.name.in_(list(values)),
+    )
+
+    async with get_session(proposal) as session:
+        result = await session.execute(selection)
+        return result.mappings().all()  # FIX: # pyright: ignore[reportReturnType]
+
+
 # -----------------------------------------------------------------------------
 # Etc.
 
diff --git a/api/src/damnit_api/graphql/subscriptions.py b/api/src/damnit_api/graphql/subscriptions.py
index 4212fb22..d59237d2 100644
--- a/api/src/damnit_api/graphql/subscriptions.py
+++ b/api/src/damnit_api/graphql/subscriptions.py
@@ -1,18 +1,100 @@
 import asyncio
+import json
 from collections.abc import AsyncGenerator
 
 import strawberry
+from aiokafka import AIOKafkaConsumer
 from async_lru import alru_cache
 from strawberry.scalars import JSON
 from strawberry.types import Info
 
-from ..db import async_latest_rows, async_table, async_variables
+from ..db import (
+    async_changed_values,
+    async_config_value,
+    async_latest_rows,
+    async_table,
+    async_variables,
+)
 from ..utils import create_map, wrap_values
 from .models import Timestamp, get_model
 from .utils import DatabaseInput, LatestData, fetch_info
 
 POLLING_INTERVAL = 1  # seconds
+KAFKA_BROKERS = ['exflwgs06.desy.de:9091']
+KAFKA_UPDATE_TOPIC = "test.damnit.db-{}"
+
+
+class DBWatcher:
+    """Fan database-update notifications for one proposal out to subscribers.
+
+    A watcher consumes the proposal's Kafka update topic in a background task
+    and puts each prepared update onto the queue of every subscriber.
+    """
+
+    # The registry holds 1 DBWatcher per proposal, so we retrieve new values
+    # from the database once, even if several clients are subscribed.
+    registry: dict[str, 'DBWatcher'] = {}
+
+    @classmethod
+    async def queue_for_proposal(cls, proposal, schema):
+        """Return a new queue that receives the updates for `proposal`.
+
+        The first subscriber creates the watcher and starts its task.
+        """
+        q = asyncio.Queue()
+        if (watcher := cls.registry.get(proposal)) is None:
+            watcher = cls(proposal, schema)
+            # Register the watcher: otherwise every subscriber would start its
+            # own consumer and drop_queue() could never find and cancel it.
+            cls.registry[proposal] = watcher
+            watcher.task = asyncio.create_task(watcher.run())
+        watcher.subscriptions.add(q)
+        return q
+
+    @classmethod
+    def drop_queue(cls, proposal, queue):
+        """Unsubscribe `queue`; stop the watcher once nobody is listening."""
+        if (watcher := cls.registry.get(proposal)) is not None:
+            watcher.subscriptions.discard(queue)
+            if not watcher.subscriptions:
+                del cls.registry[proposal]
+                if watcher.task is not None:
+                    watcher.task.cancel()
+
+    def __init__(self, proposal, schema):
+        self.proposal = proposal
+        self.schema = schema
+        self.subscriptions = set()
+        self.task: asyncio.Task | None = None
+
+    def notify(self, d: dict):
+        """Put `d` on every subscriber queue without waiting."""
+        for q in self.subscriptions:
+            q.put_nowait(d)
+
+    async def run(self):
+        """Consume the update topic and notify subscribers until cancelled."""
+        db_id = await async_config_value(self.proposal, "db_id")
+        update_topic = KAFKA_UPDATE_TOPIC.format(db_id)
+
+        consumer = AIOKafkaConsumer(update_topic, bootstrap_servers=KAFKA_BROKERS)
+        await consumer.start()
+        try:
+            async for msg in consumer:
+                d = json.loads(msg.value)
+                if d.get('msg_kind') != 'run_values_updated':
+                    continue
+
+                d2 = d['data']
+                run_var_rows = await async_changed_values(
+                    self.proposal, d2['proposal'], d2['run'], d2['values']
+                )
+                res = await prepare_latest_data(run_var_rows, self.proposal, self.schema)
+                if res is not None:
+                    self.notify(res)
+        finally:
+            await consumer.stop()
 
 
 # Per-client cursor is deliberately omitted from the cache key so that
 # concurrent subscribers coalesce into a single DB read per tick.
@@ -30,14 +112,22 @@ async def poll_proposal(proposal, schema):
     )
     if not rows:
         return None
+    return await prepare_latest_data(rows, proposal, schema)
 
-    latest_data = LatestData.from_list(rows)
+
+async def prepare_latest_data(latest_rows, proposal, schema):
+    """Process `latest_rows` of `proposal` into the update sent to subscribers.
+
+    Shared by the polling path and the Kafka-driven watcher.
+    """
+    latest_data = LatestData.from_list(latest_rows)
     latest_runs = create_map(
         await fetch_info(proposal, runs=list(latest_data.runs.keys())),
         key="run",
     )
 
     latest_variables = await async_variables(proposal)
+
     model = get_model(proposal)
     if model.update(latest_variables, timestamp=latest_data.timestamp):
         schema.update(model.stype)
@@ -99,13 +189,12 @@ async def latest_data(
     database: DatabaseInput,
     timestamp: Timestamp,  # FIX: # pyright: ignore[reportInvalidTypeForm]
 ) -> AsyncGenerator[JSON]:  # FIX: # pyright: ignore[reportInvalidTypeForm]
-    while True:
-        await asyncio.sleep(POLLING_INTERVAL)
-
-        snapshot = await poll_proposal(
-            proposal=database.proposal,
-            schema=info.schema,
-        )
-        result = filter_for_client(snapshot, timestamp)
-        if result is not None:
-            yield result  # FIX: # pyright: ignore[reportReturnType]
+    q = await DBWatcher.queue_for_proposal(database.proposal, info.schema)
+
+    try:
+        while True:
+            update = await q.get()
+            yield update
+    finally:
+        # Unsubscribe when the generator is closed or cancelled.
+        DBWatcher.drop_queue(database.proposal, q)