Skip to content

Commit f8dea08

Browse files
authored
Merge origin/main into copilot/refine-cli-agents-md
2 parents 2e27ca2 + 8eb9451 commit f8dea08

14 files changed

Lines changed: 507 additions & 52 deletions

File tree

.github/workflows/lint-test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
if: always()
5959

6060
- name: Code complexity
61-
run: uv run xenon --max-absolute B --max-modules A --max-average A plugboard/
61+
run: uv run xenon --max-absolute C --max-modules A --max-average A plugboard/
6262
if: always()
6363

6464
- name: Notebook output cleared

plugboard-schemas/plugboard_schemas/_validation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from ._validator_registry import validator
1919

2020

21+
_SYSTEM_STOP_EVENT = "system_stop"
22+
23+
2124
def _build_component_graph(
2225
connectors: dict[str, dict[str, _t.Any]],
2326
) -> dict[str, set[str]]:
@@ -100,6 +103,9 @@ def validate_all_inputs_connected(
100103
all_inputs = set(io.get("inputs", []))
101104
connected = connected_inputs.get(comp_name, set())
102105
unconnected = all_inputs - connected
106+
if unconnected:
107+
event_covered_fields = set().union(*io.get("event_field_coverage", {}).values())
108+
unconnected -= event_covered_fields
103109
if unconnected:
104110
errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}")
105111
return errors

plugboard/cli/server/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,14 @@ async def _post_to_api(url: str, data: dict) -> None:
4343
def _import_recursive(path: Path, base_package: _t.Optional[str] = None) -> None:
4444
"""Import all modules recursively from the given path."""
4545
logger = DI.logger.resolve_sync()
46-
for root, _dirs, files in os.walk(path):
46+
for root, dirs, files in os.walk(path):
47+
# Update dirs in place so os.walk skips hidden directories like .venv.
48+
dirs[:] = [directory for directory in dirs if not directory.startswith(".")]
4749
for file in files:
4850
if file.endswith(".py") and not file.startswith("__"):
4951
# Construct module name
50-
rel_path = os.path.relpath(os.path.join(root, file), path)
51-
module_name = rel_path.replace(os.sep, ".")[:-3]
52+
rel_path = Path(root, file).relative_to(path)
53+
module_name = ".".join(rel_path.with_suffix("").parts)
5254

5355
if base_package:
5456
module_name = f"{base_package}.{module_name}"

plugboard/component/component.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(
9595
initial_values=self._initial_values,
9696
input_events=self.__class__.io.input_events,
9797
output_events=self.__class__.io.output_events,
98+
event_field_coverage=self.__class__.io.event_field_coverage,
9899
namespace=self.name,
99100
component=self,
100101
)
@@ -143,10 +144,9 @@ def parameters(self) -> dict[str, _t.Any]:
143144
return self._parameters
144145

145146
@classmethod
146-
def _configure_io(cls) -> None:
147-
# Get all parent classes that are Component subclasses
147+
def _get_aggregated_io_args(cls) -> tuple[dict[str, set], list[str]]:
148+
"""Get combined set of all io arguments and exports from this class and all parents."""
148149
parent_comps = cls._get_component_bases()
149-
# Create combined set of all io arguments from this class and all parents
150150
io_args: dict[str, set] = defaultdict(set)
151151
exports: list[str] = []
152152
for c in parent_comps + [cls]:
@@ -157,12 +157,30 @@ def _configure_io(cls) -> None:
157157
io_args["output_events"].update(c_io.output_events)
158158
if c_exports := getattr(c, "exports"):
159159
exports.extend(c_exports)
160+
return io_args, exports
161+
162+
@classmethod
163+
def _get_event_field_coverage(cls) -> dict[str, list[str]]:
164+
"""Get event field coverage from all handlers in this class and all parents."""
165+
event_field_coverage = {}
166+
for attr_name in dir(cls):
167+
attr = getattr(cls, attr_name, None)
168+
if callable(attr) and hasattr(attr, "_event_field_coverage"):
169+
event_field_coverage.update(attr._event_field_coverage)
170+
return event_field_coverage
171+
172+
@classmethod
173+
def _configure_io(cls) -> None:
174+
# Get all parent classes that are Component subclasses
175+
io_args, exports = cls._get_aggregated_io_args()
176+
event_field_coverage = cls._get_event_field_coverage()
160177
# Set io arguments for subclass
161178
cls.io = IO(
162179
inputs=sorted(io_args["inputs"], key=str),
163180
outputs=sorted(io_args["outputs"], key=str),
164181
input_events=sorted(io_args["input_events"], key=str),
165182
output_events=sorted(io_args["output_events"], key=str),
183+
event_field_coverage=event_field_coverage,
166184
)
167185
# Set exports for subclass
168186
cls.exports = sorted(set(exports))
@@ -356,7 +374,7 @@ async def _wrapper() -> None:
356374
raise e
357375
self._bind_outputs()
358376
await self.io.write()
359-
self._field_inputs_ready = False
377+
self._reset_input_trackers()
360378
await self._set_status(Status.WAITING, publish=not self._is_running)
361379

