Skip to content

Commit a697c73

Browse files
committed
feat: support multilayer plot using maidr-ts
1 parent e7ee68e commit a697c73

File tree

6 files changed

+124
-18
lines changed

6 files changed

+124
-18
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
4+
import maidr
5+
6+
"""
7+
Create a simple multilayer plot with a bar chart and a line chart.
8+
9+
Returns
10+
-------
11+
Tuple[plt.Figure, plt.Axes]
12+
The figure and axes objects of the created plot.
13+
14+
Examples
15+
--------
16+
>>> fig, ax = create_multilayer_plot()
17+
>>> isinstance(fig, plt.Figure)
18+
True
19+
"""
20+
maidr.set_engine("ts")
21+
# Generate sample data
22+
x = np.arange(5)
23+
bar_data = np.array([3, 5, 2, 7, 3])
24+
line_data = np.array([10, 8, 12, 14, 9])
25+
26+
# Create a figure and a set of subplots
27+
fig, ax1 = plt.subplots(figsize=(8, 5))
28+
29+
# Create the bar chart on the first y-axis
30+
ax1.bar(x, bar_data, color="skyblue", label="Bar Data")
31+
ax1.set_xlabel("X values")
32+
ax1.set_ylabel("Bar values", color="blue")
33+
ax1.tick_params(axis="y", labelcolor="blue")
34+
35+
# Create a second y-axis sharing the same x-axis
36+
ax2 = ax1.twinx()
37+
38+
# Create the line chart on the second y-axis
39+
ax2.plot(x, line_data, color="red", marker="o", linestyle="-", label="Line Data")
40+
ax2.set_xlabel("X values")
41+
ax2.set_ylabel("Line values", color="red")
42+
ax2.tick_params(axis="y", labelcolor="red")
43+
44+
# Add title and legend
45+
plt.title("Multilayer Plot Example")
46+
47+
# Add legends for both axes
48+
lines1, labels1 = ax1.get_legend_handles_labels()
49+
lines2, labels2 = ax2.get_legend_handles_labels()
50+
ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper left")
51+
52+
# Adjust layout
53+
fig.tight_layout()
54+
55+
# plt.show()
56+
maidr.show(fig)
57+
# maidr.save_html(fig, "multi-layer.html")

maidr/api.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,31 @@ def render(plot: Any) -> Tag:
2020

2121
def show(plot: Any, renderer: Literal["auto", "ipython", "browser"] = "auto") -> object:
2222
ax = FigureManager.get_axes(plot)
23-
maidr = FigureManager.get_maidr(ax.get_figure())
24-
return maidr.show(renderer)
23+
htmls = []
24+
if type(ax) is list:
25+
for axes in ax:
26+
maidr = FigureManager.get_maidr(axes.get_figure())
27+
htmls.append(maidr.render())
28+
htmls[-1].show(renderer)
29+
else:
30+
maidr = FigureManager.get_maidr(ax.get_figure())
31+
return maidr.show(renderer)
2532

2633

2734
def save_html(
2835
plot: Any, file: str, *, lib_dir: str | None = "lib", include_version: bool = True
2936
) -> str:
3037
ax = FigureManager.get_axes(plot)
31-
maidr = FigureManager.get_maidr(ax.get_figure())
32-
return maidr.save_html(file, lib_dir=lib_dir, include_version=include_version)
38+
htmls = []
39+
if type(ax) is list:
40+
for axes in ax:
41+
maidr = FigureManager.get_maidr(axes.get_figure())
42+
htmls.append(maidr.render())
43+
htmls[-1].save_html(file, libdir=lib_dir, include_version=include_version)
44+
return htmls[-1]
45+
else:
46+
maidr = FigureManager.get_maidr(ax.get_figure())
47+
return maidr.save_html(file, lib_dir=lib_dir, include_version=include_version)
3348

3449

3550
def stacked(plot: Axes | BarContainer) -> Maidr:

maidr/core/maidr.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def _create_html_doc(self) -> HTMLDocument:
139139

140140
def _flatten_maidr(self) -> dict | list[dict]:
141141
"""Return a single plot schema or a list of schemas from the Maidr instance."""
142-
if self.plot_type == PlotType.LINE:
142+
if self.plot_type == PlotType.LINE or self.plot_type == PlotType.DODGED:
143143
self._plots = [self._plots[0]]
144144
maidr = [plot.schema for plot in self._plots]
145145

