Skip to content

Commit 710a98b

Browse files
committed
Additional tests for schemas
1 parent 3995607 commit 710a98b

2 files changed

Lines changed: 86 additions & 1 deletion

File tree

plugboard/schemas/tune.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,12 @@ class ObjectiveSpec(BaseFieldSpec):
5454

5555
@model_validator(mode="before")
5656
@classmethod
57-
def _fill_defaults(cls, data: dict[str, _t.Any]) -> dict[str, _t.Any]:
57+
def _fill_defaults(
58+
cls, data: dict[str, _t.Any] | list[dict[str, _t.Any]]
59+
) -> dict[str, _t.Any] | list[dict[str, _t.Any]]:
60+
if isinstance(data, list):
61+
# If the data is a list, skip because it is already a list of objectives
62+
return data
5863
if "field_type" not in data:
5964
data["field_type"] = "field"
6065
if data["field_type"] != "field":

tests/unit/test_schemas.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Provides unit tests for the schemas module."""
2+
3+
import pytest
4+
5+
from plugboard.schemas import TuneArgsSpec, TuneSpec
6+
7+
8+
def test_tune_spec() -> None:
9+
"""Test the TuneSpec class."""
10+
valid_spec = {
11+
"objective": {
12+
"object_type": "component",
13+
"object_name": "my_component",
14+
"field_name": "my_metric",
15+
},
16+
"parameters": [
17+
{
18+
"object_type": "component",
19+
"object_name": "my_component",
20+
"field_type": "arg",
21+
"field_name": "my_param",
22+
"type": "ray.tune.uniform",
23+
"lower": 0.0,
24+
"upper": 1.0,
25+
},
26+
{
27+
"object_type": "component",
28+
"object_name": "my_component",
29+
"field_type": "initial_value",
30+
"field_name": "x",
31+
"type": "ray.tune.randint",
32+
"lower": 1,
33+
"upper": 10,
34+
},
35+
{
36+
"object_type": "component",
37+
"object_name": "my_component",
38+
"field_type": "arg",
39+
"field_name": "my_choice",
40+
"categories": ["option1", "option2", "option3"],
41+
},
42+
],
43+
"num_samples": 100,
44+
"mode": "max",
45+
"max_concurrent": 5,
46+
"algorithm": {
47+
"type": "ray.tune.search.optuna.OptunaSearch",
48+
"study_name": "my_study",
49+
"storage": "sqlite:///my_study.db",
50+
},
51+
}
52+
# Validate the TuneSpec with the valid specification
53+
_ = TuneSpec(args=TuneArgsSpec.model_validate(valid_spec))
54+
55+
invalid_spec = valid_spec.copy()
56+
invalid_spec["mode"] = ["min", "max"]
57+
# Invalid mode should raise a validation error
58+
with pytest.raises(ValueError):
59+
_ = TuneSpec(args=TuneArgsSpec.model_validate(invalid_spec))
60+
61+
invalid_spec["objective"] = [
62+
{
63+
"object_type": "component",
64+
"object_name": "my_component",
65+
"field_name": "my_metric",
66+
},
67+
{
68+
"object_type": "component",
69+
"object_name": "another_component",
70+
"field_name": "another_metric",
71+
},
72+
{
73+
"object_type": "component",
74+
"object_name": "my_component",
75+
"field_name": "yet_another_metric",
76+
},
77+
]
78+
# Invalid objective length should raise a validation error
79+
with pytest.raises(ValueError):
80+
_ = TuneSpec(args=TuneArgsSpec.model_validate(invalid_spec))

0 commit comments

Comments
 (0)