Skip to content
65 changes: 56 additions & 9 deletions mne_hfo/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
from mne_hfo import merge_channel_events, merge_overlapping_events
from mne_hfo.io import create_annotations_df

def plot_hfos(raw, annotations):
mne.viz.set_browser_backend("qt")
raw.set_annotations(annotations)
raw.plot(block=True)

def plot_hfo_event(raw, annotations, eventId):
mne.viz.set_browser_backend("qt")
Expand Down Expand Up @@ -44,7 +40,7 @@ def plot_hfo_event(raw, annotations, eventId):

ch_mins = np.min(subset, axis=1)
ch_maxs = np.max(subset, axis=1)

fig, axs = plt.subplots(subset.shape[0],1, sharex='col', gridspec_kw={'hspace': 0})
# if type(axs) != list:
# axs = [axs]
Expand All @@ -63,9 +59,60 @@ def plot_hfo_event(raw, annotations, eventId):
t_event = np.arange(orig_onset, orig_onset+orig_duration, 1/sfreq)
ax.fill_between(t_event, ch_mins[i], ch_maxs[i], facecolor='red', alpha=0.5)
fig.suptitle(f"Detections comprising Event {eventId}")
plt.show()
...
return fig, axs

def plot_corr_matrix(corr_matrix: np.ndarray,
det_list: list,
ax = None):
"""
Compares similarity between detector results.
Creates a plot of the comparison values in a len(det_list) x len(det_list) plot.


The detectors should be fit to the same data.

Parameters
----------
corr_matrix : np.ndarray
A numpy 2D matrix with all the comparison values for each detector listed in
det_list.
det_list : List
A list containing all Detector instances. Detectors should already be fit to the
data.
ax : matplotlib.axes.Axes (optional)
The axes to which to plot the chart. If no ax given, it will use the current axis.

Returns
-------
ax : matplotlib.axes.Axes
Axes object with comparison chart plotted
"""

# If no axis is provided, the current axis will be used,
if ax == None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if this works if there is no figure. The way I have always seen this done is:

if ax is None:
fig, ax = plt.subplots(...)

that way you also have control over things like figure size

ax = plt.gca()

# Creates image using correlation matrix
im = ax.imshow(corr_matrix, cmap='inferno')
ax.set_xticks(np.arange(len(det_list)), labels=[det.__class__() for det in det_list])
ax.set_yticks(np.arange(len(det_list)), labels=[det.__class__() for det in det_list])
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")

# Loop over data dimensions and create text annotations.
for i in range(len(det_list)):
for j in range(len(det_list)):
if round(float(corr_matrix[i, j]),3) > 0.5:
color = 'k'
else:
color = 'w'
text = ax.text(j, i, round(float(corr_matrix[i, j]),3),
ha="center", va="center", color=color)

# Generates colorbar
cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.set_ylabel("Similarity", rotation=-90, va="bottom")
ax.set_title("Detector Comparison")

def get_merged_annotations():
pass
ax.plot()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this line do anything?

return ax
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for consistency with the other plotting functions, I would return both fig, ax and then I would make the fig, ax as optional inputs for plot_hfo_events