Skip to content

Plots improvements #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitattributes
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
phys2bids/_version.py export-subst
nigsp/_version.py export-subst

*.py eol=lf
*.rst eol=lf
13 changes: 9 additions & 4 deletions nigsp/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def export_metric(scgraph, outext, outprefix):
return 0


def plot_metric(scgraph, outprefix, atlas=None, thr=None):
def plot_metric(scgraph, outprefix, atlas=None, title=None, thr=None):
"""
If possible, plot metrics as markerplot.

Expand All @@ -109,10 +109,10 @@ def plot_metric(scgraph, outprefix, atlas=None, thr=None):
The internal object containing all data.
outprefix : str
The prefix of the png file to export
img : 3DNiftiImage or None, optional
The nifti image of the atlas
atlas : 3D Nifti1Image, numpy.ndarray, or None, optional
Either a nifti image containing a valid atlas or a set of parcel coordinates.
title : None or str, optional
Add a title to the graph
thr : float or None, optional
The threshold to use in plotting the nodes.
"""
Expand All @@ -135,14 +135,19 @@ def plot_metric(scgraph, outprefix, atlas=None, thr=None):
if atlas_plot is not None:
if scgraph.sdi is not None:
viz.plot_nodes(
scgraph.sdi, atlas_plot, filename=f"{outprefix}sdi.png", thr=thr
scgraph.sdi,
atlas_plot,
filename=f"{outprefix}sdi.png",
title=title,
thr=thr,
)
elif scgraph.gsdi is not None:
for k in scgraph.gsdi.keys():
viz.plot_nodes(
scgraph.gsdi[k],
atlas_plot,
filename=f"{outprefix}gsdi_{k}.png",
title=title,
thr=thr,
)

Expand Down
2 changes: 1 addition & 1 deletion nigsp/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def compute_graph_energy(self, mean=False): # pragma: no cover
)
return self

def split_graph(self, index=None, keys=["low", "high"]):
def split_graph(self, index=None, keys=["low-pass", "high-pass"]):
"""Implement timeseries.median_cutoff_frequency_idx as class method."""
if index is None:
index = self.index
Expand Down
14 changes: 7 additions & 7 deletions nigsp/operations/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def resize_ts(timeseries, resize=None, globally=False):
if resize == "spc": # pragma: no cover
LGR.info("Expressing timeseries in signal percentage change")
timeseries = spc_ts(timeseries, globally=globally)
elif resize == "norm": # pragma: no cover
elif resize in ["norm", "zscore"]: # pragma: no cover
LGR.info("Normalise timeseries")
timeseries = normalise_ts(timeseries, globally=globally)
elif resize == "demean": # pragma: no cover
Expand Down Expand Up @@ -338,13 +338,13 @@ def median_cutoff_frequency_idx(energy):
return freq_idx


