Skip to content

Commit 29a8120

Browse files
committed
fix: Pass headers into TiledWriter
1 parent 1ab51e8 commit 29a8120

File tree

5 files changed

+173
-201
lines changed

5 files changed

+173
-201
lines changed

src/blueapi/core/context.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from types import ModuleType, NoneType, UnionType
88
from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints
99

10+
from bluesky.callbacks.tiled_writer import TiledWriter
1011
from bluesky.protocols import HasName
1112
from bluesky.run_engine import RunEngine
1213
from dodal.common.beamlines.beamline_utils import get_path_provider, set_path_provider
@@ -16,6 +17,8 @@
1617
from pydantic.fields import FieldInfo
1718
from pydantic.json_schema import JsonSchemaValue, SkipJsonSchema
1819
from pydantic_core import CoreSchema, core_schema
20+
from tiled.client import from_uri
21+
from tiled.client.base import BaseClient
1922

2023
from blueapi import utils
2124
from blueapi.client.numtracker import NumtrackerClient
@@ -108,6 +111,7 @@ class BlueskyContext:
108111
run_engine: RunEngine = field(
109112
default_factory=lambda: RunEngine(context_managers=[])
110113
)
114+
tiled_client: BaseClient | None = field(default=None, init=False, repr=False)
111115
numtracker: NumtrackerClient | None = field(default=None, init=False, repr=False)
112116
plans: dict[str, Plan] = field(default_factory=dict)
113117
devices: dict[str, Device] = field(default_factory=dict)
@@ -157,6 +161,12 @@ def _update_scan_num(md: dict[str, Any]) -> int:
157161
"the devices. Remove this path provider to use numtracker."
158162
)
159163

