Skip to content

Commit 662a485

Browse files
fix: Better handling for optional arguments (#1124)
Adjusts handling of optional arguments in plan parameters to use the `required` field from the JsonSchema only instead of additionally explicitly allowing `null` as an argument. This makes the schema easier to read and automated tools should work better with it.
1 parent 22b287d commit 662a485

File tree

3 files changed

+101
-85
lines changed

3 files changed

+101
-85
lines changed

src/blueapi/core/context.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass, field
44
from importlib import import_module
55
from inspect import Parameter, signature
6-
from types import ModuleType, UnionType
6+
from types import ModuleType, NoneType, UnionType
77
from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints
88

99
from bluesky.protocols import HasName
@@ -12,7 +12,7 @@
1212
from ophyd_async.core import NotConnected
1313
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, create_model
1414
from pydantic.fields import FieldInfo
15-
from pydantic.json_schema import JsonSchemaValue
15+
from pydantic.json_schema import JsonSchemaValue, SkipJsonSchema
1616
from pydantic_core import CoreSchema, core_schema
1717

1818
from blueapi import utils
@@ -323,12 +323,12 @@ def _type_spec_for_function(
323323
no_default = para.default is Parameter.empty
324324
factory = None if no_default else DefaultFactory(para.default)
325325
new_args[name] = (
326-
self._convert_type(arg_type),
326+
self._convert_type(arg_type, no_default),
327327
FieldInfo(default_factory=factory),
328328
)
329329
return new_args
330330

331-
def _convert_type(self, typ: type | Any) -> type:
331+
def _convert_type(self, typ: type | Any, no_default: bool = True) -> type:
332332
"""
333333
Recursively convert a type to something that can be deserialised by
334334
pydantic. Bluesky protocols (and types that extend them) are replaced
@@ -344,12 +344,14 @@ def _convert_type(self, typ: type | Any) -> type:
344344
Returns:
345345
A Type that can be deserialised by Pydantic
346346
"""
347+
if typ is NoneType and not no_default:
348+
return SkipJsonSchema[NoneType]
347349
root = get_origin(typ)
348350
if is_bluesky_type(typ) or (root is not None and is_bluesky_type(root)):
349351
return self._reference(typ)
350352
args = get_args(typ)
351353
if args:
352-
new_types = tuple(self._convert_type(i) for i in args)
354+
new_types = tuple(self._convert_type(i, no_default) for i in args)
353355
if root == UnionType:
354356
root = Union
355357
return root[new_types] if root else typ # type: ignore

tests/system_tests/plans.json

Lines changed: 24 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,9 @@
3737
"title": "Delay"
3838
},
3939
"metadata": {
40-
"anyOf": [
41-
{
42-
"additionalProperties": true,
43-
"type": "object"
44-
},
45-
{
46-
"type": "null"
47-
}
48-
],
49-
"title": "Metadata"
40+
"additionalProperties": true,
41+
"title": "Metadata",
42+
"type": "object"
5043
}
5144
},
5245
"required": [
@@ -89,7 +82,7 @@
8982
},
9083
"radius": {
9184
"description": "Radius of the circle",
92-
"exclusiveMinimum": 0,
85+
"exclusiveMinimum": 0.0,
9386
"title": "Radius",
9487
"type": "number"
9588
},
@@ -228,18 +221,18 @@
228221
},
229222
"x_radius": {
230223
"description": "The radius along the x axis of the ellipse",
231-
"exclusiveMinimum": 0,
224+
"exclusiveMinimum": 0.0,
232225
"title": "X Radius",
233226
"type": "number"
234227
},
235228
"y_radius": {
236229
"description": "The radius along the y axis of the ellipse",
237-
"exclusiveMinimum": 0,
230+
"exclusiveMinimum": 0.0,
238231
"title": "Y Radius",
239232
"type": "number"
240233
},
241234
"angle": {
242-
"default": 0,
235+
"default": 0.0,
243236
"description": "The angle of the ellipse (degrees)",
244237
"title": "Angle",
245238
"type": "number"
@@ -510,7 +503,7 @@
510503
"type": "number"
511504
},
512505
"angle": {
513-
"default": 0,
506+
"default": 0.0,
514507
"description": "Clockwise rotation angle of the rectangle",
515508
"title": "Angle",
516509
"type": "number"
@@ -724,7 +717,7 @@
724717
"type": "integer"
725718
},
726719
"rotate": {
727-
"default": 0,
720+
"default": 0.0,
728721
"description": "How much to rotate the angle of the spiral",
729722
"title": "Rotate",
730723
"type": "number"
@@ -908,16 +901,9 @@
908901
"$ref": "#/$defs/Spec"
909902
},
910903
"metadata": {
911-
"anyOf": [
912-
{
913-
"additionalProperties": true,
914-
"type": "object"
915-
},
916-
{
917-
"type": "null"
918-
}
919-
],
920-
"title": "Metadata"
904+
"additionalProperties": true,
905+
"title": "Metadata",
906+
"type": "object"
921907
}
922908
},
923909
"required": [
@@ -946,15 +932,8 @@
946932
"title": "Value"
947933
},
948934
"group": {
949-
"anyOf": [
950-
{
951-
"type": "string"
952-
},
953-
{
954-
"type": "null"
955-
}
956-
],
957-
"title": "Group"
935+
"title": "Group",
936+
"type": "string"
958937
},
959938
"wait": {
960939
"title": "Wait",
@@ -987,15 +966,8 @@
987966
"title": "Value"
988967
},
989968
"group": {
990-
"anyOf": [
991-
{
992-
"type": "string"
993-
},
994-
{
995-
"type": "null"
996-
}
997-
],
998-
"title": "Group"
969+
"title": "Group",
970+
"type": "string"
999971
},
1000972
"wait": {
1001973
"title": "Wait",
@@ -1022,15 +994,8 @@
1022994
"type": "object"
1023995
},
1024996
"group": {
1025-
"anyOf": [
1026-
{
1027-
"type": "string"
1028-
},
1029-
{
1030-
"type": "null"
1031-
}
1032-
],
1033-
"title": "Group"
997+
"title": "Group",
998+
"type": "string"
1034999
}
10351000
},
10361001
"required": [
@@ -1052,15 +1017,8 @@
10521017
"type": "object"
10531018
},
10541019
"group": {
1055-
"anyOf": [
1056-
{
1057-
"type": "string"
1058-
},
1059-
{
1060-
"type": "null"
1061-
}
1062-
],
1063-
"title": "Group"
1020+
"title": "Group",
1021+
"type": "string"
10641022
}
10651023
},
10661024
"required": [
@@ -1095,26 +1053,12 @@
10951053
"additionalProperties": false,
10961054
"properties": {
10971055
"group": {
1098-
"anyOf": [
1099-
{
1100-
"type": "string"
1101-
},
1102-
{
1103-
"type": "null"
1104-
}
1105-
],
1106-
"title": "Group"
1056+
"title": "Group",
1057+
"type": "string"
11071058
},
11081059
"timeout": {
1109-
"anyOf": [
1110-
{
1111-
"type": "number"
1112-
},
1113-
{
1114-
"type": "null"
1115-
}
1116-
],
1117-
"title": "Timeout"
1060+
"title": "Timeout",
1061+
"type": "number"
11181062
}
11191063
},
11201064
"title": "wait",

tests/unit_tests/core/test_context.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from dataclasses import dataclass
44
from pathlib import Path
5+
from types import NoneType
56
from typing import Generic, TypeVar, Union
67
from unittest.mock import patch
78

@@ -26,6 +27,7 @@
2627
from ophyd_async.epics.adaravis import AravisDetector
2728
from ophyd_async.epics.motor import Motor
2829
from pydantic import TypeAdapter, ValidationError
30+
from pydantic.json_schema import SkipJsonSchema
2931
from pytest import LogCaptureFixture
3032

3133
from blueapi.config import EnvironmentConfig, MetadataConfig, Source, SourceKind
@@ -431,6 +433,18 @@ def test_reference_type_conversion_new_style_union(
431433
assert empty_context._convert_type(Movable | int) == movable_ref | int
432434

433435

436+
def test_reference_type_conversion_new_style_optional(
437+
empty_context: BlueskyContext,
438+
):
439+
movable_ref: type = empty_context._reference(Movable)
440+
assert empty_context._convert_type(Movable) == movable_ref
441+
assert empty_context._convert_type(Movable | None) == movable_ref | None
442+
assert (
443+
empty_context._convert_type(Movable | None, no_default=False)
444+
== movable_ref | SkipJsonSchema[NoneType]
445+
)
446+
447+
434448
def test_default_device_reference(empty_context: BlueskyContext):
435449
def default_movable(mov: Movable = inject("demo")) -> MsgGenerator:
436450
yield from ()
@@ -596,3 +610,59 @@ class InnerClass: ...
596610
@pytest.mark.parametrize("type,expected", qualified_name_test_data)
597611
def test_qualified_name_with_types(type: type, expected: str):
598612
assert qualified_name(type) == expected
613+
614+
615+
def test_optional_arg_generated_schema(
616+
empty_context: BlueskyContext,
617+
):
618+
def demo_plan(foo: int | None = None) -> MsgGenerator:
619+
yield from ()
620+
621+
empty_context.register_plan(demo_plan)
622+
schema = empty_context.plans["demo_plan"].model.model_json_schema()
623+
assert schema["properties"] == {
624+
"foo": {"title": "Foo", "type": "integer"},
625+
}
626+
assert "foo" not in schema.get("required", [])
627+
628+
629+
def test_overloaded_arg_generated_schema(
630+
empty_context: BlueskyContext,
631+
):
632+
def demo_plan(foo: int | str) -> MsgGenerator:
633+
yield from ()
634+
635+
empty_context.register_plan(demo_plan)
636+
schema = empty_context.plans["demo_plan"].model.model_json_schema()
637+
assert schema["properties"] == {
638+
"foo": {"title": "Foo", "anyOf": [{"type": "integer"}, {"type": "string"}]}
639+
}
640+
assert "foo" in schema.get("required", [])
641+
642+
643+
def test_optional_overloaded_arg_generated_schema(
644+
empty_context: BlueskyContext,
645+
):
646+
def demo_plan(foo: int | str | None = None) -> MsgGenerator:
647+
yield from ()
648+
649+
empty_context.register_plan(demo_plan)
650+
schema = empty_context.plans["demo_plan"].model.model_json_schema()
651+
assert schema["properties"] == {
652+
"foo": {"title": "Foo", "anyOf": [{"type": "integer"}, {"type": "string"}]}
653+
}
654+
assert "foo" not in schema.get("required", [])
655+
656+
657+
def test_explicit_none_arg_generated_schema(
658+
empty_context: BlueskyContext,
659+
):
660+
def demo_plan(foo: int | None) -> MsgGenerator:
661+
yield from ()
662+
663+
empty_context.register_plan(demo_plan)
664+
schema = empty_context.plans["demo_plan"].model.model_json_schema()
665+
assert schema["properties"] == {
666+
"foo": {"title": "Foo", "anyOf": [{"type": "integer"}, {"type": "null"}]}
667+
}
668+
assert "foo" in schema.get("required", [])

0 commit comments

Comments
 (0)