def graph_filter(timeseries, eigenvec, freq_idx, keys=["low", "high"]):
def graph_filter(timeseries, eigenvec, freq_idx, keys=["low-pass", "high-pass"]):
"""
Filter a graph decomposition into two parts based on freq_idx.

Return the two eigenvector lists (high freq and low freq) that are equal
to the original eigenvector list, but "low" is zero-ed for all frequencies
>= of the given index, and "high" is zero-ed for all frequencies < to the
to the original eigenvector list, but "low-pass" is zero-ed for all frequencies
>= of the given index, and "high-pass" is zero-ed for all frequencies < to the
given index.
Also return their projection onto a timeseries.

Expand All @@ -357,7 +357,7 @@ def graph_filter(timeseries, eigenvec, freq_idx, keys=["low", "high"]):
freq_idx : int or list
The index of the frequency that splits the spectral power into two
(more or less) equal parts - i.e. the index of the first frequency in
the "high" component.
the "high-pass" component.
keys : list, optional
The keys to call the split parts with

Expand All @@ -371,8 +371,8 @@ def graph_filter(timeseries, eigenvec, freq_idx, keys=["low", "high"]):
Raises
------
IndexError
If the given index is 0 (all "high"), the last possible index (all "low"),
or higher than the last possible index (not applicable).
If the given index is 0 (all "high-pass"), the last possible index
(all "low-pass"), or higher than the last possible index (not applicable).
"""
# #!# Find better name
# #!# Implement an index splitter
Expand Down
24 changes: 12 additions & 12 deletions nigsp/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,27 @@ def test_integration(timeseries, sc_mtx, atlas, mean_fc, sdi, testdir):
# Check that files were created
assert isdir(testdir)
assert isdir(join(testdir, "logs"))
assert isdir(join(testdir, "testfile_timeseries_low"))
assert isdir(join(testdir, "testfile_timeseries_high"))
assert isfile(join(testdir, "testfile_timeseries_low", "000.tsv"))
assert isfile(join(testdir, "testfile_timeseries_high", "000.tsv"))
assert isdir(join(testdir, "testfile_timeseries_low-pass"))
assert isdir(join(testdir, "testfile_timeseries_high-pass"))
assert isfile(join(testdir, "testfile_timeseries_low-pass", "000.tsv"))
assert isfile(join(testdir, "testfile_timeseries_high-pass", "000.tsv"))
assert isfile(join(testdir, "testfile_fc.tsv"))
assert isfile(join(testdir, "testfile_fc_low.tsv"))
assert isfile(join(testdir, "testfile_fc_high.tsv"))
assert isfile(join(testdir, "testfile_fc_low-pass.tsv"))
assert isfile(join(testdir, "testfile_fc_high-pass.tsv"))
assert isfile(join(testdir, "testfile_eigenval.tsv"))
assert isfile(join(testdir, "testfile_eigenvec.tsv"))
assert isfile(join(testdir, "testfile_eigenvec_low.tsv"))
assert isfile(join(testdir, "testfile_eigenvec_high.tsv"))
assert isfile(join(testdir, "testfile_eigenvec_low-pass.tsv"))
assert isfile(join(testdir, "testfile_eigenvec_high-pass.tsv"))
assert isfile(join(testdir, "testfile_sdi.tsv"))
assert isfile(join(testdir, "testfile_mkd_sdi.tsv"))
assert isfile(join(testdir, "testfile_laplacian.png"))
assert isfile(join(testdir, "testfile_sc.png"))
assert isfile(join(testdir, "testfile_fc.png"))
assert isfile(join(testdir, "testfile_fc_low.png"))
assert isfile(join(testdir, "testfile_fc_high.png"))
assert isfile(join(testdir, "testfile_fc_low-pass.png"))
assert isfile(join(testdir, "testfile_fc_high-pass.png"))
assert isfile(join(testdir, "testfile_greyplot.png"))
assert isfile(join(testdir, "testfile_greyplot_low.png"))
assert isfile(join(testdir, "testfile_greyplot_high.png"))
assert isfile(join(testdir, "testfile_greyplot_low-pass.png"))
assert isfile(join(testdir, "testfile_greyplot_high-pass.png"))
assert isfile(join(testdir, "testfile_sdi.png"))
assert isfile(join(testdir, "testfile_mkd_sdi.png"))

Expand Down
10 changes: 5 additions & 5 deletions nigsp/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ def test_sdi():
ts3 = np.arange(5, 7)[..., np.newaxis]
sdi_in = np.log2(np.arange(3.0, 1.0, -1.0))

ts = {"low": ts1, "high": ts2}
ts = {"low-pass": ts1, "high-pass": ts2}
sdi_out = metrics.sdi(ts)
assert (sdi_out == sdi_in).all()

ts = {"HIGH": ts2, "LOW": ts1}
ts = {"HIGH-PASS": ts2, "LOW-PASS": ts1}
sdi_out = metrics.sdi(ts)
assert (sdi_out == sdi_in).all()

Expand All @@ -29,8 +29,8 @@ def test_sdi():
assert (sdi_out == sdi_in).all()

ts = {
"low": np.repeat(np.repeat(ts1[..., np.newaxis], 3, axis=1), 3, axis=2),
"high": np.repeat(np.repeat(ts2[..., np.newaxis], 3, axis=1), 3, axis=2),
"low-pass": np.repeat(np.repeat(ts1[..., np.newaxis], 3, axis=1), 3, axis=2),
"high-pass": np.repeat(np.repeat(ts2[..., np.newaxis], 3, axis=1), 3, axis=2),
}
sdi_out = metrics.sdi(ts, mean=True)
sdi_out = np.around(sdi_out, decimals=15)
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_break_sdi():
ts = {"alpha": ts1, "beta": ts2, "gamma": ts3}

