Skip to content

Commit ba9a33a

Browse files
remramaCopilot
andauthored
Add Bland-Altman plots to evaluation module (#248)
* Draft with blandaltman from pingouin src * Scatter only (dropped pingouin src) * Add method-specific bias and LoA * Add passing tests * Add refline * Uniform parameters * Specify zorders * Swap scatter and facetgrid kwarg arguments * Resolve mutable default argument * Handle edge cases of LoA bands * Consistent assertion checks * ruff * ruff on tests * Skip title alignment in early versions of matplotlib * Fix kwargs in docstrings Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Deep assertion checks for `sleep_stats` Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Better check for CI bands in tests Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Clarify test comment * Add docstring example * Remove spread in negative edgecases * Added top-level docstring --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent bdb0628 commit ba9a33a

2 files changed

Lines changed: 375 additions & 0 deletions

File tree

src/yasa/evaluation.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,3 +1600,273 @@ def calibration_func(x, method="auto", adjust_all=False):
16001600
return (x - intercept) / (1 + slope)
16011601

16021602
return calibration_func
1603+
1604+
def plot_blandaltman(
1605+
self,
1606+
sleep_stats=None,
1607+
bias_method="auto",
1608+
loa_method="auto",
1609+
ci_method="auto",
1610+
flag_biased=False,
1611+
scatter_kwargs=None,
1612+
**kwargs,
1613+
):
1614+
"""Plot Bland-Altman agreement plots for one or more sleep statistics.
1615+
1616+
Each panel shows observed-minus-reference differences (y-axis) against reference values
1617+
(x-axis) for one sleep statistic. Bias and limits of agreement are drawn as lines, with
1618+
optional confidence-interval bands. Methods (parametric, regression, or bootstrap) are
1619+
chosen automatically per statistic based on the assumption tests stored in
1620+
:py:attr:`~yasa.SleepStatsAgreement.assumptions`, or can be set explicitly.
1621+
1622+
.. seealso:: :py:meth:`~yasa.SleepStatsAgreement.report`,
1623+
:py:meth:`~yasa.SleepStatsAgreement.summary`
1624+
1625+
Parameters
1626+
----------
1627+
sleep_stats : list or None
1628+
List of sleep statistics to plot. Default (None) is to plot all sleep statistics.
1629+
bias_method : str
1630+
If ``'param'``, bias is always the mean difference (horizontal line). If ``'regr'``,
1631+
bias is always a regression line. If ``'auto'`` (default), the method is chosen per
1632+
statistic based on the proportional-bias assumption test.
1633+
loa_method : str
1634+
If ``'param'``, limits of agreement are always horizontal lines (bias ± 1.96 SD). If
1635+
``'regr'``, they are always regression-modeled lines. If ``'auto'`` (default), the
1636+
method is chosen per statistic based on the homoscedasticity assumption test.
1637+
ci_method : str or None
1638+
If ``'param'``, parametric CIs are drawn. If ``'boot'``, bootstrap CIs are drawn. If
1639+
``'auto'`` (default), chosen per statistic based on the normality assumption test.
1640+
If ``None``, no confidence intervals are drawn.
1641+
flag_biased : bool
1642+
If True, sleep statistics with a statistically significant bias (i.e., the ``unbiased``
1643+
assumption is violated) are drawn with a red bias line instead of grey.
1644+
scatter_kwargs : dict
1645+
Other keyword arguments are passed through to :py:func:`matplotlib.pyplot.scatter`.
1646+
**kwargs : dict
1647+
Other keyword arguments are passed through to :py:class:`seaborn.FacetGrid`.
1648+
1649+
Returns
1650+
-------
1651+
g : :py:class:`seaborn.FacetGrid`
1652+
Seaborn FacetGrid
1653+
1654+
Examples
1655+
--------
1656+
.. plot::
1657+
1658+
>>> import yasa
1659+
>>> n = 20
1660+
>>> ref_hyps = [yasa.simulate_hypnogram(scorer="PSG", seed=i) for i in range(n)]
1661+
>>> obs_hyps = [ref_hyps[i].simulate_similar(scorer="Device", seed=i) for i in range(n)]
1662+
>>> eea = yasa.EpochByEpochAgreement(ref_hyps, obs_hyps)
1663+
>>> sstats = eea.get_sleep_stats()
1664+
>>> ssa = yasa.SleepStatsAgreement(sstats.loc["PSG"], sstats.loc["Device"])
1665+
>>> stats = ["TST", "WASO", "N1", "REM"]
1666+
>>> g = ssa.plot_blandaltman(sleep_stats=stats, ci_method="param")
1667+
"""
1668+
import seaborn as sns # noqa
1669+
import matplotlib.pyplot as plt
1670+
1671+
assert isinstance(sleep_stats, (list, type(None))), "`sleep_stats` must be a list or None"
1672+
assert isinstance(bias_method, str), "`bias_method` must be a string"
1673+
assert bias_method in self._bias_method_opts, (
1674+
f"`bias_method` must be one of {self._bias_method_opts}"
1675+
)
1676+
assert isinstance(loa_method, str), "`loa_method` must be a string"
1677+
assert loa_method in self._loa_method_opts, (
1678+
f"`loa_method` must be one of {self._loa_method_opts}"
1679+
)
1680+
assert ci_method is None or (
1681+
isinstance(ci_method, str) and ci_method in self._ci_method_opts
1682+
), f"`ci_method` must be one of {self._ci_method_opts} or None"
1683+
assert isinstance(flag_biased, bool), "`flag_biased` must be True or False"
1684+
assert isinstance(scatter_kwargs, (dict, type(None))), (
1685+
"`scatter_kwargs` must be a dict or None"
1686+
)
1687+
if scatter_kwargs is None:
1688+
scatter_kwargs = {}
1689+
if sleep_stats is None:
1690+
sleep_stats = self.sleep_statistics
1691+
1692+
# Validate sleep_stats content
1693+
assert isinstance(sleep_stats, list), "`sleep_stats` must be a list"
1694+
assert len(sleep_stats) > 0, "`sleep_stats` must be a non-empty list"
1695+
assert all(isinstance(stat, str) for stat in sleep_stats), (
1696+
"`sleep_stats` must be a list of strings"
1697+
)
1698+
assert len(sleep_stats) == len(set(sleep_stats)), (
1699+
"`sleep_stats` must not contain duplicate entries"
1700+
)
1701+
valid_stats = set(self.sleep_statistics)
1702+
invalid_stats = [stat for stat in sleep_stats if stat not in valid_stats]
1703+
assert not invalid_stats, (
1704+
"`sleep_stats` contains invalid statistics: "
1705+
f"{sorted(invalid_stats)}; valid options are {sorted(valid_stats)}"
1706+
)
1707+
# Resolve per-stat bias and loa methods
1708+
if bias_method == "auto":
1709+
bias_param_idx = self.auto_methods.query("bias == 'param'").index.tolist()
1710+
elif bias_method == "param":
1711+
bias_param_idx = sleep_stats
1712+
else:
1713+
bias_param_idx = []
1714+
1715+
if loa_method == "auto":
1716+
loa_param_idx = self.auto_methods.query("loa == 'param'").index.tolist()
1717+
elif loa_method == "param":
1718+
loa_param_idx = sleep_stats
1719+
else:
1720+
loa_param_idx = []
1721+
1722+
# Retrieve values and CIs
1723+
if ci_method is not None:
1724+
vals = self.summary(ci_method=ci_method)
1725+
else:
1726+
vals = pd.concat({"center": self._vals}, names=["interval"], axis=1).swaplevel(axis=1)
1727+
1728+
agreement_adj = self._agreement * np.sqrt(np.pi / 2)
1729+
1730+
# Identify stats with significant bias for optional flagging
1731+
biased_stats = (
1732+
self.assumptions.query("unbiased == False").index.tolist() if flag_biased else []
1733+
)
1734+
1735+
# Select scatterplot arguments and update with optional input
1736+
default_scatter_kwargs = dict(facecolor="none", edgecolor="black", alpha=0.8)
1737+
scatter_kwargs = default_scatter_kwargs | scatter_kwargs
1738+
# Select FacetGrid arguments and update with optional input
1739+
default_facetgrid_kwargs = dict(
1740+
data=self._data.reset_index("sleep_stat"),
1741+
col="sleep_stat",
1742+
col_order=sleep_stats,
1743+
col_wrap=5 if len(sleep_stats) > 5 else None,
1744+
height=2,
1745+
aspect=1,
1746+
sharex=False,
1747+
sharey=False,
1748+
)
1749+
facetgrid_kwargs = default_facetgrid_kwargs | kwargs
1750+
# Choose display levels with zorder
1751+
data_zorder = 30
1752+
bias_zorder = 20
1753+
loa_zorder = 10
1754+
refline_zorder = 0
1755+
# Initialize a grid of plots with an Axes for each sleep statistic
1756+
g = sns.FacetGrid(**facetgrid_kwargs)
1757+
# Draw scatterplot on each axis
1758+
g.map(plt.scatter, self.ref_scorer, "difference", zorder=data_zorder, **scatter_kwargs)
1759+
# Draw a horizontal line at y=0 on each axis
1760+
g.refline(y=0, color="black", linewidth=1, linestyle="solid", zorder=refline_zorder)
1761+
# Choose arguments for all calls to axhspan and fill_between for bias and LoA CI bands
1762+
band_kwargs = dict(edgecolor="none", alpha=0.15)
1763+
# Choose arguments for all calls to axhline and plot for bias and LoA lines
1764+
line_kwargs = dict(linewidth=1, linestyle="dashed", alpha=0.9)
1765+
loa_color = "tab:blue"
1766+
bias_default_color = "tab:gray" # when not flagged as biased
1767+
bias_flagged_color = "tab:red"
1768+
# Draw bias lines, LoA lines, and CI bands on each axis
1769+
for stat, ax in zip(sleep_stats, g.axes.flat, strict=True):
1770+
x_min, x_max = ax.get_xlim()
1771+
x_line = np.array([x_min, x_max])
1772+
v = vals.loc[stat]
1773+
has_ci = ci_method is not None
1774+
bias_color = bias_flagged_color if stat in biased_stats else bias_default_color
1775+
1776+
# --- Bias line ---
1777+
if stat in bias_param_idx:
1778+
y_bias = v[("bias_mean", "center")]
1779+
ax.axhline(y_bias, color=bias_color, zorder=bias_zorder, **line_kwargs)
1780+
if has_ci:
1781+
ax.axhspan(
1782+
v[("bias_mean", "lower")],
1783+
v[("bias_mean", "upper")],
1784+
facecolor=bias_color,
1785+
zorder=bias_zorder - 1,
1786+
**band_kwargs,
1787+
)
1788+
y_bias_arr = np.full_like(x_line, y_bias, dtype=float)
1789+
else:
1790+
intercept = v[("bias_intercept", "center")]
1791+
slope = v[("bias_slope", "center")]
1792+
y_bias_arr = intercept + slope * x_line
1793+
ax.plot(x_line, y_bias_arr, color=bias_color, zorder=bias_zorder, **line_kwargs)
1794+
if has_ci:
1795+
y_ci_a = v[("bias_intercept", "lower")] + v[("bias_slope", "lower")] * x_line
1796+
y_ci_b = v[("bias_intercept", "upper")] + v[("bias_slope", "upper")] * x_line
1797+
y_lo = np.minimum(y_ci_a, y_ci_b)
1798+
y_hi = np.maximum(y_ci_a, y_ci_b)
1799+
ax.fill_between(
1800+
x_line,
1801+
y_lo,
1802+
y_hi,
1803+
facecolor=bias_color,
1804+
zorder=bias_zorder - 1,
1805+
**band_kwargs,
1806+
)
1807+
1808+
# --- LoA lines ---
1809+
if stat in loa_param_idx:
1810+
for loa_var in ("loa_lower", "loa_upper"):
1811+
y_loa = v[(loa_var, "center")]
1812+
ax.axhline(y_loa, color=loa_color, zorder=loa_zorder, **line_kwargs)
1813+
if has_ci:
1814+
ax.axhspan(
1815+
v[(loa_var, "lower")],
1816+
v[(loa_var, "upper")],
1817+
facecolor=loa_color,
1818+
zorder=loa_zorder - 1,
1819+
**band_kwargs,
1820+
)
1821+
else:
1822+
loa_int = v[("loa_intercept", "center")]
1823+
loa_slp = v[("loa_slope", "center")]
1824+
y_spread = agreement_adj * np.maximum(0.0, loa_int + loa_slp * x_line)
1825+
ax.plot(
1826+
x_line, y_bias_arr + y_spread, color=loa_color, zorder=loa_zorder, **line_kwargs
1827+
)
1828+
ax.plot(
1829+
x_line, y_bias_arr - y_spread, color=loa_color, zorder=loa_zorder, **line_kwargs
1830+
)
1831+
if has_ci:
1832+
lint_lo = v[("loa_intercept", "lower")]
1833+
lint_hi = v[("loa_intercept", "upper")]
1834+
lslp_lo = v[("loa_slope", "lower")]
1835+
lslp_hi = v[("loa_slope", "upper")]
1836+
spread_a = agreement_adj * (lint_lo + lslp_lo * x_line)
1837+
spread_b = agreement_adj * (lint_hi + lslp_hi * x_line)
1838+
spread_lo = np.minimum(spread_a, spread_b)
1839+
spread_hi = np.maximum(spread_a, spread_b)
1840+
ax.fill_between(
1841+
x_line,
1842+
y_bias_arr + spread_lo,
1843+
y_bias_arr + spread_hi,
1844+
facecolor=loa_color,
1845+
zorder=loa_zorder - 1,
1846+
**band_kwargs,
1847+
)
1848+
ax.fill_between(
1849+
x_line,
1850+
y_bias_arr - spread_hi,
1851+
y_bias_arr - spread_lo,
1852+
facecolor=loa_color,
1853+
zorder=loa_zorder - 1,
1854+
**band_kwargs,
1855+
)
1856+
1857+
# Tidy-up axis limits with symmetric y-axis and minimal ticks
1858+
for ax in g.axes.flat:
1859+
bound = max(map(abs, ax.get_ylim()))
1860+
ax.set_ylim(-bound, bound)
1861+
ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=2, integer=True, symmetric=True))
1862+
ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=1, integer=True))
1863+
# More aesthetics
1864+
ylabel = " - ".join((self.obs_scorer, self.ref_scorer))
1865+
g.set_ylabels(ylabel)
1866+
g.set_xlabels(self.ref_scorer)
1867+
g.set_titles(col_template="{col_name}")
1868+
if hasattr(g.fig, "align_titles"): # introduced in matplotlib v3.9.0
1869+
g.fig.align_titles()
1870+
g.fig.align_labels()
1871+
g.tight_layout(w_pad=1, h_pad=2)
1872+
return g

