Skip to content

Commit ca9295d

Browse files
committed
Add tranpose option for bar_plot_stacked
1 parent 2ca1df7 commit ca9295d

File tree

1 file changed

+46
-14
lines changed

1 file changed

+46
-14
lines changed

src/toplot/weights.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,38 @@
3131
from toplot.scattermap import scattermap
3232

3333

34+
def _bar(names, sizes, offsets, error_bars, color, transpose: bool = False, ax=None):
35+
"""Like matplotlib `bar` but can be transposed."""
36+
if ax is None:
37+
ax = plt.gca()
38+
39+
if not transpose:
40+
return ax.bar(
41+
names,
42+
sizes,
43+
bottom=offsets,
44+
color=color,
45+
yerr=error_bars,
46+
)
47+
48+
ax.invert_yaxis()
49+
return ax.barh(
50+
names,
51+
sizes,
52+
left=offsets,
53+
color=color,
54+
xerr=error_bars,
55+
)
56+
57+
3458
def bar_plot_stacked(
3559
dataframe,
3660
quantile_range=(0.025, 0.975),
3761
height: Literal["mean", "median"] = "mean",
3862
ax=None,
3963
labels: bool = True,
4064
fontsize=None,
65+
transpose: bool = False,
4166
):
4267
"""Plot posterior of a topic as probability bars by stacking categories per set.
4368
@@ -53,6 +78,7 @@ def bar_plot_stacked(
5378
ax: Matplotlib axes to plot on.
5479
labels: If `True`, annotate bars with category labels.
5580
fontsize: Font size for the category labels.
81+
transpose: If `True`, swap the x and y axes of the bar plot.
5682
5783
Example:
5884
```python
@@ -126,26 +152,32 @@ def bar_plot_stacked(
126152
if j == n_categories - 1:
127153
err_j = None
128154

129-
ax.bar(
155+
_bar(
130156
feature_name,
131-
feature_weights.loc[category],
132-
bottom=offsets[j],
157+
sizes=feature_weights.loc[category],
158+
offsets=offsets[j],
159+
error_bars=err_j,
133160
color=color,
134-
yerr=err_j,
161+
transpose=transpose,
162+
ax=ax,
135163
)
136164
if labels:
137-
ax.text(
138-
x=feature_name,
139-
y=offsets[j] + feature_weights.loc[category] / 2,
140-
s=category,
141-
ha="center",
142-
va="center",
143-
fontsize=fontsize,
165+
text_properties = dict(
166+
s=category, ha="center", va="center", fontsize=fontsize
144167
)
168+
position = offsets[j] + feature_weights.loc[category] / 2
169+
if not transpose:
170+
ax.text(x=feature_name, y=position, **text_properties)
171+
else:
172+
ax.text(x=position, y=feature_name, **text_properties)
145173
# Rotate the x-axis labels.
146-
ax.tick_params(axis="x", labelrotation=90, labelsize=fontsize)
147-
ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
148-
ax.set_ylabel("Probability")
174+
if not transpose:
175+
ax.tick_params(axis="x", labelrotation=90, labelsize=fontsize)
176+
ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
177+
ax.set_ylabel("Probability")
178+
else:
179+
ax.xaxis.set_major_formatter(mtick.PercentFormatter(1.0))
180+
ax.set_xlabel("Probability")
149181
return ax
150182

151183

0 commit comments

Comments
 (0)