with raises(ValueError) as errorinfo:
metrics.sdi(ts, keys=["high", "low"])
metrics.sdi(ts, keys=["high-pass", "low-pass"])
assert "provided keys" in str(errorinfo.value)

with raises(ValueError) as errorinfo:
Expand Down
89 changes: 70 additions & 19 deletions nigsp/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@

LGR = logging.getLogger(__name__)
SET_DPI = 100
FIGSIZE = (18, 10)
FIGSIZE = (12, 7)
FIGSIZE_SQUARE = (6, 5)
FIGSIZE_LONG = (10, 5)
FIGSIZE_LONG = (12, 4)


def plot_connectivity(mtx, filename=None, closeplot=False):
def plot_connectivity(mtx, filename=None, title=None, crange=None, closeplot=False):
"""
Create a connectivity matrix plot.

Expand All @@ -43,6 +43,10 @@ def plot_connectivity(mtx, filename=None, closeplot=False):
A (square) array with connectivity information inside.
filename : None, str, or os.PathLike, optional
The path to save the plot on disk.
title : None or str, optional
Add a title to the graph
range : None or list, optional
Set vmin and vmax
closeplot : bool, optional
Whether to close plots after saving or not. Mainly used for debug
or use with live python/ipython instances.
Expand Down Expand Up @@ -85,11 +89,31 @@ def plot_connectivity(mtx, filename=None, closeplot=False):
LGR.warning("Given matrix is not a square matrix!")

LGR.info("Creating connectivity plot.")
plt.figure(figsize=FIGSIZE_SQUARE)
plot_matrix(mtx)
fig = plt.figure(figsize=FIGSIZE_SQUARE)
ax = fig.subplots()

pc_args = {"mat": mtx, "axes": ax}
if crange is not None:
if type(crange) in [list, tuple]:
pc_args["vmin"] = crange[0]
pc_args["vmax"] = crange[1]
else:
vmax = np.nanpercentile(mtx, 98) # mtx.max()
vmin = np.abs(np.nanpercentile(mtx, 2)) # mtx.min()
if crange == "auto-symm" and mtx.min() < 0 and vmax > 0:
pc_args["vmax"] = vmax if vmax > vmin else vmin
pc_args["vmin"] = -vmin if vmin > vmax else -vmax
elif crange == "auto-zero" or mtx.min() > 0 or vmax < 0:
pass
else:
raise NotImplementedError(f"{crange} option not implemented.")

plot_matrix(**pc_args)
if title is not None:
fig.suptitle(title)

if filename is not None:
plt.savefig(filename, dpi=SET_DPI)
plt.savefig(filename, dpi=SET_DPI, bbox_inches="tight")
closeplot = True

if closeplot:
Expand Down Expand Up @@ -162,17 +186,18 @@ def plot_greyplot(timeseries, filename=None, title=None, resize=None, closeplot=
timeseries = resize_ts(timeseries, resize)

LGR.info("Creating greyplot.")
plt.figure(figsize=FIGSIZE_LONG)
if title is not None:
plt.title(title)
vmax = np.percentile(timeseries, 99)
vmin = np.percentile(timeseries, 1)
plt.imshow(timeseries, cmap="gray", vmin=vmin, vmax=vmax)
plt.colorbar()
fig = plt.figure(figsize=FIGSIZE_LONG)
ax = fig.subplots()
im = ax.imshow(timeseries, cmap="gray", vmin=vmin, vmax=vmax)
plt.colorbar(im, ax=ax)
if title is not None:
fig.suptitle(title)
plt.tight_layout()

if filename is not None:
plt.savefig(filename, dpi=SET_DPI)
plt.savefig(filename, dpi=SET_DPI, bbox_inches="tight")
closeplot = True

if closeplot:
Expand All @@ -183,7 +208,9 @@ def plot_greyplot(timeseries, filename=None, title=None, resize=None, closeplot=
return 0


def plot_nodes(ns, atlas, filename=None, thr=None, closeplot=False):
def plot_nodes(
ns, atlas, filename=None, title=None, thr=None, cmap=None, closeplot=False
):
"""
Create a marker plot in the MNI space.

