Skip to content

Commit 273a26e

Browse files
committed
feat: add comprehensive tests and plotting functions for domain motion analysis
1 parent 763975e commit 273a26e

5 files changed

Lines changed: 675 additions & 10 deletions

File tree

docs/development.md

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,68 @@ pytest --ff # "failed first"
123123

124124
Test data should be placed in the `test_data/` directory. If test data is missing, tests will be automatically skipped.
125125

126+
### Plot Tests
127+
128+
The test suite includes comprehensive tests for all plotting functions in `sarcasm/plots.py`. These tests verify that plots are generated correctly without errors.
129+
130+
**Test Classes:**
131+
132+
| Test Class | Location | Marker | Description |
133+
|------------|----------|--------|-------------|
134+
| `TestStructureMetadata` | `test_structure.py` | - | Fast metadata tests |
135+
| `TestStructureTimelapseAnalysis` | `test_structure.py` | `slow` | Time-lapse analysis |
136+
| `TestStructureSingleImageAnalysis` | `test_structure.py` | `slow` | Single image analysis |
137+
| `TestStructureErrors` | `test_structure.py` | - | Fast error handling |
138+
| `TestStructureIntegration` | `test_structure.py` | `slow`, `integration` | Full workflow tests |
139+
| `TestStructurePlots` | `test_structure.py` | `slow` | Structure plotting (13 tests) |
140+
| `TestDomainMotionPlots` | `test_structure.py` | `slow` | Domain motion plots (3 tests) |
141+
| `TestMotion` | `test_motion.py` | `slow` | LOI detection and analysis |
142+
| `TestMotionIntegration` | `test_motion.py` | `slow`, `integration` | Full motion workflow |
143+
| `TestMotionPlots` | `test_motion.py` | `slow` | Motion plotting (10 tests) |
144+
145+
**Running Tests by Speed:**
146+
147+
```bash
148+
# Run only fast tests (unit tests, error handling, metadata)
149+
pytest -m "not slow" -v
150+
151+
# Run only slow tests (analysis, detection, plotting)
152+
pytest -m "slow" -v
153+
154+
# Run integration tests only
155+
pytest -m "integration" -v
156+
157+
# Skip both slow and integration tests
158+
pytest -m "not slow and not integration" -v
159+
```
160+
161+
**Running Plot Tests:**
162+
163+
```bash
164+
# Run all structure plot tests
165+
pytest tests/test_structure.py::TestStructurePlots -v
166+
167+
# Run all motion plot tests
168+
pytest tests/test_motion.py::TestMotionPlots -v
169+
170+
# Run domain motion plot tests (slow, uses 30kPa data)
171+
pytest tests/test_structure.py::TestDomainMotionPlots -v
172+
173+
# Run all plot tests together
174+
pytest tests/test_structure.py::TestStructurePlots tests/test_structure.py::TestDomainMotionPlots tests/test_motion.py::TestMotionPlots -v
175+
```
176+
177+
**Test Artifacts:**
178+
179+
Plot tests generate `*_sarcasm/` folders containing analysis results. By default, these are automatically cleaned up after the test session completes.
180+
181+
```bash
182+
# Keep generated folders for debugging
183+
pytest --keep-artifacts -v
184+
```
185+
186+
**Note:** The test suite uses `matplotlib.use('Agg')` to prevent popup windows during testing.
187+
126188
## Code Quality
127189

128190
The project uses several tools for code quality:

sarcasm/plots.py

Lines changed: 195 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sarcasm.motion import Motion
3030
from sarcasm.plot_utils import PlotUtils
3131
from sarcasm.structure import Structure
32+
from sarcasm.structure_modules import domain_clustering
3233
from sarcasm.utils import Utils
3334

3435

@@ -247,7 +248,8 @@ def plot_image(ax: Axes, sarc_obj: Union[Structure, Motion], frame: int = 0, cma
247248
font_properties={'size': PlotUtils.fontsize - 1}))
248249
ax.set_xticks([])
249250
ax.set_yticks([])
250-
ax.set_title(title, fontsize=PlotUtils.fontsize)
251+
if title is not None:
252+
ax.set_title(title, fontsize=PlotUtils.fontsize)
251253

