-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathDiffusionViz.py
89 lines (76 loc) · 2.98 KB
/
DiffusionViz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import abc
# `cols` is the colour palette; plot() indexes it once per plotted series.
from bokeh.palettes import Category20_9 as cols
import os
import matplotlib as mpl
# Headless fallback: when the DISPLAY environment variable is unset or empty,
# report it on stdout and select the non-interactive Agg backend. This runs
# before pyplot is imported below, which is why that import is not grouped
# with the others at the top.
if os.environ.get('DISPLAY', '') == '':
    print('no display found. Using non-interactive Agg backend')
    mpl.use('Agg')
import matplotlib.pyplot as plt
import future.utils
import six
import itertools
__author__ = 'Giulio Rossetti'
__license__ = "BSD-2-Clause"
# Endless cycle over five matplotlib marker codes. This is module-level state
# shared by every plot: each plotted series takes the next marker through
# next(marker), so the sequence continues across plots instead of restarting.
marker = itertools.cycle(('D', '+', '>', 'o', '*'))
@six.add_metaclass(abc.ABCMeta)
class DiffusionPlot(object):
    """
    Abstract base class for matplotlib line plots of per-status diffusion series.

    Subclasses provide the data through :meth:`iteration_series`; :meth:`plot`
    draws one line (plus a shaded band) for each status it returns.
    """

    def __init__(self, model, trends):
        """
        :param model: The diffusion model. It must expose ``available_statuses`` (a mapping from status label to
                      status code) and ``graph`` (an object providing ``number_of_nodes()``)
        :param trends: The trends data; stored on the instance, not read by this base class
        """
        self.model = model
        self.trends = trends
        statuses = model.available_statuses
        # Reverse lookup (status code -> status label) used for legend labels and status filtering.
        self.srev = {v: k for k, v in future.utils.iteritems(statuses)}
        self.ylabel = ""
        self.title = ""
        self.nnodes = model.graph.number_of_nodes()
        # When True, plotted values are divided by the number of nodes.
        self.normalized = True

    @abc.abstractmethod
    def iteration_series(self, percentile):
        """
        Prepare the data to be visualized

        :param percentile: The percentile for the trend variance area
        :return: a dictionary where status codes (the keys of ``self.srev``) are keys and each value is an indexable
                 triple of equally long series: item 0 and item 2 bound the shaded area, item 1 is the plotted line
        """
        pass

    def plot(self, filename=None, percentile=90, statuses=None):
        """
        Generates the plot

        :param filename: Output filename. If None the figure is shown with ``plt.show()`` instead of being saved.
        :param percentile: The percentile for the trend variance area
        :param statuses: List of statuses to plot. If not specified all statuses trends will be shown.
        """
        pres = self.iteration_series(percentile)

        fig = plt.figure(figsize=(20, 10))

        longest = 0  # length of the longest plotted series, used for the x-axis limit
        drawn = 0    # number of series actually plotted, drives colour selection
        for status, bands in future.utils.iteritems(pres):
            label = self.srev[status]
            if statuses is not None and label not in statuses:
                continue

            n_points = len(bands[0])
            longest = max(longest, n_points)
            xs = list(range(n_points))

            if self.normalized:
                lower = bands[0] / self.nnodes
                central = bands[1] / self.nnodes
                upper = bands[2] / self.nnodes
            else:
                lower, central, upper = bands[0], bands[1], bands[2]

            # Wrap around the palette so plotting more series than colours does not raise IndexError.
            color = cols[drawn % len(cols)]
            plt.plot(xs, central, lw=2, label=label, alpha=0.5, marker=next(marker), markersize=12, color=color)
            # alpha must be numeric: matplotlib rejects a string such as "0.2".
            plt.fill_between(xs, lower, upper, alpha=0.2, color=color)
            drawn += 1

        plt.grid(axis="y")
        plt.xlabel("Iterations", fontsize=24)
        plt.ylabel(self.ylabel, fontsize=24)
        plt.legend(loc="best", fontsize=18)
        if longest > 0:
            # Skip the degenerate (0, 0) range when nothing was plotted.
            plt.xlim((0, longest))

        plt.tight_layout()
        if filename is not None:
            plt.savefig(filename)
            # Close the figure (rather than only clearing it) so repeated calls do not accumulate open figures.
            plt.close(fig)
        else:
            plt.show()