Skip to content

Commit c2cb99d

Browse files
authored
feat: support hightlighing except for segmented plots and boxplots (#59)
1 parent 26c1e13 commit c2cb99d

16 files changed

+230
-98
lines changed

example/line/matplotlib/example_mpl_line.py

-1
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,3 @@
2727
# Show the plot
2828
plt.show()
2929
maidr.show(line_plot)
30-
print()

maidr/__init__.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@
22

33
from .core import Maidr
44
from .core.enum import PlotType
5-
from .patch import barplot, boxplot, clear, heatmap, histogram, lineplot, scatterplot
5+
from .patch import (
6+
barplot,
7+
boxplot,
8+
clear,
9+
heatmap,
10+
highlight,
11+
histogram,
12+
lineplot,
13+
scatterplot,
14+
)
615
from .maidr import close, save_html, show, stacked
716

817
__all__ = [

maidr/core/context_manager.py

+79-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from __future__ import annotations
2+
13
import contextlib
24
import contextvars
35
import threading
46

57
import wrapt
68

9+
from maidr.core.plot.boxplot import BoxPlotContainer
10+
711

812
class ContextManager:
913
_instance = None
@@ -18,6 +22,10 @@ def __new__(cls):
1822
cls._instance = super(ContextManager, cls).__new__()
1923
return cls._instance
2024

25+
@classmethod
26+
def is_internal_context(cls):
27+
return cls._internal_context.get()
28+
2129
@classmethod
2230
@contextlib.contextmanager
2331
def set_internal_context(cls):
@@ -27,10 +35,6 @@ def set_internal_context(cls):
2735
finally:
2836
cls._internal_context.reset(token_internal_context)
2937

30-
@classmethod
31-
def is_internal_context(cls):
32-
return cls._internal_context.get()
33-
3438

3539
@wrapt.decorator
3640
def manage_context(wrapped=None, _=None, args=None, kwargs=None):
@@ -41,3 +45,74 @@ def manage_context(wrapped=None, _=None, args=None, kwargs=None):
4145
# Set the internal context to avoid cyclic processing.
4246
with ContextManager.set_internal_context():
4347
return wrapped(*args, **kwargs)
48+
49+
50+
class BoxplotContextManager(ContextManager):
51+
_bxp_context = contextvars.ContextVar("bxp_context", default=BoxPlotContainer())
52+
53+
@classmethod
54+
@contextlib.contextmanager
55+
def set_internal_context(cls):
56+
with super(BoxplotContextManager, cls).set_internal_context():
57+
token = cls._bxp_context.set(BoxPlotContainer())
58+
try:
59+
yield cls.get_bxp_context()
60+
finally:
61+
cls._bxp_context.reset(token)
62+
63+
@classmethod
64+
def get_bxp_context(cls) -> BoxPlotContainer:
65+
return cls._bxp_context.get()
66+
67+
@classmethod
68+
def add_bxp_context(cls, bxp_context: dict) -> None:
69+
cls.get_bxp_context().add_artists(bxp_context)
70+
71+
@classmethod
72+
def set_bxp_orientation(cls, orientation: str) -> None:
73+
cls.get_bxp_context().set_orientation(orientation)
74+
75+
76+
class HighlightContextManager:
77+
_instance = None
78+
_lock = threading.Lock()
79+
80+
_maidr_element = contextvars.ContextVar("_maidr_element", default=False)
81+
_elements = contextvars.ContextVar("elements", default=[])
82+
83+
def __new__(cls):
84+
if not cls._instance:
85+
with cls._lock:
86+
if not cls._instance:
87+
cls._instance = super(HighlightContextManager, cls).__new__()
88+
return cls._instance
89+
90+
@classmethod
91+
def is_maidr_element(cls):
92+
return cls._maidr_element.get()
93+
94+
@classmethod
95+
@contextlib.contextmanager
96+
def set_maidr_element(cls, element):
97+
if element not in cls._elements.get():
98+
yield
99+
return
100+
101+
token_maidr_element = cls._maidr_element.set(True)
102+
try:
103+
yield
104+
finally:
105+
cls._maidr_element.reset(token_maidr_element)
106+
# Remove element from the context list after tagging
107+
new_elements = cls._elements.get().copy()
108+
new_elements.remove(element)
109+
cls._elements.set(new_elements)
110+
111+
@classmethod
112+
@contextlib.contextmanager
113+
def set_maidr_elements(cls, elements: list):
114+
token_paths = cls._elements.set(elements)
115+
try:
116+
yield
117+
finally:
118+
cls._elements.reset(token_paths)

maidr/core/maidr.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from matplotlib.figure import Figure
1212

13+
from maidr.core.context_manager import HighlightContextManager
1314
from maidr.core.plot import MaidrPlot
1415

1516

@@ -104,7 +105,9 @@ def destroy(self) -> None:
104105

105106
def _create_html_tag(self) -> Tag:
106107
"""Create the MAIDR HTML using HTML tags."""
107-
svg = self._get_svg()
108+
tagged_elements = [element for plot in self._plots for element in plot.elements]
109+
with HighlightContextManager.set_maidr_elements(tagged_elements):
110+
svg = self._get_svg()
108111
maidr = f"\nlet maidr = {json.dumps(self._flatten_maidr(), indent=2)}\n"
109112

110113
# Inject plot's svg and MAIDR structure into html tag.

maidr/core/plot/barplot.py

+3
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,7 @@ def _extract_bar_container_data(
5151
if len(plot) != len(level):
5252
return None
5353

54+
# Tag the elements for highlighting.
55+
self._elements.extend(plot)
56+
5457
return [float(patch.get_height()) for patch in plot]

maidr/core/plot/boxplot.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,12 @@ class BoxPlot(
126126
DictMergerMixin,
127127
):
128128
def __init__(self, ax: Axes, **kwargs) -> None:
129+
super().__init__(ax, PlotType.BOX)
130+
129131
self._bxp_stats = kwargs.pop("bxp_stats", None)
130132
self._orientation = kwargs.pop("orientation", "vert")
131133
self._bxp_extractor = BoxPlotExtractor(orientation=self._orientation)
132-
super().__init__(ax, PlotType.BOX)
134+
self._support_highlighting = False
133135

134136
def render(self) -> dict:
135137
base_schema = super().render()

maidr/core/plot/grouped_barplot.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class GroupedBarPlot(
1818
):
1919
def __init__(self, ax: Axes, plot_type: PlotType) -> None:
2020
super().__init__(ax, plot_type)
21+
self._support_highlighting = False
2122

2223
def _extract_axes_data(self) -> dict:
2324
base_ax_schema = super()._extract_axes_data()

maidr/core/plot/heatmap.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import numpy.ma as ma
4+
35
from matplotlib.axes import Axes
46
from matplotlib.cm import ScalarMappable
57
from matplotlib.collections import QuadMesh
@@ -44,23 +46,29 @@ def _extract_axes_data(self) -> dict:
4446

4547
def _extract_plot_data(self) -> list[list]:
4648
plot = self.extract_scalar_mappable(self.ax)
47-
data = HeatPlot._extract_scalar_mappable_data(plot)
49+
data = self._extract_scalar_mappable_data(plot)
4850

4951
if data is None:
5052
raise ExtractionError(self.type, plot)
5153

5254
return data
5355

54-
@staticmethod
55-
def _extract_scalar_mappable_data(sm: ScalarMappable | None) -> list[list] | None:
56+
def _extract_scalar_mappable_data(
57+
self, sm: ScalarMappable | None
58+
) -> list[list] | None:
5659
if sm is None or sm.get_array() is None:
5760
return None
5861

5962
array = sm.get_array().data
6063
if isinstance(sm, QuadMesh):
6164
# Data gets flattened in ScalarMappable, when the plot is from QuadMesh.
6265
# So, reshaping the data to reflect the original quadrilaterals
63-
m, n, _ = sm.get_coordinates().shape
66+
m, n, _ = ma.shape(sm.get_coordinates())
6467
array = array.reshape(m - 1, n - 1) # Coordinates shape is (M + 1, N + 1)
6568

69+
# Tag the elements for highlighting
70+
self._elements.append(sm)
71+
else:
72+
self._support_highlighting = False
73+
6674
return [list(map(float, row)) for row in array]

maidr/core/plot/histogram.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,16 @@ def __init__(self, ax: Axes) -> None:
1515

1616
def _extract_plot_data(self) -> list[dict]:
1717
plot = self.extract_container(self.ax, BarContainer)
18-
data = HistPlot._extract_bar_container_data(plot)
18+
data = self._extract_bar_container_data(plot)
1919

2020
if data is None:
2121
raise ExtractionError(self.type, plot)
2222

2323
return data
2424

25-
@staticmethod
26-
def _extract_bar_container_data(plot: BarContainer | None) -> list[dict] | None:
25+
def _extract_bar_container_data(
26+
self, plot: BarContainer | None
27+
) -> list[dict] | None:
2728
if plot is None or plot.patches is None:
2829
return None
2930

@@ -44,4 +45,7 @@ def _extract_bar_container_data(plot: BarContainer | None) -> list[dict] | None:
4445
}
4546
)
4647

48+
# Tag the elements for highlighting
49+
self._elements.extend(plot.patches)
50+
4751
return data

maidr/core/plot/lineplot.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,29 @@ class LinePlot(MaidrPlot, LineExtractorMixin):
1313
def __init__(self, ax: Axes) -> None:
1414
super().__init__(ax, PlotType.LINE)
1515

16+
def _get_selector(self) -> str:
17+
return "g[maidr='true'] > path"
18+
1619
def _extract_plot_data(self) -> list[dict]:
1720
plot = self.extract_line(self.ax)
18-
data = LinePlot._extract_line_data(plot)
21+
data = self._extract_line_data(plot)
1922

2023
if data is None:
2124
raise ExtractionError(self.type, plot)
2225

2326
return data
2427

25-
@staticmethod
26-
def _extract_line_data(plot: Line2D | None) -> list[dict] | None:
28+
def _extract_line_data(self, plot: Line2D | None) -> list[dict] | None:
2729
if plot is None or plot.get_xydata() is None:
2830
return None
2931

32+
# Tag the elements for highlighting.
33+
self._elements.append(plot)
34+
3035
return [
3136
{
32-
MaidrKey.X.value: float(x),
33-
MaidrKey.Y.value: float(y),
37+
MaidrKey.X: float(x),
38+
MaidrKey.Y: float(y),
3439
}
3540
for x, y in plot.get_xydata()
3641
]

maidr/core/plot/maidr_plot.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,32 @@ class MaidrPlot(ABC):
3737
def __init__(self, ax: Axes, plot_type: PlotType) -> None:
3838
# graphic object
3939
self.ax = ax
40+
self._support_highlighting = True
41+
self._elements = []
4042

4143
# MAIDR data
4244
self.type = plot_type
4345
self._schema = {}
4446

4547
def render(self) -> dict:
4648
"""Initialize the MAIDR schema dictionary with basic plot information."""
47-
return {
49+
maidr_schema = {
4850
MaidrKey.TYPE: self.type,
4951
MaidrKey.TITLE: self.ax.get_title(),
5052
MaidrKey.AXES: self._extract_axes_data(),
5153
MaidrKey.DATA: self._extract_plot_data(),
5254
}
5355

56+
# Include selector only if the plot supports highlighting.
57+
if self._support_highlighting:
58+
maidr_schema[MaidrKey.SELECTOR] = self._get_selector()
59+
60+
return maidr_schema
61+
62+
def _get_selector(self) -> str:
63+
"""Return the CSS selector for highlighting elements."""
64+
return "path[maidr='true']"
65+
5466
def _extract_axes_data(self) -> dict:
5567
"""Extract the plot's axes data"""
5668
return {
@@ -74,8 +86,14 @@ def schema(self) -> dict:
7486
self._schema = self.render()
7587
return self._schema
7688

89+
@property
90+
def elements(self) -> list:
91+
if not self._schema:
92+
self._schema = self.render()
93+
return self._elements
94+
7795
def set_id(self, maidr_id: str) -> None:
7896
"""Set the unique identifier for the plot within the MAIDR schema."""
7997
if not self._schema:
8098
self._schema = self.render()
81-
self._schema[MaidrKey.ID.value] = maidr_id
99+
self._schema[MaidrKey.ID] = maidr_id

maidr/core/plot/scatterplot.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import numpy.ma as ma
4+
35
from matplotlib.axes import Axes
46
from matplotlib.collections import PathCollection
57

@@ -13,28 +15,29 @@ class ScatterPlot(MaidrPlot, CollectionExtractorMixin):
1315
def __init__(self, ax: Axes) -> None:
1416
super().__init__(ax, PlotType.SCATTER)
1517

18+
def _get_selector(self) -> str:
19+
return "g[maidr='true'] > use"
20+
1621
def _extract_plot_data(self) -> list[dict]:
1722
plot = self.extract_collection(self.ax, PathCollection)
18-
data = ScatterPlot._extract_point_data(plot)
23+
data = self._extract_point_data(plot)
1924

2025
if data is None:
2126
raise ExtractionError(self.type, plot)
2227

2328
return data
2429

25-
@staticmethod
26-
def _extract_point_data(plot: PathCollection | None) -> list[dict] | None:
30+
def _extract_point_data(self, plot: PathCollection | None) -> list[dict] | None:
2731
if plot is None or plot.get_offsets() is None:
2832
return None
2933

30-
data = []
31-
for point in plot.get_offsets().data:
32-
x, y = point
33-
data.append(
34-
{
35-
MaidrKey.X.value: float(x),
36-
MaidrKey.Y.value: float(y),
37-
}
38-
)
34+
# Tag the elements for highlighting.
35+
self._elements.append(plot)
3936

40-
return data
37+
return [
38+
{
39+
MaidrKey.X: float(x),
40+
MaidrKey.Y: float(y),
41+
}
42+
for x, y in ma.getdata(plot.get_offsets())
43+
]

0 commit comments

Comments
 (0)