362380
return _wrapper
@@ -365,6 +383,11 @@ async def _wrapper() -> None:
365383
def _has_field_inputs(self) -> bool:
366384
return len(self.io.inputs) > 0
367385

386+
@property
387+
def _has_connected_field_inputs(self) -> bool:
388+
"""Whether any declared field inputs are connected via input channels."""
389+
return self.io.has_connected_field_inputs
390+
368391
@cached_property
369392
def _has_event_inputs(self) -> bool:
370393
input_events = set([evt.safe_type() for evt in self.io.input_events])
@@ -409,7 +432,7 @@ async def _io_read_with_status_check(self) -> None:
409432
task.cancel()
410433
for task in done:
411434
exc = task.exception()
412-
if isinstance(exc, EventStreamClosedError) and len(self.io.inputs) == 0:
435+
if isinstance(exc, EventStreamClosedError) and not self._has_connected_field_inputs:
413436
await self.io.close() # Call close for final wait and flush event buffer
414437
elif exc is not None:
415438
raise exc
@@ -422,7 +445,7 @@ async def _periodic_status_check(self) -> None:
422445
# TODO : Eventually producer graph update will be event driven. For now,
423446
# : the update is performed periodically, so it's called here along
424447
# : with the status check.
425-
if len(self.io.inputs) == 0:
448+
if not self._has_connected_field_inputs:
426449
await self._update_producer_graph()
427450

428451
async def _status_check(self) -> None:
@@ -455,8 +478,11 @@ def _bind_inputs(self) -> None:
455478
for field in self.io.inputs:
456479
field_default = getattr(self, field, None)
457480
value = self._field_inputs.get(field, field_default)
458-
setattr(self, field, value)
481+
super().__setattr__(field, value)
482+
483+
def _reset_input_trackers(self) -> None:
459484
self._field_inputs = {}
485+
self._field_inputs_ready = False
460486

461487
def _bind_outputs(self) -> None:
462488
"""Binds component fields to output fields."""

