diff --git a/MANIFEST.in b/MANIFEST.in
index 7c063ca..0b66f27 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,2 +1,3 @@
include src/lama_aesthetics/styles/*.mplstyle
-
+include src/lama_aesthetics/fonts/*.otf
+include src/lama_aesthetics/fonts/*.ttf
diff --git a/README.md b/README.md
index a9e221b..1c12369 100644
--- a/README.md
+++ b/README.md
@@ -5,10 +5,22 @@
[](https://lamalab-org.github.io/lama-aesthetics/)
[](https://img.shields.io/github/license/lamalab-org/lama-aesthetics)
-Plotting styles and helpers by LamaLab
+Publication-quality plotting styles and helpers for matplotlib by [LamaLab](https://lamalab.org).
-- **Github repository**:
-- **Documentation**
+- **GitHub repository**:
+- **Documentation**:
+
+## Features at a glance
+
+| Category | What you get |
+|---|---|
+| **Styles** | `main` (publication), `presentation` (talks), `dark` (dark-themed) |
+| **Bundled font** | CMU Sans Serif — auto-registered, no system install needed |
+| **Range frame** | Tufte-style axis frame trimmed to the data range, with automatic nice-number bounds for numeric axes and full support for categorical axes |
+| **Label helpers** | `ylabel_top` — horizontal y-label placed above the top tick |
+| **Reference lines** | `add_identity` — dynamic 1:1 diagonal that follows axis changes |
+| **Figure decomposition** | `decompose_figure` — split a multi-series figure into one figure per labeled artist |
+| **Dimension constants** | `ONE_COL_WIDTH`, `TWO_COL_WIDTH`, `ONE_COL_HEIGHT`, `TWO_COL_HEIGHT` (golden ratio) |
## Installation
@@ -24,21 +36,48 @@ uv pip install -e .
make install
```
-## Usage
+## Quick start
+
+```python
+import matplotlib.pyplot as plt
+import numpy as np
+import lama_aesthetics
+
+lama_aesthetics.get_style("main")
+
+fig, ax = plt.subplots()
+x = np.linspace(0.5, 9.3, 40)
+y = np.sin(x) * 10 + 15
+
+ax.plot(x, y)
+lama_aesthetics.range_frame(ax, x, y) # nice-number axis bounds by default
+lama_aesthetics.ylabel_top("y", ax=ax)
+
+plt.show()
+```
+
+---
-### Styles
+## Styles
-The library provides two main plotting styles:
+The library ships three matplotlib style sheets, all using the bundled **CMU Sans Serif** font:
-- **main**: Optimized for publications, reports, and other documents.
-- **presentation**: Features larger fonts and thicker lines for better visibility in presentations
+| Style | Apply with | Description |
+|---|---|---|
+| **main** | `get_style("main")` | Optimized for single- and two-column journal figures. Compact figure size (3.3 × 2.5 in), inward ticks, thin lines. |
+| **presentation** | `get_style("presentation")` | Same layout but with larger fonts (13 / 12 pt) for slides and posters. |
+| **dark** | `get_style("dark")` | Black background, white text and lines — ideal for dark-themed slides or dashboards. |
+
+All styles disable the right and top spines, use inward-facing ticks, and remove the legend frame for a clean Tufte-inspired look.
```python
-import matplotlib.pyplot as plt
-import numpy as np
import lama_aesthetics
-lama_aesthetics.get_style("main") # or la.get_style("presentation")
+lama_aesthetics.get_style("main")
+# or
+lama_aesthetics.get_style("presentation")
+# or
+lama_aesthetics.get_style("dark")
```
@@ -49,68 +88,198 @@ lama_aesthetics.get_style("main") # or la.get_style("presentation")
Left: Main style; Right: Presentation style
-### Helpers
+### Bundled font
+
+The **CMU Sans Serif** font (`cmunss.otf`) is bundled with the package and
+registered automatically the first time you apply a style. You can also
+register it manually:
+
+```python
+lama_aesthetics.register_fonts()
+font_name = lama_aesthetics.get_font_name() # → "CMU Sans Serif"
+```
+
+---
+
+## Plotting utilities
+
+### `range_frame` — Tufte-style range frame with nice-number bounds
+
+`range_frame` trims the axis spines to the data range and offsets them outward
+for a clean, separated look. It automatically detects whether each axis is
+**numeric** or **categorical** and handles them differently:
+
+**Numeric axes (default: `nice=True`):**
+The spine bounds are snapped to "nice" tick positions (multiples of 1, 2, 2.5,
+5, or 10 × 10^n) that bracket the data. This means the axis line always
+starts and ends exactly on a tick mark — no more floating spine endpoints.
+Tick positions are computed via matplotlib's `MaxNLocator` and explicitly set
+on the axes so there is zero drift between ticks and spine bounds.
+
+**Categorical axes:**
+When `x` or `y` contains strings (e.g. `["a", "b", "c"]`), the spine spans
+the integer range `0 .. len(values) - 1`. No rounding is applied.
+
+```python
+import matplotlib.pyplot as plt
+import numpy as np
+from lama_aesthetics.plotutils import range_frame
+
+fig, axes = plt.subplots(1, 3, figsize=(12, 3))
+
+# 1) Basic numeric — bounds snap to nice ticks
+x = np.linspace(3.2, 47.8, 20)
+y = x * 1.3 - 1.5
+axes[0].plot(x, y)
+range_frame(axes[0], x, y) # nice=True by default
+axes[0].set_title("Numeric (nice bounds)")
+
+# 2) Categorical x-axis
+categories = ["Model A", "Model B", "Model C", "Model D"]
+values = np.array([0.82, 0.91, 0.87, 0.95])
+axes[1].plot(categories, values)
+range_frame(axes[1], categories, values) # categorical x, numeric y
+axes[1].set_title("Categorical x-axis")
+
+# 3) Disable nice bounds — raw padding only
+axes[2].plot(x, y)
+range_frame(axes[2], x, y, nice=False, pad=0.05)
+axes[2].set_title("nice=False (raw padding)")
-The package includes several plotting utilities to enhance your visualizations:
+plt.tight_layout()
+plt.show()
+```
+
+#### Parameters
+
+| Parameter | Default | Description |
+|---|---|---|
+| `ax` | — | Matplotlib `Axes` object |
+| `x`, `y` | — | Data arrays (numeric or string/categorical) |
+| `pad` | `0.1` | Padding factor applied to both axes (only used when `nice=False`) |
+| `pad_x` | `None` | Per-axis padding near the x-axis (vertical). Overrides `pad`. |
+| `pad_y` | `None` | Per-axis padding near the y-axis (horizontal). Overrides `pad`. |
+| `nice` | `True` | Snap numeric spine bounds to nice tick positions that bracket the data. When `True`, padding parameters are ignored for numeric axes. |
+
+### `ylabel_top` — horizontal y-label above the axis
+
+Places the y-axis label horizontally above the top tick, making it easier to
+read without head-tilting:
+
+```python
+from lama_aesthetics.plotutils import ylabel_top
+
+fig, ax = plt.subplots()
+ax.plot([0, 1, 2], [0, 1, 4])
+ylabel_top("Energy (eV)", ax=ax)
+```
+
+| Parameter | Default | Description |
+|---|---|---|
+| `string` | — | Label text |
+| `ax` | `None` | Axes (defaults to `plt.gca()`) |
+| `x_pad` | `0.01` | Horizontal offset in axes coordinates |
+| `y_pad` | `0.02` | Vertical offset above the top tick |
+
+### `add_identity` — dynamic 1:1 reference line
+
+Adds a diagonal identity line that automatically adjusts when the axis limits
+change (e.g. during zoom or pan):
+
+```python
+from lama_aesthetics.plotutils import add_identity
+
+fig, ax = plt.subplots()
+predicted = np.array([1.1, 2.3, 2.9, 4.2])
+observed = np.array([1.0, 2.0, 3.0, 4.0])
+ax.scatter(observed, predicted)
+add_identity(ax, linestyle="--", color="gray", alpha=0.6)
+```
-- **range_frame**: Draws a frame around the data range.
-- **ylabel_top**: Places the y-label at the top of the y-axis.
-- **add_identity**: Adds a diagonal reference line.
+### `decompose_figure` — split a figure by legend entry
-### Figure Dimensions
+Takes a figure (or a single `Axes`) containing multiple labeled series and
+returns a list of `(label, figure)` tuples — one separate figure per legend
+entry. This is useful when you want to highlight individual series, e.g. to
+include them separately in a paper or presentation.
-The package provides predefined figure dimension constants based on common journal column widths and the golden ratio:
+```python
+from lama_aesthetics import decompose_figure, get_style
+
+get_style("main")
+
+fig, ax = plt.subplots()
+x = np.linspace(0, 2 * np.pi, 50)
+ax.plot(x, np.sin(x), label="sin(x)")
+ax.plot(x, np.cos(x), label="cos(x)")
+ax.plot(x, np.sin(x) + np.cos(x), label="sin+cos")
+ax.set_xlabel("x")
+ax.set_ylabel("f(x)")
+ax.set_title("Trigonometric Functions")
+ax.legend()
+
+parts = decompose_figure(fig) # also accepts an Axes directly
+
+for label, part_fig in parts:
+ part_fig.savefig(f"{label}.png")
+ plt.close(part_fig)
+```
+
+Each decomposed figure inherits axis labels, title, limits, scale, tick
+positions, spine visibility, and grid state from the original. Pass
+`show_legend=False` to omit the legend from the individual figures.
+
+**Supported artist types:** line plots (`plot`), scatter plots (`scatter`),
+bar charts (`bar`), and filled regions (`fill_between`).
+
+---
+
+## Figure dimension constants
+
+Predefined sizes based on common journal column widths and the golden ratio:
```python
from lama_aesthetics import (
- ONE_COL_WIDTH,
- TWO_COL_WIDTH,
- ONE_COL_HEIGHT,
- TWO_COL_HEIGHT,
+ ONE_COL_WIDTH, # 3 inches
+ TWO_COL_WIDTH, # 7.25 inches
+ ONE_COL_HEIGHT, # ONE_COL_WIDTH / φ ≈ 1.854 inches
+ TWO_COL_HEIGHT, # TWO_COL_WIDTH / φ ≈ 4.481 inches
)
-# Create a single-column figure with golden ratio proportions
fig, ax = plt.subplots(figsize=(ONE_COL_WIDTH, ONE_COL_HEIGHT))
-
-# Create a two-column figure with golden ratio proportions
-fig, ax = plt.subplots(figsize=(TWO_COL_WIDTH, TWO_COL_HEIGHT))
```
-Available constants:
+---
-- `ONE_COL_WIDTH`: 3 inches (typical single-column width)
-- `TWO_COL_WIDTH`: 7.25 inches (typical two-column width)
-- `ONE_COL_HEIGHT`: Single-column height based on golden ratio
-- `TWO_COL_HEIGHT`: Two-column height based on golden ratio
-
-### Plotting Utilities Examples
+## Full example
```python
import matplotlib.pyplot as plt
import numpy as np
+import lama_aesthetics
from lama_aesthetics.plotutils import range_frame, ylabel_top, add_identity
-# Create sample data
+lama_aesthetics.get_style("main")
+
x = np.linspace(0, 10, 100)
y = np.sin(x)
-# Create a figure
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
-# Example 1: Range frame - only shows axes within the data range
+# Range frame with nice bounds
axes[0].plot(x, y)
range_frame(axes[0], x, y)
axes[0].set_title("Range Frame")
-# Example 2: Top Y-label - places ylabel at top of y-axis
+# Top y-label
axes[1].plot(x, y)
ylabel_top("sin(x)", axes[1])
axes[1].set_title("Top Y-Label")
-# Example 3: Identity line - adds a diagonal reference line
-x_scatter = np.linspace(0, 1, 20)
-y_scatter = x_scatter + 0.1*np.random.randn(20)
-axes[2].scatter(x_scatter, y_scatter)
+# Identity line
+x_sc = np.linspace(0, 1, 20)
+y_sc = x_sc + 0.1 * np.random.randn(20)
+axes[2].scatter(x_sc, y_sc)
add_identity(axes[2], linestyle="--", color="gray")
axes[2].set_title("Identity Line")
@@ -118,6 +287,29 @@ plt.tight_layout()
plt.show()
```
-
Left: Range Frame; Center: Top Y-Label; Right: Identity Line
+
+

+
Left: Range Frame; Center: Top Y-Label; Right: Identity Line
+
+
+---
+
+## API reference
+
+| Function / Constant | Module | Description |
+|---|---|---|
+| `get_style(name)` | `aesthetics` | Apply a bundled style (`"main"`, `"presentation"`, `"dark"`) |
+| `register_fonts()` | `aesthetics` | Register bundled CMU Sans Serif with matplotlib |
+| `get_font_name()` | `aesthetics` | Return the registered font family name |
+| `range_frame(ax, x, y, ...)` | `plotutils` | Tufte-style range frame with nice-number bounds |
+| `ylabel_top(string, ax, ...)` | `plotutils` | Place y-label horizontally above the top tick |
+| `add_identity(ax, ...)` | `plotutils` | Add a dynamic 1:1 diagonal reference line |
+| `decompose_figure(fig, ...)` | `plotutils` | Split a figure into one figure per labeled artist |
+| `ONE_COL_WIDTH` | `aesthetics` | 3 inches |
+| `TWO_COL_WIDTH` | `aesthetics` | 7.25 inches |
+| `ONE_COL_HEIGHT` | `aesthetics` | `ONE_COL_WIDTH / golden_ratio` |
+| `TWO_COL_HEIGHT` | `aesthetics` | `TWO_COL_WIDTH / golden_ratio` |
+
+---
Repository initiated with [lamalab-org/cookiecutter-uv](https://github.com/lamalab-org/cookiecutter-uv).
diff --git a/pyproject.toml b/pyproject.toml
index a630d76..57a8a9d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,12 +39,12 @@ dependencies = [
"numpy",
"scipy",
]
-requires-python = ">=3.10"
+requires-python = ">=3.9"
readme = "README.md"
license = { text = "MIT license" }
[tool.setuptools.package-data]
-lama_aesthetics = ["styles/*.mplstyle"]
+lama_aesthetics = ["styles/*.mplstyle", "fonts/*.otf", "fonts/*.ttf"]
[project.optional-dependencies]
optional_dependencies = []
diff --git a/src/lama_aesthetics/__init__.py b/src/lama_aesthetics/__init__.py
index b19ba41..c94a267 100644
--- a/src/lama_aesthetics/__init__.py
+++ b/src/lama_aesthetics/__init__.py
@@ -4,16 +4,21 @@
STYLES,
TWO_COL_HEIGHT,
TWO_COL_WIDTH,
+ get_font_name,
get_style,
+ register_fonts,
)
-from lama_aesthetics.plotutils import add_identity, range_frame, ylabel_top
+from lama_aesthetics.plotutils import add_identity, decompose_figure, range_frame, ylabel_top
__all__ = [
"STYLES",
"get_style",
+ "register_fonts",
+ "get_font_name",
"range_frame",
"ylabel_top",
"add_identity",
+ "decompose_figure",
"ONE_COL_WIDTH",
"TWO_COL_WIDTH",
"ONE_COL_HEIGHT",
diff --git a/src/lama_aesthetics/aesthetics.py b/src/lama_aesthetics/aesthetics.py
index 4747317..9f440dd 100644
--- a/src/lama_aesthetics/aesthetics.py
+++ b/src/lama_aesthetics/aesthetics.py
@@ -1,11 +1,49 @@
import importlib.resources
import matplotlib.pyplot as plt
+from matplotlib import font_manager
+from matplotlib.font_manager import FontProperties
from scipy.constants import golden
+# Track whether fonts have been registered
+_fonts_registered = False
+
+
+def register_fonts() -> None:
+ """Register bundled fonts with Matplotlib's font manager.
+
+ This function adds the CMU Sans Serif font bundled with the package
+ to Matplotlib's font manager, allowing it to be used without system-wide
+ installation. Safe to call multiple times (will only register once).
+ """
+ global _fonts_registered
+ if _fonts_registered:
+ return
+
+ # Get path to the bundled font
+ font_files = ["cmunss.otf"]
+
+ for font_file in font_files:
+ with importlib.resources.as_file(importlib.resources.files("lama_aesthetics.fonts").joinpath(font_file)) as font_path:
+ font_manager.fontManager.addfont(str(font_path))
+
+ _fonts_registered = True
+
+
+def get_font_name() -> str:
+ """Get the family name of the bundled CMU Sans Serif font.
+
+ Returns:
+ The font family name as recognized by Matplotlib.
+ """
+ with importlib.resources.as_file(importlib.resources.files("lama_aesthetics.fonts").joinpath("cmunss.otf")) as font_path:
+ return FontProperties(fname=str(font_path)).get_name()
+
+
STYLES = {
"main": "lamalab.mplstyle",
"presentation": "presentation.mplstyle",
+ "dark": "lamalab_dark.mplstyle",
}
# Figure dimensions
@@ -18,8 +56,11 @@
def get_style(style_name: str) -> None:
"""Get the path to a matplotlib style file and apply it.
+ This function registers bundled fonts before applying the style,
+ ensuring the CMU Sans Serif font is available without system installation.
+
Args:
- style_name: Name of the style ('main' or 'presentation')
+ style_name: Name of the style ('main', 'presentation', or 'dark')
Raises:
KeyError: If style_name is not in STYLES dictionary
@@ -27,6 +68,9 @@ def get_style(style_name: str) -> None:
if style_name not in STYLES:
raise KeyError(f"Style '{style_name}' not found. Available styles: {list(STYLES.keys())}")
+ # Register bundled fonts before applying the style
+ register_fonts()
+
style_file = STYLES[style_name]
# Get the file contents as a string
diff --git a/src/lama_aesthetics/fonts/cmunss.otf b/src/lama_aesthetics/fonts/cmunss.otf
new file mode 100644
index 0000000..49fecee
Binary files /dev/null and b/src/lama_aesthetics/fonts/cmunss.otf differ
diff --git a/src/lama_aesthetics/plotutils.py b/src/lama_aesthetics/plotutils.py
index b293198..01d5ed6 100644
--- a/src/lama_aesthetics/plotutils.py
+++ b/src/lama_aesthetics/plotutils.py
@@ -1,32 +1,134 @@
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
+from matplotlib.collections import PathCollection, PolyCollection
+from matplotlib.container import BarContainer
+from matplotlib.figure import Figure
+from matplotlib.lines import Line2D
+from matplotlib.ticker import MaxNLocator
-def range_frame(ax, x, y, pad=0.1):
+def _get_axis_bounds(values):
+ """Return axis bounds for numeric or categorical values."""
+ arr = np.asarray(values)
+
+ try:
+ numeric_arr = np.asarray(values, dtype=float)
+ except (TypeError, ValueError):
+ return 0, max(len(arr) - 1, 0), False
+
+ return numeric_arr.min(), numeric_arr.max(), True
+
+
+def _nice_tick_bounds(data_min, data_max):
+ """Return nice tick positions and spine bounds that strictly bracket the data.
+
+ Uses matplotlib's ``MaxNLocator`` to compute tick positions for the
+ data range and then selects the outermost ticks as spine bounds.
+ The returned bounds are guaranteed to satisfy
+ ``bound_lo <= data_min`` and ``bound_hi >= data_max``, and every
+ tick between the bounds (inclusive) is included.
+
+ Args:
+ data_min: Minimum value present in the data.
+ data_max: Maximum value present in the data.
+
+ Returns:
+ ``(bound_lo, bound_hi, ticks)`` where *ticks* is a 1-D array of
+ the tick positions that fall within ``[bound_lo, bound_hi]``.
+ """
+ if data_min == data_max:
+ # Degenerate case — expand symmetrically so the locator has a range.
+ if data_min == 0:
+ data_min, data_max = -0.5, 0.5
+ else:
+ delta = abs(data_min) * 0.1
+ data_min, data_max = data_min - delta, data_max + delta
+
+ locator = MaxNLocator(nbins="auto", steps=[1, 2, 2.5, 5, 10])
+ ticks = np.asarray(locator.tick_values(data_min, data_max))
+
+ # The locator already returns values that bracket [data_min, data_max],
+ # but enforce the invariant explicitly.
+ bound_lo = float(ticks[ticks <= data_min + 1e-12].max()) if np.any(ticks <= data_min + 1e-12) else float(ticks[0])
+ bound_hi = float(ticks[ticks >= data_max - 1e-12].min()) if np.any(ticks >= data_max - 1e-12) else float(ticks[-1])
+
+ # Keep only ticks within the chosen bounds.
+ mask = (ticks >= bound_lo - 1e-12) & (ticks <= bound_hi + 1e-12)
+ ticks = ticks[mask]
+
+ return bound_lo, bound_hi, ticks
+
+
+def range_frame(ax, x, y, pad=0.1, pad_x=None, pad_y=None, nice=True):
"""
Set the limits of the axes to include all data points with a padding of
`pad` times the range of the data. This is useful to ensure that the data
points are not cut off by the axes.
+ Per-axis padding can be controlled with ``pad_x`` and ``pad_y``. When
+ either is *None* (the default) the value of ``pad`` is used instead.
+
+ When ``nice`` is *True* (the default) and the axis carries numerical
+ data, the spine bounds are snapped to nice tick positions that bracket
+ the data, so that the axis line starts and ends exactly at tick marks.
+ Tick positions are computed via matplotlib's ``MaxNLocator`` and
+ explicitly set on the axes so there is no drift between ticks and
+ spine endpoints. The ``pad`` / ``pad_x`` / ``pad_y`` parameters are
+ ignored for any axis that receives nice bounds.
+
Args:
ax: The axes object.
x: The x-coordinates of the data points.
y: The y-coordinates of the data points.
- pad: The padding factor.
+ pad: The default padding factor applied to both axes.
+ pad_x: Padding near the x-axis (vertical direction). Overrides ``pad`` when set.
+ pad_y: Padding near the y-axis (horizontal direction). Overrides ``pad`` when set.
+ nice: If *True* (default), snap numeric spine bounds to nice tick
+ positions that bracket the data.
"""
- y_min, y_max = y.min(), y.max()
- x_min, x_max = x.min(), x.max()
+ if pad_x is None:
+ pad_x = pad
+ if pad_y is None:
+ pad_y = pad
- ax.set_ylim(y_min - pad * (y_max - y_min), y_max + pad * (y_max - y_min))
- ax.set_xlim(x_min - pad * (x_max - x_min), x_max + pad * (x_max - x_min))
+ y_min, y_max, y_is_numeric = _get_axis_bounds(y)
+ x_min, x_max, x_is_numeric = _get_axis_bounds(x)
+
+ # --- Y axis ------------------------------------------------------------
+ if y_is_numeric:
+ if nice:
+ y_bound_min, y_bound_max, y_ticks = _nice_tick_bounds(y_min, y_max)
+ ax.set_yticks(y_ticks)
+ ax.set_ylim(y_bound_min, y_bound_max)
+ else:
+ y_bound_min = y_min
+ y_bound_max = y_max
+ ax.set_ylim(y_min - pad_x * (y_max - y_min), y_max + pad_x * (y_max - y_min))
+ else:
+ y_bound_min, y_bound_max = y_min, y_max
+ ax.set_ylim(y_min, y_max)
+
+ # --- X axis ------------------------------------------------------------
+ if x_is_numeric:
+ if nice:
+ x_bound_min, x_bound_max, x_ticks = _nice_tick_bounds(x_min, x_max)
+ ax.set_xticks(x_ticks)
+ ax.set_xlim(x_bound_min, x_bound_max)
+ else:
+ x_bound_min = x_min
+ x_bound_max = x_max
+ ax.set_xlim(x_min - pad_y * (x_max - x_min), x_max + pad_y * (x_max - x_min))
+ else:
+ x_bound_min, x_bound_max = x_min, x_max
+ ax.set_xlim(x_min, x_max)
ax.spines["left"].set_position(("outward", 10))
ax.spines["bottom"].set_position(("outward", 10))
- ax.spines["bottom"].set_bounds(x_min, x_max)
- ax.spines["left"].set_bounds(y_min, y_max)
+ ax.spines["bottom"].set_bounds(x_bound_min, x_bound_max)
+ ax.spines["left"].set_bounds(y_bound_min, y_bound_max)
def ylabel_top(string: str, ax: Optional[plt.Axes] = None, x_pad: float = 0.01, y_pad: float = 0.02) -> None:
@@ -105,3 +207,201 @@ def callback(axes):
axes.callbacks.connect("xlim_changed", callback)
axes.callbacks.connect("ylim_changed", callback)
return axes
+
+
+def _setup_axes_like(source_ax: plt.Axes, target_ax: plt.Axes) -> None:
+ """Copy axis labels, title, limits, scales, and spine visibility from *source_ax* to *target_ax*."""
+ target_ax.set_xlim(source_ax.get_xlim())
+ target_ax.set_ylim(source_ax.get_ylim())
+ target_ax.set_xlabel(source_ax.get_xlabel())
+ target_ax.set_ylabel(source_ax.get_ylabel())
+ target_ax.set_title(source_ax.get_title())
+ target_ax.set_xscale(source_ax.get_xscale())
+ target_ax.set_yscale(source_ax.get_yscale())
+
+ # Copy tick parameters
+ target_ax.set_xticks(source_ax.get_xticks())
+ target_ax.set_yticks(source_ax.get_yticks())
+ try:
+ target_ax.set_xticklabels([t.get_text() for t in source_ax.get_xticklabels()])
+ target_ax.set_yticklabels([t.get_text() for t in source_ax.get_yticklabels()])
+ except Exception:
+ pass
+
+ # Re-apply limits after setting ticks (ticks can change limits)
+ target_ax.set_xlim(source_ax.get_xlim())
+ target_ax.set_ylim(source_ax.get_ylim())
+
+ # Copy spine visibility
+ for spine_name in ("top", "bottom", "left", "right"):
+ target_ax.spines[spine_name].set_visible(source_ax.spines[spine_name].get_visible())
+ target_ax.spines[spine_name].set_bounds(*source_ax.spines[spine_name].get_bounds()) if source_ax.spines[spine_name].get_bounds() else None
+ target_ax.spines[spine_name].set_position(source_ax.spines[spine_name].get_position())
+
+ # Copy grid state
+ target_ax.grid(source_ax.xaxis.get_gridlines()[0].get_visible() if source_ax.xaxis.get_gridlines() else False)
+
+
+def _replay_line(source_line: Line2D, target_ax: plt.Axes) -> None:
+ """Re-draw a Line2D artist onto *target_ax*."""
+ target_ax.plot(
+ source_line.get_xdata(),
+ source_line.get_ydata(),
+ color=source_line.get_color(),
+ linestyle=source_line.get_linestyle(),
+ linewidth=source_line.get_linewidth(),
+ marker=source_line.get_marker(),
+ markersize=source_line.get_markersize(),
+ markerfacecolor=source_line.get_markerfacecolor(),
+ markeredgecolor=source_line.get_markeredgecolor(),
+ markeredgewidth=source_line.get_markeredgewidth(),
+ alpha=source_line.get_alpha(),
+ label=source_line.get_label(),
+ )
+
+
+def _replay_scatter(source_coll: PathCollection, target_ax: plt.Axes, label: str) -> None:
+ """Re-draw a scatter PathCollection onto *target_ax*."""
+ offsets = source_coll.get_offsets()
+ if len(offsets) == 0:
+ return
+ facecolors = source_coll.get_facecolor()
+ edgecolors = source_coll.get_edgecolor()
+ sizes = source_coll.get_sizes()
+ target_ax.scatter(
+ offsets[:, 0],
+ offsets[:, 1],
+ color=facecolors if len(facecolors) > 1 else facecolors[0],
+ edgecolors=edgecolors if len(edgecolors) > 1 else edgecolors[0],
+ s=sizes if len(sizes) > 1 else sizes[0],
+ alpha=source_coll.get_alpha(),
+ label=label,
+ )
+
+
+def _replay_bar(container: BarContainer, target_ax: plt.Axes) -> None:
+ """Re-draw a BarContainer onto *target_ax*."""
+ for patch in container.patches:
+ target_ax.bar(
+ patch.get_x() + patch.get_width() / 2,
+ patch.get_height(),
+ width=patch.get_width(),
+ bottom=patch.get_y(),
+ color=patch.get_facecolor(),
+ edgecolor=patch.get_edgecolor(),
+ linewidth=patch.get_linewidth(),
+ alpha=patch.get_alpha(),
+ label=container.get_label() if patch is container.patches[0] else None,
+ )
+
+
+def _replay_fill(source_coll: PolyCollection, target_ax: plt.Axes, label: str) -> None:
+ """Re-draw a PolyCollection (e.g. fill_between) onto *target_ax*."""
+ for path in source_coll.get_paths():
+ verts = path.vertices
+ target_ax.fill(
+ verts[:, 0],
+ verts[:, 1],
+ facecolor=source_coll.get_facecolor()[0],
+ edgecolor=source_coll.get_edgecolor()[0] if len(source_coll.get_edgecolor()) > 0 else None,
+ alpha=source_coll.get_alpha(),
+ label=label,
+ )
+ label = None # only label the first polygon
+
+
+def decompose_figure(
+ fig_or_ax: Union[Figure, plt.Axes],
+ *,
+ show_legend: bool = True,
+) -> List[Tuple[str, Figure]]:
+ """Decompose a matplotlib figure into individual figures, one per labeled artist.
+
+ Each returned figure contains a single plotted element (line, scatter,
+ bar group, fill, …) together with the same axis labels, limits, and
+ title as the original. Only artists that carry a label (and would
+ therefore appear in a legend) are considered; artists whose label
+ starts with ``_`` are skipped, following the matplotlib convention.
+
+ Args:
+ fig_or_ax: A :class:`~matplotlib.figure.Figure` or a single
+ :class:`~matplotlib.axes.Axes` instance. When a *Figure*
+ is given the first ``Axes`` is used.
+ show_legend: If *True* (default) a legend is added to every
+ decomposed figure.
+
+ Returns:
+ A list of ``(label, figure)`` tuples where *label* is the
+ legend text associated with the artist and *figure* is a new
+ :class:`~matplotlib.figure.Figure` containing only that artist.
+
+ Example::
+
+ fig, ax = plt.subplots()
+ ax.plot([0, 1], [0, 1], label="Linear")
+ ax.plot([0, 1], [0, 2], label="Steep")
+
+ parts = decompose_figure(fig)
+ for label, part_fig in parts:
+ part_fig.savefig(f"{label}.png")
+ """
+ # Resolve axes ----------------------------------------------------------
+ if isinstance(fig_or_ax, Figure):
+ axes_list = fig_or_ax.get_axes()
+ if not axes_list:
+ return []
+ source_ax = axes_list[0]
+ source_fig = fig_or_ax
+ else:
+ source_ax = fig_or_ax
+ source_fig = fig_or_ax.get_figure()
+
+ figsize = source_fig.get_size_inches()
+
+ # Collect labelled artists ----------------------------------------------
+ items: List[Tuple[str, object]] = []
+
+ # 1. Lines
+ for line in source_ax.get_lines():
+ lbl = line.get_label()
+ if lbl and not lbl.startswith("_"):
+ items.append((lbl, line))
+
+ # 2. Bar containers
+ for container in source_ax.containers:
+ if isinstance(container, BarContainer):
+ lbl = container.get_label()
+ if lbl and not lbl.startswith("_"):
+ items.append((lbl, container))
+
+ # 3. Collections (scatter / fill_between)
+ for coll in source_ax.collections:
+ lbl = coll.get_label()
+ if lbl and not lbl.startswith("_"):
+ items.append((lbl, coll))
+
+ # Build one figure per item ---------------------------------------------
+ results: List[Tuple[str, Figure]] = []
+ for label, artist in items:
+ new_fig, new_ax = plt.subplots(figsize=figsize)
+
+ # Reproduce axes properties
+ _setup_axes_like(source_ax, new_ax)
+
+ # Replay the single artist
+ if isinstance(artist, Line2D):
+ _replay_line(artist, new_ax)
+ elif isinstance(artist, BarContainer):
+ _replay_bar(artist, new_ax)
+ elif isinstance(artist, PathCollection):
+ _replay_scatter(artist, new_ax, label)
+ elif isinstance(artist, PolyCollection):
+ _replay_fill(artist, new_ax, label)
+
+ if show_legend:
+ new_ax.legend()
+
+ new_fig.tight_layout()
+ results.append((label, new_fig))
+
+ return results
diff --git a/src/lama_aesthetics/styles/lamalab_dark.mplstyle b/src/lama_aesthetics/styles/lamalab_dark.mplstyle
new file mode 100644
index 0000000..7aeb77e
--- /dev/null
+++ b/src/lama_aesthetics/styles/lamalab_dark.mplstyle
@@ -0,0 +1,74 @@
+# '000000',
+axes.prop_cycle : cycler('color', ['0C5DA5', '00B945', 'FF9500', 'FF2C00', '845B97', '474747', '9e9e9e', "9A607F"])
+
+#axes.prop_cycle : cycler('color', ["DB444B", "006BA2", "3EBCD2", "379A8B", "EBB434", "#B4BA39", "#9A607F", '#9e9e9e', "#D1B07C"])
+
+
+# Figure size
+figure.figsize : 3.3, 2.5 # max width is 3.5 for single column
+# Set x axis
+xtick.direction : in
+xtick.major.size : 3
+xtick.major.width : 0.5
+xtick.minor.size : 1.5
+xtick.minor.width : 0.5
+xtick.minor.visible : True
+xtick.top : False
+xtick.minor.top: False
+xtick.minor.bottom: False
+
+# Set y axis
+ytick.direction : in
+ytick.major.size : 3
+ytick.major.width : 0.5
+ytick.minor.size : 1.5
+ytick.minor.width : 0.5
+ytick.minor.visible : True
+ytick.minor.left: False
+ytick.right: False
+
+# Set line widths
+axes.linewidth : 0.5
+grid.linewidth : 0.5
+lines.linewidth : 1.
+lines.markersize: 3
+axes.xmargin: 0
+axes.ymargin: 0
+
+# Remove legend frame
+legend.frameon : False
+
+# Always save as 'tight'
+# savefig.bbox : tight
+# savefig.pad_inches : 0.01 # Use virtually all space when we specify figure dimensions
+
+# Font sizes
+axes.labelsize: 11
+xtick.labelsize: 10
+ytick.labelsize: 10
+legend.fontsize: 10
+font.size: 10
+
+# Font Family
+font.family: sans-serif
+font.sans-serif: CMU Sans Serif, IBM Plex Sans, Roboto, Helvetica
+mathtext.fontset : dejavusans
+
+# LaTeX packages
+text.latex.preamble : \usepackage{amsmath} \usepackage{amssymb} \usepackage{sfmath}
+
+# remove non-data ink
+axes.spines.top: False
+axes.spines.right: False
+axes.axisbelow: True
+
+# Dark theme colors
+text.color: white
+axes.labelcolor: white
+axes.edgecolor: white
+xtick.color: white
+ytick.color: white
+figure.facecolor: black
+axes.facecolor: black
+savefig.facecolor: black
+legend.labelcolor: white
diff --git a/tests/test_plotutils.py b/tests/test_plotutils.py
index dec9179..65e5294 100644
--- a/tests/test_plotutils.py
+++ b/tests/test_plotutils.py
@@ -1,21 +1,47 @@
import matplotlib.pyplot as plt
import numpy as np
-from lama_aesthetics.plotutils import add_identity, range_frame, ylabel_top
+from lama_aesthetics.plotutils import (
+ _nice_tick_bounds,
+ add_identity,
+ decompose_figure,
+ range_frame,
+ ylabel_top,
+)
def test_range_frame():
- """Test that range_frame sets axis limits correctly."""
+ """Test that range_frame sets axis limits correctly (nice=True by default)."""
fig, ax = plt.subplots()
x = np.array([0, 1, 2, 3, 4])
y = np.array([0, 2, 4, 6, 8])
- range_frame(ax, x, y, pad=0.1)
+ range_frame(ax, x, y)
+
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+
+ # Nice bounds should contain all data
+ assert xlim[0] <= x.min()
+ assert xlim[1] >= x.max()
+ assert ylim[0] <= y.min()
+ assert ylim[1] >= y.max()
+
+ plt.close(fig)
+
+
+def test_range_frame_nice_false():
+ """Test that range_frame with nice=False uses raw padding."""
+ fig, ax = plt.subplots()
+ x = np.array([0, 1, 2, 3, 4])
+ y = np.array([0, 2, 4, 6, 8])
+
+ range_frame(ax, x, y, pad=0.1, nice=False)
xlim = ax.get_xlim()
ylim = ax.get_ylim()
- # Check that limits include all data points
+ # With nice=False the old padding-based behaviour applies
assert xlim[0] < x.min()
assert xlim[1] > x.max()
assert ylim[0] < y.min()
@@ -24,6 +50,67 @@ def test_range_frame():
plt.close(fig)
+def test_range_frame_per_axis_pad():
+ """Test that pad_x and pad_y override the default pad independently (nice=False)."""
+ fig, ax = plt.subplots()
+ x = np.array([0, 1, 2, 3, 4])
+ y = np.array([0, 2, 4, 6, 8])
+
+ # pad_x controls padding near x-axis (vertical / y-limits)
+ # pad_y controls padding near y-axis (horizontal / x-limits)
+ range_frame(ax, x, y, pad_x=0.2, pad_y=0.0, nice=False)
+
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+
+ # pad_y=0.0 → x-limits equal data range (no horizontal padding)
+ assert xlim[0] == x.min()
+ assert xlim[1] == x.max()
+
+ # pad_x=0.2 → y-limits have vertical padding
+ y_range = y.max() - y.min()
+ assert ylim[0] < y.min()
+ assert ylim[1] > y.max()
+ assert abs(ylim[0] - (y.min() - 0.2 * y_range)) < 1e-10
+ assert abs(ylim[1] - (y.max() + 0.2 * y_range)) < 1e-10
+
+ plt.close(fig)
+
+
+def test_range_frame_non_numeric_x_axis():
+ """Non-numeric x values should use 0..len(x)-1 for the range frame."""
+ fig, ax = plt.subplots()
+ x = ["a", "b", "c", "d"]
+ y = np.array([0, 2, 4, 6])
+
+ ax.plot(x, y)
+ range_frame(ax, x, y, pad=0.1)
+
+ xlim = ax.get_xlim()
+
+ assert xlim == (0.0, float(len(x) - 1))
+ assert ax.spines["bottom"].get_bounds() == (0, len(x) - 1)
+
+ plt.close(fig)
+
+
+def test_range_frame_non_numeric_y_axis():
+ """Non-numeric y values should use 0..len(y)-1 for the range frame."""
+ fig, ax = plt.subplots()
+ x = np.array([0, 1, 2, 3])
+ y = ["low", "mid", "high", "top"]
+
+ ax.plot(x, y)
+ range_frame(ax, x, y, pad=0.1)
+
+ ylim = ax.get_ylim()
+
+ assert ylim == (0.0, float(len(y) - 1))
+ assert ax.spines["left"].get_bounds() == (0, len(y) - 1)
+
+ plt.close(fig)
+
+
def test_ylabel_top():
"""Test that ylabel_top sets ylabel without errors."""
fig, ax = plt.subplots()
@@ -63,3 +150,210 @@ def test_add_identity():
assert result == ax
plt.close(fig)
+
+
+# --- nice tick bounds tests -------------------------------------------------
+
+
+def test_nice_tick_bounds_brackets_data():
+ """_nice_tick_bounds should return bounds that contain the data."""
+ lo, hi, ticks = _nice_tick_bounds(3.2, 47.8)
+ assert lo <= 3.2
+ assert hi >= 47.8
+ # Bounds must be tick positions
+ assert any(abs(lo - t) < 1e-10 for t in ticks)
+ assert any(abs(hi - t) < 1e-10 for t in ticks)
+
+
+def test_nice_tick_bounds_already_nice():
+ """When data already spans a tick-aligned range, bounds should match."""
+ lo, hi, ticks = _nice_tick_bounds(0, 10)
+ assert lo == 0.0
+ assert hi == 10.0
+
+
+def test_nice_tick_bounds_negative():
+ """Negative data ranges should also produce nice bounds."""
+ lo, hi, ticks = _nice_tick_bounds(-7.3, -1.2)
+ assert lo <= -7.3
+ assert hi >= -1.2
+ assert any(abs(lo - t) < 1e-10 for t in ticks)
+ assert any(abs(hi - t) < 1e-10 for t in ticks)
+
+
+def test_range_frame_nice_bounds_spine_alignment():
+ """With nice=True the spine bounds must coincide with actual tick positions."""
+ fig, ax = plt.subplots()
+ x = np.array([0.5, 1.3, 2.7, 3.9])
+ y = np.array([1.1, 4.4, 7.2, 9.8])
+
+ range_frame(ax, x, y, nice=True)
+
+ x_spine = ax.spines["bottom"].get_bounds()
+ y_spine = ax.spines["left"].get_bounds()
+
+ # Spine bounds should contain all data
+ assert x_spine[0] <= x.min()
+ assert x_spine[1] >= x.max()
+ assert y_spine[0] <= y.min()
+ assert y_spine[1] >= y.max()
+
+ # Spine bounds must be actual tick positions
+ x_ticks = ax.xaxis.get_ticklocs()
+ y_ticks = ax.yaxis.get_ticklocs()
+ assert any(abs(x_spine[0] - t) < 1e-10 for t in x_ticks), f"x spine lo {x_spine[0]} not in ticks {x_ticks}"
+ assert any(abs(x_spine[1] - t) < 1e-10 for t in x_ticks), f"x spine hi {x_spine[1]} not in ticks {x_ticks}"
+ assert any(abs(y_spine[0] - t) < 1e-10 for t in y_ticks), f"y spine lo {y_spine[0]} not in ticks {y_ticks}"
+ assert any(abs(y_spine[1] - t) < 1e-10 for t in y_ticks), f"y spine hi {y_spine[1]} not in ticks {y_ticks}"
+
+ plt.close(fig)
+
+
+def test_range_frame_nice_non_numeric_unchanged():
+ """nice=True should not affect categorical axes."""
+ fig, ax = plt.subplots()
+ x = ["a", "b", "c"]
+ y = np.array([1, 5, 9])
+
+ ax.plot(x, y)
+ range_frame(ax, x, y, nice=True)
+
+ # Categorical x-axis: bounds should be index-based
+ assert ax.spines["bottom"].get_bounds() == (0, 2)
+
+ plt.close(fig)
+
+
+# --- decompose_figure tests ------------------------------------------------
+
+
+def test_decompose_figure_lines():
+ """decompose_figure should return one figure per labelled line."""
+ fig, ax = plt.subplots()
+ ax.plot([0, 1, 2], [0, 1, 4], label="Linear")
+ ax.plot([0, 1, 2], [0, 2, 8], label="Steep")
+ ax.set_xlabel("x")
+ ax.set_ylabel("y")
+ ax.set_title("My Plot")
+
+ parts = decompose_figure(fig)
+
+ assert len(parts) == 2
+ assert parts[0][0] == "Linear"
+ assert parts[1][0] == "Steep"
+
+ # Each returned figure should have exactly one line on its axes
+ for label, part_fig in parts:
+ part_ax = part_fig.get_axes()[0]
+ labelled_lines = [line for line in part_ax.get_lines() if not line.get_label().startswith("_")]
+ assert len(labelled_lines) == 1
+ assert labelled_lines[0].get_label() == label
+ # Axes metadata preserved
+ assert part_ax.get_xlabel() == "x"
+ assert part_ax.get_ylabel() == "y"
+ assert part_ax.get_title() == "My Plot"
+
+ for _, f in parts:
+ plt.close(f)
+ plt.close(fig)
+
+
+def test_decompose_figure_scatter():
+ """decompose_figure should handle scatter plots."""
+ fig, ax = plt.subplots()
+ ax.scatter([0, 1], [2, 3], label="Group A")
+ ax.scatter([4, 5], [6, 7], label="Group B")
+
+ parts = decompose_figure(fig)
+
+ assert len(parts) == 2
+ assert {p[0] for p in parts} == {"Group A", "Group B"}
+
+ for _, f in parts:
+ plt.close(f)
+ plt.close(fig)
+
+
+def test_decompose_figure_bars():
+ """decompose_figure should handle bar plots."""
+ fig, ax = plt.subplots()
+ ax.bar([0, 1, 2], [3, 4, 5], label="Series 1")
+ ax.bar([0, 1, 2], [1, 2, 3], bottom=[3, 4, 5], label="Series 2")
+
+ parts = decompose_figure(fig)
+
+ assert len(parts) == 2
+ assert parts[0][0] == "Series 1"
+ assert parts[1][0] == "Series 2"
+
+ for _, f in parts:
+ plt.close(f)
+ plt.close(fig)
+
+
+def test_decompose_figure_skips_unlabelled():
+ """Artists without a label or with _-prefixed labels should be skipped."""
+ fig, ax = plt.subplots()
+ ax.plot([0, 1], [0, 1]) # no label
+ ax.plot([0, 1], [0, 2], label="_hidden")
+ ax.plot([0, 1], [0, 3], label="Visible")
+
+ parts = decompose_figure(fig)
+
+ assert len(parts) == 1
+ assert parts[0][0] == "Visible"
+
+ for _, f in parts:
+ plt.close(f)
+ plt.close(fig)
+
+
+def test_decompose_figure_accepts_axes():
+ """decompose_figure should accept a single Axes object."""
+ fig, ax = plt.subplots()
+ ax.plot([0, 1], [0, 1], label="A")
+ ax.plot([0, 1], [0, 2], label="B")
+
+ parts = decompose_figure(ax) # pass Axes, not Figure
+
+ assert len(parts) == 2
+
+ for _, f in parts:
+ plt.close(f)
+ plt.close(fig)
+
+
+def test_decompose_figure_no_legend():
+ """When show_legend=False no legend should be present."""
+ fig, ax = plt.subplots()
+ ax.plot([0, 1], [0, 1], label="A")
+
+ parts = decompose_figure(fig, show_legend=False)
+
+ assert len(parts) == 1
+ part_ax = parts[0][1].get_axes()[0]
+ assert part_ax.get_legend() is None
+
+ for _, f in parts:
+ plt.close(f)
+ plt.close(fig)
+
+
+def test_decompose_figure_preserves_limits():
+ """Axis limits should be the same as the original plot."""
+ fig, ax = plt.subplots()
+ ax.plot([0, 10], [0, 100], label="A")
+ ax.plot([0, 10], [100, 0], label="B")
+ original_xlim = ax.get_xlim()
+ original_ylim = ax.get_ylim()
+
+ parts = decompose_figure(fig)
+
+ for _, part_fig in parts:
+ part_ax = part_fig.get_axes()[0]
+ assert part_ax.get_xlim() == original_xlim
+ assert part_ax.get_ylim() == original_ylim
+
+ for _, f in parts:
+ plt.close(f)
+ plt.close(fig)