Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"damnit~=0.2.1",
"aiohttp~=3.11",
"aiofiles>=24.1.0",
"aiokafka>=0.13",
"anyio>=4.11.0",
"python-ulid[pydantic]>=3.1.0",
"sqlmodel>=0.0.31",
Expand Down
25 changes: 25 additions & 0 deletions api/src/damnit_api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,31 @@ async def async_variable_tags(proposal):
return variable_tags


async def async_config_value(proposal, key):
    """Look up a single configuration value in the ``metameta`` table.

    Args:
        proposal: Proposal whose database is queried.
        key: Value matched against the table's ``key`` column.

    Returns:
        The ``value`` column of the first matching row, or ``None`` when no
        row matches.
    """
    table = await async_table(proposal, name="metameta")

    # Select only the value column; the key column is used purely as filter.
    value_column = table.c.value
    query = select(value_column).where(table.c.key == key)

    async with get_session(proposal) as session:
        rows = await session.execute(query)
        value = rows.scalar()
    return value


async def async_changed_values(proposal, data_proposal, run, values):
    """Fetch the ``run_variables`` rows of one run for a set of variables.

    Args:
        proposal: Proposal whose database is queried.
        data_proposal: Value matched against the ``proposal`` column.
        run: Value matched against the ``run`` column.
        values: Names of the variables that changed. Either a sequence of
            names or a mapping keyed by name is accepted; only the names
            are used.

    Returns:
        A list of row mappings, one per matching ``run_variables`` row.
    """
    run_vars = await async_table(proposal, name="run_variables")

    # ``list()`` yields the names for both a sequence and a name-keyed mapping.
    names = list(values)

    selection = select(run_vars).where(
        run_vars.c.proposal == data_proposal,
        run_vars.c.run == run,
        # Filter on the variable name. The run column is already pinned
        # above, so testing it against the variable names could never select
        # the intended rows.
        # NOTE(review): assumes run_variables has a ``name`` column holding
        # the variable name -- confirm against the table schema.
        run_vars.c.name.in_(names),
    )

    async with get_session(proposal) as session:
        result = await session.execute(selection)
        return result.mappings().all()  # FIX: # pyright: ignore[reportReturnType]


# -----------------------------------------------------------------------------
# Etc.

Expand Down
95 changes: 83 additions & 12 deletions api/src/damnit_api/graphql/subscriptions.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,87 @@
import asyncio
import json
from collections.abc import AsyncGenerator

import strawberry
from aiokafka import AIOKafkaConsumer
from async_lru import alru_cache
from strawberry.scalars import JSON
from strawberry.types import Info

from ..db import async_latest_rows, async_table, async_variables
from ..db import (
async_latest_rows,
async_table,
async_variables,
async_config_value,
async_changed_values,
)
from ..utils import create_map, wrap_values
from .models import Timestamp, get_model
from .utils import DatabaseInput, LatestData, fetch_info

# Delay handed to asyncio.sleep() between two database polls.
POLLING_INTERVAL = 1 # seconds

# Bootstrap servers passed to the AIOKafkaConsumer created in DBWatcher.run().
KAFKA_BROKERS = ['exflwgs06.desy.de:9091']
# Topic name template; DBWatcher.run() fills the placeholder with the "db_id"
# configuration value read from the proposal's database.
KAFKA_UPDATE_TOPIC = "test.damnit.db-{}"


