diff --git a/MANIFEST.in b/MANIFEST.in index 7c063ca..0b66f27 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ include src/lama_aesthetics/styles/*.mplstyle - +include src/lama_aesthetics/fonts/*.otf +include src/lama_aesthetics/fonts/*.ttf diff --git a/README.md b/README.md index a9e221b..1c12369 100644 --- a/README.md +++ b/README.md @@ -5,10 +5,22 @@ [![Docs](https://img.shields.io/badge/docs-gh--pages-blue)](https://lamalab-org.github.io/lama-aesthetics/) [![License](https://img.shields.io/github/license/lamalab-org/lama-aesthetics)](https://img.shields.io/github/license/lamalab-org/lama-aesthetics) -Plotting styles and helpers by LamaLab +Publication-quality plotting styles and helpers for matplotlib by [LamaLab](https://lamalab.org). -- **Github repository**: -- **Documentation** +- **GitHub repository**: +- **Documentation**: + +## Features at a glance + +| Category | What you get | +|---|---| +| **Styles** | `main` (publication), `presentation` (talks), `dark` (dark-themed) | +| **Bundled font** | CMU Sans Serif — auto-registered, no system install needed | +| **Range frame** | Tufte-style axis frame trimmed to the data range, with automatic nice-number bounds for numeric axes and full support for categorical axes | +| **Label helpers** | `ylabel_top` — horizontal y-label placed above the top tick | +| **Reference lines** | `add_identity` — dynamic 1:1 diagonal that follows axis changes | +| **Figure decomposition** | `decompose_figure` — split a multi-series figure into one figure per labeled artist | +| **Dimension constants** | `ONE_COL_WIDTH`, `TWO_COL_WIDTH`, `ONE_COL_HEIGHT`, `TWO_COL_HEIGHT` (golden ratio) | ## Installation @@ -24,21 +36,48 @@ uv pip install -e . make install ``` -## Usage +## Quick start + +```python +import matplotlib.pyplot as plt +import numpy as np +import lama_aesthetics + +lama_aesthetics.get_style("main") + +fig, ax = plt.subplots() +x = np.linspace(0.5, 9.3, 40) +y = np.sin(x) * 10 + 15 + +ax.plot(x, y) +lama_aesthetics.range_frame(ax, x, y) # nice-number axis bounds by default +lama_aesthetics.ylabel_top("y", ax=ax) + +plt.show() +``` + +--- -### Styles +## Styles -The library provides two main plotting styles: +The library ships three matplotlib style sheets, all using the bundled **CMU Sans Serif** font: -- **main**: Optimized for publications, reports, and other documents. -- **presentation**: Features larger fonts and thicker lines for better visibility in presentations +| Style | Apply with | Description | +|---|---|---| +| **main** | `get_style("main")` | Optimized for single- and two-column journal figures. Compact figure size (3.3 × 2.5 in), inward ticks, thin lines. | +| **presentation** | `get_style("presentation")` | Same layout but with larger fonts (13 / 12 pt) for slides and posters. | +| **dark** | `get_style("dark")` | Black background, white text and lines — ideal for dark-themed slides or dashboards. | + +All styles disable the right and top spines, use inward-facing ticks, and remove the legend frame for a clean Tufte-inspired look. ```python -import matplotlib.pyplot as plt -import numpy as np import lama_aesthetics -lama_aesthetics.get_style("main") # or la.get_style("presentation") +lama_aesthetics.get_style("main") +# or +lama_aesthetics.get_style("presentation") +# or +lama_aesthetics.get_style("dark") ```
@@ -49,68 +88,198 @@ lama_aesthetics.get_style("main") # or la.get_style("presentation")

Left: Main style; Right: Presentation style

-### Helpers +### Bundled font + +The **CMU Sans Serif** font (`cmunss.otf`) is bundled with the package and +registered automatically the first time you apply a style. You can also +register it manually: + +```python +lama_aesthetics.register_fonts() +font_name = lama_aesthetics.get_font_name() # → "CMU Sans Serif" +``` + +--- + +## Plotting utilities + +### `range_frame` — Tufte-style range frame with nice-number bounds + +`range_frame` trims the axis spines to the data range and offsets them outward +for a clean, separated look. It automatically detects whether each axis is +**numeric** or **categorical** and handles them differently: + +**Numeric axes (default: `nice=True`):** +The spine bounds are snapped to "nice" tick positions (multiples of 1, 2, 2.5, +5, or 10 × 10^n) that bracket the data. This means the axis line always +starts and ends exactly on a tick mark — no more floating spine endpoints. +Tick positions are computed via matplotlib's `MaxNLocator` and explicitly set +on the axes so there is zero drift between ticks and spine bounds. + +**Categorical axes:** +When `x` or `y` contains strings (e.g. `["a", "b", "c"]`), the spine spans +the integer range `0 .. len(values) - 1`. No rounding is applied. + +```python +import matplotlib.pyplot as plt +import numpy as np +from lama_aesthetics.plotutils import range_frame + +fig, axes = plt.subplots(1, 3, figsize=(12, 3)) + +# 1) Basic numeric — bounds snap to nice ticks +x = np.linspace(3.2, 47.8, 20) +y = x * 1.3 - 1.5 +axes[0].plot(x, y) +range_frame(axes[0], x, y) # nice=True by default +axes[0].set_title("Numeric (nice bounds)") + +# 2) Categorical x-axis +categories = ["Model A", "Model B", "Model C", "Model D"] +values = np.array([0.82, 0.91, 0.87, 0.95]) +axes[1].plot(categories, values) +range_frame(axes[1], categories, values) # categorical x, numeric y +axes[1].set_title("Categorical x-axis") + +# 3) Disable nice bounds — raw padding only +axes[2].plot(x, y) +range_frame(axes[2], x, y, nice=False, pad=0.05) +axes[2].set_title("nice=False (raw padding)") -The package includes several plotting utilities to enhance your visualizations: +plt.tight_layout() +plt.show() +``` + +#### Parameters + +| Parameter | Default | Description | +|---|---|---| +| `ax` | — | Matplotlib `Axes` object | +| `x`, `y` | — | Data arrays (numeric or string/categorical) | +| `pad` | `0.1` | Padding factor applied to both axes (only used when `nice=False`) | +| `pad_x` | `None` | Per-axis padding near the x-axis (vertical). Overrides `pad`. | +| `pad_y` | `None` | Per-axis padding near the y-axis (horizontal). Overrides `pad`. | +| `nice` | `True` | Snap numeric spine bounds to nice tick positions that bracket the data. When `True`, padding parameters are ignored for numeric axes. | + +### `ylabel_top` — horizontal y-label above the axis + +Places the y-axis label horizontally above the top tick, making it easier to +read without head-tilting: + +```python +from lama_aesthetics.plotutils import ylabel_top + +fig, ax = plt.subplots() +ax.plot([0, 1, 2], [0, 1, 4]) +ylabel_top("Energy (eV)", ax=ax) +``` + +| Parameter | Default | Description | +|---|---|---| +| `string` | — | Label text | +| `ax` | `None` | Axes (defaults to `plt.gca()`) | +| `x_pad` | `0.01` | Horizontal offset in axes coordinates | +| `y_pad` | `0.02` | Vertical offset above the top tick | + +### `add_identity` — dynamic 1:1 reference line + +Adds a diagonal identity line that automatically adjusts when the axis limits +change (e.g. during zoom or pan): + +```python +from lama_aesthetics.plotutils import add_identity + +fig, ax = plt.subplots() +predicted = np.array([1.1, 2.3, 2.9, 4.2]) +observed = np.array([1.0, 2.0, 3.0, 4.0]) +ax.scatter(observed, predicted) +add_identity(ax, linestyle="--", color="gray", alpha=0.6) +``` -- **range_frame**: Draws a frame around the data range. -- **ylabel_top**: Places the y-label at the top of the y-axis. -- **add_identity**: Adds a diagonal reference line. +### `decompose_figure` — split a figure by legend entry -### Figure Dimensions +Takes a figure (or a single `Axes`) containing multiple labeled series and +returns a list of `(label, figure)` tuples — one separate figure per legend +entry. This is useful when you want to highlight individual series, e.g. to +include them separately in a paper or presentation. -The package provides predefined figure dimension constants based on common journal column widths and the golden ratio: +```python +from lama_aesthetics import decompose_figure, get_style + +get_style("main") + +fig, ax = plt.subplots() +x = np.linspace(0, 2 * np.pi, 50) +ax.plot(x, np.sin(x), label="sin(x)") +ax.plot(x, np.cos(x), label="cos(x)") +ax.plot(x, np.sin(x) + np.cos(x), label="sin+cos") +ax.set_xlabel("x") +ax.set_ylabel("f(x)") +ax.set_title("Trigonometric Functions") +ax.legend() + +parts = decompose_figure(fig) # also accepts an Axes directly + +for label, part_fig in parts: + part_fig.savefig(f"{label}.png") + plt.close(part_fig) +``` + +Each decomposed figure inherits axis labels, title, limits, scale, tick +positions, spine visibility, and grid state from the original. Pass +`show_legend=False` to omit the legend from the individual figures. + +**Supported artist types:** line plots (`plot`), scatter plots (`scatter`), +bar charts (`bar`), and filled regions (`fill_between`). + +--- + +## Figure dimension constants + +Predefined sizes based on common journal column widths and the golden ratio: ```python from lama_aesthetics import ( - ONE_COL_WIDTH, - TWO_COL_WIDTH, - ONE_COL_HEIGHT, - TWO_COL_HEIGHT, + ONE_COL_WIDTH, # 3 inches + TWO_COL_WIDTH, # 7.25 inches + ONE_COL_HEIGHT, # ONE_COL_WIDTH / φ ≈ 1.854 inches + TWO_COL_HEIGHT, # TWO_COL_WIDTH / φ ≈ 4.481 inches ) -# Create a single-column figure with golden ratio proportions fig, ax = plt.subplots(figsize=(ONE_COL_WIDTH, ONE_COL_HEIGHT)) - -# Create a two-column figure with golden ratio proportions -fig, ax = plt.subplots(figsize=(TWO_COL_WIDTH, TWO_COL_HEIGHT)) ``` -Available constants: +--- -- `ONE_COL_WIDTH`: 3 inches (typical single-column width) -- `TWO_COL_WIDTH`: 7.25 inches (typical two-column width) -- `ONE_COL_HEIGHT`: Single-column height based on golden ratio -- `TWO_COL_HEIGHT`: Two-column height based on golden ratio - -### Plotting Utilities Examples +## Full example ```python import matplotlib.pyplot as plt import numpy as np +import lama_aesthetics from lama_aesthetics.plotutils import range_frame, ylabel_top, add_identity -# Create sample data +lama_aesthetics.get_style("main") + x = np.linspace(0, 10, 100) y = np.sin(x) -# Create a figure fig, axes = plt.subplots(1, 3, figsize=(12, 4)) -# Example 1: Range frame - only shows axes within the data range +# Range frame with nice bounds axes[0].plot(x, y) range_frame(axes[0], x, y) axes[0].set_title("Range Frame") -# Example 2: Top Y-label - places ylabel at top of y-axis +# Top y-label axes[1].plot(x, y) ylabel_top("sin(x)", axes[1]) axes[1].set_title("Top Y-Label") -# Example 3: Identity line - adds a diagonal reference line -x_scatter = np.linspace(0, 1, 20) -y_scatter = x_scatter + 0.1*np.random.randn(20) -axes[2].scatter(x_scatter, y_scatter) +# Identity line +x_sc = np.linspace(0, 1, 20) +y_sc = x_sc + 0.1 * np.random.randn(20) +axes[2].scatter(x_sc, y_sc) add_identity(axes[2], linestyle="--", color="gray") axes[2].set_title("Identity Line") @@ -118,6 +287,29 @@ plt.tight_layout() plt.show() ``` -
Helper function examples

Left: Range Frame; Center: Top Y-Label; Right: Identity Line

+
+ Helper function examples +

Left: Range Frame; Center: Top Y-Label; Right: Identity Line

+
+ +--- + +## API reference + +| Function / Constant | Module | Description | +|---|---|---| +| `get_style(name)` | `aesthetics` | Apply a bundled style (`"main"`, `"presentation"`, `"dark"`) | +| `register_fonts()` | `aesthetics` | Register bundled CMU Sans Serif with matplotlib | +| `get_font_name()` | `aesthetics` | Return the registered font family name | +| `range_frame(ax, x, y, ...)` | `plotutils` | Tufte-style range frame with nice-number bounds | +| `ylabel_top(string, ax, ...)` | `plotutils` | Place y-label horizontally above the top tick | +| `add_identity(ax, ...)` | `plotutils` | Add a dynamic 1:1 diagonal reference line | +| `decompose_figure(fig, ...)` | `plotutils` | Split a figure into one figure per labeled artist | +| `ONE_COL_WIDTH` | `aesthetics` | 3 inches | +| `TWO_COL_WIDTH` | `aesthetics` | 7.25 inches | +| `ONE_COL_HEIGHT` | `aesthetics` | `ONE_COL_WIDTH / golden_ratio` | +| `TWO_COL_HEIGHT` | `aesthetics` | `TWO_COL_WIDTH / golden_ratio` | + +--- Repository initiated with [lamalab-org/cookiecutter-uv](https://github.com/lamalab-org/cookiecutter-uv). diff --git a/pyproject.toml b/pyproject.toml index a630d76..57a8a9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,12 +39,12 @@ dependencies = [ "numpy", "scipy", ] -requires-python = ">=3.10" +requires-python = ">=3.9" readme = "README.md" license = { text = "MIT license" } [tool.setuptools.package-data] -lama_aesthetics = ["styles/*.mplstyle"] +lama_aesthetics = ["styles/*.mplstyle", "fonts/*.otf", "fonts/*.ttf"] [project.optional-dependencies] optional_dependencies = [] diff --git a/src/lama_aesthetics/__init__.py b/src/lama_aesthetics/__init__.py index b19ba41..c94a267 100644 --- a/src/lama_aesthetics/__init__.py +++ b/src/lama_aesthetics/__init__.py @@ -4,16 +4,21 @@ STYLES, TWO_COL_HEIGHT, TWO_COL_WIDTH, + get_font_name, get_style, + register_fonts, ) -from lama_aesthetics.plotutils import add_identity, range_frame, ylabel_top +from lama_aesthetics.plotutils import add_identity, decompose_figure, range_frame, ylabel_top __all__ = [ "STYLES", "get_style", + "register_fonts", + "get_font_name", "range_frame", "ylabel_top", "add_identity", + "decompose_figure", "ONE_COL_WIDTH", "TWO_COL_WIDTH", "ONE_COL_HEIGHT", diff --git a/src/lama_aesthetics/aesthetics.py b/src/lama_aesthetics/aesthetics.py index 4747317..9f440dd 100644 --- a/src/lama_aesthetics/aesthetics.py +++ b/src/lama_aesthetics/aesthetics.py @@ -1,11 +1,49 @@ import importlib.resources import matplotlib.pyplot as plt +from matplotlib import font_manager +from matplotlib.font_manager import FontProperties from scipy.constants import golden +# Track whether fonts have been registered +_fonts_registered = False + + +def register_fonts() -> None: + """Register bundled fonts with Matplotlib's font manager. + + This function adds the CMU Sans Serif font bundled with the package + to Matplotlib's font manager, allowing it to be used without system-wide + installation. Safe to call multiple times (will only register once). + """ + global _fonts_registered + if _fonts_registered: + return + + # Get path to the bundled font + font_files = ["cmunss.otf"] + + for font_file in font_files: + with importlib.resources.as_file(importlib.resources.files("lama_aesthetics.fonts").joinpath(font_file)) as font_path: + font_manager.fontManager.addfont(str(font_path)) + + _fonts_registered = True + + +def get_font_name() -> str: + """Get the family name of the bundled CMU Sans Serif font. + + Returns: + The font family name as recognized by Matplotlib. + """ + with importlib.resources.as_file(importlib.resources.files("lama_aesthetics.fonts").joinpath("cmunss.otf")) as font_path: + return FontProperties(fname=str(font_path)).get_name() + + STYLES = { "main": "lamalab.mplstyle", "presentation": "presentation.mplstyle", + "dark": "lamalab_dark.mplstyle", } # Figure dimensions @@ -18,8 +56,11 @@ def get_style(style_name: str) -> None: """Get the path to a matplotlib style file and apply it. + This function registers bundled fonts before applying the style, + ensuring the CMU Sans Serif font is available without system installation. + Args: - style_name: Name of the style ('main' or 'presentation') + style_name: Name of the style ('main', 'presentation', or 'dark') Raises: KeyError: If style_name is not in STYLES dictionary @@ -27,6 +68,9 @@ def get_style(style_name: str) -> None: if style_name not in STYLES: raise KeyError(f"Style '{style_name}' not found. Available styles: {list(STYLES.keys())}") + # Register bundled fonts before applying the style + register_fonts() + style_file = STYLES[style_name] # Get the file contents as a string diff --git a/src/lama_aesthetics/fonts/cmunss.otf b/src/lama_aesthetics/fonts/cmunss.otf new file mode 100644 index 0000000..49fecee Binary files /dev/null and b/src/lama_aesthetics/fonts/cmunss.otf differ diff --git a/src/lama_aesthetics/plotutils.py b/src/lama_aesthetics/plotutils.py index b293198..01d5ed6 100644 --- a/src/lama_aesthetics/plotutils.py +++ b/src/lama_aesthetics/plotutils.py @@ -1,32 +1,134 @@ -from typing import Optional +from typing import List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np +from matplotlib.collections import PathCollection, PolyCollection +from matplotlib.container import BarContainer +from matplotlib.figure import Figure +from matplotlib.lines import Line2D +from matplotlib.ticker import MaxNLocator -def range_frame(ax, x, y, pad=0.1): +def _get_axis_bounds(values): + """Return axis bounds for numeric or categorical values.""" + arr = np.asarray(values) + + try: + numeric_arr = np.asarray(values, dtype=float) + except (TypeError, ValueError): + return 0, max(len(arr) - 1, 0), False + + return numeric_arr.min(), numeric_arr.max(), True + + +def _nice_tick_bounds(data_min, data_max): + """Return nice tick positions and spine bounds that strictly bracket the data. + + Uses matplotlib's ``MaxNLocator`` to compute tick positions for the + data range and then selects the outermost ticks as spine bounds. + The returned bounds are guaranteed to satisfy + ``bound_lo <= data_min`` and ``bound_hi >= data_max``, and every + tick between the bounds (inclusive) is included. + + Args: + data_min: Minimum value present in the data. + data_max: Maximum value present in the data. + + Returns: + ``(bound_lo, bound_hi, ticks)`` where *ticks* is a 1-D array of + the tick positions that fall within ``[bound_lo, bound_hi]``. + """ + if data_min == data_max: + # Degenerate case — expand symmetrically so the locator has a range. + if data_min == 0: + data_min, data_max = -0.5, 0.5 + else: + delta = abs(data_min) * 0.1 + data_min, data_max = data_min - delta, data_max + delta + + locator = MaxNLocator(nbins="auto", steps=[1, 2, 2.5, 5, 10]) + ticks = np.asarray(locator.tick_values(data_min, data_max)) + + # The locator already returns values that bracket [data_min, data_max], + # but enforce the invariant explicitly. + bound_lo = float(ticks[ticks <= data_min + 1e-12].max()) if np.any(ticks <= data_min + 1e-12) else float(ticks[0]) + bound_hi = float(ticks[ticks >= data_max - 1e-12].min()) if np.any(ticks >= data_max - 1e-12) else float(ticks[-1]) + + # Keep only ticks within the chosen bounds. + mask = (ticks >= bound_lo - 1e-12) & (ticks <= bound_hi + 1e-12) + ticks = ticks[mask] + + return bound_lo, bound_hi, ticks + + +def range_frame(ax, x, y, pad=0.1, pad_x=None, pad_y=None, nice=True): """ Set the limits of the axes to include all data points with a padding of `pad` times the range of the data. This is useful to ensure that the data points are not cut off by the axes. + Per-axis padding can be controlled with ``pad_x`` and ``pad_y``. When + either is *None* (the default) the value of ``pad`` is used instead. + + When ``nice`` is *True* (the default) and the axis carries numerical + data, the spine bounds are snapped to nice tick positions that bracket + the data, so that the axis line starts and ends exactly at tick marks. + Tick positions are computed via matplotlib's ``MaxNLocator`` and + explicitly set on the axes so there is no drift between ticks and + spine endpoints. The ``pad`` / ``pad_x`` / ``pad_y`` parameters are + ignored for any axis that receives nice bounds. + Args: ax: The axes object. x: The x-coordinates of the data points. y: The y-coordinates of the data points. - pad: The padding factor. + pad: The default padding factor applied to both axes. + pad_x: Padding near the x-axis (vertical direction). Overrides ``pad`` when set. + pad_y: Padding near the y-axis (horizontal direction). Overrides ``pad`` when set. + nice: If *True* (default), snap numeric spine bounds to nice tick + positions that bracket the data. """ - y_min, y_max = y.min(), y.max() - x_min, x_max = x.min(), x.max() + if pad_x is None: + pad_x = pad + if pad_y is None: + pad_y = pad - ax.set_ylim(y_min - pad * (y_max - y_min), y_max + pad * (y_max - y_min)) - ax.set_xlim(x_min - pad * (x_max - x_min), x_max + pad * (x_max - x_min)) + y_min, y_max, y_is_numeric = _get_axis_bounds(y) + x_min, x_max, x_is_numeric = _get_axis_bounds(x) + + # --- Y axis ------------------------------------------------------------ + if y_is_numeric: + if nice: + y_bound_min, y_bound_max, y_ticks = _nice_tick_bounds(y_min, y_max) + ax.set_yticks(y_ticks) + ax.set_ylim(y_bound_min, y_bound_max) + else: + y_bound_min = y_min + y_bound_max = y_max + ax.set_ylim(y_min - pad_x * (y_max - y_min), y_max + pad_x * (y_max - y_min)) + else: + y_bound_min, y_bound_max = y_min, y_max + ax.set_ylim(y_min, y_max) + + # --- X axis ------------------------------------------------------------ + if x_is_numeric: + if nice: + x_bound_min, x_bound_max, x_ticks = _nice_tick_bounds(x_min, x_max) + ax.set_xticks(x_ticks) + ax.set_xlim(x_bound_min, x_bound_max) + else: + x_bound_min = x_min + x_bound_max = x_max + ax.set_xlim(x_min - pad_y * (x_max - x_min), x_max + pad_y * (x_max - x_min)) + else: + x_bound_min, x_bound_max = x_min, x_max + ax.set_xlim(x_min, x_max) ax.spines["left"].set_position(("outward", 10)) ax.spines["bottom"].set_position(("outward", 10)) - ax.spines["bottom"].set_bounds(x_min, x_max) - ax.spines["left"].set_bounds(y_min, y_max) + ax.spines["bottom"].set_bounds(x_bound_min, x_bound_max) + ax.spines["left"].set_bounds(y_bound_min, y_bound_max) def ylabel_top(string: str, ax: Optional[plt.Axes] = None, x_pad: float = 0.01, y_pad: float = 0.02) -> None: @@ -105,3 +207,201 @@ def callback(axes): axes.callbacks.connect("xlim_changed", callback) axes.callbacks.connect("ylim_changed", callback) return axes + + +def _setup_axes_like(source_ax: plt.Axes, target_ax: plt.Axes) -> None: + """Copy axis labels, title, limits, scales, and spine visibility from *source_ax* to *target_ax*.""" + target_ax.set_xlim(source_ax.get_xlim()) + target_ax.set_ylim(source_ax.get_ylim()) + target_ax.set_xlabel(source_ax.get_xlabel()) + target_ax.set_ylabel(source_ax.get_ylabel()) + target_ax.set_title(source_ax.get_title()) + target_ax.set_xscale(source_ax.get_xscale()) + target_ax.set_yscale(source_ax.get_yscale()) + + # Copy tick parameters + target_ax.set_xticks(source_ax.get_xticks()) + target_ax.set_yticks(source_ax.get_yticks()) + try: + target_ax.set_xticklabels([t.get_text() for t in source_ax.get_xticklabels()]) + target_ax.set_yticklabels([t.get_text() for t in source_ax.get_yticklabels()]) + except Exception: + pass + + # Re-apply limits after setting ticks (ticks can change limits) + target_ax.set_xlim(source_ax.get_xlim()) + target_ax.set_ylim(source_ax.get_ylim()) + + # Copy spine visibility + for spine_name in ("top", "bottom", "left", "right"): + target_ax.spines[spine_name].set_visible(source_ax.spines[spine_name].get_visible()) + target_ax.spines[spine_name].set_bounds(*source_ax.spines[spine_name].get_bounds()) if source_ax.spines[spine_name].get_bounds() else None + target_ax.spines[spine_name].set_position(source_ax.spines[spine_name].get_position()) + + # Copy grid state + target_ax.grid(source_ax.xaxis.get_gridlines()[0].get_visible() if source_ax.xaxis.get_gridlines() else False) + + +def _replay_line(source_line: Line2D, target_ax: plt.Axes) -> None: + """Re-draw a Line2D artist onto *target_ax*.""" + target_ax.plot( + source_line.get_xdata(), + source_line.get_ydata(), + color=source_line.get_color(), + linestyle=source_line.get_linestyle(), + linewidth=source_line.get_linewidth(), + marker=source_line.get_marker(), + markersize=source_line.get_markersize(), + markerfacecolor=source_line.get_markerfacecolor(), + markeredgecolor=source_line.get_markeredgecolor(), + markeredgewidth=source_line.get_markeredgewidth(), + alpha=source_line.get_alpha(), + label=source_line.get_label(), + ) + + +def _replay_scatter(source_coll: PathCollection, target_ax: plt.Axes, label: str) -> None: + """Re-draw a scatter PathCollection onto *target_ax*.""" + offsets = source_coll.get_offsets() + if len(offsets) == 0: + return + facecolors = source_coll.get_facecolor() + edgecolors = source_coll.get_edgecolor() + sizes = source_coll.get_sizes() + target_ax.scatter( + offsets[:, 0], + offsets[:, 1], + color=facecolors if len(facecolors) > 1 else facecolors[0], + edgecolors=edgecolors if len(edgecolors) > 1 else edgecolors[0], + s=sizes if len(sizes) > 1 else sizes[0], + alpha=source_coll.get_alpha(), + label=label, + ) + + +def _replay_bar(container: BarContainer, target_ax: plt.Axes) -> None: + """Re-draw a BarContainer onto *target_ax*.""" + for patch in container.patches: + target_ax.bar( + patch.get_x() + patch.get_width() / 2, + patch.get_height(), + width=patch.get_width(), + bottom=patch.get_y(), + color=patch.get_facecolor(), + edgecolor=patch.get_edgecolor(), + linewidth=patch.get_linewidth(), + alpha=patch.get_alpha(), + label=container.get_label() if patch is container.patches[0] else None, + ) + + +def _replay_fill(source_coll: PolyCollection, target_ax: plt.Axes, label: str) -> None: + """Re-draw a PolyCollection (e.g. fill_between) onto *target_ax*.""" + for path in source_coll.get_paths(): + verts = path.vertices + target_ax.fill( + verts[:, 0], + verts[:, 1], + facecolor=source_coll.get_facecolor()[0], + edgecolor=source_coll.get_edgecolor()[0] if len(source_coll.get_edgecolor()) > 0 else None, + alpha=source_coll.get_alpha(), + label=label, + ) + label = None # only label the first polygon + + +def decompose_figure( + fig_or_ax: Union[Figure, plt.Axes], + *, + show_legend: bool = True, +) -> List[Tuple[str, Figure]]: + """Decompose a matplotlib figure into individual figures, one per labeled artist. + + Each returned figure contains a single plotted element (line, scatter, + bar group, fill, …) together with the same axis labels, limits, and + title as the original. Only artists that carry a label (and would + therefore appear in a legend) are considered; artists whose label + starts with ``_`` are skipped, following the matplotlib convention. + + Args: + fig_or_ax: A :class:`~matplotlib.figure.Figure` or a single + :class:`~matplotlib.axes.Axes` instance. When a *Figure* + is given the first ``Axes`` is used. + show_legend: If *True* (default) a legend is added to every + decomposed figure. + + Returns: + A list of ``(label, figure)`` tuples where *label* is the + legend text associated with the artist and *figure* is a new + :class:`~matplotlib.figure.Figure` containing only that artist. + + Example:: + + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], label="Linear") + ax.plot([0, 1], [0, 2], label="Steep") + + parts = decompose_figure(fig) + for label, part_fig in parts: + part_fig.savefig(f"{label}.png") + """ + # Resolve axes ---------------------------------------------------------- + if isinstance(fig_or_ax, Figure): + axes_list = fig_or_ax.get_axes() + if not axes_list: + return [] + source_ax = axes_list[0] + source_fig = fig_or_ax + else: + source_ax = fig_or_ax + source_fig = fig_or_ax.get_figure() + + figsize = source_fig.get_size_inches() + + # Collect labelled artists ---------------------------------------------- + items: List[Tuple[str, object]] = [] + + # 1. Lines + for line in source_ax.get_lines(): + lbl = line.get_label() + if lbl and not lbl.startswith("_"): + items.append((lbl, line)) + + # 2. Bar containers + for container in source_ax.containers: + if isinstance(container, BarContainer): + lbl = container.get_label() + if lbl and not lbl.startswith("_"): + items.append((lbl, container)) + + # 3. Collections (scatter / fill_between) + for coll in source_ax.collections: + lbl = coll.get_label() + if lbl and not lbl.startswith("_"): + items.append((lbl, coll)) + + # Build one figure per item --------------------------------------------- + results: List[Tuple[str, Figure]] = [] + for label, artist in items: + new_fig, new_ax = plt.subplots(figsize=figsize) + + # Reproduce axes properties + _setup_axes_like(source_ax, new_ax) + + # Replay the single artist + if isinstance(artist, Line2D): + _replay_line(artist, new_ax) + elif isinstance(artist, BarContainer): + _replay_bar(artist, new_ax) + elif isinstance(artist, PathCollection): + _replay_scatter(artist, new_ax, label) + elif isinstance(artist, PolyCollection): + _replay_fill(artist, new_ax, label) + + if show_legend: + new_ax.legend() + + new_fig.tight_layout() + results.append((label, new_fig)) + + return results diff --git a/src/lama_aesthetics/styles/lamalab_dark.mplstyle b/src/lama_aesthetics/styles/lamalab_dark.mplstyle new file mode 100644 index 0000000..7aeb77e --- /dev/null +++ b/src/lama_aesthetics/styles/lamalab_dark.mplstyle @@ -0,0 +1,74 @@ +# '000000', +axes.prop_cycle : cycler('color', ['0C5DA5', '00B945', 'FF9500', 'FF2C00', '845B97', '474747', '9e9e9e', "9A607F"]) + +#axes.prop_cycle : cycler('color', ["DB444B", "006BA2", "3EBCD2", "379A8B", "EBB434", "#B4BA39", "#9A607F", '#9e9e9e', "#D1B07C"]) + + +# Figure size +figure.figsize : 3.3, 2.5 # max width is 3.5 for single column +# Set x axis +xtick.direction : in +xtick.major.size : 3 +xtick.major.width : 0.5 +xtick.minor.size : 1.5 +xtick.minor.width : 0.5 +xtick.minor.visible : True +xtick.top : False +xtick.minor.top: False +xtick.minor.bottom: False + +# Set y axis +ytick.direction : in +ytick.major.size : 3 +ytick.major.width : 0.5 +ytick.minor.size : 1.5 +ytick.minor.width : 0.5 +ytick.minor.visible : True +ytick.minor.left: False +ytick.right: False + +# Set line widths +axes.linewidth : 0.5 +grid.linewidth : 0.5 +lines.linewidth : 1. +lines.markersize: 3 +axes.xmargin: 0 +axes.ymargin: 0 + +# Remove legend frame +legend.frameon : False + +# Always save as 'tight' +# savefig.bbox : tight +# savefig.pad_inches : 0.01 # Use virtually all space when we specify figure dimensions + +# Font sizes +axes.labelsize: 11 +xtick.labelsize: 10 +ytick.labelsize: 10 +legend.fontsize: 10 +font.size: 10 + +# Font Family +font.family: sans-serif +font.sans-serif: CMU Sans Serif, IBM Plex Sans, Roboto, Helvetica +mathtext.fontset : dejavusans + +# LaTeX packages +text.latex.preamble : \usepackage{amsmath} \usepackage{amssymb} \usepackage{sfmath} + +# remove non-data ink +axes.spines.top: False +axes.spines.right: False +axes.axisbelow: True + +# Dark theme colors +text.color: white +axes.labelcolor: white +axes.edgecolor: white +xtick.color: white +ytick.color: white +figure.facecolor: black +axes.facecolor: black +savefig.facecolor: black +legend.labelcolor: white diff --git a/tests/test_plotutils.py b/tests/test_plotutils.py index dec9179..65e5294 100644 --- a/tests/test_plotutils.py +++ b/tests/test_plotutils.py @@ -1,21 +1,47 @@ import matplotlib.pyplot as plt import numpy as np -from lama_aesthetics.plotutils import add_identity, range_frame, ylabel_top +from lama_aesthetics.plotutils import ( + _nice_tick_bounds, + add_identity, + decompose_figure, + range_frame, + ylabel_top, +) def test_range_frame(): - """Test that range_frame sets axis limits correctly.""" + """Test that range_frame sets axis limits correctly (nice=True by default).""" fig, ax = plt.subplots() x = np.array([0, 1, 2, 3, 4]) y = np.array([0, 2, 4, 6, 8]) - range_frame(ax, x, y, pad=0.1) + range_frame(ax, x, y) + + xlim = ax.get_xlim() + ylim = ax.get_ylim() + + # Nice bounds should contain all data + assert xlim[0] <= x.min() + assert xlim[1] >= x.max() + assert ylim[0] <= y.min() + assert ylim[1] >= y.max() + + plt.close(fig) + + +def test_range_frame_nice_false(): + """Test that range_frame with nice=False uses raw padding.""" + fig, ax = plt.subplots() + x = np.array([0, 1, 2, 3, 4]) + y = np.array([0, 2, 4, 6, 8]) + + range_frame(ax, x, y, pad=0.1, nice=False) xlim = ax.get_xlim() ylim = ax.get_ylim() - # Check that limits include all data points + # With nice=False the old padding-based behaviour applies assert xlim[0] < x.min() assert xlim[1] > x.max() assert ylim[0] < y.min() @@ -24,6 +50,67 @@ def test_range_frame(): plt.close(fig) +def test_range_frame_per_axis_pad(): + """Test that pad_x and pad_y override the default pad independently (nice=False).""" + fig, ax = plt.subplots() + x = np.array([0, 1, 2, 3, 4]) + y = np.array([0, 2, 4, 6, 8]) + + # pad_x controls padding near x-axis (vertical / y-limits) + # pad_y controls padding near y-axis (horizontal / x-limits) + range_frame(ax, x, y, pad_x=0.2, pad_y=0.0, nice=False) + + xlim = ax.get_xlim() + ylim = ax.get_ylim() + + # pad_y=0.0 → x-limits equal data range (no horizontal padding) + assert xlim[0] == x.min() + assert xlim[1] == x.max() + + # pad_x=0.2 → y-limits have vertical padding + y_range = y.max() - y.min() + assert ylim[0] < y.min() + assert ylim[1] > y.max() + assert abs(ylim[0] - (y.min() - 0.2 * y_range)) < 1e-10 + assert abs(ylim[1] - (y.max() + 0.2 * y_range)) < 1e-10 + + plt.close(fig) + + +def test_range_frame_non_numeric_x_axis(): + """Non-numeric x values should use 0..len(x)-1 for the range frame.""" + fig, ax = plt.subplots() + x = ["a", "b", "c", "d"] + y = np.array([0, 2, 4, 6]) + + ax.plot(x, y) + range_frame(ax, x, y, pad=0.1) + + xlim = ax.get_xlim() + + assert xlim == (0.0, float(len(x) - 1)) + assert ax.spines["bottom"].get_bounds() == (0, len(x) - 1) + + plt.close(fig) + + +def test_range_frame_non_numeric_y_axis(): + """Non-numeric y values should use 0..len(y)-1 for the range frame.""" + fig, ax = plt.subplots() + x = np.array([0, 1, 2, 3]) + y = ["low", "mid", "high", "top"] + + ax.plot(x, y) + range_frame(ax, x, y, pad=0.1) + + ylim = ax.get_ylim() + + assert ylim == (0.0, float(len(y) - 1)) + assert ax.spines["left"].get_bounds() == (0, len(y) - 1) + + plt.close(fig) + + def test_ylabel_top(): """Test that ylabel_top sets ylabel without errors.""" fig, ax = plt.subplots() @@ -63,3 +150,210 @@ def test_add_identity(): assert result == ax plt.close(fig) + + +# --- nice tick bounds tests ------------------------------------------------- + + +def test_nice_tick_bounds_brackets_data(): + """_nice_tick_bounds should return bounds that contain the data.""" + lo, hi, ticks = _nice_tick_bounds(3.2, 47.8) + assert lo <= 3.2 + assert hi >= 47.8 + # Bounds must be tick positions + assert any(abs(lo - t) < 1e-10 for t in ticks) + assert any(abs(hi - t) < 1e-10 for t in ticks) + + +def test_nice_tick_bounds_already_nice(): + """When data already spans a tick-aligned range, bounds should match.""" + lo, hi, ticks = _nice_tick_bounds(0, 10) + assert lo == 0.0 + assert hi == 10.0 + + +def test_nice_tick_bounds_negative(): + """Negative data ranges should also produce nice bounds.""" + lo, hi, ticks = _nice_tick_bounds(-7.3, -1.2) + assert lo <= -7.3 + assert hi >= -1.2 + assert any(abs(lo - t) < 1e-10 for t in ticks) + assert any(abs(hi - t) < 1e-10 for t in ticks) + + +def test_range_frame_nice_bounds_spine_alignment(): + """With nice=True the spine bounds must coincide with actual tick positions.""" + fig, ax = plt.subplots() + x = np.array([0.5, 1.3, 2.7, 3.9]) + y = np.array([1.1, 4.4, 7.2, 9.8]) + + range_frame(ax, x, y, nice=True) + + x_spine = ax.spines["bottom"].get_bounds() + y_spine = ax.spines["left"].get_bounds() + + # Spine bounds should contain all data + assert x_spine[0] <= x.min() + assert x_spine[1] >= x.max() + assert y_spine[0] <= y.min() + assert y_spine[1] >= y.max() + + # Spine bounds must be actual tick positions + x_ticks = ax.xaxis.get_ticklocs() + y_ticks = ax.yaxis.get_ticklocs() + assert any(abs(x_spine[0] - t) < 1e-10 for t in x_ticks), f"x spine lo {x_spine[0]} not in ticks {x_ticks}" + assert any(abs(x_spine[1] - t) < 1e-10 for t in x_ticks), f"x spine hi {x_spine[1]} not in ticks {x_ticks}" + assert any(abs(y_spine[0] - t) < 1e-10 for t in y_ticks), f"y spine lo {y_spine[0]} not in ticks {y_ticks}" + assert any(abs(y_spine[1] - t) < 1e-10 for t in y_ticks), f"y spine hi {y_spine[1]} not in ticks {y_ticks}" + + plt.close(fig) + + +def test_range_frame_nice_non_numeric_unchanged(): + """nice=True should not affect categorical axes.""" + fig, ax = plt.subplots() + x = ["a", "b", "c"] + y = np.array([1, 5, 9]) + + ax.plot(x, y) + range_frame(ax, x, y, nice=True) + + # Categorical x-axis: bounds should be index-based + assert ax.spines["bottom"].get_bounds() == (0, 2) + + plt.close(fig) + + +# --- decompose_figure tests ------------------------------------------------ + + +def test_decompose_figure_lines(): + """decompose_figure should return one figure per labelled line.""" + fig, ax = plt.subplots() + ax.plot([0, 1, 2], [0, 1, 4], label="Linear") + ax.plot([0, 1, 2], [0, 2, 8], label="Steep") + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_title("My Plot") + + parts = decompose_figure(fig) + + assert len(parts) == 2 + assert parts[0][0] == "Linear" + assert parts[1][0] == "Steep" + + # Each returned figure should have exactly one line on its axes + for label, part_fig in parts: + part_ax = part_fig.get_axes()[0] + labelled_lines = [line for line in part_ax.get_lines() if not line.get_label().startswith("_")] + assert len(labelled_lines) == 1 + assert labelled_lines[0].get_label() == label + # Axes metadata preserved + assert part_ax.get_xlabel() == "x" + assert part_ax.get_ylabel() == "y" + assert part_ax.get_title() == "My Plot" + + for _, f in parts: + plt.close(f) + plt.close(fig) + + +def test_decompose_figure_scatter(): + """decompose_figure should handle scatter plots.""" + fig, ax = plt.subplots() + ax.scatter([0, 1], [2, 3], label="Group A") + ax.scatter([4, 5], [6, 7], label="Group B") + + parts = decompose_figure(fig) + + assert len(parts) == 2 + assert {p[0] for p in parts} == {"Group A", "Group B"} + + for _, f in parts: + plt.close(f) + plt.close(fig) + + +def test_decompose_figure_bars(): + """decompose_figure should handle bar plots.""" + fig, ax = plt.subplots() + ax.bar([0, 1, 2], [3, 4, 5], label="Series 1") + ax.bar([0, 1, 2], [1, 2, 3], bottom=[3, 4, 5], label="Series 2") + + parts = decompose_figure(fig) + + assert len(parts) == 2 + assert parts[0][0] == "Series 1" + assert parts[1][0] == "Series 2" + + for _, f in parts: + plt.close(f) + plt.close(fig) + + +def test_decompose_figure_skips_unlabelled(): + """Artists without a label or with _-prefixed labels should be skipped.""" + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1]) # no label + ax.plot([0, 1], [0, 2], label="_hidden") + ax.plot([0, 1], [0, 3], label="Visible") + + parts = decompose_figure(fig) + + assert len(parts) == 1 + assert parts[0][0] == "Visible" + + for _, f in parts: + plt.close(f) + plt.close(fig) + + +def test_decompose_figure_accepts_axes(): + """decompose_figure should accept a single Axes object.""" + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], label="A") + ax.plot([0, 1], [0, 2], label="B") + + parts = decompose_figure(ax) # pass Axes, not Figure + + assert len(parts) == 2 + + for _, f in parts: + plt.close(f) + plt.close(fig) + + +def test_decompose_figure_no_legend(): + """When show_legend=False no legend should be present.""" + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], label="A") + + parts = decompose_figure(fig, show_legend=False) + + assert len(parts) == 1 + part_ax = parts[0][1].get_axes()[0] + assert part_ax.get_legend() is None + + for _, f in parts: + plt.close(f) + plt.close(fig) + + +def test_decompose_figure_preserves_limits(): + """Axis limits should be the same as the original plot.""" + fig, ax = plt.subplots() + ax.plot([0, 10], [0, 100], label="A") + ax.plot([0, 10], [100, 0], label="B") + original_xlim = ax.get_xlim() + original_ylim = ax.get_ylim() + + parts = decompose_figure(fig) + + for _, part_fig in parts: + part_ax = part_fig.get_axes()[0] + assert part_ax.get_xlim() == original_xlim + assert part_ax.get_ylim() == original_ylim + + for _, f in parts: + plt.close(f) + plt.close(fig)