plugboard/component/io_controller.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
initial_values: _t.Optional[dict[str, _t.Iterable]] = None,
3939
input_events: _t.Optional[list[_t.Type[Event]]] = None,
4040
output_events: _t.Optional[list[_t.Type[Event]]] = None,
41+
event_field_coverage: _t.Optional[dict[str, list[str]]] = None,
4142
namespace: str = IO_NS_UNSET,
4243
component: _t.Optional[Component] = None,
4344
) -> None:
@@ -47,10 +48,27 @@ def __init__(
4748
self.initial_values = initial_values or {}
4849
self.input_events = input_events or []
4950
self.output_events = output_events or []
51+
self.event_field_coverage = event_field_coverage or {}
5052
if set(self.initial_values.keys()) - set(self.inputs):
5153
raise ValueError("Initial values must be for input fields only.")
54+
5255
self._component = component
56+
self._initial_values = {k: deque(v) for k, v in self.initial_values.items()}
57+
self._input_event_types = {Event.safe_type(evt.type) for evt in self.input_events}
58+
self._output_event_types = {Event.safe_type(evt.type) for evt in self.output_events}
59+
60+
self._logger = DI.logger.resolve_sync().bind(
61+
cls=self.__class__.__name__, namespace=self.namespace
62+
)
63+
self._logger.info("IOController created")
64+
65+
# Initialise channel stores
66+
self._input_channels: dict[tuple[str, str], Channel] = {}
67+
self._output_channels: dict[tuple[str, str], Channel] = {}
68+
self._input_event_channels: dict[str, Channel] = {}
69+
self._output_event_channels: dict[str, Channel] = {}
5370

71+
# Initialise buffers
5472
self.buf_fields: dict[str, IOBuffer] = {
5573
_io_key_in: IOFieldBuffer(),
5674
_io_key_out: IOFieldBuffer(),
@@ -60,21 +78,9 @@ def __init__(
6078
_io_key_out: IOEventBuffer(),
6179
}
6280

63-
self._input_channels: dict[tuple[str, str], Channel] = {}
64-
self._output_channels: dict[tuple[str, str], Channel] = {}
65-
self._input_event_channels: dict[str, Channel] = {}
66-
self._output_event_channels: dict[str, Channel] = {}
67-
self._input_event_types = {Event.safe_type(evt.type) for evt in self.input_events}
68-
self._output_event_types = {Event.safe_type(evt.type) for evt in self.output_events}
69-
self._initial_values = {k: deque(v) for k, v in self.initial_values.items()}
70-
self._read_tasks: dict[str | _t_field_key, asyncio.Task] = {}
81+
# Initialise orchestration state
7182
self._is_closed = False
72-
73-
self._logger = DI.logger.resolve_sync().bind(
74-
cls=self.__class__.__name__, namespace=self.namespace
75-
)
76-
self._logger.info("IOController created")
77-
83+
self._read_tasks: dict[str | _t_field_key, asyncio.Task] = {}
7884
self._received_fields: dict[str, _t.Any] = {}
7985
self._received_fields_lock = asyncio.Lock()
8086
self._received_events: deque[Event] = deque()
@@ -86,8 +92,9 @@ def is_closed(self) -> bool:
8692
"""Returns `True` if the `IOController` is closed, `False` otherwise."""
8793
return self._is_closed
8894

89-
@cached_property
90-
def _has_field_inputs(self) -> bool:
95+
@property
96+
def has_connected_field_inputs(self) -> bool:
97+
"""Returns whether any field inputs are connected via channels."""
9198
return len(self._input_channels) > 0
9299

93100
@cached_property
@@ -96,7 +103,7 @@ def _has_event_inputs(self) -> bool:
96103

97104
@cached_property
98105
def _has_inputs(self) -> bool:
99-
return self._has_field_inputs or self._has_event_inputs
106+
return self.has_connected_field_inputs or self._has_event_inputs
100107

101108
async def read(self, timeout: float | None = None) -> None:
102109
"""Reads data and/or events from input channels.
@@ -139,7 +146,7 @@ async def read(self, timeout: float | None = None) -> None:
139146

140147
def _set_read_tasks(self) -> list[asyncio.Task]:
141148
read_tasks: list[asyncio.Task] = []
142-
if self._has_field_inputs:
149+
if self.has_connected_field_inputs:
143150
if _fields_read_task not in self._read_tasks:
144151
read_fields_task = asyncio.create_task(self._read_fields(), name=_fields_read_task)
145152
self._read_tasks[_fields_read_task] = read_fields_task
@@ -374,7 +381,7 @@ def _add_channel_for_event(
374381

375382
def _create_input_field_group_tasks(self) -> None:
376383
"""Groups input field channels by field name and launches read tasks for group inputs."""
377-
if not self._has_field_inputs:
384+
if not self.has_connected_field_inputs:
378385
return
379386
field_channels: dict[str, list[tuple[_t_field_key, Channel]]] = defaultdict(list)
380387
for key, chan in self._input_channels.items():
@@ -410,6 +417,7 @@ def dict(self) -> dict[str, _t.Any]: # noqa: D102
410417
"input_events": [e.safe_type() for e in self.input_events],
411418
"output_events": [e.safe_type() for e in self.output_events],
412419
"initial_values": {k: list(v) for k, v in self._initial_values.items()},
420+
"event_field_coverage": {k: list(v) for k, v in self.event_field_coverage.items()},
413421
}
414422

415423

plugboard/events/event.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,28 @@ def safe_type(cls, event_type: _t.Optional[str] = None) -> str:
7575
"""Returns a safe event type string for use in broker topic strings."""
7676
return (event_type or cls.type).replace(".", "_").replace("-", "_")
7777

78+
@_t.overload
7879
@classmethod
79-
def handler(cls, method: AsyncCallable) -> AsyncCallable:
80+
def handler(cls, method: AsyncCallable) -> AsyncCallable: ...
81+
82+
@_t.overload
83+
@classmethod
84+
def handler(
85+
cls, *, populates_fields: _t.Optional[list[str]] = None
86+
) -> _t.Callable[[AsyncCallable], AsyncCallable]: ...
87+
88+
@classmethod
89+
def handler(
90+
cls,
91+
method: _t.Optional[AsyncCallable] = None,
92+
*,
93+
populates_fields: _t.Optional[list[str]] = None,
94+
) -> _t.Union[AsyncCallable, _t.Callable[[AsyncCallable], AsyncCallable]]:
8095
"""Registers a class method as an event handler."""
96+
if method is None:
97+
# Invoked as @Event.handler(populates_fields=[...])
98+
return EventHandlers.add(cls, populates_fields=populates_fields)
99+
# Invoked as @Event.handler
81100
return EventHandlers.add(cls)(method)
82101

83102

plugboard/events/event_handlers.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,16 @@ class EventHandlers: # pragma: no cover
1818
_handlers: _t.ClassVar[dict[str, dict[str, AsyncCallable]]] = defaultdict(dict)
1919

2020
@classmethod
21-
def add(cls, event: _t.Type[Event] | Event) -> _t.Callable[[AsyncCallable], AsyncCallable]:
21+
def add(
22+
cls,
23+
event: _t.Type[Event] | Event,
24+
populates_fields: _t.Optional[list[str]] = None,
25+
) -> _t.Callable[[AsyncCallable], AsyncCallable]:
2226
"""Decorator that registers class methods as handlers for specific event types.
2327
2428
Args:
2529
event: Event class this handler processes
30+
populates_fields: Optional list of fields that the handler populates
2631
2732
Returns:
2833
Callable: Decorated method
@@ -31,6 +36,12 @@ def add(cls, event: _t.Type[Event] | Event) -> _t.Callable[[AsyncCallable], Asyn
3136
def decorator(method: AsyncCallable) -> AsyncCallable:
3237
class_path = cls._get_class_path_for_method(method)
3338
cls._handlers[class_path][event.type] = method
39+
40+
if populates_fields is not None:
41+
if not hasattr(method, "_event_field_coverage"):
42+
setattr(method, "_event_field_coverage", {})
43+
getattr(method, "_event_field_coverage")[event.type] = populates_fields
44+
3445
return method
3546

3647
return decorator
@@ -57,10 +68,11 @@ def get(cls, _class: _t.Type, event: _t.Type[Event] | Event) -> AsyncCallable:
5768
Raises:
5869
KeyError: If no handler found for class or event type
5970
"""
71+
store = cls._handlers
6072
for base_class in _class.__mro__:
6173
base_path = f"{base_class.__module__}.{base_class.__name__}"
62-
if base_path in cls._handlers and event.type in cls._handlers[base_path]:
63-
return cls._handlers[base_path][event.type]
74+
if base_path in store and event.type in store[base_path]:
75+
return store[base_path][event.type]
6476
raise KeyError(
6577
f"No handler found for class '{_class.__name__}' and event type '{event.type}'"
6678
)

plugboard/library/data_writer.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(
5050
outputs=None,
5151
input_events=self.__class__.io.input_events,
5252
output_events=self.__class__.io.output_events,
53+
event_field_coverage=self.__class__.io.event_field_coverage,
5354
namespace=self.name,
5455
component=self,
5556
)
@@ -76,18 +77,39 @@ async def _convert(self, data: dict[str, deque]) -> _t.Any:
7677
def _bind_inputs(self) -> None:
7778
"""Binds input fields to component fields and append to internal buffer."""
7879
super()._bind_inputs()
79-
for field in self.io.inputs:
80+
for field in self._field_inputs:
8081
value = getattr(self, field, None)
8182
self._buffer[field].append(value)
8283

84+
@property
85+
def _completed_rows(self) -> int:
86+
"""Calculates how many fully formed rows exist in the buffer."""
87+
if not self.io.inputs:
88+
return 0
89+
return min((len(self._buffer[f]) for f in self.io.inputs), default=0)
90+
91+
@property
92+
def _can_step(self) -> bool:
93+
"""We can step if we have at least one fully formed row."""
94+
return self._completed_rows > 0
95+
8396
async def _save_chunk(self) -> None:
84-
"""Write data from the buffer."""
97+
"""Write completed data rows from the buffer."""
98+
completed_rows = self._completed_rows
99+
if completed_rows == 0:
100+
return
101+
85102
if self._task is not None:
86103
await self._task
87-
# Create task to save next chunk of data
88-
chunk = await self._convert(self._buffer)
104+
105+
# Extract only the completed rows into a new chunk
106+
chunk_data = {
107+
field: deque([self._buffer[field].popleft() for _ in range(completed_rows)])
108+
for field in self.io.inputs
109+
}
110+
111+
chunk = await self._convert(chunk_data)
89112
self._task = asyncio.create_task(self._save(chunk))
90-
self._buffer = defaultdict(deque)
91113

92114
async def step(self) -> None:
93115
"""Trigger save when buffer is at target size."""

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ dev = [
6868
"pre-commit>=3.8,<5",
6969
"radon>=6.0.1,<7",
7070
"ruff>=0.5,<1",
71-
"types-aiofiles>=24.1,<25",
71+
"types-aiofiles>=24.1,<26",
7272
"xenon>=0.9.3,<1",
7373
]
7474
test = [

0 commit comments

Comments
 (0)