252254
# Add inset axis if zoom_region is specified
253255
if zoom_region:
@@ -326,6 +328,7 @@ def plot_z_bands(ax: plt.Axes, sarc_obj: Union[Structure, Motion], frame=0, cmap
326328
# Mark the zoomed region on the main plot
327329
PlotUtils.plot_box(ax, xlim=(x1, x2), ylim=(y1, y2), c='w')
328330

331+
@staticmethod
329332
def plot_z_bands_midlines(ax: plt.Axes, sarc_obj: Union[Structure, Motion], frame=0, cmap='berlin',
330333
alpha=1, scalebar=True, title=None, color_scalebar='w',
331334
show_loi=True, zoom_region: Tuple[int, int, int, int] = None,
@@ -925,12 +928,10 @@ def plot_sarcomere_domains(ax: Axes, sarc_obj: Structure, frame=0, alpha=0.5, cm
925928
sarcomere_length_vectors = sarc_obj.data['sarcomere_length_vectors'][frame]
926929
area_min = sarc_obj.data['params.analyze_sarcomere_domains.area_min']
927930
dilation_radius = sarc_obj.data['params.analyze_sarcomere_domains.dilation_radius']
928-
domain_mask = sarc_obj._analyze_domains(domains, pos_vectors=pos_vectors,
929-
sarcomere_length_vectors=sarcomere_length_vectors,
930-
sarcomere_orientation_vectors=sarcomere_orientation_vectors,
931-
size=sarc_obj.metadata.size,
932-
pixelsize=sarc_obj.metadata.pixelsize,
933-
dilation_radius=dilation_radius, area_min=area_min)[0]
931+
domain_mask, *_ = domain_clustering.analyze_domains(
932+
domains, pos_vectors, sarcomere_orientation_vectors, sarcomere_length_vectors,
933+
size=sarc_obj.metadata.size, pixelsize=sarc_obj.metadata.pixelsize,
934+
dilation_radius=dilation_radius, area_min=area_min)
934935

935936
domain_mask_masked = np.ma.masked_where(domain_mask == 0, domain_mask)
936937
cmap = plt.get_cmap(cmap)
@@ -1517,6 +1518,193 @@ def plot_overlay_velocity(ax, motion_obj: Motion, number_contr=None, t_lim=(0, 0
15171518
ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
15181519
ax.xaxis.set_minor_locator(MultipleLocator(0.25))
15191520

1521+
@staticmethod
1522+
def plot_domain_timeseries(ax: Axes, sarc_obj: Structure, t_lim: Tuple[float, float] = (0, 12),
1523+
y_lim: Tuple[float, float] = (1.6, 2.2), n_rows: Optional[int] = None,
1524+
show_contr: bool = True, use_median: bool = False):
1525+
"""
1526+
Plots domain sarcomere length time-series in a stacked multi-subplot layout.
1527+
1528+
Each domain's sarcomere length time-series is shown in a separate row, with optional
1529+
contraction period shading. Similar layout to plot_delta_slen for Motion objects.
1530+
1531+
Parameters
1532+
----------
1533+
ax : matplotlib.axes.Axes
1534+
The axes to draw the plot on.
1535+
sarc_obj : Structure
1536+
The Structure object with domain motion analysis results.
1537+
t_lim : tuple of float, optional
1538+
The time limits for the plot in seconds. Defaults to (0, 12).
1539+
y_lim : tuple of float, optional
1540+
The y-axis limits for sarcomere length in µm. Defaults to (1.6, 2.2).
1541+
n_rows : int or None, optional
1542+
Number of domains to display. If None, shows all domains. Defaults to None.
1543+
show_contr : bool, optional
1544+
Whether to shade contraction periods. Defaults to True.
1545+
use_median : bool, optional
1546+
If True, use median sarcomere length instead of mean. Defaults to False.
1547+
1548+
Raises
1549+
------
1550+
ValueError
1551+
If domain motion analysis has not been run.
1552+
"""
1553+
# Validate prerequisites
1554+
if 'domain_slen_timeseries' not in sarc_obj.data:
1555+
raise ValueError("Domain motion analysis not run. Call analyze_domain_motion() first.")
1556+
1557+
# Get data
1558+
if use_median:
1559+
slen_timeseries = sarc_obj.data['domain_slen_median_timeseries']
1560+
else:
1561+
slen_timeseries = sarc_obj.data['domain_slen_timeseries']
1562+
n_domains, n_frames = slen_timeseries.shape
1563+
time = np.arange(n_frames) * sarc_obj.metadata.frametime
1564+
1565+
# Determine number of rows
1566+
if n_rows is None:
1567+
n_rows = n_domains
1568+
n_rows = min(n_rows, n_domains)
1569+
1570+
# Get contraction data if available
1571+
domain_contr = sarc_obj.data.get('domain_contr', None)
1572+
domain_labels_contr = sarc_obj.data.get('domain_labels_contr', None)
1573+
1574+
# Calculate y-ticks
1575+
y_range = y_lim[1] - y_lim[0]
1576+
y_step = y_range / 4
1577+
yticks = [y_lim[0] + y_step, y_lim[0] + 2 * y_step, y_lim[0] + 3 * y_step]
1578+
1579+
# Domain colormap
1580+
cm = plt.cm.gist_rainbow(np.linspace(0, 1, n_domains))
1581+
1582+
# Create inset axes for each domain
1583+
list_y = np.linspace(0, 1, num=n_rows, endpoint=False)
1584+
for i, y in enumerate(list_y):
1585+
domain_idx = n_rows - 1 - i # Reverse order so domain 1 is at bottom
1586+
if domain_idx >= n_domains:
1587+
continue
1588+
1589+
ax_i = ax.inset_axes((0., y, 1, 1 / n_rows - 0.02))
1590+
ax_i.plot(time, slen_timeseries[domain_idx], c=cm[domain_idx], lw=0.8)
1591+
ax_i.axhline(np.nanmean(slen_timeseries[domain_idx]), linewidth=0.5, linestyle=':', c='k')
1592+
1593+
# Shade contraction periods
1594+
if show_contr and domain_contr is not None:
1595+
contr = domain_contr[domain_idx]
1596+
ax_i.fill_between(time, y_lim[0], y_lim[1], where=contr, color='lavender', alpha=0.7)
1597+
1598+
# Configure axes
1599+
if i > 0:
1600+
ax_i.set_xticks([])
1601+
else:
1602+
PlotUtils.polish_xticks(ax_i, 2, 1)
1603+
1604+
ax_i.set_ylim(y_lim)
1605+
ax_i.set_xlim(t_lim)
1606+
ax_i.set_yticks(yticks)
1607+
ax_i.set_yticklabels([f'{yt:.2f}' for yt in yticks], fontsize='x-small')
1608+
1609+
# Add domain label
1610+
ax_i.text(0.02, 0.85, f'D{domain_idx + 1}', transform=ax_i.transAxes,
1611+
fontsize='x-small', fontweight='bold', color=cm[domain_idx])
1612+
1613+
# Configure main axes
1614+
ax.set_xlabel('Time [s]')
1615+
ax.set_ylabel('Sarcomere length [µm]')
1616+
ax.spines['bottom'].set_color('w')
1617+
ax.spines['top'].set_color('w')
1618+
ax.xaxis.label.set_color('k')
1619+
ax.tick_params(axis='x', colors='w')
1620+
ax.tick_params(axis='y', colors='w')
1621+
1622+
@staticmethod
1623+
def plot_overlay_domain_timeseries(ax: Axes, sarc_obj: Structure, t_lim: Tuple[float, float] = (0, 12),
1624+
y_lim: Tuple[float, float] = (1.4, 2.2), show_contr: bool = True,
1625+
show_average: bool = True, use_median: bool = False,
1626+
domain_indices: Optional[list] = None):
1627+
"""
1628+
Plots domain sarcomere length time-series as overlaid trajectories.
1629+
1630+
All domain time-series are plotted on the same axes with different colors,
1631+
optionally with an average line and contraction period shading.
1632+
1633+
Parameters
1634+
----------
1635+
ax : matplotlib.axes.Axes
1636+
The axes to draw the plot on.
1637+
sarc_obj : Structure
1638+
The Structure object with domain motion analysis results.
1639+
t_lim : tuple of float, optional
1640+
The time limits for the plot in seconds. Defaults to (0, 12).
1641+
y_lim : tuple of float, optional
1642+
The y-axis limits for sarcomere length in µm. Defaults to (1.6, 2.2).
1643+
show_contr : bool, optional
1644+
Whether to shade contraction periods (uses union of all domain contractions).
1645+
Defaults to True.
1646+
show_average : bool, optional
1647+
Whether to show the average across all domains. Defaults to True.
1648+
use_median : bool, optional
1649+
If True, use median sarcomere length instead of mean. Defaults to False.
1650+
domain_indices : list or None, optional
1651+
List of domain indices (0-based) to plot. If None, plots all domains.
1652+
Defaults to None.
1653+
1654+
Raises
1655+
------
1656+
ValueError
1657+
If domain motion analysis has not been run.
1658+
"""
1659+
# Validate prerequisites
1660+
if 'domain_slen_timeseries' not in sarc_obj.data:
1661+
raise ValueError("Domain motion analysis not run. Call analyze_domain_motion() first.")
1662+
1663+
# Get data
1664+
if use_median:
1665+
slen_timeseries = sarc_obj.data['domain_slen_median_timeseries']
1666+
else:
1667+
slen_timeseries = sarc_obj.data['domain_slen_timeseries']
1668+
n_domains, n_frames = slen_timeseries.shape
1669+
time = np.arange(n_frames) * sarc_obj.metadata.frametime
1670+
1671+
# Select domains to plot
1672+
if domain_indices is None:
1673+
domain_indices = list(range(n_domains))
1674+
domain_indices = [i for i in domain_indices if 0 <= i < n_domains]
1675+
1676+
# Get contraction data if available
1677+
domain_contr = sarc_obj.data.get('domain_contr', None)
1678+
1679+
# Shade contraction periods (union across selected domains)
1680+
if show_contr and domain_contr is not None:
1681+
any_contr = np.any(domain_contr[domain_indices], axis=0)
1682+
ax.fill_between(time, y_lim[0], y_lim[1], where=any_contr, color='lavender', alpha=0.5)
1683+
1684+
# Domain colormap
1685+
cm = plt.cm.gist_rainbow(np.linspace(0, 1, n_domains))
1686+
1687+
# Plot individual domain trajectories
1688+
for domain_idx in domain_indices:
1689+
ax.plot(time, slen_timeseries[domain_idx], c=cm[domain_idx], lw=0.8,
1690+
label=f'Domain {domain_idx + 1}', alpha=0.8)
1691+
1692+
# Plot average trajectory
1693+
if show_average and len(domain_indices) > 1:
1694+
avg_slen = np.nanmean(slen_timeseries[domain_indices], axis=0)
1695+
ax.plot(time, avg_slen, c='k', lw=2, linestyle='-', label='Average')
1696+
1697+
# Configure axes
1698+
ax.set_xlabel('Time [s]')
1699+
ax.set_ylabel('Sarcomere length [µm]')
1700+
ax.set_xlim(t_lim)
1701+
ax.set_ylim(y_lim)
1702+
PlotUtils.polish_xticks(ax, 2, 1)
1703+
PlotUtils.polish_yticks(ax, 0.2, 0.1)
1704+
1705+
# Add legend
1706+
ax.legend(loc='upper right', fontsize='x-small')
1707+
15201708
@staticmethod
15211709
def plot_phase_space(ax: Axes, motion_obj: Motion, t_lim=(0, 4), number_contr=None, frame=None):
15221710
"""

0 commit comments

Comments
 (0)