tests/test_evaluation.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,3 +405,108 @@ def test_invalid_decimals_raises(self):
405405
def test_invalid_bias_method_raises(self):
406406
with pytest.raises(AssertionError):
407407
ssa.report(bias_method="invalid")
408+
409+
410+
class TestSleepStatsAgreementPlotBlandAltman(unittest.TestCase):
411+
"""Test the plot_blandaltman method.
412+
413+
Use ci_method="param" to avoid the bootstrap path with small samples (N_SESSIONS=5).
414+
"""
415+
416+
@classmethod
417+
def setUpClass(cls):
418+
import matplotlib
419+
420+
matplotlib.use("Agg")
421+
422+
def test_returns_facetgrid(self):
423+
import seaborn as sns
424+
425+
g = ssa.plot_blandaltman(ci_method="param")
426+
assert isinstance(g, sns.FacetGrid)
427+
428+
def test_default_auto_methods(self):
429+
g = ssa.plot_blandaltman(ci_method="param")
430+
assert len(g.axes.flat) == len(ssa.sleep_statistics)
431+
432+
def test_param_bias_param_loa(self):
433+
g = ssa.plot_blandaltman(bias_method="param", loa_method="param", ci_method="param")
434+
# Each axis should have lines drawn (axhline creates Line2D objects)
435+
for ax in g.axes.flat:
436+
assert len(ax.lines) > 0
437+
438+
def test_regr_bias_regr_loa(self):
439+
g = ssa.plot_blandaltman(bias_method="regr", loa_method="regr", ci_method="param")
440+
for ax in g.axes.flat:
441+
assert len(ax.lines) > 0
442+
443+
def test_no_ci(self):
444+
g = ssa.plot_blandaltman(ci_method=None)
445+
# With no CI, axes should have no patches (no fill_between / axhspan)
446+
for ax in g.axes.flat:
447+
assert len(ax.patches) == 0
448+
449+
def test_ci_adds_patches(self):
450+
g = ssa.plot_blandaltman(ci_method="param")
451+
# Patches from parametric CI bands should be drawn with axhspan
452+
has_patches = any(len(ax.patches) > 0 for ax in g.axes.flat)
453+
assert has_patches
454+
455+
def test_sleep_stats_subset(self):
456+
subset = ssa.sleep_statistics[:3]
457+
g = ssa.plot_blandaltman(sleep_stats=subset, ci_method="param")
458+
assert len(g.axes.flat) == len(subset)
459+
460+
def test_flag_biased_false(self):
461+
# Should not raise
462+
g = ssa.plot_blandaltman(flag_biased=False, ci_method="param")
463+
assert g is not None
464+
465+
def test_flag_biased_true(self):
466+
# Should not raise
467+
g = ssa.plot_blandaltman(flag_biased=True, ci_method="param")
468+
assert g is not None
469+
470+
def test_xlabel_is_ref_scorer(self):
471+
g = ssa.plot_blandaltman(ci_method="param")
472+
# x-axis label should be the reference scorer name
473+
assert g.axes.flat[-1].get_xlabel() == REF_SCORER
474+
475+
def test_ylabel_format(self):
476+
g = ssa.plot_blandaltman(ci_method="param")
477+
expected = f"{OBS_SCORER} - {REF_SCORER}"
478+
assert g.axes.flat[0].get_ylabel() == expected
479+
480+
def test_invalid_bias_method_raises(self):
481+
with pytest.raises(AssertionError):
482+
ssa.plot_blandaltman(bias_method="invalid")
483+
484+
def test_invalid_loa_method_raises(self):
485+
with pytest.raises(AssertionError):
486+
ssa.plot_blandaltman(loa_method="invalid")
487+
488+
def test_invalid_ci_method_raises(self):
489+
with pytest.raises(AssertionError):
490+
ssa.plot_blandaltman(ci_method="invalid")
491+
492+
def test_invalid_flag_biased_raises(self):
493+
with pytest.raises(AssertionError):
494+
ssa.plot_blandaltman(flag_biased="yes")
495+
496+
def test_scatter_kwargs_passthrough(self):
497+
g = ssa.plot_blandaltman(ci_method="param", scatter_kwargs={"edgecolor": "red"})
498+
# Scatter points on first axis should have the custom color
499+
scatter = ax_collections(g.axes.flat[0])
500+
assert len(scatter) > 0
501+
502+
def test_facetgrid_kwargs_passthrough(self):
503+
g = ssa.plot_blandaltman(ci_method="param", col_wrap=1)
504+
# FacetGrid col_wrap should reflect the override
505+
assert g._col_wrap == 1
506+
507+
508+
def ax_collections(ax):
509+
"""Return PathCollections (scatter plots) from an Axes."""
510+
from matplotlib.collections import PathCollection
511+
512+
return [c for c in ax.collections if isinstance(c, PathCollection)]

0 commit comments

Comments
 (0)