Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Diagram Visualization
persim.plot_diagrams
persim.bottleneck_matching
persim.wasserstein_matching
persim.Barcode
persim.plot_landscape
persim.plot_landscape_simple

Expand Down
220 changes: 204 additions & 16 deletions persim/visuals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import numpy as np
import matplotlib.pyplot as plt

__all__ = ["plot_diagrams", "bottleneck_matching", "wasserstein_matching"]
import io

__all__ = [
"plot_diagrams",
"bottleneck_matching",
"wasserstein_matching",
"Barcode"
]


def plot_diagrams(
Expand All @@ -24,21 +31,21 @@ def plot_diagrams(
Parameters
----------
diagrams: ndarray (n_pairs, 2) or list of diagrams
A diagram or list of diagrams. If diagram is a list of diagrams,
A diagram or list of diagrams. If diagram is a list of diagrams,
then plot all on the same plot using different colors.
plot_only: list of numeric
If specified, an array of only the diagrams that should be plotted.
title: string, default is None
If title is defined, add it as title of the plot.
xy_range: list of numeric [xmin, xmax, ymin, ymax]
User provided range of axes. This is useful for comparing
User provided range of axes. This is useful for comparing
multiple persistence diagrams.
labels: string or list of strings
Legend labels for each diagram.
Legend labels for each diagram.
If none are specified, we use H_0, H_1, H_2,... by default.
colormap: string, default is 'default'
Any of matplotlib color palettes.
Some options are 'default', 'seaborn', 'sequential'.
Any of matplotlib color palettes.
Some options are 'default', 'seaborn', 'sequential'.
See all available styles with

.. code:: python
Expand All @@ -48,17 +55,17 @@ def plot_diagrams(

size: numeric, default is 20
Pixel size of each point plotted.
ax_color: any valid matplotlib color type.
ax_color: any valid matplotlib color type.
See [https://matplotlib.org/api/colors_api.html](https://matplotlib.org/api/colors_api.html) for complete API.
diagonal: bool, default is True
Plot the diagonal x=y line.
lifetime: bool, default is False. If True, diagonal is turned to False.
Plot life time of each point instead of birth and death.
Plot life time of each point instead of birth and death.
Essentially, visualize (x, y-x).
legend: bool, default is True
If true, show the legend.
show: bool, default is False
Call plt.show() after plotting. If you are using self.plot() as part
Call plt.show() after plotting. If you are using self.plot() as part
of a subplot, set show=False and call plt.show() only once at the end.
"""

Expand Down Expand Up @@ -165,26 +172,23 @@ def plot_diagrams(
if show is True:
plt.show()

def plot_a_bar(p, q, c='b', linestyle='-'):
plt.plot([p[0], q[0]], [p[1], q[1]], c=c, linestyle=linestyle, linewidth=1)

def bottleneck_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None):
""" Visualize bottleneck matching between two diagrams

Parameters
===========

dgm1: Mx(>=2)
dgm1: Mx(>=2)
array of birth/death pairs for PD 1
dgm2: Nx(>=2)
dgm2: Nx(>=2)
array of birth/death paris for PD 2
matching: ndarray(Mx+Nx, 3)
A list of correspondences in an optimal matching, as well as their distance, where:
* First column is index of point in first persistence diagram, or -1 if diagonal
* Second column is index of point in second persistence diagram, or -1 if diagonal
* Third column is the distance of each matching
labels: list of strings
names of diagrams for legend. Default = ["dgm1", "dgm2"],
names of diagrams for legend. Default = ["dgm1", "dgm2"],
ax: matplotlib Axis object
For plotting on a particular axis.

Expand Down Expand Up @@ -248,7 +252,7 @@ def wasserstein_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None)
* Second column is index of point in second persistence diagram, or -1 if diagonal
* Third column is the distance of each matching
labels: list of strings
names of diagrams for legend. Default = ["dgm1", "dgm2"],
names of diagrams for legend. Default = ["dgm1", "dgm2"],
ax: matplotlib Axis object
For plotting on a particular axis.

Expand Down Expand Up @@ -286,3 +290,187 @@ def wasserstein_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None)
ax.plot([dgm1[i, 0], dgm2[j, 0]], [dgm1[i, 1], dgm2[j, 1]], "g")

plot_diagrams([dgm1, dgm2], labels=labels, ax=ax)

class Barcode:
__doc__ = """
Barcode visualisation made easy!

Note that this convenience class requires instantiation as the number
of subplots produced depends on the dimension of the data.
"""

def __init__(self, diagrams, verbose=False):
"""
Parameters
===========
diagrams: list-like
typically the output of ripser(nodes)['dgms']
verbose: bool
Execute print statemens for extra information; currently only echoes
number of bars in each dimension (Default=False).

Examples
===========
>>> n = 300
>>> t = np.linspace(0, 2 * np.pi, n)
>>> noise = np.random.normal(0, 0.1, size=n)
>>> data = np.vstack([((3+d) * np.cos(t[i]+d), (3+d) * np.sin(t[i]+d)) for i, d in enumerate(noise)])
>>> diagrams = ripser(data)
>>> bc = Barcode(diagrams['dgms'])
>>> bc.plot_barcode()
"""
if not isinstance(diagrams, list):
diagrams = [diagrams]

self.diagrams = diagrams
self._verbose = verbose
self._dim = len(diagrams)

def plot_barcode(self, figsize=None, show=True, export_png=False, dpi=100, **kwargs):
"""Wrapper method to produce barcode plot

Parameters
===========
figsize: tuple
figure size, default=(6,6) if H0+H1 only, (6,4) otherwise

show: boolean
show the figure via plt.show()

Comment thread
DeliciousHair marked this conversation as resolved.
export_png: boolean
write image to png data, returned as io.BytesIO() instance,
default=False

**kwargs: artist paramters for the barcodes, defaults:
c='grey'
linestyle='-'
linewidth=0.5
dpi=100 (for png export)

Returns
===========
out: list or None
list of png exports if export_png=True, otherwise None
"""
if self._dim == 2:
if figsize is None:
figsize = (6, 6)

return self._plot_H0_H1(
figsize=figsize,
show=show,
export_png=export_png,
dpi=dpi,
**kwargs
)

else:
if figsize is None:
figsize = (6, 4)

return self._plot_Hn(
figsize=figsize,
show=show,
export_png=export_png,
dpi=dpi,
**kwargs
)

def _plot_H0_H1(self, *, figsize, show, export_png, dpi, **kwargs):
out = []

fig, ax = plt.subplots(2, 1, figsize=figsize)

for dim, diagram in enumerate(self.diagrams):
self._plot_many_bars(dim, diagram, dim, ax, **kwargs)

if export_png:
fp = io.BytesIO()
plt.savefig(fp, dpi=dpi)
fp.seek(0)

out += [fp]

if show:
plt.show()
else:
plt.close()

if any(out):
return out

def _plot_Hn(self, *, figsize, show, export_png, dpi, **kwargs):
out = []

for dim, diagram in enumerate(self.diagrams):
fig, ax = plt.subplots(1, 1, figsize=figsize)

self._plot_many_bars(dim, diagram, 0, [ax], **kwargs)

if export_png:
fp = io.BytesIO()
plt.savefig(fp, dpi=dpi)
fp.seek(0)

out += [fp]

if show:
plt.show()
else:
plt.close()

if any(out):
return out

def _plot_many_bars(self, dim, diagram, idx, ax, **kwargs):
number_of_bars = len(diagram)
if self._verbose:
print("Number of bars in dimension %d: %d" % (dim, number_of_bars))

if number_of_bars > 0:
births = np.vstack([(elem[0], i) for i, elem in enumerate(diagram)])
deaths = np.vstack([(elem[1], i) for i, elem in enumerate(diagram)])

inf_bars = np.where(np.isinf(deaths))[0]
max_death = deaths[np.isfinite(deaths[:, 0]), 0].max()

number_of_bars_fin = births.shape[0] - inf_bars.shape[0]
number_of_bars_inf = inf_bars.shape[0]

_ = [self._plot_a_bar(ax[idx], birth, deaths[i], max_death, **kwargs) for i, birth in enumerate(births)]

# the line below is to plot a vertical red line showing the maximal finite bar length
ax[idx].plot(
[max_death, max_death],
[0, number_of_bars - 1],
c='r',
linestyle='--',
linewidth=0.7
)

title = "H%d barcode: %d finite, %d infinite" % (dim, number_of_bars_fin, number_of_bars_inf)
ax[idx].set_title(title, fontsize=9)
ax[idx].set_yticks([])

for loc in ('right', 'left', 'top'):
ax[idx].spines[loc].set_visible(False)

@staticmethod
def _plot_a_bar(ax, birth, death, max_death, c='gray', linestyle='-', linewidth=0.5):
if np.isinf(death[0]):
death[0] = 1.05 * max_death
ax.plot(
death[0],
death[1],
c=c,
markersize=4,
marker='>',
)

ax.plot(
[birth[0], death[0]],
[birth[1], death[1]],
c=c,
linestyle=linestyle,
linewidth=linewidth,
)