Skip to content

Commit 4dd1b39

Browse files
gshvartsedeno
andauthored
adding function for plotting specific interval lists side by side (#1330)
* adding function for plotting specific interval lists side by side * incoporating as a class method and adding testing * update time units and docstring * add error and warning related to input intervallist * Fix linting * Add typing and improve docstring --------- Co-authored-by: Eric Denovellis <[email protected]> Co-authored-by: Eric Denovellis <[email protected]>
1 parent c25ea48 commit 4dd1b39

File tree

2 files changed

+128
-28
lines changed

2 files changed

+128
-28
lines changed

src/spyglass/common/common_interval.py

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import itertools
2+
import warnings
23
from functools import reduce
34
from typing import Iterable, List, Optional, Tuple, TypeVar, Union
45

56
import datajoint as dj
7+
import matplotlib as mpl
68
import matplotlib.pyplot as plt
79
import numpy as np
810
import pandas as pd
@@ -79,25 +81,83 @@ def fetch_interval(self):
7981
raise ValueError(f"Expected one row, got {len(self)}")
8082
return Interval(self.fetch1())
8183

82-
def plot_intervals(self, figsize=(20, 5), return_fig=False):
83-
"""Plot the intervals in the interval list."""
84-
interval_list = pd.DataFrame(self)
85-
fig, ax = plt.subplots(figsize=figsize)
86-
interval_count = 0
87-
for row in interval_list.itertuples(index=False):
88-
for interval in row.valid_times:
89-
ax.plot(interval, [interval_count, interval_count])
90-
ax.scatter(
91-
interval,
92-
[interval_count, interval_count],
93-
alpha=0.8,
94-
zorder=2,
95-
)
96-
interval_count += 1
97-
ax.set_yticks(np.arange(interval_list.shape[0]))
98-
ax.set_yticklabels(interval_list.interval_list_name)
99-
ax.set_xlabel("Time [s]")
100-
ax.grid(True)
84+
def plot_intervals(
85+
self, start_time: float = 0, return_fig: bool = False
86+
) -> Optional[plt.Figure]:
87+
"""
88+
Plot all intervals in the given IntervalList table.
89+
90+
Parameters
91+
----------
92+
start_time : float, optional
93+
The reference time (in seconds) for the interval comparison plot.
94+
For example, the first timepoint of a session. Defaults to 0.
95+
return_fig : bool, optional
96+
If True, return the matplotlib Figure object. Defaults to False.
97+
98+
Returns
99+
-------
100+
fig : matplotlib.figure.Figure or None
101+
The matplotlib Figure object if `return_fig` is True, otherwise None.
102+
103+
Raises
104+
------
105+
ValueError
106+
If more than one unique `nwb_file_name` is found in the IntervalList.
107+
The intended use is to compare intervals within a single NWB file.
108+
UserWarning
109+
If more than 100 intervals are being plotted.
110+
"""
111+
interval_lists_df = pd.DataFrame(self)
112+
113+
if len(interval_lists_df["nwb_file_name"].unique()) > 1:
114+
raise ValueError(
115+
">1 nwb_file_name found in IntervalList. the intended use of plot_intervals is to compare intervals within a single nwb_file_name."
116+
)
117+
118+
interval_list_names = interval_lists_df["interval_list_name"].values
119+
120+
n_compare = len(interval_list_names)
121+
122+
if n_compare > 100:
123+
warnings.warn(
124+
f"plot_intervals is plotting {n_compare} intervals. if this is unintended, please pass in a smaller IntervalList.",
125+
UserWarning,
126+
)
127+
128+
# plot broken bar horizontals
129+
fig, ax = plt.subplots(figsize=(20, 2 / 3 * n_compare))
130+
131+
# get n colors
132+
cmap = plt.get_cmap("turbo", n_compare)
133+
custom_palette = [mpl.colors.rgb2hex(cmap(i)) for i in range(cmap.N)]
134+
135+
def convert_intervals_to_range(intervals, start_time):
136+
return [
137+
((i[0] - start_time) / 60, (i[1] - i[0]) / 60)
138+
for i in intervals
139+
] # return time in minutes
140+
141+
all_intervals = interval_lists_df["valid_times"].values
142+
for i, (intervals, color) in enumerate(
143+
zip(all_intervals, custom_palette)
144+
):
145+
int_range = convert_intervals_to_range(intervals, start_time)
146+
ax.broken_barh(
147+
int_range, (10 * (i + 1), 6), facecolors=color, alpha=0.7
148+
)
149+
150+
ax.set_ylim(5, 10 * (n_compare + 1) + 5)
151+
ax.set_xlabel("time from start (min)", fontsize=16)
152+
ax.set_yticks(
153+
np.arange(n_compare) * 10 + 15,
154+
labels=interval_list_names,
155+
fontsize=16,
156+
)
157+
ax.set_xticks(ax.get_xticks())
158+
ax.set_xticklabels(ax.get_xticklabels())
159+
ax.tick_params(axis="x", labelsize=16)
160+
101161
if return_fig:
102162
return fig
103163

tests/common/test_interval.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23
from numpy import array_equal
34

@@ -7,17 +8,56 @@ def interval_list(common):
78
yield common.IntervalList()
89

910

10-
def test_plot_intervals(mini_insert, interval_list):
11-
fig = (interval_list & 'interval_list_name LIKE "raw%"').plot_intervals(
12-
return_fig=True
11+
def test_plot_intervals(interval_list, start_time=0):
12+
fig = (interval_list).plot_intervals(return_fig=True, start_time=start_time)
13+
14+
ax = fig.axes[0]
15+
16+
# extract interval list names from the figure
17+
fig_interval_list_names = np.array(
18+
[label.get_text() for label in ax.get_yticklabels()]
19+
)
20+
21+
# extract interval list names from the IntervalList table
22+
fetch_interval_list_names = np.array(
23+
interval_list.fetch("interval_list_name")
1324
)
14-
interval_list_name = fig.get_axes()[0].get_yticklabels()[0].get_text()
15-
times_fetch = (
16-
interval_list & {"interval_list_name": interval_list_name}
17-
).fetch1("valid_times")[0]
18-
times_plot = fig.get_axes()[0].lines[0].get_xdata()
1925

20-
assert array_equal(times_fetch, times_plot), "plot_intervals failed"
26+
# check that the interval list names match between the two methods
27+
assert array_equal(
28+
fig_interval_list_names, fetch_interval_list_names
29+
), "plot_intervals failed: plotted interval list names do not match"
30+
31+
# extract the interval times from the figure
32+
intervals_fig = []
33+
for collection in ax.collections:
34+
collection_data = []
35+
for patch in collection.get_paths():
36+
# Extract patch vertices to get x data
37+
vertices = patch.vertices
38+
x_start = vertices[0, 0]
39+
x_end = vertices[2, 0]
40+
collection_data.append(
41+
[x_start * 60 + start_time, x_end * 60 + start_time]
42+
)
43+
intervals_fig.append(np.array(collection_data))
44+
intervals_fig = np.array(intervals_fig, dtype="object")
45+
46+
# extract interval times from the IntervalList table
47+
intervals_fetch = interval_list.fetch("valid_times")
48+
49+
all_equal = True
50+
for i_fig, i_fetch in zip(intervals_fig, intervals_fetch):
51+
# permit rounding errors up to the 4th decimal place, as a result of inaccuracies during unit conversions
52+
i_fig = np.round(i_fig.astype("float"), 4)
53+
i_fetch = np.round(i_fetch.astype("float"), 4)
54+
if not array_equal(i_fig, i_fetch):
55+
all_equal = False
56+
57+
# check that the interval list times match between the two methods
58+
assert (
59+
all_equal
60+
), "plot_intervals failed: plotted interval times do not match"
2161

2262

2363
def test_plot_epoch(mini_insert, interval_list):

0 commit comments

Comments
 (0)