@@ -332,8 +332,13 @@ def scattermap_plot(
332332 return ax
333333
334334
335- def hinton (data : pd .DataFrame , max_weight = None , ax = None ):
336- """Draw Hinton diagram for visualizing a the size and sign of a weight matrix.
335+ def hinton (
336+ data : pd .DataFrame ,
337+ max_weight : float | None = None ,
338+ ax : plt .Axes | None = None ,
339+ grid : bool = True ,
340+ ):
341+ r"""Draw Hinton diagram for visualizing a the size and sign of a weight matrix.
337342
338343 A red (blue) marker indicates a positive (negative) weight. The size scales as
339344 $\propto \sqrt{|w|}$.
@@ -342,6 +347,7 @@ def hinton(data: pd.DataFrame, max_weight=None, ax=None):
342347 data: Weights to plot.
343348 max_weight: The size that corresponds to a full width marker.
344349 ax: Axes to plot on.
350+ grid: Whether to draw a grid.
345351 """
346352 ax = ax if ax is not None else plt .gca ()
347353
@@ -362,6 +368,7 @@ def hinton(data: pd.DataFrame, max_weight=None, ax=None):
362368 size ,
363369 facecolor = color ,
364370 edgecolor = color ,
371+ zorder = 3 ,
365372 )
366373 ax .add_patch (rect )
367374
@@ -388,14 +395,21 @@ def _make_two_level_ticks(hierarchical_index: pd.MultiIndex) -> tuple:
388395 xticks , xtick_colours = _make_two_level_ticks (data .columns )
389396 ax .set_xticklabels (xticks , rotation = 90 )
390397 if xtick_colours is not None :
391- for xtick , color in zip (ax .get_xticklabels (), xtick_colours ):
398+ for j , ( xtick , color ) in enumerate ( zip (ax .get_xticklabels (), xtick_colours ) ):
392399 xtick .set_color (color )
400+ if grid :
401+ ax .axvline (range_x [j ], color = color , linewidth = 0.75 , zorder = 2 )
402+ elif grid :
403+ ax .xaxis .grid (True , zorder = 1 )
393404
394405 yticks , ytick_colours = _make_two_level_ticks (data .index )
395406 ax .set_yticklabels (yticks )
396407 if ytick_colours is not None :
397- for ytick , color in zip (ax .get_yticklabels (), ytick_colours ):
408+ for i , ( ytick , color ) in enumerate ( zip (ax .get_yticklabels (), ytick_colours ) ):
398409 ytick .set_color (color )
410+ if grid :
411+ ax .axhline (range_y [i ], color = color , linewidth = 0.75 , zorder = 2 )
412+ elif grid :
413+ ax .yaxis .grid (True , zorder = 1 )
399414
400- plt .grid (True , which = "both" )
401415 return ax
0 commit comments