|
1 | | -from typing import List, Tuple |
| 1 | +from typing import List, Optional, Tuple |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | import pandas as pd |
5 | 5 | from matplotlib.axes import Axes |
6 | 6 | from matplotlib.figure import Figure |
| 7 | +from scipy.stats import gaussian_kde, probplot |
| 8 | +from sklearn.metrics import mean_squared_error |
7 | 9 |
|
8 | 10 | from .base_plotting import DefaultStaticPlotting |
9 | 11 | from .colors import default_colors |
10 | 12 |
|
11 | 13 |
|
def rmse(pred: np.ndarray, tar: np.ndarray) -> float:
    """
    Calculate the Root Mean Square Error between predicted and target values.

    Parameters
    ----------
    pred : np.ndarray
        Array of predicted values. Any array-like accepted by ``np.asarray``
        (list, tuple, pandas Series, ...) is also accepted.
    tar : np.ndarray
        Array of target/actual values, with the same shape as ``pred``.

    Returns
    -------
    float
        The Root Mean Square Error value, ``sqrt(mean((pred - tar) ** 2))``.

    Raises
    ------
    ValueError
        If ``pred`` and ``tar`` do not have the same shape.
    """

    # Work on float copies/views: this accepts plain sequences, compares
    # pandas objects by position rather than by index label, and avoids
    # wrap-around when the inputs are unsigned integer arrays.
    pred_arr = np.asarray(pred, dtype=float)
    tar_arr = np.asarray(tar, dtype=float)

    # Comparing shapes (not just len) prevents silent broadcasting of,
    # e.g., an (n, 1) column against an (n,) vector.
    if pred_arr.shape != tar_arr.shape:
        raise ValueError("pred and tar must have the same length")

    return float(np.sqrt(np.mean((pred_arr - tar_arr) ** 2)))
| 35 | + |
| 36 | + |
def bias(pred: np.ndarray, tar: np.ndarray) -> float:
    """
    Calculate the bias between predicted and target values.

    Parameters
    ----------
    pred : np.ndarray
        Array of predicted values. Any array-like accepted by ``np.asarray``
        (list, tuple, pandas Series, ...) is also accepted.
    tar : np.ndarray
        Array of target/actual values, with the same shape as ``pred``.

    Returns
    -------
    float
        The bias value (mean difference between predictions and targets),
        i.e. ``mean(pred - tar)``. Positive means over-prediction.

    Raises
    ------
    ValueError
        If ``pred`` and ``tar`` do not have the same shape.
    """

    # Float coercion accepts plain sequences, compares pandas objects by
    # position, and avoids wrap-around for unsigned integer inputs.
    pred_arr = np.asarray(pred, dtype=float)
    tar_arr = np.asarray(tar, dtype=float)

    if pred_arr.shape != tar_arr.shape:
        raise ValueError("pred and tar must have the same length")

    # Vectorized mean instead of the builtin sum(), which would iterate
    # the array element by element at Python level.
    return float(np.mean(pred_arr - tar_arr))
| 58 | + |
| 59 | + |
def si(pred: np.ndarray, tar: np.ndarray) -> float:
    """
    Calculate the Scatter Index between predicted and target values.

    The value is computed as
    ``sqrt(sum(((pred - mean(pred)) - (tar - mean(tar))) ** 2) / sum(tar ** 2))``.

    Parameters
    ----------
    pred : np.ndarray
        Array of predicted values. Any array-like accepted by ``np.asarray``
        (list, tuple, pandas Series, ...) is also accepted.
    tar : np.ndarray
        Array of target/actual values, with the same shape as ``pred``.

    Returns
    -------
    float
        The Scatter Index value.

    Raises
    ------
    ValueError
        If ``pred`` and ``tar`` do not have the same shape.
    """

    # Float coercion accepts plain sequences (which have no ``.mean()``),
    # compares pandas objects by position, and avoids integer overflow in
    # the squared terms.
    pred_arr = np.asarray(pred, dtype=float)
    tar_arr = np.asarray(tar, dtype=float)

    if pred_arr.shape != tar_arr.shape:
        raise ValueError("pred and tar must have the same length")

    # Difference of the mean-removed series, so a constant offset between
    # pred and tar does not contribute to the index.
    centered_diff = (pred_arr - pred_arr.mean()) - (tar_arr - tar_arr.mean())

    # Vectorized sums instead of the builtin sum() over NumPy arrays.
    return float(np.sqrt(np.sum(centered_diff**2) / np.sum(tar_arr**2)))
