Skip to content

Commit 5fe8f5c

Browse files
committed
Merge branch 'master' of github.com:automl/fanova
2 parents 99c6234 + 0b9123a commit 5fe8f5c

1 file changed

Lines changed: 107 additions & 95 deletions

File tree

fanova/visualizer.py

Lines changed: 107 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import matplotlib.pyplot as plt
99
from matplotlib import cm
1010
from mpl_toolkits.mplot3d import Axes3D
11-
from ConfigSpace.hyperparameters import Hyperparameter, CategoricalHyperparameter, Constant
11+
from ConfigSpace.hyperparameters import Hyperparameter, CategoricalHyperparameter, Constant, OrdinalHyperparameter, \
12+
NumericalHyperparameter
1213

1314

1415
class Visualizer(object):
@@ -84,7 +85,7 @@ def generate_pairwise_marginal(self, param_list, resolution=20):
8485
array that contains the predicted importance (shape is len(categoricals) or resolution depending on
8586
parameter, so e.g. zz = np.array(3, 20) if first parameter has three categories and resolution is set to 20
8687
"""
87-
if not len(param_list) == 2:
88+
if len(set(param_list)) != 2:
8889
raise ValueError("You have to specify 2 (different) parameters")
8990

9091
params, param_names, param_indices = self._get_parameter(param_list)
@@ -97,10 +98,13 @@ def generate_pairwise_marginal(self, param_list, resolution=20):
9798
if isinstance(p, CategoricalHyperparameter):
9899
grid_orig.append(p.choices)
99100
grid_fanova.append(np.arange(len(p.choices)))
101+
elif isinstance(p, OrdinalHyperparameter):
102+
grid_orig.append(p.sequence)
103+
grid_fanova.append(np.arange(len(p.sequence)))
100104
elif isinstance(p, Constant):
101105
grid_orig.append((p.value,))
102106
grid_fanova.append(np.arange(1))
103-
else:
107+
elif isinstance(p, NumericalHyperparameter):
104108
if p.log:
105109
base = np.e # assuming ConfigSpace uses the natural logarithm
106110
log_lower = np.log(p.lower) / np.log(base)
@@ -110,26 +114,20 @@ def generate_pairwise_marginal(self, param_list, resolution=20):
110114
grid = np.linspace(p.lower, p.upper, resolution)
111115
grid_orig.append(grid)
112116
grid_fanova.append(grid)
117+
else:
118+
raise ValueError("Hyperparameter %s of type %s not supported." % (p.name, type(p)))
113119

114-
# Turn into arrays, squeeze all but the first two dimensions (to avoid squeezing away the dimension for Constants)
120+
# Turn into arrays, squeeze all but the first two dimensions (avoid squeezing away the dimension for Constants)
115121
param_indices = np.array(param_indices)
116122
param_indices = param_indices.reshape([s for i, s in enumerate(param_indices.shape) if i in [0, 1] or s != 1])
117123
grid_fanova = np.array(grid_fanova)
118124
grid_fanova = grid_fanova.reshape([s for i, s in enumerate(grid_fanova.shape) if i in [0, 1] or s != 1])
119125

120-
# The swap-parameter is here because this was how I found this code and without understanding of the
121-
# in-detail implementation of fANOVA I assume there is a reason for the first element of the fanova
122-
# method marginal_mean_variance_for_values to expect the element with less elements as the first parameter
123-
swap = len(grid_fanova[1]) > len(grid_fanova[0])
124-
125126
# Populating the result
126127
zz = np.zeros((len(grid_fanova[0]), len(grid_fanova[1])))
127128
for i, x_value in enumerate(grid_fanova[0]):
128129
for j, y_value in enumerate(grid_fanova[1]):
129-
if swap:
130-
zz[i][j] = self.fanova.marginal_mean_variance_for_values(param_indices[::-1], [y_value, x_value])[0]
131-
else:
132-
zz[i][j] = self.fanova.marginal_mean_variance_for_values(param_indices, [x_value, y_value])[0]
130+
zz[i][j] = self.fanova.marginal_mean_variance_for_values(param_indices, [x_value, y_value])[0]
133131

134132
return grid_orig, zz
135133

@@ -154,48 +152,19 @@ def plot_pairwise_marginal(self, param_list, resolution=20, show=False, three_d=
154152
add_colorbar: bool
155153
whether to add the colorbar for 3d plots
156154
"""
157-
if not len(param_list) == 2:
155+
if len(set(param_list)) != 2:
158156
raise ValueError("You have to specify 2 (different) parameters")
159157

160158
params, param_names, param_indices = self._get_parameter(param_list)
161159

162-
first_is_cat = isinstance(params[0], (CategoricalHyperparameter, Constant))
163-
second_is_cat = isinstance(params[1], (CategoricalHyperparameter, Constant))
160+
first_is_numerical = isinstance(params[0], NumericalHyperparameter)
161+
second_is_numerical = isinstance(params[1], NumericalHyperparameter)
164162

165163
plt.close()
166164
fig = plt.figure()
167165
plt.title('%s and %s' % (param_names[0], param_names[1]))
168166

169-
if first_is_cat or second_is_cat:
170-
# At least one of the two parameters is categorical
171-
if first_is_cat and second_is_cat:
172-
# Both parameters are categorical -> create hotmap
173-
choices, zz = self.generate_pairwise_marginal(param_indices, resolution)
174-
plt.imshow(zz, cmap='hot', interpolation='nearest')
175-
plt.xticks(np.arange(0, len(choices[0])), choices[0], fontsize=8)
176-
plt.yticks(np.arange(0, len(choices[1])), choices[1], fontsize=8)
177-
plt.xlabel(param_names[0])
178-
plt.ylabel(param_names[1])
179-
plt.colorbar().set_label(self._y_label)
180-
else:
181-
# Only one of them is categorical -> create multi-line-plot
182-
# Make sure categorical is first in indices (for iteration below)
183-
param_indices = param_indices if first_is_cat else param_indices[::-1]
184-
params = params if first_is_cat else params[::-1]
185-
choices, zz = self.generate_pairwise_marginal(param_indices, resolution)
186-
187-
for i, cat in enumerate(choices[0]):
188-
if params[1].log:
189-
plt.semilogx(choices[1], zz[i], label='%s' % str(cat))
190-
else:
191-
plt.plot(choices[1], zz[i], label='%s' % str(cat))
192-
193-
plt.ylabel(self._y_label)
194-
plt.xlabel(param_names[0] if second_is_cat else param_names[1]) # x-axis displays non-categorical
195-
plt.legend()
196-
plt.tight_layout()
197-
198-
else:
167+
if first_is_numerical and second_is_numerical:
199168
# Both parameters are numerical -> create 3D-plot
200169
grid_list, zz = self.generate_pairwise_marginal(param_indices, resolution)
201170

@@ -223,6 +192,38 @@ def plot_pairwise_marginal(self, param_list, resolution=20, show=False, three_d=
223192

224193
plt.ylabel(param_names[1])
225194
plt.colorbar()
195+
else:
196+
# At least one of the two parameters is non-numerical (categorical, ordinal or constant)
197+
if first_is_numerical or second_is_numerical:
198+
# Only one of them is non-numerical -> create multi-line-plot
199+
# zz is transposed below if needed, so that its rows follow the non-numerical parameter (for iteration below)
200+
numerical_idx = 0 if first_is_numerical else 1
201+
categorical_idx = 1 - numerical_idx
202+
grid_labels, zz = self.generate_pairwise_marginal(param_indices, resolution)
203+
204+
if first_is_numerical:
205+
zz = zz.T
206+
207+
for i, cat in enumerate(grid_labels[categorical_idx]):
208+
if params[numerical_idx].log:
209+
plt.semilogx(grid_labels[numerical_idx], zz[i], label='%s' % str(cat))
210+
else:
211+
plt.plot(grid_labels[numerical_idx], zz[i], label='%s' % str(cat))
212+
213+
plt.ylabel(self._y_label)
214+
plt.xlabel(param_names[numerical_idx]) # x-axis displays numerical
215+
plt.legend()
216+
plt.tight_layout()
217+
218+
else:
219+
# Both parameters are non-numerical (categorical, ordinal or constant) -> create heatmap
220+
choices, zz = self.generate_pairwise_marginal(param_indices, resolution)
221+
plt.imshow(zz, cmap='hot', interpolation='nearest')
222+
plt.xticks(np.arange(0, len(choices[0])), choices[0], fontsize=8)
223+
plt.yticks(np.arange(0, len(choices[1])), choices[1], fontsize=8)
224+
plt.xlabel(param_names[0])
225+
plt.ylabel(param_names[1])
226+
plt.colorbar().set_label(self._y_label)
226227

227228
if show:
228229
plt.show()
@@ -255,17 +256,7 @@ def generate_marginal(self, p, resolution=100):
255256
"""
256257
p, p_name, p_idx = self._get_parameter(p)
257258

258-
if isinstance(p, (CategoricalHyperparameter, Constant)):
259-
try:
260-
categorical_size = len(p.choices)
261-
except AttributeError:
262-
categorical_size = 1
263-
marginals = [self.fanova.marginal_mean_variance_for_values([p_idx], [i]) for i in range(categorical_size)]
264-
mean, v = list(zip(*marginals))
265-
std = np.sqrt(v)
266-
return mean, std
267-
268-
else:
259+
if isinstance(p, NumericalHyperparameter):
269260
lower_bound = p.lower
270261
upper_bound = p.upper
271262
log = p.log
@@ -274,12 +265,14 @@ def generate_marginal(self, p, resolution=100):
274265
log_lower = np.log(lower_bound) / np.log(base)
275266
log_upper = np.log(upper_bound) / np.log(base)
276267
grid = np.logspace(log_lower, log_upper, resolution, endpoint=True, base=base)
277-
'''
268+
278269
if abs(grid[0] - lower_bound) > 0.00001:
279-
raise ValueError()
270+
self.logger.warning("Check the grid's (lower) accuracy for %s (plotted vs theoretical: %s vs %s)"
271+
% (p.name, grid[0], lower_bound))
280272
if abs(grid[-1] - upper_bound) > 0.00001:
281-
raise ValueError()
282-
'''
273+
self.logger.warning("Check the grid's (upper) accuracy for %s (plotted vs theoretical: %s vs %s)"
274+
% (p.name, grid[-1], upper_bound))
275+
283276
else:
284277
grid = np.linspace(lower_bound, upper_bound, resolution)
285278
mean = np.zeros(resolution)
@@ -292,6 +285,20 @@ def generate_marginal(self, p, resolution=100):
292285
std[i] = np.sqrt(v)
293286
return mean, std, grid
294287

288+
else:
289+
if isinstance(p, CategoricalHyperparameter):
290+
categorical_size = len(p.choices)
291+
elif isinstance(p, Constant):
292+
categorical_size = 1
293+
elif isinstance(p, OrdinalHyperparameter):
294+
categorical_size = len(p.sequence)
295+
else:
296+
raise ValueError("Parameter %s of type %s not supported." % (p.name, type(p)))
297+
marginals = [self.fanova.marginal_mean_variance_for_values([p_idx], [i]) for i in range(categorical_size)]
298+
mean, v = list(zip(*marginals))
299+
std = np.sqrt(v)
300+
return mean, std
301+
295302
def plot_marginal(self, param, resolution=100, log_scale=None, show=True, incumbents=None):
296303
"""
297304
Creates a plot of marginal of a selected parameter
@@ -312,40 +319,7 @@ def plot_marginal(self, param, resolution=100, log_scale=None, show=True, incumb
312319
param, param_name, param_idx = self._get_parameter(param)
313320

314321
# check if numerical (categorical, ordinal and constant parameters are handled in the else-branch)
315-
if isinstance(param, (CategoricalHyperparameter, Constant)):
316-
# PREPROCESS
317-
try:
318-
labels = param.choices
319-
categorical_size = len(param.choices)
320-
except AttributeError:
321-
labels = str(param)
322-
categorical_size = 1
323-
indices = np.arange(1, categorical_size + 1, 1)
324-
mean, std = self.generate_marginal(param_idx)
325-
min_y = mean[0]
326-
max_y = mean[0]
327-
328-
# PLOT
329-
b = plt.boxplot([[x] for x in mean])
330-
plt.xticks(indices, labels)
331-
# blow up boxes
332-
for box, std_ in zip(b["boxes"], std):
333-
y = box.get_ydata()
334-
y[2:4] = y[2:4] + std_
335-
y[0:2] = y[0:2] - std_
336-
y[4] = y[4] - std_
337-
box.set_ydata(y)
338-
min_y = min(min_y, y[0] - std_)
339-
max_y = max(max_y, y[2] + std_)
340-
341-
plt.ylim([min_y, max_y])
342-
343-
plt.ylabel(self._y_label)
344-
plt.xlabel(param_name)
345-
plt.tight_layout()
346-
347-
else:
348-
322+
if isinstance(param, NumericalHyperparameter):
349323
# PREPROCESS
350324
mean, std, grid = self.generate_marginal(param_idx, resolution)
351325
mean = np.asarray(mean)
@@ -382,6 +356,44 @@ def plot_marginal(self, param, resolution=100, log_scale=None, show=True, incumb
382356
plt.legend()
383357
plt.tight_layout()
384358

359+
else:
360+
# PREPROCESS
361+
if isinstance(param, CategoricalHyperparameter):
362+
labels = param.choices
363+
categorical_size = len(param.choices)
364+
elif isinstance(param, OrdinalHyperparameter):
365+
labels = param.sequence
366+
categorical_size = len(param.sequence)
367+
elif isinstance(param, Constant):
368+
labels = str(param)
369+
categorical_size = 1
370+
else:
371+
raise ValueError("Parameter %s of type %s not supported." % (param.name, type(param)))
372+
373+
indices = np.arange(1, categorical_size + 1, 1)
374+
mean, std = self.generate_marginal(param_idx)
375+
min_y = mean[0]
376+
max_y = mean[0]
377+
378+
# PLOT
379+
b = plt.boxplot([[x] for x in mean])
380+
plt.xticks(indices, labels)
381+
# blow up boxes
382+
for box, std_ in zip(b["boxes"], std):
383+
y = box.get_ydata()
384+
y[2:4] = y[2:4] + std_
385+
y[0:2] = y[0:2] - std_
386+
y[4] = y[4] - std_
387+
box.set_ydata(y)
388+
min_y = min(min_y, y[0] - std_)
389+
max_y = max(max_y, y[2] + std_)
390+
391+
plt.ylim([min_y, max_y])
392+
393+
plt.ylabel(self._y_label)
394+
plt.xlabel(param_name)
395+
plt.tight_layout()
396+
385397
if show:
386398
plt.show()
387399
else:

0 commit comments

Comments
 (0)