Skip to content

Commit e6e0cb4

Browse files
Merge pull request optuna#5867 from not522/refactor-plot-contour
Refactor plot contour
2 parents 945f856 + 2e143f2 commit e6e0cb4

File tree

1 file changed

+31
-79
lines changed

1 file changed

+31
-79
lines changed

optuna/visualization/matplotlib/_contour.py

Lines changed: 31 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -190,19 +190,7 @@ def _calculate_axis_data(
190190
return ci, cat_param_labels, cat_param_pos, list(returned_values)
191191

192192

193-
def _calculate_griddata(
194-
info: _SubContourInfo,
195-
) -> tuple[
196-
np.ndarray,
197-
np.ndarray,
198-
np.ndarray,
199-
list[int],
200-
list[str],
201-
list[int],
202-
list[str],
203-
_PlotValues,
204-
_PlotValues,
205-
]:
193+
def _calculate_griddata(info: _SubContourInfo) -> tuple[np.ndarray, _PlotValues, _PlotValues]:
206194
xaxis = info.xaxis
207195
yaxis = info.yaxis
208196
z_values_dict = info.z_values
@@ -220,17 +208,7 @@ def _calculate_griddata(
220208

221209
# Return empty values when x or y has no value.
222210
if len(x_values) == 0 or len(y_values) == 0:
223-
return (
224-
np.array([]),
225-
np.array([]),
226-
np.array([]),
227-
[],
228-
[],
229-
[],
230-
[],
231-
_PlotValues([], []),
232-
_PlotValues([], []),
233-
)
211+
return np.array([]), _PlotValues([], []), _PlotValues([], [])
234212

235213
xi, cat_param_labels_x, cat_param_pos_x, transformed_x_values = _calculate_axis_data(
236214
xaxis,
@@ -261,90 +239,64 @@ def _calculate_griddata(
261239
infeasible.x.append(x_value)
262240
infeasible.y.append(y_value)
263241

264-
return (
265-
xi,
266-
yi,
267-
zi,
268-
cat_param_pos_x,
269-
cat_param_labels_x,
270-
cat_param_pos_y,
271-
cat_param_labels_y,
272-
feasible,
273-
infeasible,
274-
)
242+
return zi, feasible, infeasible
275243

276244

277245
def _generate_contour_subplot(
278246
info: _SubContourInfo, ax: "Axes", cmap: "Colormap"
279247
) -> "ContourSet" | None:
248+
ax.label_outer()
249+
280250
if len(info.xaxis.indices) < 2 or len(info.yaxis.indices) < 2:
281-
ax.label_outer()
282251
return None
283252

284253
ax.set(xlabel=info.xaxis.name, ylabel=info.yaxis.name)
285254
ax.set_xlim(info.xaxis.range[0], info.xaxis.range[1])
286255
ax.set_ylim(info.yaxis.range[0], info.yaxis.range[1])
287256
x_values, y_values = _filter_missing_values(info.xaxis, info.yaxis)
257+
xi, x_cat_param_label, x_cat_param_pos, _ = _calculate_axis_data(info.xaxis, x_values)
258+
yi, y_cat_param_label, y_cat_param_pos, _ = _calculate_axis_data(info.yaxis, y_values)
288259
if info.xaxis.is_cat:
289-
_, x_cat_param_label, x_cat_param_pos, _ = _calculate_axis_data(info.xaxis, x_values)
290260
ax.set_xticks(x_cat_param_pos)
291261
ax.set_xticklabels(x_cat_param_label)
292262
else:
293263
ax.set_xscale("log" if info.xaxis.is_log else "linear")
294264
if info.yaxis.is_cat:
295-
_, y_cat_param_label, y_cat_param_pos, _ = _calculate_axis_data(info.yaxis, y_values)
296265
ax.set_yticks(y_cat_param_pos)
297266
ax.set_yticklabels(y_cat_param_label)
298267
else:
299268
ax.set_yscale("log" if info.yaxis.is_log else "linear")
300269

301270
if info.xaxis.name == info.yaxis.name:
302-
ax.label_outer()
303271
return None
304272

305-
(
306-
xi,
307-
yi,
308-
zi,
309-
x_cat_param_pos,
310-
x_cat_param_label,
311-
y_cat_param_pos,
312-
y_cat_param_label,
313-
feasible_plot_values,
314-
infeasible_plot_values,
315-
) = _calculate_griddata(info)
273+
zi, feasible_plot_values, infeasible_plot_values = _calculate_griddata(info)
316274
cs = None
317275
if len(zi) > 0:
318-
if info.xaxis.is_log:
319-
ax.set_xscale("log")
320-
if info.yaxis.is_log:
321-
ax.set_yscale("log")
322-
if info.xaxis.name != info.yaxis.name:
323-
# Contour the gridded data.
324-
ax.contour(xi, yi, zi, 15, linewidths=0.5, colors="k")
325-
cs = ax.contourf(xi, yi, zi, 15, cmap=cmap.reversed())
326-
assert isinstance(cs, ContourSet)
327-
# Plot data points.
328-
ax.scatter(
329-
feasible_plot_values.x,
330-
feasible_plot_values.y,
331-
marker="o",
332-
c="black",
333-
s=20,
334-
edgecolors="grey",
335-
linewidth=2.0,
336-
)
337-
ax.scatter(
338-
infeasible_plot_values.x,
339-
infeasible_plot_values.y,
340-
marker="o",
341-
c="#cccccc",
342-
s=20,
343-
edgecolors="grey",
344-
linewidth=2.0,
345-
)
276+
# Contour the gridded data.
277+
ax.contour(xi, yi, zi, 15, linewidths=0.5, colors="k")
278+
cs = ax.contourf(xi, yi, zi, 15, cmap=cmap.reversed())
279+
assert isinstance(cs, ContourSet)
280+
# Plot data points.
281+
ax.scatter(
282+
feasible_plot_values.x,
283+
feasible_plot_values.y,
284+
marker="o",
285+
c="black",
286+
s=20,
287+
edgecolors="grey",
288+
linewidth=2.0,
289+
)
290+
ax.scatter(
291+
infeasible_plot_values.x,
292+
infeasible_plot_values.y,
293+
marker="o",
294+
c="#cccccc",
295+
s=20,
296+
edgecolors="grey",
297+
linewidth=2.0,
298+
)
346299

347-
ax.label_outer()
348300
return cs
349301

350302

0 commit comments

Comments
 (0)