1+ from typing import Optional
2+
13import matplotlib .pyplot as plt
24import numpy as np
3- from typing import Optional
45
56
67def range_frame (ax , x , y , pad = 0.1 ):
8+ """
9+ Set the limits of the axes to include all data points with a padding of
10+ `pad` times the range of the data. This is useful to ensure that the data
11+ points are not cut off by the axes.
12+
13+ Args:
14+ ax: The axes object.
15+ x: The x-coordinates of the data points.
16+ y: The y-coordinates of the data points.
17+ pad: The padding factor.
18+ """
719 y_min , y_max = y .min (), y .max ()
820 x_min , x_max = x .min (), x .max ()
921
@@ -17,14 +29,20 @@ def range_frame(ax, x, y, pad=0.1):
1729 ax .spines ["left" ].set_bounds (y_min , y_max )
1830
1931
20- def ylabel_top (
21- string : str , ax : Optional [plt .Axes ] = None , x_pad : float = 0.01 , y_pad : float = 0.02
22- ) -> None :
23- # Rotate the ylabel (such that you can read it comfortably) and place it
24- # above the top ytick. This requires some logic, so it cannot be
25- # incorporated in `style`. See
26- # <https://stackoverflow.com/a/27919217/353337> on how to get the axes
27- # coordinates of the top ytick.
32+ def ylabel_top (string : str , ax : Optional [plt .Axes ] = None , x_pad : float = 0.01 , y_pad : float = 0.02 ) -> None :
33+ """
34+ Rotate the ylabel (such that you can read it comfortably) and place it
35+ above the top ytick. This requires some logic, so it cannot be
36+ incorporated in `style`. See
37+ <https://stackoverflow.com/a/27919217/353337> on how to get the axes
38+ coordinates of the top ytick.
39+
40+ Args:
41+ string: The string to be displayed as the ylabel.
42+ ax: The axes object.
43+ x_pad: The x-padding in axes coordinates.
44+ y_pad: The y-padding in axes coordinates.
45+ """
2846 if ax is None :
2947 ax = plt .gca ()
3048
@@ -66,14 +84,24 @@ def ylabel_top(
6684
6785
6886def add_identity (axes , * line_args , ** line_kwargs ):
69- identity , = axes .plot ([], [], * line_args , ** line_kwargs )
87+ """
88+ Add a 1:1 line to the axes. This is useful to compare the data to a
89+
90+ Args:
91+ axes: The axes object.
92+ line_args: The positional arguments for the line.
93+ line_kwargs: The keyword arguments for the line.
94+ """
95+ (identity ,) = axes .plot ([], [], * line_args , ** line_kwargs )
96+
7097 def callback (axes ):
7198 low_x , high_x = axes .get_xlim ()
7299 low_y , high_y = axes .get_ylim ()
73100 low = max (low_x , low_y )
74101 high = min (high_x , high_y )
75102 identity .set_data ([low , high ], [low , high ])
103+
76104 callback (axes )
77- axes .callbacks .connect (' xlim_changed' , callback )
78- axes .callbacks .connect (' ylim_changed' , callback )
79- return axes
105+ axes .callbacks .connect (" xlim_changed" , callback )
106+ axes .callbacks .connect (" ylim_changed" , callback )
107+ return axes
0 commit comments