Skip to content

feat: support subplots in py-maidr using maidr-ts #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 18, 2025
57 changes: 57 additions & 0 deletions example/multilayer/example_mpl_multilayer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import matplotlib.pyplot as plt
import numpy as np

import maidr

"""
Create a simple multilayer plot with a bar chart and a line chart.

Returns
-------
Tuple[plt.Figure, plt.Axes]
The figure and axes objects of the created plot.

Examples
--------
>>> fig, ax = create_multilayer_plot()
>>> isinstance(fig, plt.Figure)
True
"""
maidr.set_engine("ts")
# Generate sample data
x = np.arange(5)
bar_data = np.array([3, 5, 2, 7, 3])
line_data = np.array([10, 8, 12, 14, 9])

# Create a figure and a set of subplots
fig, ax1 = plt.subplots(figsize=(8, 5))

# Create the bar chart on the first y-axis
ax1.bar(x, bar_data, color="skyblue", label="Bar Data")
ax1.set_xlabel("X values")
ax1.set_ylabel("Bar values", color="blue")
ax1.tick_params(axis="y", labelcolor="blue")

# Create a second y-axis sharing the same x-axis
ax2 = ax1.twinx()

# Create the line chart on the second y-axis
ax2.plot(x, line_data, color="red", marker="o", linestyle="-", label="Line Data")
ax2.set_xlabel("X values")
ax2.set_ylabel("Line values", color="red")
ax2.tick_params(axis="y", labelcolor="red")

# Add title and legend
plt.title("Multilayer Plot Example")

# Add legends for both axes
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper left")

# Adjust layout
fig.tight_layout()

# plt.show()
maidr.show(fig)
# maidr.save_html(fig, "multi-layer.html")
140 changes: 140 additions & 0 deletions example/multilayer/example_multilayer_plot.ipynb

Large diffs are not rendered by default.

320 changes: 320 additions & 0 deletions example/multipanel/example_multipanel_plot.ipynb

Large diffs are not rendered by default.

57 changes: 57 additions & 0 deletions example/multipanel/matplotlib/example_mpl_multipanel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
"""
Example of creating a multipanel plot with matplotlib.

This script demonstrates how to create a figure with multiple panels
containing different types of plots: line plot, bar plot, and scatter plot.
"""

import matplotlib.pyplot as plt
import numpy as np

import maidr

maidr.set_engine("ts")

x_line = np.array([1, 2, 3, 4, 5, 6, 7, 8])
y_line = np.array([2, 4, 1, 5, 3, 7, 6, 8])

# Data for bar plot
categories = ["A", "B", "C", "D", "E"]
values = np.random.rand(5) * 10

# Data for bar plot
categories_2 = ["A", "B", "C", "D", "E"]
values_2 = np.random.randn(5) * 100

# Data for scatter plot
x_scatter = np.random.randn(50)
y_scatter = np.random.randn(50)

# Create a figure with 3 subplots arranged vertically
fig, axs = plt.subplots(3, 1, figsize=(10, 12))

# First panel: Line plot
axs[0].plot(x_line, y_line, color="blue", linewidth=2)
axs[0].set_title("Line Plot: Random Data")
axs[0].set_xlabel("X-axis")
axs[0].set_ylabel("Values")
axs[0].grid(True, linestyle="--", alpha=0.7)

# Second panel: Bar plot
axs[1].bar(categories, values, color="green", alpha=0.7)
axs[1].set_title("Bar Plot: Random Values")
axs[1].set_xlabel("Categories")
axs[1].set_ylabel("Values")

# Third panel: Bar plot
axs[2].bar(categories_2, values_2, color="blue", alpha=0.7)
axs[2].set_title("Bar Plot 2: Random Values")
axs[2].set_xlabel("Categories")
axs[2].set_ylabel("Values")

# Adjust layout to prevent overlap
plt.tight_layout()

# Display the figure
maidr.show(fig)
65 changes: 65 additions & 0 deletions example/multipanel/seaborn/example_sns_multipanel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""
Example of creating a multipanel plot with seaborn.

