|
1 | 1 | import itertools |
| 2 | +import warnings |
2 | 3 | from functools import reduce |
3 | 4 | from typing import Iterable, List, Optional, Tuple, TypeVar, Union |
4 | 5 |
|
5 | 6 | import datajoint as dj |
| 7 | +import matplotlib as mpl |
6 | 8 | import matplotlib.pyplot as plt |
7 | 9 | import numpy as np |
8 | 10 | import pandas as pd |
@@ -79,25 +81,83 @@ def fetch_interval(self): |
79 | 81 | raise ValueError(f"Expected one row, got {len(self)}") |
80 | 82 | return Interval(self.fetch1()) |
81 | 83 |
|
82 | | - def plot_intervals(self, figsize=(20, 5), return_fig=False): |
83 | | - """Plot the intervals in the interval list.""" |
84 | | - interval_list = pd.DataFrame(self) |
85 | | - fig, ax = plt.subplots(figsize=figsize) |
86 | | - interval_count = 0 |
87 | | - for row in interval_list.itertuples(index=False): |
88 | | - for interval in row.valid_times: |
89 | | - ax.plot(interval, [interval_count, interval_count]) |
90 | | - ax.scatter( |
91 | | - interval, |
92 | | - [interval_count, interval_count], |
93 | | - alpha=0.8, |
94 | | - zorder=2, |
95 | | - ) |
96 | | - interval_count += 1 |
97 | | - ax.set_yticks(np.arange(interval_list.shape[0])) |
98 | | - ax.set_yticklabels(interval_list.interval_list_name) |
99 | | - ax.set_xlabel("Time [s]") |
100 | | - ax.grid(True) |
| 84 | + def plot_intervals( |
| 85 | + self, start_time: float = 0, return_fig: bool = False |
| 86 | + ) -> Optional[plt.Figure]: |
| 87 | + """ |
| 88 | + Plot all intervals in the given IntervalList table. |
| 89 | +
|
| 90 | + Parameters |
| 91 | + ---------- |
| 92 | + start_time : float, optional |
| 93 | + The reference time (in seconds) for the interval comparison plot. |
| 94 | + For example, the first timepoint of a session. Defaults to 0. |
| 95 | + return_fig : bool, optional |
| 96 | + If True, return the matplotlib Figure object. Defaults to False. |
| 97 | +
|
| 98 | + Returns |
| 99 | + ------- |
| 100 | + fig : matplotlib.figure.Figure or None |
| 101 | + The matplotlib Figure object if `return_fig` is True, otherwise None. |
| 102 | +
|
| 103 | + Raises |
| 104 | + ------ |
| 105 | + ValueError |
| 106 | + If more than one unique `nwb_file_name` is found in the IntervalList. |
| 107 | + The intended use is to compare intervals within a single NWB file. |
| 108 | + UserWarning |
| 109 | + If more than 100 intervals are being plotted. |
| 110 | + """ |
| 111 | + interval_lists_df = pd.DataFrame(self) |
| 112 | + |
| 113 | + if len(interval_lists_df["nwb_file_name"].unique()) > 1: |
| 114 | + raise ValueError( |
| 115 | + ">1 nwb_file_name found in IntervalList. the intended use of plot_intervals is to compare intervals within a single nwb_file_name." |
| 116 | + ) |
| 117 | + |
| 118 | + interval_list_names = interval_lists_df["interval_list_name"].values |
| 119 | + |
| 120 | + n_compare = len(interval_list_names) |
| 121 | + |
| 122 | + if n_compare > 100: |
| 123 | + warnings.warn( |
| 124 | + f"plot_intervals is plotting {n_compare} intervals. if this is unintended, please pass in a smaller IntervalList.", |
| 125 | + UserWarning, |
| 126 | + ) |
| 127 | + |
| 128 | + # plot broken bar horizontals |
| 129 | + fig, ax = plt.subplots(figsize=(20, 2 / 3 * n_compare)) |
| 130 | + |
| 131 | + # get n colors |
| 132 | + cmap = plt.get_cmap("turbo", n_compare) |
| 133 | + custom_palette = [mpl.colors.rgb2hex(cmap(i)) for i in range(cmap.N)] |
| 134 | + |
| 135 | + def convert_intervals_to_range(intervals, start_time): |
| 136 | + return [ |
| 137 | + ((i[0] - start_time) / 60, (i[1] - i[0]) / 60) |
| 138 | + for i in intervals |
| 139 | + ] # return time in minutes |
| 140 | + |
| 141 | + all_intervals = interval_lists_df["valid_times"].values |
| 142 | + for i, (intervals, color) in enumerate( |
| 143 | + zip(all_intervals, custom_palette) |
| 144 | + ): |
| 145 | + int_range = convert_intervals_to_range(intervals, start_time) |
| 146 | + ax.broken_barh( |
| 147 | + int_range, (10 * (i + 1), 6), facecolors=color, alpha=0.7 |
| 148 | + ) |
| 149 | + |
| 150 | + ax.set_ylim(5, 10 * (n_compare + 1) + 5) |
| 151 | + ax.set_xlabel("time from start (min)", fontsize=16) |
| 152 | + ax.set_yticks( |
| 153 | + np.arange(n_compare) * 10 + 15, |
| 154 | + labels=interval_list_names, |
| 155 | + fontsize=16, |
| 156 | + ) |
| 157 | + ax.set_xticks(ax.get_xticks()) |
| 158 | + ax.set_xticklabels(ax.get_xticklabels()) |
| 159 | + ax.tick_params(axis="x", labelsize=16) |
| 160 | + |
101 | 161 | if return_fig: |
102 | 162 | return fig |
103 | 163 |
|
|
0 commit comments