class DBWatcher:
    """Fan out database update notifications of one proposal to subscribers.

    A watcher consumes the proposal's Kafka update topic, loads the changed
    rows from the database and pushes the prepared result into the queue of
    every subscribed client.
    """

    # The registry holds 1 DBWatcher per proposal, so we retrieve new values
    # from the database once, even if several clients are subscribed.
    registry: dict[str, "DBWatcher"] = {}

    @classmethod
    async def queue_for_proposal(cls, proposal, schema):
        """Subscribe to updates of ``proposal`` and return the client's queue.

        The first subscriber of a proposal creates the watcher and starts its
        consumer task; later subscribers share that watcher.

        Args:
            proposal: Proposal to watch; also the registry key.
            schema: Schema object forwarded to ``prepare_latest_data``.

        Returns:
            A new ``asyncio.Queue`` receiving every update for the proposal.
        """
        watcher = cls.registry.get(proposal)
        if watcher is None:
            watcher = cls(proposal, schema)
            # Register the watcher so later subscribers reuse it and
            # drop_queue() can find it again to stop the consumer.
            cls.registry[proposal] = watcher
        if watcher.task is None or watcher.task.done():
            # Start the consumer, or restart it if a previous run ended
            # (e.g. it failed), so new subscribers are never attached to a
            # dead watcher.
            watcher.task = asyncio.create_task(watcher.run())

        q = asyncio.Queue()
        watcher.subscriptions.add(q)
        return q

    @classmethod
    def drop_queue(cls, proposal, queue):
        """Unsubscribe ``queue``; stop the watcher once nobody is listening.

        Unknown proposals and queues that are not subscribed are ignored.
        """
        watcher = cls.registry.get(proposal)
        if watcher is None:
            return

        watcher.subscriptions.discard(queue)
        if watcher.subscriptions:
            return

        # Last subscriber gone: forget the watcher and stop its consumer.
        del cls.registry[proposal]
        if watcher.task is not None:
            watcher.task.cancel()

    def __init__(self, proposal, schema):
        self.proposal = proposal
        self.schema = schema
        # Queues of the currently subscribed clients.
        self.subscriptions = set()
        # Background task running run(); set by queue_for_proposal().
        self.task: asyncio.Task | None = None

    def notify(self, d: dict):
        """Put ``d`` into the queue of every subscriber without waiting."""
        for q in self.subscriptions:
            q.put_nowait(d)

    async def run(self):
        """Consume the proposal's update topic and notify subscribers.

        Runs until cancelled; the Kafka consumer is always stopped on exit.
        """
        db_id = await async_config_value(self.proposal, "db_id")
        update_topic = KAFKA_UPDATE_TOPIC.format(db_id)

        consumer = AIOKafkaConsumer(update_topic, bootstrap_servers=KAFKA_BROKERS)
        await consumer.start()
        try:
            async for msg in consumer:
                d = json.loads(msg.value)
                # Only value updates are relevant for subscribers.
                if d.get('msg_kind') != 'run_values_updated':
                    continue

                d2 = d['data']
                run_var_rows = await async_changed_values(
                    self.proposal, d2['proposal'], d2['run'], d2['values']
                )
                res = await prepare_latest_data(
                    run_var_rows, self.proposal, self.schema
                )
                if res is not None:
                    self.notify(res)
        finally:
            await consumer.stop()




# Per-client cursor is deliberately omitted from the cache key so that
# concurrent subscribers coalesce into a single DB read per tick.
Expand All @@ -30,14 +99,18 @@ async def poll_proposal(proposal, schema):
)
if not rows:
return None
return prepare_latest_data(rows, proposal, schema)

latest_data = LatestData.from_list(rows)

async def prepare_latest_data(latest_rows, proposal, schema):
latest_data = LatestData.from_list(latest_rows)
latest_runs = create_map(
await fetch_info(proposal, runs=list(latest_data.runs.keys())),
key="run",
)

latest_variables = await async_variables(proposal)
model = get_model(proposal)
if model.update(latest_variables, timestamp=latest_data.timestamp):
schema.update(model.stype)

Expand Down Expand Up @@ -99,13 +172,11 @@ async def latest_data(
database: DatabaseInput,
timestamp: Timestamp, # FIX: # pyright: ignore[reportInvalidTypeForm]
) -> AsyncGenerator[JSON]: # FIX: # pyright: ignore[reportInvalidTypeForm]
while True:
await asyncio.sleep(POLLING_INTERVAL)

snapshot = await poll_proposal(
proposal=database.proposal,
schema=info.schema,
)
result = filter_for_client(snapshot, timestamp)
if result is not None:
yield result # FIX: # pyright: ignore[reportReturnType]
q = await DBWatcher.queue_for_proposal(database.proposal, info.schema)

try:
while True:
update = await q.get()
yield update
finally:
DBWatcher.drop_queue(database.proposal, q)