Skip to content

Commit 2ca1df7

Browse files
committed
Add grid option
1 parent 09ba735 commit 2ca1df7

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

src/toplot/weights.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,13 @@ def scattermap_plot(
332332
return ax
333333

334334

335-
def hinton(data: pd.DataFrame, max_weight=None, ax=None):
336-
"""Draw Hinton diagram for visualizing a the size and sign of a weight matrix.
335+
def hinton(
336+
data: pd.DataFrame,
337+
max_weight: float | None = None,
338+
ax: plt.Axes | None = None,
339+
grid: bool = True,
340+
):
341+
r"""Draw Hinton diagram for visualizing a the size and sign of a weight matrix.
337342
338343
A red (blue) marker indicates a positive (negative) weight. The size scales as
339344
$\propto \sqrt{|w|}$.
@@ -342,6 +347,7 @@ def hinton(data: pd.DataFrame, max_weight=None, ax=None):
342347
data: Weights to plot.
343348
max_weight: The size that corresponds to a full width marker.
344349
ax: Axes to plot on.
350+
grid: Whether to draw a grid.
345351
"""
346352
ax = ax if ax is not None else plt.gca()
347353

@@ -362,6 +368,7 @@ def hinton(data: pd.DataFrame, max_weight=None, ax=None):
362368
size,
363369
facecolor=color,
364370
edgecolor=color,
371+
zorder=3,
365372
)
366373
ax.add_patch(rect)
367374

@@ -388,14 +395,21 @@ def _make_two_level_ticks(hierarchical_index: pd.MultiIndex) -> tuple:
388395
xticks, xtick_colours = _make_two_level_ticks(data.columns)
389396
ax.set_xticklabels(xticks, rotation=90)
390397
if xtick_colours is not None:
391-
for xtick, color in zip(ax.get_xticklabels(), xtick_colours):
398+
for j, (xtick, color) in enumerate(zip(ax.get_xticklabels(), xtick_colours)):
392399
xtick.set_color(color)
400+
if grid:
401+
ax.axvline(range_x[j], color=color, linewidth=0.75, zorder=2)
402+
elif grid:
403+
ax.xaxis.grid(True, zorder=1)
393404

394405
yticks, ytick_colours = _make_two_level_ticks(data.index)
395406
ax.set_yticklabels(yticks)
396407
if ytick_colours is not None:
397-
for ytick, color in zip(ax.get_yticklabels(), ytick_colours):
408+
for i, (ytick, color) in enumerate(zip(ax.get_yticklabels(), ytick_colours)):
398409
ytick.set_color(color)
410+
if grid:
411+
ax.axhline(range_y[i], color=color, linewidth=0.75, zorder=2)
412+
elif grid:
413+
ax.yaxis.grid(True, zorder=1)
399414

400-
plt.grid(True, which="both")
401415
return ax

0 commit comments

Comments
 (0)