Skip to content

Commit 46cf41d

Browse files
committed
refactor(generator): remove backward compatibility for old style yaml/configs
1 parent 5979381 commit 46cf41d

6 files changed

Lines changed: 146 additions & 73 deletions

File tree

assets/userguide/examples/yaml_multi.yaml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ scenarios:
77
plans:
88
baseline:
99
distribution:
10-
kind: normal
11-
mean: 0.0
12-
std: 1.0
10+
kind: univariate
11+
family: Normal
12+
parametrization_name: meanStd
13+
mu: 0.0
14+
sigma: 1.0
1315
state:
1416
type: baseline
1517
second:
@@ -20,8 +22,10 @@ scenarios:
2022
plans:
2123
shifted:
2224
distribution:
23-
kind: normal
24-
mean: 4.0
25-
std: 1.0
25+
kind: univariate
26+
family: Normal
27+
parametrization_name: meanStd
28+
mu: 4.0
29+
sigma: 1.0
2630
state:
2731
type: shifted

assets/userguide/examples/yaml_single.yaml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@ segments:
77
plans:
88
baseline:
99
distribution:
10-
kind: normal
11-
mean: 0.0
12-
std: 1.0
10+
kind: univariate
11+
family: Normal
12+
parametrization_name: meanStd
13+
mu: 0.0
14+
sigma: 1.0
1315
state:
1416
type: baseline
1517
shifted:
1618
distribution:
17-
kind: normal
18-
mean: 2.0
19-
std: 1.0
19+
kind: univariate
20+
family: Normal
21+
parametrization_name: meanStd
22+
mu: 2.0
23+
sigma: 1.0
2024
state:
2125
type: shifted

notebooks/user_guide/02-generator-api.ipynb

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -624,9 +624,9 @@
624624
"metadata": {},
625625
"outputs": [],
626626
"source": [
627-
"repeat_a = GenericSeriesGenerator(seed=123).generate_from_scenario(mean_shift_scenario, name=\"repeat_a\")\n",
628-
"repeat_b = GenericSeriesGenerator(seed=123).generate_from_scenario(mean_shift_scenario, name=\"repeat_b\")\n",
629-
"repeat_c = GenericSeriesGenerator(seed=124).generate_from_scenario(mean_shift_scenario, name=\"repeat_c\")\n",
627+
"repeat_a = GenericSeriesGenerator(seed=42).generate_from_scenario(mean_shift_scenario, name=\"repeat_a\")\n",
628+
"repeat_b = GenericSeriesGenerator(seed=42).generate_from_scenario(mean_shift_scenario, name=\"repeat_b\")\n",
629+
"repeat_c = GenericSeriesGenerator(seed=42).generate_from_scenario(mean_shift_scenario, name=\"repeat_c\")\n",
630630
"\n",
631631
"print(\"Same-seed first five values match:\", repeat_a.data[:5, 0].tolist() == repeat_b.data[:5, 0].tolist())\n",
632632
"print(\"Different-seed first five values match:\", repeat_a.data[:5, 0].tolist() == repeat_c.data[:5, 0].tolist())"
@@ -800,12 +800,24 @@
800800
" ],\n",
801801
" \"plans\": {\n",
802802
" \"baseline\": {\n",
803-
" \"distribution\": {\"kind\": \"normal\", \"mean\": 0.0, \"std\": 1.0},\n",
803+
" \"distribution\": {\n",
804+
" \"kind\": \"univariate\",\n",
805+
" \"family\": \"Normal\",\n",
806+
" \"parametrization_name\": \"meanStd\",\n",
807+
" \"mu\": 0.0,\n",
808+
" \"sigma\": 1.0,\n",
809+
" },\n",
804810
" \"state\": {\"type\": \"baseline\"},\n",
805811
" \"name\": \"mapping baseline\",\n",
806812
" },\n",
807813
" \"shifted\": {\n",
808-
" \"distribution\": {\"kind\": \"normal\", \"mean\": 2.5, \"std\": 1.0},\n",
814+
" \"distribution\": {\n",
815+
" \"kind\": \"univariate\",\n",
816+
" \"family\": \"Normal\",\n",
817+
" \"parametrization_name\": \"meanStd\",\n",
818+
" \"mu\": 2.5,\n",
819+
" \"sigma\": 1.0,\n",
820+
" },\n",
809821
" \"state\": {\"type\": \"shifted\"},\n",
810822
" \"name\": \"mapping shifted\",\n",
811823
" },\n",

pysatl_cpd/data/generator/__init__.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,24 @@
257257
... {"plan_name": "b", "length": 20},
258258
... ],
259259
... "plans": {
260-
... "a": {"distribution": {"kind": "normal", "mean": 0.0, "std": 1.0}},
261-
... "b": {"distribution": {"kind": "normal", "mean": 2.0, "std": 1.0}},
260+
... "a": {
261+
... "distribution": {
262+
... "kind": "univariate",
263+
... "family": "Normal",
264+
... "parametrization_name": "meanStd",
265+
... "mu": 0.0,
266+
... "sigma": 1.0,
267+
... }
268+
... },
269+
... "b": {
270+
... "distribution": {
271+
... "kind": "univariate",
272+
... "family": "Normal",
273+
... "parametrization_name": "meanStd",
274+
... "mu": 2.0,
275+
... "sigma": 1.0,
276+
... }
277+
... },
262278
... },
263279
... }
264280
>>> spec = scenario_from_mapping(mapping)