164+
if (tiled_conf := configuration.tiled) is not None and tiled_conf.enabled:
165+
self.tiled_client = client = from_uri(
166+
str(tiled_conf.url), api_key=tiled_conf.api_key
167+
)
168+
self.run_engine.subscribe(TiledWriter(client))
169+
160170
def find_device(self, addr: str | list[str]) -> Device | None:
161171
"""
162172
Find a device in this context, allows for recursive search.

src/blueapi/service/interface.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from collections.abc import Mapping
22
from functools import cache
3+
from queue import Full
34
from typing import Any
45

5-
from bluesky.callbacks.tiled_writer import TiledWriter
66
from bluesky_stomp.messaging import StompClient
77
from bluesky_stomp.models import Broker, DestinationBase, MessageTopic
8-
from tiled.client import from_uri
98

109
from blueapi.cli.scratch import get_python_environment
11-
from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig, TiledConfig
10+
from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig
1211
from blueapi.core.context import BlueskyContext
1312
from blueapi.core.event import EventStream
1413
from blueapi.log import set_up_logging
@@ -20,7 +19,7 @@
2019
TaskRequest,
2120
WorkerTask,
2221
)
23-
from blueapi.worker.event import TaskStatusEnum, WorkerState
22+
from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState
2423
from blueapi.worker.task import Task
2524
from blueapi.worker.task_worker import TaskWorker, TrackableTask
2625

@@ -87,16 +86,6 @@ def stomp_client() -> StompClient | None:
8786
return None
8887

8988

90-
@cache
91-
def tiled_writer() -> TiledWriter | None:
92-
tiled_config: TiledConfig = config().tiled
93-
if tiled_config.enabled:
94-
client = from_uri(str(tiled_config.url), api_key=tiled_config.api_key)
95-
return TiledWriter(client, batch_size=1)
96-
else:
97-
return None
98-
99-
10089
def setup(config: ApplicationConfig) -> None:
10190
"""Creates and starts a worker with supplied config"""
10291
set_config(config)
@@ -105,8 +94,6 @@ def setup(config: ApplicationConfig) -> None:
10594
# Eagerly initialize worker and messaging connection
10695
worker()
10796
stomp_client()
108-
if writer := tiled_writer():
109-
context().run_engine.subscribe(writer)
11097

11198

11299
def teardown() -> None:
@@ -116,7 +103,6 @@ def teardown() -> None:
116103
context.cache_clear()
117104
worker.cache_clear()
118105
stomp_client.cache_clear()
119-
tiled_writer.cache_clear()
120106

121107

122108
def _publish_event_streams(
@@ -161,7 +147,10 @@ def submit_task(task_request: TaskRequest) -> str:
161147
task = Task(
162148
name=task_request.name,
163149
params=task_request.params,
164-
metadata={"instrument_session": task_request.instrument_session},
150+
metadata={
151+
"instrument_session": task_request.instrument_session,
152+
"tiled_access_tags": [task_request.instrument_session],
153+
},
165154
)
166155
return worker().submit_task(task)
167156

@@ -175,8 +164,38 @@ def begin_task(
175164
task: WorkerTask, pass_through_headers: Mapping[str, str] | None = None
176165
) -> WorkerTask:
177166
"""Trigger a task. Will fail if the worker is busy"""
167+
if worker().get_active_task() is not None:
168+
raise Full()
178169
if nt := context().numtracker:
179170
nt.set_headers(pass_through_headers or {})
171+
172+
def unset_headers_when_task_finished(
173+
event: WorkerEvent, correlation_id: str | None
174+
) -> None:
175+
if (
176+
event.task_status
177+
and event.task_status.task_id == task.task_id
178+
and event.task_status.task_complete
179+
):
180+
nt.set_headers({})
181+
182+
worker().worker_events.subscribe(unset_headers_when_task_finished)
183+
if tiled_client := context().tiled_client:
184+
tiled_client.context.http_client.headers.update(pass_through_headers or {})
185+
186+
def unset_headers_when_task_finished(
187+
event: WorkerEvent, correlation_id: str | None
188+
) -> None:
189+
if (
190+
event.task_status
191+
and event.task_status.task_id == task.task_id
192+
and event.task_status.task_complete
193+
):
194+
for header in pass_through_headers or {}:
195+
del tiled_client.context.http_client.headers[header]
196+
197+
worker().worker_events.subscribe(unset_headers_when_task_finished)
198+
180199
if task.task_id is not None:
181200
worker().begin_task(task.task_id)
182201
return task

tests/unit_tests/core/test_context.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,14 @@
3030
from pydantic.json_schema import SkipJsonSchema
3131
from pytest import LogCaptureFixture
3232

33-
from blueapi.config import EnvironmentConfig, MetadataConfig, Source, SourceKind
33+
from blueapi.config import (
34+
ApplicationConfig,
35+
EnvironmentConfig,
36+
MetadataConfig,
37+
Source,
38+
SourceKind,
39+
TiledConfig,
40+
)
3441
from blueapi.core import BlueskyContext, is_bluesky_compatible_device
3542
from blueapi.core.context import DefaultFactory, generic_bounds, qualified_name
3643
from blueapi.utils.connect_devices import _establish_device_connections
@@ -693,3 +700,30 @@ def demo_plan(foo: int | None) -> MsgGenerator:
693700
"foo": {"title": "Foo", "anyOf": [{"type": "integer"}, {"type": "null"}]}
694701
}
695702
assert "foo" in schema.get("required", [])
703+
704+
705+
def test_setup_without_tiled_not_makes_tiled_inserter():
706+
with patch("blueapi.core.context.from_uri") as from_uri:
707+
BlueskyContext(ApplicationConfig())
708+
709+
assert from_uri.call_count == 0
710+
711+
712+
def test_setup_with_tiled_makes_tiled_inserter():
713+
with patch("blueapi.core.context.from_uri") as from_uri:
714+
BlueskyContext(ApplicationConfig(tiled=TiledConfig(enabled=True)))
715+
716+
assert from_uri.call_count == 1
717+
assert from_uri.call_args.args == ("http://localhost:8407/",)
718+
assert from_uri.call_args.kwargs == {"api_key": None}
719+
720+
721+
def test_setup_with_tiled_api_key_makes_tiled_inserter():
722+
with patch("blueapi.core.context.from_uri") as from_uri:
723+
BlueskyContext(
724+
ApplicationConfig(tiled=TiledConfig(enabled=True, api_key="foobarbaz"))
725+
)
726+
727+
assert from_uri.call_count == 1
728+
assert from_uri.call_args.args == ("http://localhost:8407/",)
729+
assert from_uri.call_args.kwargs == {"api_key": "foobarbaz"}

tests/unit_tests/service/test_interface.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
Source,
2727
SourceKind,
2828
StompConfig,
29-
TiledConfig,
3029
)
3130
from blueapi.core.context import BlueskyContext
3231
from blueapi.service import interface
@@ -510,34 +509,6 @@ def test_setup_with_numtracker_makes_start_document_provider():
510509
clear_path_provider()
511510

512511

513-
def test_setup_without_tiled_not_makes_tiled_inserter():
514-
with patch("blueapi.service.interface.from_uri") as from_uri:
515-
conf = ApplicationConfig()
516-
interface.setup(conf)
517-
518-
assert from_uri.call_count == 0
519-
520-
521-
def test_setup_with_tiled_makes_tiled_inserter():
522-
with patch("blueapi.service.interface.from_uri") as from_uri:
523-
conf = ApplicationConfig(tiled=TiledConfig(enabled=True))
524-
interface.setup(conf)
525-
526-
assert from_uri.call_count == 1
527-
assert from_uri.call_args.args == ("http://localhost:8407/",)
528-
assert from_uri.call_args.kwargs == {"api_key": None}
529-
530-
531-
def test_setup_with_tiled_api_key_makes_tiled_inserter():
532-
with patch("blueapi.service.interface.from_uri") as from_uri:
533-
conf = ApplicationConfig(tiled=TiledConfig(enabled=True, api_key="foobarbaz"))
534-
interface.setup(conf)
535-
536-
assert from_uri.call_count == 1
537-
assert from_uri.call_args.args == ("http://localhost:8407/",)
538-
assert from_uri.call_args.kwargs == {"api_key": "foobarbaz"}
539-
540-
541512
def test_setup_with_numtracker_raises_if_provider_is_defined_in_device_module():
542513
conf = ApplicationConfig(
543514
env=EnvironmentConfig(

0 commit comments

Comments
 (0)