| 84 | + |
| 85 | + |
def density_scatter(
    x: np.ndarray, y: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute a density scatter for two arrays using gaussian KDE.

    The density of the joint (x, y) sample is evaluated at every sample
    point, and the points are returned ordered by increasing density so
    that, when plotted in order, the densest points are drawn last.

    Parameters
    ----------
    x : np.ndarray
        X values for the scatter plot. Any 1-D array-like accepted by
        ``np.asarray`` is also accepted.
    y : np.ndarray
        Y values for the scatter plot, same shape as ``x``.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray]
        A tuple containing:
        - x values sorted by increasing density
        - y values sorted by increasing density
        - Density values corresponding to each point (ascending)

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not have the same shape.

    Notes
    -----
    ``scipy.stats.gaussian_kde`` may itself raise for degenerate data
    (for example too few points or perfectly collinear points).
    """

    # Coerce first so the positional reordering below works for lists and
    # is not interpreted as label-based lookup on pandas Series.
    x_arr = np.asarray(x)
    y_arr = np.asarray(y)

    if x_arr.shape != y_arr.shape:
        raise ValueError("x and y must have the same length")

    # gaussian_kde expects one row per dimension: shape (2, n).
    points = np.vstack([x_arr, y_arr])
    density = gaussian_kde(points)(points)

    # Ascending density order.
    order = density.argsort()

    return x_arr[order], y_arr[order], density[order]
| 117 | + |
| 118 | + |
def validation_scatter(
    axs: Axes,
    x: np.ndarray,
    y: np.ndarray,
    xlabel: str,
    ylabel: str,
    title: str,
) -> None:
    """
    Draw a validation panel on the given axes.

    The panel contains a scatter of (x, y) coloured by gaussian-KDE density,
    a red 1:1 reference line, the Q-Q points of x against y, and a text box
    listing RMSE, MSE, bias and scatter index of x versus y.

    Parameters
    ----------
    axs : Axes
        Matplotlib axes to draw on. It is modified in place.
    x : np.ndarray
        X values for the scatter plot.
    y : np.ndarray
        Y values for the scatter plot.
    xlabel : str
        Label for the X-axis.
    ylabel : str
        Label for the Y-axis.
    title : str
        Title for the plot.

    Notes
    -----
    Both axes are limited to ``[0, m]`` with
    ``m = ceil(max(max(x), max(y)) + 0.1)``, with five evenly spaced ticks
    and an equal aspect ratio.
    """

    # Points ordered by increasing density, so dense regions are drawn last.
    xs, ys, density = density_scatter(x, y)
    axs.scatter(xs, ys, c=density, s=5, cmap="rainbow")

    # Axis labels and title.
    axs.set_xlabel(xlabel)
    axs.set_ylabel(ylabel)
    axs.set_title(title)

    # Shared square extent for both axes plus the 1:1 reference line.
    upper = np.ceil(max(max(x), max(y)) + 0.1)
    ticks = np.linspace(0, upper, 5)
    axs.set_xlim(0, upper)
    axs.set_ylim(0, upper)
    axs.plot([0, upper], [0, upper], "-r")
    axs.set_xticks(ticks)
    axs.set_yticks(ticks)
    axs.set_aspect("equal")

    # Q-Q points: ordered sample values of x against those of y.
    (_, x_ordered), _ = probplot(x, dist="norm")
    (_, y_ordered), _ = probplot(y, dist="norm")
    axs.plot(
        x_ordered, y_ordered, "o", markersize=0.5, color="k", label="Q-Q plot"
    )

    # Error statistics of x versus y.
    mse_value = mean_squared_error(xs, ys)
    rmse_value = rmse(xs, ys)
    bias_value = bias(xs, ys)
    si_value = si(xs, ys)
    stats_lines = [
        r"RMSE = %.2f" % (rmse_value,),
        r"mse = %.2f" % (mse_value,),
        r"BIAS = %.2f" % (bias_value,),
        R"SI = %.2f" % (si_value,),
    ]

    # Semi-transparent rounded box in the upper-left corner (axes coords).
    box_style = {
        "boxstyle": "round",
        "facecolor": "w",
        "edgecolor": "grey",
        "linewidth": 0.8,
        "alpha": 0.5,
    }
    axs.text(
        0.05,
        0.95,
        "\n".join(stats_lines),
        transform=axs.transAxes,
        fontsize=9,
        verticalalignment="top",
        bbox=box_style,
    )
| 195 | + |
| 196 | + |
12 | 197 | def plot_scatters_in_triangle( |
13 | 198 | dataframes: List[pd.DataFrame], |
14 | | - data_colors: List[str] = default_colors, |
| 199 | + data_colors: Optional[List[str]] = None, |
15 | 200 | **kwargs, |
16 | | -) -> Tuple[Figure, Axes]: |
| 201 | +) -> Tuple[Figure, np.ndarray]: |
17 | 202 | """ |
18 | | - Plot a scatter plot of the dataframes with axes in a triangle. |
| 203 | + Plot scatter plots of the dataframes with axes in a triangle arrangement. |
19 | 204 |
|
20 | 205 | Parameters |
21 | 206 | ---------- |
22 | 207 | dataframes : List[pd.DataFrame] |
23 | | - List of dataframes to plot. |
24 | | - data_colors : List[str], optional |
25 | | - List of colors for the dataframes. |
| 208 | + List of dataframes to plot. Each dataframe should contain the same columns. |
| 209 | + data_colors : Optional[List[str]], optional |
| 210 | + List of colors for the dataframes. If None, uses default_colors. |
26 | 211 | **kwargs : dict, optional |
27 | | - Keyword arguments for the scatter plot. Will be passed to the |
28 | | - DefaultStaticPlotting.plot_scatter method, which is the same |
29 | | - as the one in matplotlib.pyplot.scatter. |
30 | | - For example, to change the marker size, you can use: |
31 | | - ``plot_scatters_in_triangle(dataframes, s=10)`` |
| 212 | + Additional keyword arguments for the scatter plot. These will be passed to |
| 213 | + matplotlib.pyplot.scatter. Common parameters include: |
| 214 | + - s : float, marker size |
| 215 | + - alpha : float, transparency |
| 216 | + - marker : str, marker style |
32 | 217 |
|
33 | 218 | Returns |
34 | 219 | ------- |
35 | | - fig : Figure |
36 | | - Figure object. |
37 | | - axes : Axes |
38 | | - Axes object. |
| 220 | + Tuple[Figure, np.ndarray] |
| 221 | + A tuple containing: |
| 222 | + - Figure object |
| 223 | + - 2D array of Axes objects |
| 224 | +
|
| 225 | + Raises |
| 226 | + ------ |
| 227 | + ValueError |
| 228 | + If the variables in the first dataframe are not present in all other dataframes. |
39 | 229 | """ |
40 | 230 |
|
| 231 | + if data_colors is None: |
| 232 | + data_colors = default_colors |
| 233 | + |
41 | 234 | # Get the number and names of variables from the first dataframe |
42 | 235 | variables_names = list(dataframes[0].columns) |
43 | 236 | num_variables = len(variables_names) |
|
0 commit comments