pysatl_cpd/data/generator/config.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -190,26 +190,6 @@ def parse_distribution_spec(mapping: Mapping[str, Any]) -> DistributionSpec:
190190
"""
191191

192192
kind = _require_str(mapping.get("kind"), "distribution.kind")
193-
if kind == "normal":
194-
return UnivariateDistributionSpec(
195-
"Normal",
196-
"meanStd",
197-
mu=_optional_float(mapping.get("mean"), "distribution.mean", default=0.0),
198-
sigma=_optional_float(mapping.get("std"), "distribution.std", default=1.0),
199-
)
200-
if kind == "uniform":
201-
return UnivariateDistributionSpec(
202-
"ContinuousUniform",
203-
"standard",
204-
lower_bound=_optional_float(mapping.get("low"), "distribution.low", default=0.0),
205-
upper_bound=_optional_float(mapping.get("high"), "distribution.high", default=1.0),
206-
)
207-
if kind == "exponential":
208-
return UnivariateDistributionSpec(
209-
"Exponential",
210-
"scale",
211-
beta=_optional_float(mapping.get("scale"), "distribution.scale", default=1.0),
212-
)
213193
if kind == "univariate":
214194
return _parse_univariate_family_distribution(mapping)
215195
if kind == "multivariate_normal":
@@ -512,24 +492,3 @@ def _require_float(raw: object, path: str) -> float:
512492
if not isinstance(raw, int | float) or isinstance(raw, bool):
513493
raise ValueError(f"{path} must be a number")
514494
return float(raw)
515-
516-
517-
def _optional_float(raw: object, path: str, *, default: float) -> float:
518-
"""
519-
Get a float value or default if None.
520-
521-
Parameters
522-
----------
523-
raw
524-
Raw data to validate.
525-
path
526-
Path for error messages.
527-
default
528-
Default value if raw is None.
529-
530-
Returns
531-
-------
532-
value
533-
Validated float or default.
534-
"""
535-
return default if raw is None else _require_float(raw, path)

tests/unit/data/generator/test_config.py

Lines changed: 91 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,23 @@ def test_scenario_from_mapping_loads_direct_univariate_distribution() -> None:
4141
"plans": {
4242
"baseline": {
4343
"state": {"type": "baseline"},
44-
"distribution": {"kind": "normal", "mean": 0.0, "std": 1.0},
44+
"distribution": {
45+
"kind": "univariate",
46+
"family": "Normal",
47+
"parametrization_name": "meanStd",
48+
"mu": 0.0,
49+
"sigma": 1.0,
50+
},
4551
},
4652
"shifted": {
4753
"state": {"type": "shifted"},
48-
"distribution": {"kind": "normal", "mean": 2.0, "std": 1.0},
54+
"distribution": {
55+
"kind": "univariate",
56+
"family": "Normal",
57+
"parametrization_name": "meanStd",
58+
"mu": 2.0,
59+
"sigma": 1.0,
60+
},
4961
},
5062
},
5163
}
@@ -69,8 +81,20 @@ def test_scenario_from_mapping_loads_independent_columns_distribution() -> None:
6981
"distribution": {
7082
"kind": "independent_columns",
7183
"columns": {
72-
"x": {"kind": "normal", "mean": 0.0, "std": 1.0},
73-
"y": {"kind": "uniform", "low": -1.0, "high": 1.0},
84+
"x": {
85+
"kind": "univariate",
86+
"family": "Normal",
87+
"parametrization_name": "meanStd",
88+
"mu": 0.0,
89+
"sigma": 1.0,
90+
},
91+
"y": {
92+
"kind": "univariate",
93+
"family": "ContinuousUniform",
94+
"parametrization_name": "standard",
95+
"lower_bound": -1.0,
96+
"upper_bound": 1.0,
97+
},
7498
},
7599
}
76100
}
@@ -116,9 +140,11 @@ def test_scenario_from_yaml_loads_single_scenario(tmp_path) -> None: # type: ig
116140
plans:
117141
baseline:
118142
distribution:
119-
kind: normal
120-
mean: 0.0
121-
std: 1.0
143+
kind: univariate
144+
family: Normal
145+
parametrization_name: meanStd
146+
mu: 0.0
147+
sigma: 1.0
122148
""",
123149
encoding="utf-8",
124150
)
@@ -142,7 +168,11 @@ def test_scenarios_from_yaml_loads_mapping(tmp_path) -> None: # type: ignore[no
142168
plans:
143169
baseline:
144170
distribution:
145-
kind: normal
171+
kind: univariate
172+
family: Normal
173+
parametrization_name: meanStd
174+
mu: 0.0
175+
sigma: 1.0
146176
""",
147177
encoding="utf-8",
148178
)
@@ -164,7 +194,11 @@ def test_scenarios_from_yaml_loads_single_scenario_mapping(tmp_path) -> None: #
164194
plans:
165195
baseline:
166196
distribution:
167-
kind: normal
197+
kind: univariate
198+
family: Normal
199+
parametrization_name: meanStd
200+
mu: 0.0
201+
sigma: 1.0
168202
""",
169203
encoding="utf-8",
170204
)
@@ -175,6 +209,31 @@ def test_scenarios_from_yaml_loads_single_scenario_mapping(tmp_path) -> None: #
175209
assert scenarios["fallback"].name == "fallback"
176210

