Skip to content

Commit ff3e6d3

Browse files
authored
Merge pull request #1 from lamalab-org/minor_changes
feat: add docstrings
2 parents 976b257 + 679e223 commit ff3e6d3

3 files changed

Lines changed: 50 additions & 21 deletions

File tree

lama_aesthetics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .aesthetics import get_style, STYLES
1+
from .aesthetics import STYLES, get_style

lama_aesthetics/aesthetics.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,30 @@
1-
import matplotlib.pyplot as plt
2-
import os
31
import importlib.resources
42

3+
import matplotlib.pyplot as plt
4+
55
STYLES = {
66
"main": "lamalab.mplstyle",
77
"presentation": "presentation.mplstyle",
88
}
99

10+
1011
def get_style(style_name: str) -> None:
1112
"""Get the path to a matplotlib style file and apply it.
12-
13+
1314
Args:
1415
style_name: Name of the style ('main' or 'presentation')
15-
16+
1617
Raises:
1718
KeyError: If style_name is not in STYLES dictionary
1819
"""
1920
if style_name not in STYLES:
2021
raise KeyError(f"Style '{style_name}' not found. Available styles: {list(STYLES.keys())}")
2122

2223
style_file = STYLES[style_name]
23-
24+
2425
# Get the file contents as a string
2526
# This will only work for Python 3.7 and later
26-
with importlib.resources.path('lama_aesthetics.styles', style_file) as style_path:
27+
with importlib.resources.path("lama_aesthetics.styles", style_file) as style_path:
2728
plt.style.use(style_path)
2829

29-
return
30+
return

lama_aesthetics/plotutils.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1+
from typing import Optional
2+
13
import matplotlib.pyplot as plt
24
import numpy as np
3-
from typing import Optional
45

56

67
def range_frame(ax, x, y, pad=0.1):
8+
"""
9+
Set the limits of the axes to include all data points with a padding of
10+
`pad` times the range of the data. This is useful to ensure that the data
11+
points are not cut off by the axes.
12+
13+
Args:
14+
ax: The axes object.
15+
x: The x-coordinates of the data points.
16+
y: The y-coordinates of the data points.
17+
pad: The padding factor.
18+
"""
719
y_min, y_max = y.min(), y.max()
820
x_min, x_max = x.min(), x.max()
921

@@ -17,14 +29,20 @@ def range_frame(ax, x, y, pad=0.1):
1729
ax.spines["left"].set_bounds(y_min, y_max)
1830

1931

20-
def ylabel_top(
21-
string: str, ax: Optional[plt.Axes] = None, x_pad: float = 0.01, y_pad: float = 0.02
22-
) -> None:
23-
# Rotate the ylabel (such that you can read it comfortably) and place it
24-
# above the top ytick. This requires some logic, so it cannot be
25-
# incorporated in `style`. See
26-
# <https://stackoverflow.com/a/27919217/353337> on how to get the axes
27-
# coordinates of the top ytick.
32+
def ylabel_top(string: str, ax: Optional[plt.Axes] = None, x_pad: float = 0.01, y_pad: float = 0.02) -> None:
33+
"""
34+
Rotate the ylabel (such that you can read it comfortably) and place it
35+
above the top ytick. This requires some logic, so it cannot be
36+
incorporated in `style`. See
37+
<https://stackoverflow.com/a/27919217/353337> on how to get the axes
38+
coordinates of the top ytick.
39+
40+
Args:
41+
string: The string to be displayed as the ylabel.
42+
ax: The axes object.
43+
x_pad: The x-padding in axes coordinates.
44+
y_pad: The y-padding in axes coordinates.
45+
"""
2846
if ax is None:
2947
ax = plt.gca()
3048

@@ -66,14 +84,24 @@ def ylabel_top(
6684

6785

6886
def add_identity(axes, *line_args, **line_kwargs):
69-
identity, = axes.plot([], [], *line_args, **line_kwargs)
87+
"""
88+
Add a 1:1 line to the axes. This is useful to compare the data to a
89+
90+
Args:
91+
axes: The axes object.
92+
line_args: The positional arguments for the line.
93+
line_kwargs: The keyword arguments for the line.
94+
"""
95+
(identity,) = axes.plot([], [], *line_args, **line_kwargs)
96+
7097
def callback(axes):
7198
low_x, high_x = axes.get_xlim()
7299
low_y, high_y = axes.get_ylim()
73100
low = max(low_x, low_y)
74101
high = min(high_x, high_y)
75102
identity.set_data([low, high], [low, high])
103+
76104
callback(axes)
77-
axes.callbacks.connect('xlim_changed', callback)
78-
axes.callbacks.connect('ylim_changed', callback)
79-
return axes
105+
axes.callbacks.connect("xlim_changed", callback)
106+
axes.callbacks.connect("ylim_changed", callback)
107+
return axes

0 commit comments

Comments
 (0)