This script demonstrates how to create a figure with multiple panels
containing different types of plots using seaborn: line plot, bar plot, and bar plot.
"""

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import maidr

# Set the plotting style
sns.set_theme(style="whitegrid")

# Set the maidr engine
maidr.set_engine("ts")

# Data for line plot
x_line = np.array([1, 2, 3, 4, 5, 6, 7, 8])
y_line = np.array([2, 4, 1, 5, 3, 7, 6, 8])
line_data = {"x": x_line, "y": y_line}

# Data for first bar plot
categories = ["A", "B", "C", "D", "E"]
values = np.random.rand(5) * 10
bar_data = {"categories": categories, "values": values}

# Data for second bar plot
categories_2 = ["A", "B", "C", "D", "E"]
values_2 = np.random.randn(5) * 100
bar_data_2 = {"categories": categories_2, "values": values_2}

# Create a figure with 3 subplots arranged vertically
fig, axs = plt.subplots(3, 1, figsize=(10, 12))

# First panel: Line plot using seaborn
sns.lineplot(x="x", y="y", data=line_data, color="blue", linewidth=2, ax=axs[0])
axs[0].set_title("Line Plot: Random Data")
axs[0].set_xlabel("X-axis")
axs[0].set_ylabel("Values")

# Second panel: Bar plot using seaborn
sns.barplot(
x="categories", y="values", data=bar_data, color="green", alpha=0.7, ax=axs[1]
)
axs[1].set_title("Bar Plot: Random Values")
axs[1].set_xlabel("Categories")
axs[1].set_ylabel("Values")

# Third panel: Bar plot using seaborn
sns.barplot(
x="categories", y="values", data=bar_data_2, color="blue", alpha=0.7, ax=axs[2]
)
axs[2].set_title("Bar Plot 2: Random Values") # Fixed the typo in the title
axs[2].set_xlabel("Categories")
axs[2].set_ylabel("Values")

# Adjust layout to prevent overlap
plt.tight_layout()

# Display the figure
maidr.show(fig)
10 changes: 7 additions & 3 deletions example/stacked/matplotlib/example_mpl_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import maidr

species = (
"Adelie\n $\\mu=$3700.66g",
"Chinstrap\n $\\mu=$3733.09g",
"Gentoo\n $\\mu=5076.02g$",
"Adelie\nMean = 3700.66g",
"Chinstrap\nMean = 3733.09g",
"Gentoo\nMean = 5076.02g",
)
weight_counts = {
"Below": np.array([70, 31, 58]),
Expand All @@ -15,12 +15,16 @@
width = 0.5

fig, ax = plt.subplots()

bottom = np.zeros(3)

for boolean, weight_count in weight_counts.items():
p = ax.bar(species, weight_count, width, label=boolean, bottom=bottom)
bottom += weight_count

ax.set_xlabel("Species of Penguins")
ax.set_ylabel("Average Body Mass")

ax.set_title("Number of penguins with above average body mass")
ax.legend(loc="upper right")

Expand Down
23 changes: 19 additions & 4 deletions maidr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,31 @@ def render(plot: Any) -> Tag:

def show(plot: Any, renderer: Literal["auto", "ipython", "browser"] = "auto") -> object:
ax = FigureManager.get_axes(plot)
maidr = FigureManager.get_maidr(ax.get_figure())
return maidr.show(renderer)
htmls = []
if isinstance(ax, list):
for axes in ax:
maidr = FigureManager.get_maidr(axes.get_figure())
htmls.append(maidr.render())
return htmls[-1].show(renderer)
else:
maidr = FigureManager.get_maidr(ax.get_figure())
return maidr.show(renderer)


def save_html(
plot: Any, file: str, *, lib_dir: str | None = "lib", include_version: bool = True
) -> str:
ax = FigureManager.get_axes(plot)
maidr = FigureManager.get_maidr(ax.get_figure())
return maidr.save_html(file, lib_dir=lib_dir, include_version=include_version)
htmls = []
if isinstance(ax, list):
for axes in ax:
maidr = FigureManager.get_maidr(axes.get_figure())
htmls.append(maidr.render())
htmls[-1].save_html(file, libdir=lib_dir, include_version=include_version)
return htmls[-1]
else:
maidr = FigureManager.get_maidr(ax.get_figure())
return maidr.save_html(file, lib_dir=lib_dir, include_version=include_version)


def stacked(plot: Axes | BarContainer) -> Maidr:
Expand Down
69 changes: 59 additions & 10 deletions maidr/core/maidr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tempfile
import uuid
import webbrowser
from typing import Literal
from typing import Any, Literal

from htmltools import HTML, HTMLDocument, Tag, tags
from lxml import etree
Expand Down Expand Up @@ -139,18 +139,66 @@ def _create_html_doc(self) -> HTMLDocument:

def _flatten_maidr(self) -> dict | list[dict]:
"""Return a single plot schema or a list of schemas from the Maidr instance."""
if self.plot_type == PlotType.LINE:
self._plots = [self._plots[0]]
maidr = [plot.schema for plot in self._plots]

# Replace the selector having maidr='true' with maidr={self.maidr_id}
for plot in maidr:
if MaidrKey.SELECTOR in plot:
plot[MaidrKey.SELECTOR] = plot[MaidrKey.SELECTOR].replace(
# To support legacy JS Engine we will just return the format in this way
# but soon enough this should be deprecated and when we will completely
# transition to TypeScript :)
engine = Environment.get_engine()
if engine == "js":
if self.plot_type in (PlotType.LINE, PlotType.DODGED, PlotType.STACKED):
self._plots = [self._plots[0]]
maidr = [plot.schema for plot in self._plots]
for plot in maidr:
if MaidrKey.SELECTOR in plot:
plot[MaidrKey.SELECTOR] = plot[MaidrKey.SELECTOR].replace(
"maidr='true'", f"maidr='{self.selector_id}'"
)
return maidr if len(maidr) != 1 else maidr[0]

# Now let's start building the maidr object for the newer TypeScript engine

plot_schemas = []

for plot in self._plots:
schema = plot.schema
if MaidrKey.SELECTOR in schema:
schema[MaidrKey.SELECTOR] = schema[MaidrKey.SELECTOR].replace(
"maidr='true'", f"maidr='{self.selector_id}'"
)
plot_schemas.append(
{
"schema": schema,
"row": getattr(plot, "row_index", 0),
"col": getattr(plot, "col_index", 0),
}
)

max_row = max([plot.get("row", 0) for plot in plot_schemas], default=0)
max_col = max([plot.get("col", 0) for plot in plot_schemas], default=0)

subplot_grid: list[list[dict[str, str | list[Any]]]] = [
[{} for _ in range(max_col + 1)] for _ in range(max_row + 1)
]

position_groups = {}
for plot in plot_schemas:
pos = (plot.get("row", 0), plot.get("col", 0))
if pos not in position_groups:
position_groups[pos] = []
position_groups[pos].append(plot["schema"])

for (row, col), layers in position_groups.items():
if subplot_grid[row][col]:
subplot_grid[row][col]["layers"].append(layers)
else:
subplot_grid[row][col] = {"id": Maidr._unique_id(), "layers": layers}

for i in range(len(subplot_grid)):
subplot_grid[i] = [
cell if cell is not None else {"id": Maidr._unique_id(), "layers": []}
for cell in subplot_grid[i]
]

return maidr if len(maidr) != 1 else maidr[0]
return {"id": Maidr._unique_id(), "subplots": subplot_grid}

def _get_svg(self) -> HTML:
"""Extract the chart SVG from ``matplotlib.figure.Figure``."""
Expand Down Expand Up @@ -200,6 +248,7 @@ def _inject_plot(plot: HTML, maidr: str, maidr_id) -> Tag:

engine = Environment.get_engine()

# MAIDR_TS_CDN_URL = "http://localhost:8080/maidr.js" # DEMO URL
MAIDR_TS_CDN_URL = "https://cdn.jsdelivr.net/npm/maidr-ts/dist/maidr.js"

maidr_js_script = f"""
Expand Down
6 changes: 3 additions & 3 deletions maidr/core/plot/barplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ def _extract_plot_data(self) -> list:
levels = self.extract_level(self.ax)
if engine == "ts":
formatted_data = []
combined_data = (
zip(levels, data) if plot[0].orientation == "vertical" else zip(levels, data) # type: ignore
combined_data = list(
zip(levels, data) if plot[0].orientation == "vertical" else zip(data, levels) # type: ignore
)
if len(data) == len(plot): # type: ignore
if combined_data: # type: ignore
for x, y in combined_data: # type: ignore
formatted_data.append({"x": x, "y": y})
return formatted_data
Expand Down
3 changes: 3 additions & 0 deletions maidr/core/plot/maidr_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def __init__(self, ax: Axes, plot_type: PlotType) -> None:
self.ax = ax
self._support_highlighting = True
self._elements = []
ss = self.ax.get_subplotspec()
self.row_index = ss.rowspan.start
self.col_index = ss.colspan.start

# MAIDR data
self.type = plot_type
Expand Down
20 changes: 11 additions & 9 deletions maidr/patch/barplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,18 @@ def bar(
bottom = kwargs.get("bottom")
if bottom is not None:
plot_type = PlotType.STACKED
elif args:
x = args[0]
is_numeric = False
if isinstance(x, np.ndarray) and np.issubdtype(x.dtype, np.number):
is_numeric = True
elif isinstance(x, (list, tuple)) and x and isinstance(x[0], Number):
is_numeric = True
if is_numeric:
plot_type = PlotType.DODGED
else:
if len(args) >= 3:
real_width = args[2]
else:
real_width = kwargs.get("width", 0.8)

align = kwargs.get("align", "center")

if (isinstance(real_width, (int, float)) and float(real_width) < 0.8) or (
align == "edge"
):
plot_type = PlotType.DODGED
return common(plot_type, wrapped, instance, args, kwargs)


Expand Down
Loading