Expand All @@ -198,8 +225,12 @@ def plot_nodes(ns, atlas, filename=None, thr=None, closeplot=False):
or a list of coordinates of the center of mass of parcels.
filename : None, str, or os.PathLike, optional
The path to save the plot on disk.
title : None or str, optional
Add a title to the graph
thr : float or None, optional
The threshold to use in plotting the nodes.
cmap : None or matplotlib.pyplot.cm colormap object, optional.
The colormap to adopt in plotting nodes. Defaults to reverse viridis.
closeplot : bool, optional
Whether to close plots after saving or not. Mainly used for debug
or use with live python/ipython instances.
Expand Down Expand Up @@ -252,11 +283,17 @@ def plot_nodes(ns, atlas, filename=None, thr=None, closeplot=False):
raise ValueError("Node array and coordinates array have different length.")

LGR.info("Creating markerplot.")
plt.figure(figsize=FIGSIZE)
plot_markers(ns, coord, node_threshold=thr)
fig = plt.figure(figsize=FIGSIZE)
ax = fig.subplots()

cmap = plt.cm.viridis_r if cmap is None else cmap
plot_markers(ns, coord, axes=ax, node_threshold=thr, node_cmap=cmap)
if title is not None:
fig.suptitle(title)

plt.tight_layout()
if filename is not None:
plt.savefig(filename, dpi=SET_DPI)
plt.savefig(filename, dpi=SET_DPI, bbox_inches="tight")
closeplot = True

if closeplot:
Expand All @@ -265,7 +302,9 @@ def plot_nodes(ns, atlas, filename=None, thr=None, closeplot=False):
return 0


def plot_edges(mtx, atlas, filename=None, thr=None, closeplot=False):
def plot_edges(
mtx, atlas, filename=None, title=None, thr=None, cmap=None, closeplot=False
):
"""
Create a connectivity plot in the MNI space.

Expand All @@ -280,9 +319,13 @@ def plot_edges(mtx, atlas, filename=None, thr=None, closeplot=False):
or a list of coordinates of the center of mass of parcels.
filename : None, str, or os.PathLike, optional
The path to save the plot on disk.
title : None or str, optional
Add a title to the graph
thr : float, str or None, optional
The threshold to use in plotting the nodes.
If `str`, needs to express a percentage.
cmap : None or matplotlib.pyplot.cm colormap object, optional.
The colormap to adopt in plotting nodes. Defaults to reverse viridis.
closeplot : bool, optional
Whether to close plots after saving or not. Mainly used for debug
or use with live python/ipython instances.
Expand Down Expand Up @@ -335,26 +378,34 @@ def plot_edges(mtx, atlas, filename=None, thr=None, closeplot=False):
raise ValueError("Matrix axis and coordinates array have different length.")

LGR.info("Creating connectome-like plot.")
plt.figure(figsize=FIGSIZE)
fig = plt.figure(figsize=FIGSIZE)
ax = fig.subplots()

pc_args = {
"adjacency_matrix": mtx,
"node_coords": coord,
"node_color": "black",
"node_size": 5,
"edge_threshold": thr,
"edge_cmap": plt.cm.bwr,
"colorbar": True,
"axes": ax,
}

if mtx.min() >= 0:
pc_args["edge_vmin"] = 0
pc_args["edge_vmax"] = mtx.max()
pc_args["edge_cmap"] = cm.red_transparent_full_alpha_range

pc_args["edge_cmap"] = pc_args["edge_cmap"] if cmap is None else cmap
plot_connectome(**pc_args)

if title is not None:
fig.suptitle(title)

plt.tight_layout()
if filename is not None:
plt.savefig(filename, dpi=SET_DPI)
plt.savefig(filename, dpi=SET_DPI, bbox_inches="tight")
closeplot = True

if closeplot:
Expand Down
Loading