Skip to content

Commit 45450fe

Browse files
Add Hypnogram.plot_hypnodensity method (#233)
1 parent 11e54d5 commit 45450fe

3 files changed

Lines changed: 272 additions & 7 deletions

File tree

HYPNOGRAM_ROADMAP.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Goal: make `yasa.Hypnogram` the industry-standard Python object for handling sle
2828

2929
### Visualization & export
3030
- `plot_hypnogram()` — standard hypnogram figure
31+
- `plot_hypnodensity()` — per-epoch stage probabilities as a stacked area chart (requires `proba`); supports 2/3/4/5-stage hypnograms and datetime x-axis when `start` is set
3132
- `as_int()` — integer-encoded `pandas.Series`
3233
- `as_events()` — BIDS-compatible events `DataFrame` (onset, duration, stage)
3334
- `upsample(new_freq)` — change epoch resolution
@@ -42,8 +43,5 @@ Goal: make `yasa.Hypnogram` the industry-standard Python object for handling sle
4243
### I/O
4344
- **`from_edf_annotations(raw)`** — load hypnogram from EDF+ annotations.
4445

45-
### Analysis
46-
- **`plot_hypnodensity()`** — when `proba` is available, plot the per-epoch stage probability as a color-map (signature visualization of modern auto-staging papers).
47-
4846
### Multi-scorer support
4947
- **`HypnogramSet`** — new container class for multiple scorers of the same night (alignment, pairwise agreement, consensus scoring). See [HYPNOGRAM_MULTIPLE_SCORERS.md](HYPNOGRAM_MULTIPLE_SCORERS.md) for the full design plan.

src/yasa/hypno.py

Lines changed: 163 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ class Hypnogram:
9393
- Compare two hypnograms epoch-by-epoch (kappa, F1, MCC, ...).
9494
* - :py:meth:`plot_hypnogram`
9595
- Plot the hypnogram as a standard hypnogram figure.
96+
* - :py:meth:`plot_hypnodensity`
97+
- Plot per-epoch stage probabilities as a stacked area chart (requires ``proba``).
9698
* - :py:meth:`simulate_similar`
9799
- Simulate a new hypnogram with the same transition probabilities as this one.
98100
@@ -1666,7 +1668,7 @@ def sleep_statistics(self):
16661668
16671669
>>> from yasa import simulate_hypnogram
16681670
>>> # Generate a 8 hr (= 480 minutes) 5-stage hypnogram with a 30-seconds resolution
1669-
>>> hyp = simulate_hypnogram(tib=480, seed=42)
1671+
>>> hyp = simulate_hypnogram(tib=300, seed=42)
16701672
>>> pd.Series(hyp.sleep_statistics())
16711673
TIB 480.0000
16721674
SPT 477.5000
@@ -1791,7 +1793,7 @@ def transition_matrix(self):
17911793
--------
17921794
>>> from yasa import Hypnogram, simulate_hypnogram
17931795
>>> # Generate a 8 hr (= 480 minutes) 5-stage hypnogram with a 30-seconds resolution
1794-
>>> hyp = simulate_hypnogram(tib=480, seed=42)
1796+
>>> hyp = simulate_hypnogram(tib=300, seed=42)
17951797
>>> counts, probs = hyp.transition_matrix()
17961798
>>> counts
17971799
To Stage WAKE N1 N2 N3 REM
@@ -1963,10 +1965,167 @@ def plot_hypnogram(self, **kwargs):
19631965
.. plot::
19641966
19651967
>>> from yasa import simulate_hypnogram
1966-
>>> ax = simulate_hypnogram(tib=480, seed=88).plot_hypnogram(highlight="REM")
1968+
>>> ax = simulate_hypnogram(tib=300, seed=88).plot_hypnogram(highlight="REM")
19671969
"""
19681970
return plot_hypnogram(self, **kwargs)
19691971

1972+
def plot_hypnodensity(self, palette=None, ax=None):
1973+
"""Plot the hypnodensity: per-epoch stage probabilities as a stacked area chart.
1974+
1975+
Requires that the :py:attr:`proba` attribute is set (i.e. the hypnogram was created by
1976+
:py:meth:`yasa.SleepStaging.predict`).
1977+
1978+
Parameters
1979+
----------
1980+
palette : dict or None
1981+
A dictionary mapping stage names to matplotlib colors, e.g.
1982+
``{"WAKE": "#99d7f1", "REM": "xkcd:sunflower"}``. When ``None`` (default), a
1983+
built-in palette is used. Missing stage keys fall back to ``"gray"``.
1984+
ax : :py:class:`matplotlib.axes.Axes` or None
1985+
Axis on which to draw the plot. If ``None`` (default), the current axis is used.
1986+
1987+
Returns
1988+
-------
1989+
ax : :py:class:`matplotlib.axes.Axes`
1990+
Matplotlib Axes
1991+
1992+
Raises
1993+
------
1994+
ValueError
1995+
If :py:attr:`proba` is ``None``.
1996+
1997+
Examples
1998+
--------
1999+
5-stage hypnogram:
2000+
2001+
.. plot::
2002+
2003+
>>> import numpy as np
2004+
>>> import pandas as pd
2005+
>>> from yasa import Hypnogram, simulate_hypnogram
2006+
>>> import matplotlib.pyplot as plt
2007+
>>> hyp = simulate_hypnogram(tib=300, n_stages=5, seed=42)
2008+
>>> stages = ["WAKE", "N1", "N2", "N3", "REM"]
2009+
>>> rng = np.random.default_rng(42)
2010+
>>> one_hot = (
2011+
... pd.get_dummies(hyp.hypno)
2012+
... .reindex(columns=stages, fill_value=0)
2013+
... .to_numpy(dtype=float)
2014+
... )
2015+
>>> noise = rng.dirichlet(np.ones(5) * 0.5, size=hyp.n_epochs)
2016+
>>> raw = 0.75 * one_hot + 0.25 * noise
2017+
>>> proba = pd.DataFrame(raw / raw.sum(axis=1, keepdims=True), columns=stages)
2018+
>>> ax = Hypnogram(hyp.hypno, n_stages=5, proba=proba).plot_hypnodensity()
2019+
>>> plt.tight_layout()
2020+
2021+
4-stage hypnogram:
2022+
2023+
.. plot::
2024+
2025+
>>> import numpy as np
2026+
>>> import pandas as pd
2027+
>>> from yasa import Hypnogram, simulate_hypnogram
2028+
>>> import matplotlib.pyplot as plt
2029+
>>> hyp = simulate_hypnogram(tib=300, n_stages=4, seed=42)
2030+
>>> stages = ["WAKE", "LIGHT", "DEEP", "REM"]
2031+
>>> rng = np.random.default_rng(42)
2032+
>>> one_hot = (
2033+
... pd.get_dummies(hyp.hypno)
2034+
... .reindex(columns=stages, fill_value=0)
2035+
... .to_numpy(dtype=float)
2036+
... )
2037+
>>> noise = rng.dirichlet(np.ones(4) * 0.5, size=hyp.n_epochs)
2038+
>>> raw = 0.75 * one_hot + 0.25 * noise
2039+
>>> proba = pd.DataFrame(raw / raw.sum(axis=1, keepdims=True), columns=stages)
2040+
>>> ax = Hypnogram(hyp.hypno, n_stages=4, proba=proba).plot_hypnodensity()
2041+
>>> plt.tight_layout()
2042+
2043+
2-stage hypnogram:
2044+
2045+
.. plot::
2046+
2047+
>>> import numpy as np
2048+
>>> import pandas as pd
2049+
>>> from yasa import Hypnogram, simulate_hypnogram
2050+
>>> import matplotlib.pyplot as plt
2051+
>>> hyp = simulate_hypnogram(tib=300, n_stages=2, seed=42)
2052+
>>> stages = ["WAKE", "SLEEP"]
2053+
>>> rng = np.random.default_rng(42)
2054+
>>> one_hot = (
2055+
... pd.get_dummies(hyp.hypno)
2056+
... .reindex(columns=stages, fill_value=0)
2057+
... .to_numpy(dtype=float)
2058+
... )
2059+
>>> noise = rng.dirichlet(np.ones(2) * 0.5, size=hyp.n_epochs)
2060+
>>> raw = 0.75 * one_hot + 0.25 * noise
2061+
>>> proba = pd.DataFrame(raw / raw.sum(axis=1, keepdims=True), columns=stages)
2062+
>>> ax = Hypnogram(hyp.hypno, n_stages=2, proba=proba).plot_hypnodensity()
2063+
>>> plt.tight_layout()
2064+
"""
2065+
import matplotlib.dates as mdates
2066+
import matplotlib.pyplot as plt
2067+
2068+
if self._proba is None:
2069+
raise ValueError(
2070+
"No probability data found. `proba` is only available when the Hypnogram "
2071+
"was created by `yasa.SleepStaging.predict()`."
2072+
)
2073+
2074+
# Default color palette covering all possible stage names.
2075+
# Base 5-stage colors: WAKE=#99d7f1, N1=#009ddc, N2=#0a437a, N3=#720058, REM=#ffc512
2076+
# Derived colors: LIGHT=avg(N1,N2), NREM=avg(N1,N2,N3), DEEP=N3, SLEEP=dark navy
2077+
_default_palette = {
2078+
"WAKE": "#99d7f1",
2079+
"N1": "#009ddc",
2080+
"N2": "#0a437a",
2081+
"N3": "#720058",
2082+
"REM": "#ffc512",
2083+
"LIGHT": "#0570ab", # avg(N1, N2)
2084+
"DEEP": "#720058", # = N3
2085+
"NREM": "#294b8f", # avg(N1, N2, N3)
2086+
"SLEEP": "#003366", # dark navy, pairs with light-blue WAKE
2087+
"ART": "#999999",
2088+
"UNS": "#cccccc",
2089+
}
2090+
if palette is None:
2091+
palette = _default_palette
2092+
2093+
proba = self._proba.copy()
2094+
stages = proba.columns.tolist()
2095+
colors = [palette.get(s, "gray") for s in stages]
2096+
2097+
# Increase font size while preserving original
2098+
old_fontsize = plt.rcParams["font.size"]
2099+
plt.rcParams.update({"font.size": 18})
2100+
2101+
if ax is None:
2102+
_, ax = plt.subplots(figsize=(12, 4))
2103+
2104+
# Build x-axis values
2105+
if self._start is not None:
2106+
times = pd.date_range(start=self._start, freq=self._freq, periods=self._n_epochs)
2107+
x = mdates.date2num(times)
2108+
xlabel = "Time"
2109+
else:
2110+
x = self._timedelta.total_seconds() / 60 # minutes
2111+
xlabel = "Time [mins]" if self._duration <= 90 else "Time [hrs]"
2112+
if self._duration > 90:
2113+
x = x / 60 # convert to hours
2114+
2115+
ax.stackplot(x, proba.to_numpy().T, labels=stages, colors=colors, alpha=0.85)
2116+
ax.set_xlim(x[0], x[-1])
2117+
ax.set_ylim(0, 1)
2118+
ax.set_ylabel("Probability")
2119+
ax.set_xlabel(xlabel)
2120+
ax.legend(frameon=False, bbox_to_anchor=(1, 1), loc="upper left")
2121+
ax.spines[["right", "top"]].set_visible(False)
2122+
if self._start is not None:
2123+
ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))
2124+
ax.xaxis.set_major_locator(mdates.AutoDateLocator())
2125+
2126+
plt.rcParams.update({"font.size": old_fontsize})
2127+
return ax
2128+
19702129
#######################################################################
19712130
# SIMULATION
19722131
#######################################################################
@@ -2566,7 +2725,7 @@ def hypno_find_periods(hypno, sf_hypno, threshold="5min", equal_length=False):
25662725

25672726

25682727
def simulate_hypnogram(
2569-
tib=480,
2728+
tib=300,
25702729
trans_probas=None,
25712730
init_probas=None,
25722731
seed=None,

tests/test_hypnoclass.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,3 +750,111 @@ def test_pad_negative_after_raises():
750750
def test_pad_bad_type_raises():
751751
with pytest.raises(TypeError):
752752
Hypnogram(_PAD_STAGES).pad(before=3.5)
753+
754+
755+
###############################################################################
756+
# plot_hypnodensity
757+
###############################################################################
758+
759+
# Shared fixtures for hypnodensity tests
760+
_N = 100
761+
_rng = np.random.default_rng(0)
762+
763+
764+
def _make_proba(stages):
765+
"""Return a valid probability DataFrame for the given stage list."""
766+
raw = _rng.dirichlet(np.ones(len(stages)), size=_N)
767+
return pd.DataFrame(raw, columns=stages)
768+
769+
770+
def test_plot_hypnodensity_5stage_returns_axes():
771+
proba = _make_proba(["WAKE", "N1", "N2", "N3", "REM"])
772+
hyp = Hypnogram(
773+
["WAKE"] * 20 + ["N1"] * 10 + ["N2"] * 40 + ["N3"] * 20 + ["REM"] * 10, proba=proba
774+
)
775+
ax = hyp.plot_hypnodensity()
776+
assert isinstance(ax, plt.Axes)
777+
plt.close("all")
778+
779+
780+
def test_plot_hypnodensity_4stage_returns_axes():
781+
stages_seq = ["WAKE"] * 25 + ["LIGHT"] * 25 + ["DEEP"] * 25 + ["REM"] * 25
782+
proba = _make_proba(["WAKE", "LIGHT", "DEEP", "REM"])
783+
hyp = Hypnogram(stages_seq, n_stages=4, proba=proba)
784+
ax = hyp.plot_hypnodensity()
785+
assert isinstance(ax, plt.Axes)
786+
plt.close("all")
787+
788+
789+
def test_plot_hypnodensity_3stage_returns_axes():
790+
stages_seq = ["WAKE"] * 34 + ["NREM"] * 33 + ["REM"] * 33
791+
proba = _make_proba(["WAKE", "NREM", "REM"])
792+
hyp = Hypnogram(stages_seq, n_stages=3, proba=proba)
793+
ax = hyp.plot_hypnodensity()
794+
assert isinstance(ax, plt.Axes)
795+
plt.close("all")
796+
797+
798+
def test_plot_hypnodensity_2stage_returns_axes():
799+
stages_seq = ["WAKE"] * 50 + ["SLEEP"] * 50
800+
proba = _make_proba(["WAKE", "SLEEP"])
801+
hyp = Hypnogram(stages_seq, n_stages=2, proba=proba)
802+
ax = hyp.plot_hypnodensity()
803+
assert isinstance(ax, plt.Axes)
804+
plt.close("all")
805+
806+
807+
def test_plot_hypnodensity_no_proba_raises():
808+
hyp = Hypnogram(["WAKE"] * 10 + ["N2"] * 90)
809+
with pytest.raises(ValueError, match="proba"):
810+
hyp.plot_hypnodensity()
811+
812+
813+
def test_plot_hypnodensity_with_start_uses_datetime_axis():
814+
proba = _make_proba(["WAKE", "N1", "N2", "N3", "REM"])
815+
hyp = Hypnogram(
816+
["WAKE"] * 20 + ["N1"] * 10 + ["N2"] * 40 + ["N3"] * 20 + ["REM"] * 10,
817+
proba=proba,
818+
start="2022-12-15 22:30:00",
819+
)
820+
ax = hyp.plot_hypnodensity()
821+
assert isinstance(ax, plt.Axes)
822+
# x-axis should use a DateFormatter when start is set
823+
import matplotlib.dates as mdates
824+
825+
assert isinstance(ax.xaxis.get_major_formatter(), mdates.DateFormatter)
826+
plt.close("all")
827+
828+
829+
def test_plot_hypnodensity_accepts_ax_argument():
830+
proba = _make_proba(["WAKE", "N1", "N2", "N3", "REM"])
831+
hyp = Hypnogram(
832+
["WAKE"] * 20 + ["N1"] * 10 + ["N2"] * 40 + ["N3"] * 20 + ["REM"] * 10, proba=proba
833+
)
834+
fig, ax = plt.subplots()
835+
returned_ax = hyp.plot_hypnodensity(ax=ax)
836+
assert returned_ax is ax
837+
plt.close("all")
838+
839+
840+
def test_plot_hypnodensity_custom_palette():
841+
proba = _make_proba(["WAKE", "N1", "N2", "N3", "REM"])
842+
hyp = Hypnogram(
843+
["WAKE"] * 20 + ["N1"] * 10 + ["N2"] * 40 + ["N3"] * 20 + ["REM"] * 10, proba=proba
844+
)
845+
custom = {"WAKE": "red", "N1": "green", "N2": "blue", "N3": "purple", "REM": "orange"}
846+
ax = hyp.plot_hypnodensity(palette=custom)
847+
assert isinstance(ax, plt.Axes)
848+
plt.close("all")
849+
850+
851+
def test_plot_hypnodensity_ylim_and_legend():
852+
proba = _make_proba(["WAKE", "N1", "N2", "N3", "REM"])
853+
hyp = Hypnogram(
854+
["WAKE"] * 20 + ["N1"] * 10 + ["N2"] * 40 + ["N3"] * 20 + ["REM"] * 10, proba=proba
855+
)
856+
ax = hyp.plot_hypnodensity()
857+
assert ax.get_ylim() == (0, 1)
858+
legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
859+
assert set(legend_labels) == {"WAKE", "N1", "N2", "N3", "REM"}
860+
plt.close("all")

0 commit comments

Comments
 (0)