@@ -150,7 +150,11 @@ def _flatten_maidr(self) -> dict | list[dict]:
150150
"maidr='true'", f"maidr='{self.selector_id}'"
151151
)
152152

153-
return maidr if len(maidr) != 1 else maidr[0]
153+
# return maidr if len(maidr) != 1 else maidr[0]
154+
return {
155+
"id": Maidr._unique_id(),
156+
"panels": [[{"id": Maidr._unique_id(), "layers": maidr}]],
157+
}
154158

155159
def _get_svg(self) -> HTML:
156160
"""Extract the chart SVG from ``matplotlib.figure.Figure``."""
@@ -200,6 +204,7 @@ def _inject_plot(plot: HTML, maidr: str, maidr_id) -> Tag:
200204

201205
engine = Environment.get_engine()
202206

207+
# MAIDR_TS_CDN_URL = "http://localhost:8000/maidr.js" # DEMO URL
203208
MAIDR_TS_CDN_URL = "https://cdn.jsdelivr.net/npm/maidr-ts/dist/maidr.js"
204209

205210
maidr_js_script = f"""

maidr/core/plot/barplot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _extract_plot_data(self) -> list:
4141
combined_data = (
4242
zip(levels, data) if plot[0].orientation == "vertical" else zip(levels, data) # type: ignore
4343
)
44-
if len(data) == len(plot): # type: ignore
44+
if combined_data: # type: ignore
4545
for x, y in combined_data: # type: ignore
4646
formatted_data.append({"x": x, "y": y})
4747
return formatted_data

maidr/patch/barplot.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,18 @@ def bar(
5252
bottom = kwargs.get("bottom")
5353
if bottom is not None:
5454
plot_type = PlotType.STACKED
55-
elif args:
56-
x = args[0]
57-
is_numeric = False
58-
if isinstance(x, np.ndarray) and np.issubdtype(x.dtype, np.number):
59-
is_numeric = True
60-
elif isinstance(x, (list, tuple)) and x and isinstance(x[0], Number):
61-
is_numeric = True
62-
if is_numeric:
63-
plot_type = PlotType.DODGED
55+
else:
56+
if len(args) >= 3:
57+
real_width = args[2]
58+
else:
59+
real_width = kwargs.get("width", 0.8)
60+
61+
align = kwargs.get("align", "center")
6462

63+
if (isinstance(real_width, (int, float)) and float(real_width) < 0.8) or (
64+
align == "edge"
65+
):
66+
plot_type = PlotType.DODGED
6567
return common(plot_type, wrapped, instance, args, kwargs)
6668

6769

maidr/util/mixin/extractor_mixin.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,36 @@ def extract_level(ax: Axes, key: MaidrKey = MaidrKey.X) -> list[str] | None:
4545

4646
level = None
4747
if MaidrKey.X == key:
48-
level = [label.get_text() for label in ax.get_xticklabels()]
48+
ticks = ax.get_xticks()
49+
labels = [label.get_text() for label in ax.get_xticklabels()]
50+
51+
if hasattr(ax, "dataLim") and ax.dataLim.width != 0:
52+
# Use the actual data limits rather than padded view limits
53+
data_x_min, data_x_max = ax.dataLim.x0, ax.dataLim.x0 + ax.dataLim.width
54+
# Filter tick labels to only those within the actual data range
55+
valid_indices = [
56+
i for i, pos in enumerate(ticks) if data_x_min <= pos <= data_x_max
57+
]
58+
labels = [labels[i] for i in valid_indices if i < len(labels)]
59+
60+
level = labels
4961
elif MaidrKey.Y == key:
50-
level = [label.get_text() for label in ax.get_yticklabels()]
62+
ticks = ax.get_yticks()
63+
labels = [label.get_text() for label in ax.get_yticklabels()]
64+
65+
if hasattr(ax, "dataLim") and ax.dataLim.height != 0:
66+
# Use the actual data limits rather than padded view limits
67+
data_y_min, data_y_max = (
68+
ax.dataLim.y0,
69+
ax.dataLim.y0 + ax.dataLim.height,
70+
)
71+
# Filter tick labels to only those within the actual data range
72+
valid_indices = [
73+
i for i, pos in enumerate(ticks) if data_y_min <= pos <= data_y_max
74+
]
75+
labels = [labels[i] for i in valid_indices if i < len(labels)]
76+
77+
level = labels
5178
elif MaidrKey.FILL == key:
5279
level = [container.get_label() for container in ax.containers]
5380

0 commit comments

Comments
 (0)