-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathgluon_mlp_dashboard.py
More file actions
204 lines (170 loc) · 8.41 KB
/
Copy pathgluon_mlp_dashboard.py
File metadata and controls
204 lines (170 loc) · 8.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from itertools import chain, product
import numpy as np
import mxnet as mx
from bqplot import *
from bqplot.marks import Graph
from ipywidgets import IntSlider, Dropdown, RadioButtons, HBox, VBox, Button, Layout
from bqplot import pyplot as plt
from bqplot import OrdinalScale, OrdinalColorScale
from bqplot.colorschemes import CATEGORY10
from IPython.display import display
class MLPDashboard(VBox):
    """Interactive dashboard for inspecting a trained Gluon MLP.

    The network is drawn as a bqplot ``Graph``: one node per unit and one
    link per connection between consecutive layers.  Hovering a node
    highlights its incoming links and shows, as the graph's tooltip, either a
    bar chart of the node's incoming weights/gradients or a histogram of its
    activations over ``data``, for the epoch chosen with the slider.

    Parameters
    ----------
    net : Gluon block
        The network.  Its parameters are reloaded from disk on every query.
    path : str
        Directory containing the per-epoch parameter files.
    name : str
        File-name prefix; parameters are read from
        ``<path>/<name>-<epoch>.params``.

    Keyword arguments consumed by the dashboard (all others are forwarded to
    ``VBox``): ``data`` (array-like or None, default None), ``ctx`` (default
    ``mx.cpu()``), ``num_epochs`` (default 10), ``height`` (default 800),
    ``width`` (default 900), ``directed_links`` (default False) and
    ``layer_colors`` (default ``'Orange'`` for every layer).
    """

    def __init__(self, net, path, name, **kwargs):
        self.net = net
        self.path = path
        self.name = name
        # Only wrap real data: mx.nd.array(None) cannot be built, and
        # get_activations_hist relies on ``self.data is None`` meaning
        # "no data supplied".
        data = kwargs.pop('data', None)
        self.data = None if data is None else mx.nd.array(data)
        self.ctx = kwargs.pop('ctx', mx.cpu())
        self.num_epochs = kwargs.pop('num_epochs', 10)
        # pop (not get) so these dashboard-only options are not forwarded to
        # VBox.__init__ together with the remaining kwargs.
        self.height = kwargs.pop('height', 800)
        self.width = kwargs.pop('width', 900)
        self.directed_links = kwargs.pop('directed_links', False)
        self.get_shapes()
        # One color per layer: input + each hidden layer + output.
        self.layer_colors = kwargs.pop('layer_colors',
                                       ['Orange'] * (len(self.num_hidden_layers) + 2))
        self.build_net()
        self.create_charts()
        self.graph.observe(self.hovered_change, 'hovered_point')
        super(MLPDashboard, self).__init__(children=[self.controls, self.figure], **kwargs)
        self.layout = Layout(min_height='1000px')

    def load_epoch(self, epoch_num):
        """Load the parameters saved for ``epoch_num`` into ``self.net`` on ``self.ctx``."""
        file_path = '{}/{}-{}.params'.format(self.path, self.name, epoch_num)
        self.net.load_params(file_path, ctx=self.ctx)

    def _weight_param_for_layer(self, layer_num):
        """Return the weight Parameter feeding layer ``layer_num`` (1 = first hidden layer)."""
        params = self.net.collect_params()
        # Assumes collect_params() lists weight, bias, weight, bias, ... one
        # pair per Dense layer, so the weight feeding layer ``layer_num`` sits
        # at index 2 * (layer_num - 1).
        # NOTE(review): confirm this ordering for nets with other parameterised blocks.
        return params[list(params)[2 * (layer_num - 1)]]

    def get_weights_for_node_at_layer(self, epoch_num, layer_num, node_num):
        """Return the incoming weights of one node as a 1-D numpy array.

        ``layer_num`` counts from 1 for the first hidden layer and
        ``node_num`` is the 1-based position of the node within that layer.
        Row ``node_num - 1`` of the weight matrix is used, since get_shapes
        takes the layer size from ``shape[0]`` of that matrix.
        """
        self.load_epoch(epoch_num)
        weights = self._weight_param_for_layer(layer_num).data().asnumpy()
        return weights[node_num - 1, :]

    def get_activations_hist(self, epoch, layer, node):
        """Show a histogram of one node's activations over ``self.data`` as the tooltip.

        ``layer`` is 0 for the input layer and ``node`` is the 0-based position
        within the layer.  Does nothing when no data was supplied.
        """
        if self.data is None:
            return
        self.load_epoch(epoch)
        outputs = self.net(self.data)
        self.graph.tooltip = self.hist_figure
        self.hist_figure.title = 'Activation Histogram for {}th Node at the {}th Layer - Epoch {}'.format(node, layer, epoch)
        # NOTE(review): assumes net(data) returns per-layer outputs ordered
        # from the output layer (index 0) back to the input — confirm.
        self.hist_plot.sample = outputs[len(self.num_hidden_layers) - (layer - 1)][:, node].asnumpy()

    def update_bar_chart(self, layer, node):
        """Refresh the tooltip chart for a hovered node.

        ``layer`` is 0 for the input layer and ``node`` is the 0-based
        position of the node within its layer (as computed by hovered_change).
        """
        epoch = self.epoch_slider.value
        mode = self.mode_dd.value
        if mode == 'Activations':
            self.get_activations_hist(epoch, layer, node)
            return
        if layer == 0:
            # Input nodes have no incoming weights/gradients: show an empty
            # bar chart rather than a stale activation histogram.
            self.bar_plot.x = []
            self.bar_plot.y = []
            self.graph.tooltip = self.bar_figure
            return
        # The weight/gradient getters take 1-based node numbers.
        if mode == 'Weights':
            display_vals = self.get_weights_for_node_at_layer(epoch, layer, node + 1)
        elif mode == 'Gradients':
            display_vals = self.get_gradients_for_node_at_layer(epoch, layer, node + 1)
        self.bar_figure.title = mode + ' for layer:' + str(layer) + ' node: ' + str(node) + ' at epoch: ' + str(epoch)
        self.bar_plot.x = np.arange(len(display_vals))
        self.bar_plot.y = display_vals
        self.graph.tooltip = self.bar_figure

    def hovered_change(self, change):
        """Observer for ``graph.hovered_point``: highlight links and update the tooltip."""
        point_index = change['new']
        self.set_colors(point_index)
        if point_index is None:
            return
        # Convert the flat node index into (layer, 0-based position in layer).
        for layer, count in enumerate(self.node_counts):
            if point_index < count:
                break
            point_index -= count
        self.update_bar_chart(layer, point_index)

    def set_colors(self, index):
        """Color the links that end at node ``index``; reset every other link to 0 (gray).

        ``index`` may be None (nothing hovered), which resets all links.
        """
        link_data_new = []
        highlighted = 0
        for link in self.graph.link_data:
            if link['target'] == index:
                # Cycle through 1..10 so each highlighted link maps to one of
                # the CATEGORY10 colors of the link color scale, whose domain
                # is 0..10 with 0 reserved for gray.
                value = (highlighted % 10) + 1
                highlighted += 1
            else:
                value = 0
            link_data_new.append({'source': link['source'],
                                  'target': link['target'],
                                  'value': value})
        self.graph.link_data = link_data_new

    def get_gradients_for_node_at_layer(self, epoch_num, layer_num, node_num):
        """Return the gradients of one node's incoming weights as a 1-D numpy array.

        Same indexing conventions as get_weights_for_node_at_layer
        (``node_num`` is 1-based).
        """
        self.load_epoch(epoch_num)
        # NOTE(review): only parameter values are loaded from the .params
        # file; the gradient buffer is presumably not restored, so these
        # values may not correspond to ``epoch_num`` — confirm.
        grads_weights = self._weight_param_for_layer(layer_num).grad().asnumpy()
        return grads_weights[node_num - 1, :]

    def get_shapes(self):
        """Derive layer sizes from the network's weight parameters.

        Sets ``num_inputs`` (input width of ``dense0``), ``num_hidden_layers``
        (list of hidden-layer sizes, despite the name), ``nodes_output_layer``
        and ``node_counts`` (size of every layer, input first).
        """
        params = self.net.collect_params()
        first_weight = self.net.prefix + 'dense0_weight'
        shapes = []
        for key in list(params):
            if first_weight in key:
                self.num_inputs = params[key].shape[1]
            if 'weight' in key:
                shapes.append(params[key].shape[0])
        self.num_hidden_layers = shapes[:-1]
        self.nodes_output_layer = shapes[-1]
        self.node_counts = [self.num_inputs] + self.num_hidden_layers + [self.nodes_output_layer]

    def create_charts(self):
        """Create the controls and the bar/histogram figures used as tooltips."""
        self.epoch_slider = IntSlider(description='Epoch:', min=1, max=self.num_epochs, value=1)
        self.mode_dd = Dropdown(description='View', options=['Weights', 'Gradients', 'Activations'], value='Weights')
        # NOTE(review): no click handler is attached to this button in this file.
        self.update_btn = Button(description='Update')
        self.bar_figure = plt.figure()
        self.bar_plot = plt.bar([], [], scales={'x': OrdinalScale()})
        self.hist_figure = plt.figure(title='Histogram of Activations')
        self.hist_plot = plt.hist([], bins=20)
        self.controls = HBox([self.epoch_slider, self.mode_dd, self.update_btn])
        self.graph.tooltip = self.bar_figure

    def build_net(self):
        """Build the bqplot Graph (nodes, links, layout) and its Figure."""
        # Node labels: x1..xN for inputs, h<layer>,<node> for hidden units,
        # y1..yM for outputs.
        self.layer_nodes = [['x' + str(i + 1) for i in range(self.num_inputs)]]
        for i, h in enumerate(self.num_hidden_layers):
            self.layer_nodes.append(['h' + str(i + 1) + ',' + str(j + 1) for j in range(h)])
        self.layer_nodes.append(['y' + str(i + 1) for i in range(self.nodes_output_layer)])
        self.flattened_layer_nodes = list(chain(*self.layer_nodes))

        # Global index of every node in the flattened node list.
        node_indices = {node: i for i, node in enumerate(self.flattened_layer_nodes)}

        # Fully connect each pair of consecutive layers; 'value' selects the
        # link color (0 = gray, see set_colors).
        self.link_data = []
        for src_layer, dst_layer in zip(self.layer_nodes[:-1], self.layer_nodes[1:]):
            src = [node_indices[d] for d in src_layer]
            dst = [node_indices[d] for d in dst_layer]
            for s, t in product(src, dst):
                self.link_data.append({'source': s, 'target': t, 'value': 0})

        # x: layers evenly spaced across (0, 100); every node of a layer
        # shares its layer's x.
        self.nodes_x = np.repeat(np.linspace(0, 100,
                                             len(self.layer_nodes) + 1,
                                             endpoint=False)[1:],
                                 [len(layer) for layer in self.layer_nodes])

        # y: nodes of each layer evenly spaced in (0, 100), first node on top.
        self.nodes_y = np.array([])
        for layer in self.layer_nodes:
            ys = np.linspace(0, 100, len(layer) + 1, endpoint=False)[1:]
            self.nodes_y = np.append(self.nodes_y, ys[::-1])

        # One color per layer, repeated for each node of that layer.
        n_layers = len(self.layer_nodes)
        self.node_colors = np.repeat(np.array(self.layer_colors[:n_layers]),
                                     [len(layer) for layer in self.layer_nodes]).tolist()

        xs = LinearScale(min=0, max=100)
        ys = LinearScale(min=0, max=100)
        link_color_scale = OrdinalColorScale(colors=['gray'] + CATEGORY10, domain=list(range(11)))
        self.graph = Graph(node_data=[{'label': d,
                                       'label_display': 'none'} for d in self.flattened_layer_nodes],
                           link_data=self.link_data,
                           link_type='line',
                           colors=self.node_colors, directed=self.directed_links,
                           scales={'x': xs, 'y': ys, 'link_color': link_color_scale},
                           x=self.nodes_x, y=self.nodes_y)
        self.graph.hovered_style = {'stroke': '1.5'}
        self.graph.unhovered_style = {'opacity': '0.1'}
        self.graph.selected_style = {'opacity': '1',
                                     'stroke': 'red',
                                     'stroke-width': '2.5'}

        self.figure = Figure()
        self.figure.marks = [self.graph]
        self.figure.title = 'Analyzing the Trained Neural Network'
        self.figure.layout.width = str(self.width) + 'px'
        self.figure.layout.height = str(self.height) + 'px'