177211

212+
def test_scenario_from_yaml_loads_custom_parametrization_name(tmp_path) -> None: # type: ignore[no-untyped-def]
213+
path = tmp_path / "scenario.yaml"
214+
path.write_text(
215+
"""
216+
name: yaml_scenario
217+
segments:
218+
- plan_name: baseline
219+
length: 3
220+
plans:
221+
baseline:
222+
distribution:
223+
kind: univariate
224+
family: Normal
225+
parametrization_name: meanStd
226+
mu: 0.0
227+
sigma: 1.0
228+
""",
229+
encoding="utf-8",
230+
)
231+
232+
scenario = scenario_from_yaml(path)
233+
234+
assert scenario.plans["baseline"].distribution == UnivariateDistributionSpec("Normal", "meanStd", mu=0.0, sigma=1.0)
235+
236+
178237
def test_scenario_from_mapping_rejects_unknown_distribution_kind() -> None:
179238
with pytest.raises(ValueError, match="Unsupported distribution kind"):
180239
scenario_from_mapping(
@@ -222,7 +281,9 @@ def test_scenarios_from_yaml_rejects_non_mapping_single_scenario(tmp_path) -> No
222281

223282

224283
def test_parse_distribution_spec_loads_exponential_distribution() -> None:
225-
distribution = parse_distribution_spec({"kind": "exponential", "scale": 2.5})
284+
distribution = parse_distribution_spec(
285+
{"kind": "univariate", "family": "Exponential", "parametrization_name": "scale", "beta": 2.5}
286+
)
226287

227288
assert distribution == UnivariateDistributionSpec("Exponential", "scale", beta=2.5)
228289

@@ -246,8 +307,19 @@ def test_parse_distribution_spec_loads_independent_columns_univariate_variants()
246307
{
247308
"kind": "independent_columns",
248309
"columns": {
249-
"exp": {"kind": "exponential", "scale": 3.0},
250-
"uniform": {"kind": "uniform", "low": -2.0, "high": 2.0},
310+
"exp": {
311+
"kind": "univariate",
312+
"family": "Exponential",
313+
"parametrization_name": "scale",
314+
"beta": 3.0,
315+
},
316+
"uniform": {
317+
"kind": "univariate",
318+
"family": "ContinuousUniform",
319+
"parametrization_name": "standard",
320+
"lower_bound": -2.0,
321+
"upper_bound": 2.0,
322+
},
251323
},
252324
}
253325
)
@@ -264,6 +336,12 @@ def test_parse_distribution_spec_rejects_student_t_distribution() -> None:
264336
parse_distribution_spec({"kind": "student_t", "df": 7.0})
265337

266338

339+
@pytest.mark.parametrize("kind", ["normal", "uniform", "exponential"])
340+
def test_parse_distribution_spec_rejects_removed_legacy_univariate_kinds(kind: str) -> None:
341+
with pytest.raises(ValueError, match=rf"Unsupported distribution kind '{kind}'"):
342+
parse_distribution_spec({"kind": kind})
343+
344+
267345
def test_parse_hashable_mapping_rejects_non_hashable_value() -> None:
268346
with pytest.raises(ValueError, match="metadata.bad must be hashable"):
269347
_parse_hashable_mapping({"bad": []}, "metadata")
@@ -324,7 +402,7 @@ def test_scenario_from_mapping_rejects_non_scalar_state_value() -> None:
324402
"plans": {
325403
"baseline": {
326404
"state": {"bad": []},
327-
"distribution": {"kind": "normal"},
405+
"distribution": {"kind": "univariate", "family": "Normal", "mu": 0.0, "sigma": 1.0},
328406
}
329407
},
330408
}

0 commit comments

Comments
 (0)