Skip to content

Commit d1ce281

Browse files
authored
Merge pull request #13 from dfeen87/copilot/implement-helper-functions
Add segmentation utilities for Stage 5 regime-transition detection
2 parents c355216 + e0814fe commit d1ce281

1 file changed

Lines changed: 351 additions & 0 deletions

File tree

Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,352 @@
1+
"""
2+
Segmentation utilities – Stage 5 Regime-Transition Detection
3+
=============================================================
14
5+
Provides three helper utilities that support the manuscript's validation
6+
pipeline:
7+
8+
1. :func:`compute_threshold` – empirically derive the critical threshold Φ_c
9+
from an instability-magnitude series ΔΦ(t) (manuscript Section 5).
10+
11+
2. :func:`regime_shading_intervals` – convert per-time-point regime labels into
12+
contiguous (start, end, regime) intervals suitable for axis shading
13+
(e.g. ``matplotlib.axes.Axes.axvspan``).
14+
15+
3. :func:`extract_transition_events` – extract the ordered sequence of regime
16+
transitions from the label array, annotated with the time index at which
17+
each transition occurs.
18+
19+
None of the functions assign biological meaning to the detected segments; they
20+
operate purely on the numerical ΔΦ(t) series and the discrete label array
21+
produced by :func:`~src.stage5_regime_transitions.detect_regimes.detect_regimes`.
22+
23+
References
24+
----------
25+
Manuscript Section 5 (Regime Transition Structure):
26+
27+
"The threshold Φ_c is not assumed to be universal and must be determined
28+
empirically for each system under study, for example through baseline
29+
characterisation, surrogate analysis, or controlled perturbation
30+
experiments."
31+
32+
Manuscript Section 8.3 (Synthetic Validation Protocol):
33+
34+
"Baseline statistics (μ_X, σ_X) are computed over the stable regime
35+
(t < 60 s)."
36+
"""
37+
38+
from __future__ import annotations
39+
40+
from typing import List, NamedTuple
41+
42+
import numpy as np
43+
44+
45+
# ---------------------------------------------------------------------------
46+
# 1. Threshold computation
47+
# ---------------------------------------------------------------------------
48+
49+
def compute_threshold(
50+
delta_phi: np.ndarray,
51+
method: str = "quantile",
52+
*,
53+
q: float = 0.90,
54+
baseline_mask: np.ndarray | None = None,
55+
n_sigma: float = 2.0,
56+
) -> float:
57+
"""Compute an empirical threshold Φ_c from the instability magnitude ΔΦ(t).
58+
59+
The threshold is *not* assumed to be universal (manuscript Section 5).
60+
Two estimation strategies are supported:
61+
62+
``"quantile"``
63+
Φ_c = quantile(ΔΦ, q). The quantile is computed over the full
64+
supplied series unless a *baseline_mask* is provided, in which case
65+
only the masked samples are used. This matches the quantile-based
66+
empirical characterisation described in Sections 5 and 6 of the
67+
manuscript.
68+
69+
``"baseline"``
70+
Φ_c = μ_baseline + n_sigma · σ_baseline, where μ and σ are the mean
71+
and standard deviation of ΔΦ(t) restricted to the stable baseline
72+
window identified by *baseline_mask*. This mirrors the baseline
73+
normalisation defined in manuscript Eq. (27):
74+
75+
X_n(t) = ( X(t) − μ_X ) / σ_X, X ∈ {E, I, C}
76+
77+
and the protocol in Section 8.3 which computes baseline statistics
78+
"over the stable regime (t < 60 s)".
79+
80+
Parameters
81+
----------
82+
delta_phi : array_like, shape (T,)
83+
Instability magnitude ΔΦ(t) ≥ 0 as produced by
84+
:func:`~src.stage5_regime_transitions.compute_delta_phi.compute_delta_phi`.
85+
method : {"quantile", "baseline"}, default "quantile"
86+
Threshold estimation strategy (see above).
87+
q : float, default 0.90
88+
Quantile level for the ``"quantile"`` method. Must be in (0, 1].
89+
baseline_mask : array_like of bool, shape (T,), optional
90+
Boolean mask identifying the stable baseline samples.
91+
- For ``"quantile"``: when provided, the quantile is computed only
92+
over the masked (``True``) samples.
93+
- For ``"baseline"``: required; raises ``ValueError`` if omitted.
94+
n_sigma : float, default 2.0
95+
Number of standard deviations above the baseline mean for the
96+
``"baseline"`` method. Must be non-negative.
97+
98+
Returns
99+
-------
100+
threshold : float
101+
Empirical threshold Φ_c ≥ 0.
102+
103+
Raises
104+
------
105+
ValueError
106+
If *delta_phi* is not 1-D, *method* is unrecognised, quantile
107+
parameters are out of range, or the ``"baseline"`` method is called
108+
without a valid *baseline_mask*.
109+
110+
Examples
111+
--------
112+
>>> import numpy as np
113+
>>> from src.stage5_regime_transitions.segmentation_utils import compute_threshold
114+
>>> rng = np.random.default_rng(0)
115+
>>> delta_phi = np.abs(rng.normal(1.0, 0.5, 300))
116+
>>> thr = compute_threshold(delta_phi, method="quantile", q=0.90)
117+
>>> float(thr) >= 0
118+
True
119+
>>> mask = np.zeros(300, dtype=bool)
120+
>>> mask[:60] = True # first 60 samples are the stable baseline
121+
>>> thr_b = compute_threshold(delta_phi, method="baseline", baseline_mask=mask, n_sigma=2.0)
122+
>>> float(thr_b) >= 0
123+
True
124+
"""
125+
delta_phi = np.asarray(delta_phi, dtype=float)
126+
if delta_phi.ndim != 1:
127+
raise ValueError(
128+
f"delta_phi must be a 1-D array; got shape {delta_phi.shape}."
129+
)
130+
131+
if method == "quantile":
132+
if not (0.0 < q <= 1.0):
133+
raise ValueError(
134+
f"q must be in (0, 1]; got q={q}."
135+
)
136+
samples = delta_phi
137+
if baseline_mask is not None:
138+
mask = np.asarray(baseline_mask, dtype=bool)
139+
if mask.shape != delta_phi.shape:
140+
raise ValueError(
141+
"baseline_mask must have the same shape as delta_phi; "
142+
f"got {mask.shape} vs {delta_phi.shape}."
143+
)
144+
samples = delta_phi[mask]
145+
if samples.size == 0:
146+
raise ValueError(
147+
"baseline_mask selects no samples; cannot compute quantile."
148+
)
149+
return float(np.quantile(samples, q))
150+
151+
elif method == "baseline":
152+
if baseline_mask is None:
153+
raise ValueError(
154+
"baseline_mask is required for method='baseline'."
155+
)
156+
if n_sigma < 0.0:
157+
raise ValueError(
158+
f"n_sigma must be non-negative; got n_sigma={n_sigma}."
159+
)
160+
mask = np.asarray(baseline_mask, dtype=bool)
161+
if mask.shape != delta_phi.shape:
162+
raise ValueError(
163+
"baseline_mask must have the same shape as delta_phi; "
164+
f"got {mask.shape} vs {delta_phi.shape}."
165+
)
166+
baseline_samples = delta_phi[mask]
167+
if baseline_samples.size == 0:
168+
raise ValueError(
169+
"baseline_mask selects no samples; cannot compute baseline statistics."
170+
)
171+
mu = float(np.mean(baseline_samples))
172+
sigma = float(np.std(baseline_samples, ddof=0))
173+
return mu + n_sigma * sigma
174+
175+
else:
176+
raise ValueError(
177+
f"Unknown method '{method}'. Supported methods: 'quantile', 'baseline'."
178+
)
179+
180+
181+
# ---------------------------------------------------------------------------
182+
# 2. Regime shading intervals
183+
# ---------------------------------------------------------------------------
184+
185+
class ShadingInterval(NamedTuple):
186+
"""A contiguous block of time steps assigned to the same regime.
187+
188+
Attributes
189+
----------
190+
regime : str
191+
Regime label (one of ``"stable"``, ``"pre-instability"``,
192+
``"instability"``, ``"recovery"``).
193+
start : int
194+
Index of the first time step in the interval (inclusive).
195+
stop : int
196+
Index one past the last time step in the interval (exclusive),
197+
matching Python slice semantics. The interval covers indices
198+
``[start, stop)``.
199+
"""
200+
201+
regime: str
202+
start: int
203+
stop: int
204+
205+
206+
def regime_shading_intervals(labels: np.ndarray) -> List[ShadingInterval]:
207+
"""Convert a per-time-point regime label array into contiguous intervals.
208+
209+
Groups consecutive identical labels into :class:`ShadingInterval` records
210+
that can be passed directly to ``matplotlib.axes.Axes.axvspan`` (or any
211+
interval-based shading routine) for visualization of regime boundaries in
212+
the validation pipeline.
213+
214+
Parameters
215+
----------
216+
labels : array_like, shape (T,), dtype str or object
217+
Per-time-point regime strings as produced by
218+
:func:`~src.stage5_regime_transitions.detect_regimes.detect_regimes`.
219+
Each element must be one of ``"stable"``, ``"pre-instability"``,
220+
``"instability"``, or ``"recovery"``.
221+
222+
Returns
223+
-------
224+
intervals : list of :class:`ShadingInterval`
225+
Ordered list of non-overlapping contiguous regime intervals that
226+
together cover the full index range ``[0, T)``. The list is ordered
227+
by increasing *start* index.
228+
229+
Raises
230+
------
231+
ValueError
232+
If *labels* is not a 1-D array or is empty.
233+
234+
Examples
235+
--------
236+
>>> import numpy as np
237+
>>> from src.stage5_regime_transitions.segmentation_utils import regime_shading_intervals
238+
>>> labels = np.array(
239+
... ["stable"] * 4 + ["pre-instability"] * 3 + ["instability"] * 2
240+
... )
241+
>>> intervals = regime_shading_intervals(labels)
242+
>>> [(iv.regime, iv.start, iv.stop) for iv in intervals]
243+
[('stable', 0, 4), ('pre-instability', 4, 7), ('instability', 7, 9)]
244+
"""
245+
labels = np.asarray(labels, dtype=object)
246+
if labels.ndim != 1:
247+
raise ValueError(
248+
f"labels must be a 1-D array; got shape {labels.shape}."
249+
)
250+
if labels.size == 0:
251+
raise ValueError("labels array must not be empty.")
252+
253+
intervals: List[ShadingInterval] = []
254+
current_regime = labels[0]
255+
block_start = 0
256+
257+
for i in range(1, len(labels)):
258+
if labels[i] != current_regime:
259+
intervals.append(ShadingInterval(str(current_regime), block_start, i))
260+
current_regime = labels[i]
261+
block_start = i
262+
263+
# Close the final block.
264+
intervals.append(ShadingInterval(str(current_regime), block_start, len(labels)))
265+
266+
return intervals
267+
268+
269+
# ---------------------------------------------------------------------------
270+
# 3. Transition event extraction
271+
# ---------------------------------------------------------------------------
272+
273+
class TransitionEvent(NamedTuple):
274+
"""A single regime-transition event.
275+
276+
Attributes
277+
----------
278+
from_regime : str
279+
Regime label at time step ``index - 1``.
280+
to_regime : str
281+
Regime label at time step ``index`` (the first step of the new regime).
282+
index : int
283+
Time-step index at which the transition occurs, i.e. the first index
284+
that belongs to the new regime.
285+
"""
286+
287+
from_regime: str
288+
to_regime: str
289+
index: int
290+
291+
292+
def extract_transition_events(labels: np.ndarray) -> List[TransitionEvent]:
293+
"""Extract regime-transition events from a per-time-point label array.
294+
295+
A transition event is recorded whenever adjacent time steps carry
296+
different regime labels. The returned list contains one
297+
:class:`TransitionEvent` per detected boundary, ordered by increasing
298+
*index*.
299+
300+
This function supports the manuscript's validation pipeline (Section 5
301+
and 8.3) by providing a discrete, ordered representation of when the
302+
instability magnitude ΔΦ(t) crosses regime boundaries—enabling
303+
quantitative evaluation of detection lead times, false-positive rates,
304+
and regime-ordering statistics (manuscript Table 3).
305+
306+
Parameters
307+
----------
308+
labels : array_like, shape (T,), dtype str or object
309+
Per-time-point regime strings as produced by
310+
:func:`~src.stage5_regime_transitions.detect_regimes.detect_regimes`.
311+
312+
Returns
313+
-------
314+
events : list of :class:`TransitionEvent`
315+
Ordered list of transition events. Empty if *labels* contains only
316+
a single regime throughout.
317+
318+
Raises
319+
------
320+
ValueError
321+
If *labels* is not a 1-D array or is empty.
322+
323+
Examples
324+
--------
325+
>>> import numpy as np
326+
>>> from src.stage5_regime_transitions.segmentation_utils import extract_transition_events
327+
>>> labels = np.array(
328+
... ["stable"] * 4 + ["pre-instability"] * 3 + ["instability"] * 2
329+
... )
330+
>>> events = extract_transition_events(labels)
331+
>>> [(e.from_regime, e.to_regime, e.index) for e in events]
332+
[('stable', 'pre-instability', 4), ('pre-instability', 'instability', 7)]
333+
"""
334+
labels = np.asarray(labels, dtype=object)
335+
if labels.ndim != 1:
336+
raise ValueError(
337+
f"labels must be a 1-D array; got shape {labels.shape}."
338+
)
339+
if labels.size == 0:
340+
raise ValueError("labels array must not be empty.")
341+
342+
events: List[TransitionEvent] = []
343+
for i in range(1, len(labels)):
344+
if labels[i] != labels[i - 1]:
345+
events.append(
346+
TransitionEvent(
347+
from_regime=str(labels[i - 1]),
348+
to_regime=str(labels[i]),
349+
index=i,
350+
)
351+
)
352+
return events

0 commit comments

Comments
 (0)