diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 8fb5be0..9c737f0 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -48,6 +48,6 @@ jobs: source /usr/share/miniconda/etc/profile.d/conda.sh conda activate bluemath-tk python -m unittest discover tests/datamining/ - python -m unittest discover tests/downloaders/ + python -m unittest discover tests/distributions/ python -m unittest discover tests/interpolation/ python -m unittest discover tests/wrappers/ diff --git a/bluemath_tk/core/decorators.py b/bluemath_tk/core/decorators.py index f597c40..7b1c200 100644 --- a/bluemath_tk/core/decorators.py +++ b/bluemath_tk/core/decorators.py @@ -437,3 +437,59 @@ def wrapper( ) return wrapper + + +def validate_data_calval(func): + """ + Decorator to validate data in CalVal class fit method. + + Parameters + ---------- + func : callable + The function to be decorated + + Returns + ------- + callable + The decorated function + """ + + @functools.wraps(func) + def wrapper( + self, + data: pd.DataFrame, + data_longitude: float, + data_latitude: float, + data_to_calibrate: pd.DataFrame, + max_time_diff: int = 2, + ): + if not isinstance(data, pd.DataFrame): + raise TypeError("Data must be a pandas DataFrame") + if not isinstance(data_longitude, float): + raise TypeError("Longitude must be a float") + data_longitude = data_longitude % 360 + if not isinstance(data_latitude, float): + raise TypeError("Latitude must be a float") + if not isinstance(data_to_calibrate, pd.DataFrame): + raise TypeError("Data to calibrate must be a pandas DataFrame") + if "LONGITUDE" not in data_to_calibrate.columns: + raise ValueError( + "Data to calibrate must contain a column named 'LONGITUDE'" + ) + if "LATITUDE" not in data_to_calibrate.columns: + raise ValueError("Data to calibrate must contain a column named 'LATITUDE'") + if "Hs_CAL" not in data_to_calibrate.columns: + raise ValueError("Data to calibrate must contain a column named 'Hs_CAL'") + if not isinstance(max_time_diff, int) or max_time_diff <= 0: + raise ValueError("Maximum time difference must be an integer and > 0") + + return func( + self, + data, + data_longitude, + data_latitude, + data_to_calibrate, + max_time_diff, + ) + + return wrapper diff --git a/bluemath_tk/core/plotting/scatter.py b/bluemath_tk/core/plotting/scatter.py index 8bd10ef..5be4940 100644 --- a/bluemath_tk/core/plotting/scatter.py +++ b/bluemath_tk/core/plotting/scatter.py @@ -1,43 +1,236 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple import numpy as np import pandas as pd from matplotlib.axes import Axes from matplotlib.figure import Figure +from scipy.stats import gaussian_kde, probplot +from sklearn.metrics import mean_squared_error from .base_plotting import DefaultStaticPlotting from .colors import default_colors +def rmse(pred: np.ndarray, tar: np.ndarray) -> float: + """ + Calculate the Root Mean Square Error between predicted and target values. + + Parameters + ---------- + pred : np.ndarray + Array of predicted values. + tar : np.ndarray + Array of target/actual values. + + Returns + ------- + float + The Root Mean Square Error value. + """ + + if len(pred) != len(tar): + raise ValueError("pred and tar must have the same length") + + return np.sqrt(((pred - tar) ** 2).mean()) + + +def bias(pred: np.ndarray, tar: np.ndarray) -> float: + """ + Calculate the bias between predicted and target values. + + Parameters + ---------- + pred : np.ndarray + Array of predicted values. 
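+        Must have the same length as ``tar``.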
+ tar : np.ndarray + Array of target/actual values. + + Returns + ------- + float + The bias value (mean difference between predictions and targets). + """ + + if len(pred) != len(tar): + raise ValueError("pred and tar must have the same length") + + return sum(pred - tar) / len(pred) + + +def si(pred: np.ndarray, tar: np.ndarray) -> float: + """ + Calculate the Scatter Index between predicted and target values. + + Parameters + ---------- + pred : np.ndarray + Array of predicted values. + tar : np.ndarray + Array of target/actual values. + + Returns + ------- + float + The Scatter Index value. + """ + + if len(pred) != len(tar): + raise ValueError("pred and tar must have the same length") + + pred_mean = pred.mean() + tar_mean = tar.mean() + + return np.sqrt(sum(((pred - pred_mean) - (tar - tar_mean)) ** 2) / (sum(tar**2))) + + +def density_scatter( + x: np.ndarray, y: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Compute a density scatter for two arrays using gaussian KDE. + + Parameters + ---------- + x : np.ndarray + X values for the scatter plot. + y : np.ndarray + Y values for the scatter plot. + + Returns + ------- + Tuple[np.ndarray, np.ndarray, np.ndarray] + A tuple containing: + - Sorted x values + - Sorted y values + - Density values corresponding to each point + """ + + if len(x) != len(y): + raise ValueError("x and y must have the same length") + + xy = np.vstack([x, y]) + z = gaussian_kde(xy)(xy) + idx = z.argsort() + x1, y1, z = x[idx], y[idx], z[idx] + + return x1, y1, z + + +def validation_scatter( + axs: Axes, + x: np.ndarray, + y: np.ndarray, + xlabel: str, + ylabel: str, + title: str, +) -> None: + """ + Plot a density scatter and Q-Q plot for validation. + + Parameters + ---------- + axs : Axes + Matplotlib axes to plot on. + x : np.ndarray + X values for the scatter plot. + y : np.ndarray + Y values for the scatter plot. + xlabel : str + Label for the X-axis. + ylabel : str + Label for the Y-axis. + title : str + Title for the plot. + """ + + x2, y2, z = density_scatter(x, y) + + # plot + axs.scatter(x2, y2, c=z, s=5, cmap="rainbow") + + # labels + axs.set_xlabel(xlabel) + axs.set_ylabel(ylabel) + axs.set_title(title) + + # axis limits + maxt = np.ceil(max(max(x) + 0.1, max(y) + 0.1)) + axs.set_xlim(0, maxt) + axs.set_ylim(0, maxt) + axs.plot([0, maxt], [0, maxt], "-r") + axs.set_xticks(np.linspace(0, maxt, 5)) + axs.set_yticks(np.linspace(0, maxt, 5)) + axs.set_aspect("equal") + + # qq-plot + xq = probplot(x, dist="norm") + yq = probplot(y, dist="norm") + axs.plot(xq[0][1], yq[0][1], "o", markersize=0.5, color="k", label="Q-Q plot") + + # diagnostic errors + props = dict( + boxstyle="round", facecolor="w", edgecolor="grey", linewidth=0.8, alpha=0.5 + ) + mse = mean_squared_error(x2, y2) + rmse_e = rmse(x2, y2) + BIAS = bias(x2, y2) + SI = si(x2, y2) + label = "\n".join( + ( + r"RMSE = %.2f" % (rmse_e,), + r"mse = %.2f" % (mse,), + r"BIAS = %.2f" % (BIAS,), + R"SI = %.2f" % (SI,), + ) + ) + axs.text( + 0.05, + 0.95, + label, + transform=axs.transAxes, + fontsize=9, + verticalalignment="top", + bbox=props, + ) + + def plot_scatters_in_triangle( dataframes: List[pd.DataFrame], - data_colors: List[str] = default_colors, + data_colors: Optional[List[str]] = None, **kwargs, -) -> Tuple[Figure, Axes]: +) -> Tuple[Figure, np.ndarray]: """ - Plot a scatter plot of the dataframes with axes in a triangle. + Plot scatter plots of the dataframes with axes in a triangle arrangement. 
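+    One panel is drawn for each pair of variables in the dataframes.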
Parameters ---------- dataframes : List[pd.DataFrame] - List of dataframes to plot. - data_colors : List[str], optional - List of colors for the dataframes. + List of dataframes to plot. Each dataframe should contain the same columns. + data_colors : Optional[List[str]], optional + List of colors for the dataframes. If None, uses default_colors. **kwargs : dict, optional - Keyword arguments for the scatter plot. Will be passed to the - DefaultStaticPlotting.plot_scatter method, which is the same - as the one in matplotlib.pyplot.scatter. - For example, to change the marker size, you can use: - ``plot_scatters_in_triangle(dataframes, s=10)`` + Additional keyword arguments for the scatter plot. These will be passed to + matplotlib.pyplot.scatter. Common parameters include: + - s : float, marker size + - alpha : float, transparency + - marker : str, marker style Returns ------- - fig : Figure - Figure object. - axes : Axes - Axes object. + Tuple[Figure, np.ndarray] + A tuple containing: + - Figure object + - 2D array of Axes objects + + Raises + ------ + ValueError + If the variables in the first dataframe are not present in all other dataframes. """ + if data_colors is None: + data_colors = default_colors + # Get the number and names of variables from the first dataframe variables_names = list(dataframes[0].columns) num_variables = len(variables_names) diff --git a/bluemath_tk/waves/calibration.py b/bluemath_tk/waves/calibration.py new file mode 100644 index 0000000..6857c1d --- /dev/null +++ b/bluemath_tk/waves/calibration.py @@ -0,0 +1,875 @@ +from typing import Tuple, Union + +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import matplotlib as mpl +import numpy as np +import pandas as pd +import statsmodels.api as sm +import xarray as xr +from matplotlib import pyplot as plt +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from ..core.decorators import validate_data_calval +from ..core.models import BlueMathModel +from ..core.plotting.scatter import density_scatter, validation_scatter + + +def get_matching_times_between_arrays( + times1: np.ndarray, + times2: np.ndarray, + max_time_diff: int, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Finds matching time indices between two arrays of timestamps. + + For each time in `times1`, finds the closest time in `times2` that is within `max_time_diff` hours. + Returns the indices of matching times in both arrays. + + Parameters + ---------- + times1 : np.ndarray + First array of timestamps (reference times, e.g., from model data). + times2 : np.ndarray + Second array of timestamps (e.g., from satellite or validation data). + max_time_diff : int + Maximum time difference in hours for considering times as matching. + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + Two arrays containing the indices of matching times: + - First array: indices in times1 that have matches + - Second array: corresponding indices in times2 that match + + Example + ------- + >>> idx1, idx2 = get_matching_times_between_arrays( + ... model_df.index.values, + ... sat_df.index.values, + ... max_time_diff=2, + ... 
) + """ + + indices1 = np.array([], dtype=int) + indices2 = np.array([], dtype=int) + + for i in range(len(times1)): + # Find minimum time difference for current time1 + time_diffs = np.abs(times2 - times1[i]) + min_diff = np.min(time_diffs) + + # If minimum difference is within threshold, record the indices + if min_diff < np.timedelta64(max_time_diff, "h"): + min_index = np.argmin(time_diffs) + indices1 = np.append(indices1, i) + indices2 = np.append(indices2, min_index) + + return indices1, indices2 + + +def process_imos_satellite_data( + satellite_df: pd.DataFrame, + ini_lat: float, + end_lat: float, + ini_lon: float, + end_lon: float, + depth_threshold: float = -200, +) -> pd.DataFrame: + """ + Processes IMOS satellite data for calibration. + + This function filters and processes IMOS satellite altimeter data to be used as + reference data for calibration (e.g., as `data_to_calibrate` in CalVal.fit). + + Parameters + ---------- + satellite_df : pd.DataFrame + IMOS satellite data. Must contain columns: + - 'LATITUDE' (float): Latitude in decimal degrees + - 'LONGITUDE' (float): Longitude in decimal degrees + - 'SWH_KU_quality_control' (float): Quality control flag for Ku-band + - 'SWH_KA_quality_control' (float): Quality control flag for Ka-band + - 'SWH_KU_CAL' (float): Calibrated significant wave height (Ku-band) + - 'SWH_KA_CAL' (float): Calibrated significant wave height (Ka-band) + - 'BOT_DEPTH' (float): Bathymetry (negative values for ocean) + ini_lat : float + Minimum latitude (southern boundary) for filtering. + end_lat : float + Maximum latitude (northern boundary) for filtering. + ini_lon : float + Minimum longitude (western boundary) for filtering. + end_lon : float + Maximum longitude (eastern boundary) for filtering. + depth_threshold : float, optional + Only include points with BOT_DEPTH < depth_threshold. Default is -200. + + Returns + ------- + pd.DataFrame + Filtered and processed satellite data, suitable for use as `data_to_calibrate` in CalVal.fit. + Includes a new column 'Hs_CAL' (combination of Ku-band and Ka-band calibrated significant wave heights). + + Notes + ----- + The returned DataFrame can be used directly as the `data_to_calibrate` argument in CalVal.fit. + """ + + # Filter satellite data by coordinates + satellite_df = satellite_df[ + (satellite_df.LATITUDE > ini_lat) + & (satellite_df.LATITUDE < end_lat) + & (satellite_df.LONGITUDE > ini_lon) + & (satellite_df.LONGITUDE < end_lon) + & (satellite_df.BOT_DEPTH < depth_threshold) + ] + + # Process quality control + wave_height_qlt = np.nansum( + np.concatenate( + ( + satellite_df["SWH_KU_quality_control"].values[:, np.newaxis], + satellite_df["SWH_KA_quality_control"].values[:, np.newaxis], + ), + axis=1, + ), + axis=1, + ) + good_qlt = np.where(wave_height_qlt < 1.5) + + # Process wave heights + satellite_df["Hs_CAL"] = np.nansum( + np.concatenate( + ( + satellite_df["SWH_KU_CAL"].values[:, np.newaxis], + satellite_df["SWH_KA_CAL"].values[:, np.newaxis], + ), + axis=1, + ), + axis=1, + ) + + return satellite_df.iloc[good_qlt] + + +class CalVal(BlueMathModel): + """ + Calibrates wave data using reference data. + + This class provides a framework for calibrating wave model outputs (e.g., hindcast or reanalysis) + using reference data (e.g., satellite or buoy observations). + It supports directionally-dependent calibration for both sea and swell components. + + Attributes + ---------- + direction_bin_size : int + Size of directional bins in degrees. 
+ direction_bins : np.ndarray + Array of bin edges for directions. + calibration_model : sm.OLS + The calibration model, more details in `statsmodels.api.OLS`. + calibrated_data : pd.DataFrame + DataFrame with columns ['Hs', 'Hs_CORR', 'Hs_CAL'] after calibration. + The time domain is the same as the model data. + calibration_params : dict + Dictionary with 'sea_correction' and 'swell_correction' correction coefficients. + + Example + ------- + .. jupyter-execute:: + + import pandas as pd + from bluemath_tk.waves.calibration import CalVal, process_imos_satellite_data + + # Load your model data (must have columns: 'Hs', 'Hsea', 'Dirsea', 'Hswell1', 'Dirswell1', ...) + model_df = pd.read_csv('model_data.csv', index_col=0, parse_dates=True) + + # Load IMOS satellite data and process it for calibration + sat_df = pd.read_csv('imos_satellite.csv', index_col=0, parse_dates=True) + data_to_calibrate = process_imos_satellite_data( + sat_df, ini_lat=-40, end_lat=-30, ini_lon=140, end_lon=150 + ) + + # Initialize and fit the calibration + cal = CalVal() + cal.fit( + data=model_df, + data_longitude=145.0, + data_latitude=-35.0, + data_to_calibrate=data_to_calibrate, + max_time_diff=2, + ) + + # Plot results + cal.plot_calibration_results() + + """ + + direction_bin_size: int = 22.5 + direction_bins: np.ndarray = np.arange( + direction_bin_size, 360.5, direction_bin_size + ) + + def __init__(self) -> None: + """ + Initialize the CalVal class. + """ + + super().__init__() + self.set_logger_name(name="CalVal", level="INFO", console=True) + + # Save input data + self._data: pd.DataFrame = None + self._data_longitude: float = None + self._data_latitude: float = None + self._data_to_calibrate: pd.DataFrame = None + self._max_time_diff: int = None + + # Initialize calibration results + self._data_to_fit: Tuple[pd.DataFrame, pd.DataFrame] = (None, None) + self._calibration_model: sm.OLS = None + self._calibrated_data: pd.DataFrame = None + self._calibration_params: pd.Series = None + + # Exclude large attributes from model saving + self._exclude_attributes += [ + "_data", + "_data_to_calibrate", + ] + + @property + def calibration_model(self) -> sm.OLS: + """Returns the calibration model.""" + + if self._calibration_model is None: + raise ValueError( + "Calibration model is not available. Please run the fit method first." + ) + + return self._calibration_model + + @property + def calibrated_data(self) -> pd.DataFrame: + """Returns the calibrated data.""" + + if self._calibrated_data is None: + raise ValueError( + "Calibrated data is not available. Please run the fit method first." + ) + + return self._calibrated_data + + @property + def calibration_params(self) -> pd.Series: + """Returns the calibration parameters.""" + + if self._calibration_params is None: + raise ValueError( + "Calibration parameters are not available. Please run the fit method first." + ) + + return self._calibration_params + + def _plot_data_domains(self) -> Tuple[Figure, Axes]: + """ + Plots the domains of the data points. + + Returns + ------- + Tuple[Figure, Axes] + A tuple containing the figure and axes objects. 
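+            The map shows the calibration points (black) and the model location (red).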
+ """ + + fig, ax = plt.subplots( + figsize=(10, 10), + subplot_kw={ + "projection": ccrs.PlateCarree(central_longitude=self._data_longitude) + }, + ) + land_10m = cfeature.NaturalEarthFeature( + "physical", + "land", + "10m", + edgecolor="face", + facecolor=cfeature.COLORS["land"], + ) + # Plot calibration data + ax.scatter( + self._data_to_calibrate.LONGITUDE, + self._data_to_calibrate.LATITUDE, + s=0.01, + c="k", + transform=ccrs.PlateCarree(), + ) + # Plot main data point + ax.scatter( + self._data_longitude, + self._data_latitude, + s=50, + c="red", + zorder=10, + transform=ccrs.PlateCarree(), + ) + # Set plot extent + ax.set_extent( + [ + self._data_longitude - 2, + self._data_longitude + 2, + self._data_latitude - 2, + self._data_latitude + 2, + ] + ) + ax.add_feature(land_10m) + + return fig, ax + + def _create_vec_direc(self, waves: np.ndarray, direcs: np.ndarray) -> np.ndarray: + """ + Creates a vector of wave heights for each directional bin. + + Parameters + ---------- + waves : np.ndarray + Wave heights. + direcs : np.ndarray + Wave directions in degrees. + + Returns + ------- + np.ndarray + Matrix of wave heights for each directional bin. + """ + + data = np.zeros((len(waves), len(self.direction_bins))) + for i in range(len(waves)): + if direcs[i] < 0: + direcs[i] = direcs[i] + 360 + if direcs[i] > 0 and waves[i] > 0: + bin_idx = int(direcs[i] / self.direction_bin_size) + data[i, bin_idx] = waves[i] + + return data + + @staticmethod + def _get_nparts(data: pd.DataFrame) -> int: + """ + Gets the number of parts in the wave data. + + Parameters + ---------- + data : pd.DataFrame + Wave data. + + Returns + ------- + int + The number of parts in the wave data. + """ + + return len([col for col in data.columns if col.startswith("Hswell")]) + + def _get_joined_sea_swell_data(self, data: pd.DataFrame) -> np.ndarray: + """ + Joins the sea and swell data. + + Parameters + ---------- + data : pd.DataFrame + Wave data. + + Returns + ------- + np.ndarray + The joined sea and swell matrix. + """ + + # Process sea waves + Hsea = self._create_vec_direc(data["Hsea"], data["Dirsea"]) + + # Process swells + Hs_swells = np.zeros(Hsea.shape) + for part in range(1, self._get_nparts(data) + 1): + Hs_swells += ( + self._create_vec_direc(data[f"Hswell{part}"], data[f"Dirswell{part}"]) + ) ** 2 + + # Combine sea and swell matrices + sea_swell_matrix = np.concatenate([Hsea**2, Hs_swells], axis=1) + + return sea_swell_matrix + + @validate_data_calval + def fit( + self, + data: pd.DataFrame, + data_longitude: float, + data_latitude: float, + data_to_calibrate: pd.DataFrame, + max_time_diff: int = 2, + ) -> None: + """ + Calibrate the model data using reference (calibration) data. + + This method matches the model data and calibration data in time, constructs directionally-binned sea and swell matrices, + and fits a linear regression to obtain correction coefficients for each direction bin. + + Parameters + ---------- + data : pd.DataFrame + Model data to calibrate. Must contain columns: + - 'Hs' (float): Significant wave height + - 'Hsea' (float): Sea component significant wave height + - 'Dirsea' (float): Sea component mean direction (degrees) + - 'Hswell1', 'Dirswell1', ... (float): Swell components (at least one required) + The index must be datetime-like. + data_longitude : float + Longitude of the model location (used for plotting and filtering). + data_latitude : float + Latitude of the model location (used for plotting and filtering). 
+ data_to_calibrate : pd.DataFrame + Reference data for calibration. Must contain column: + - 'Hs_CAL' (float): Calibrated significant wave height (e.g., from satellite) + The index must be datetime-like. + max_time_diff : int, optional + Maximum time difference (in hours) allowed when matching model and calibration data. + Default is 2. + + Notes + ----- + After calling this method, the calibration parameters are stored in `self.calibration_params` and the calibrated data + is available in `self.calibrated_data`. + """ + + self.logger.info("Starting calibration fit procedure.") + + # Save input data + self._data = data.copy() + self._data_longitude = data_longitude + self._data_latitude = data_latitude + self._data_to_calibrate = data_to_calibrate.copy() + self._max_time_diff = max_time_diff + + # Plot data domains + self.logger.info("Plotting data domains.") + self._plot_data_domains() + + # Construct matrices for calibration + self.logger.info("Matching times and constructing matrices for calibration.") + + # Get matching times + times_data_to_fit, times_data_to_calibrate = get_matching_times_between_arrays( + self._data.index.values, + self._data_to_calibrate.index.values, + max_time_diff=self._max_time_diff, + ) + self._data_to_fit = ( + self._data.iloc[times_data_to_fit], + self._data_to_calibrate.iloc[times_data_to_calibrate], + ) + + # Get joined sea and swell data + sea_swell_matrix = self._get_joined_sea_swell_data(self._data_to_fit[0]) + + # Perform calibration + self.logger.info("Fitting OLS regression for calibration.") + X = sm.add_constant(sea_swell_matrix) + self._calibration_model = sm.OLS(self._data_to_fit[1]["Hs_CAL"] ** 2, X) + calibrated_model_results = self._calibration_model.fit() + + # Get significant correction coefficients + significant_model_params = [ + model_param + if calibrated_model_results.pvalues[imp] < 0.05 and model_param > 0 + else 1.0 + for imp, model_param in enumerate(calibrated_model_results.params) + ] + + # Save sea and swell correction coefficients + self._calibration_params = { + "sea_correction": { + ip: param + for ip, param in enumerate( + np.sqrt(significant_model_params[: len(self.direction_bins)]) + ) + }, + "swell_correction": { + ip: param + for ip, param in enumerate( + np.sqrt(significant_model_params[len(self.direction_bins) :]) + ) + }, + } + + # Save calibrated data to be used in plot_calibration_results() + self._calibrated_data = self.predict(self._data_to_fit[0]) + self._calibrated_data["Hs_CAL"] = self._data_to_fit[1]["Hs_CAL"].values + + self.logger.info("Calibration fit procedure completed.") + + def predict( + self, data: Union[pd.DataFrame, xr.Dataset] + ) -> Union[pd.DataFrame, xr.Dataset]: + """ + Apply the calibration correction to new data. + + Parameters + ---------- + data : pd.DataFrame or xr.Dataset + Data to correct. If DataFrame, must contain columns: + - 'Hs', 'Hsea', 'Dirsea', 'Hswell1', 'Dirswell1', ... + If xarray.Dataset, must have variables 'efth' and 'dp' (directional spectra). + + Returns + ------- + pd.DataFrame or xr.Dataset + Corrected data. For DataFrame, returns columns ['Hs', 'Hs_CORR'] (original and corrected SWH). + For Dataset, adds variables 'corr_coeffs' and 'corr_efth'. + + Notes + ----- + The correction is directionally dependent and uses the coefficients obtained from `fit`. + """ + + if self._calibration_params is None: + raise ValueError( + "Calibration parameters are not available. Run fit() first." + ) + + if isinstance(data, xr.Dataset): + self.logger.info( + "Input is xarray.Dataset. 
Applying correction to spectra data." + ) + + corrected_data = data.copy() # Copy data to avoid modifying original data + peak_directions = corrected_data.spec.stats(["dp"]).load() + correction_coeffs = np.ones(peak_directions.dp.shape) + for n_part in peak_directions.part: + if n_part == 0: + correction_coeffs[n_part, :] = np.array( + [ + self.calibration_params["sea_correction"][ + int(peak_direction / self.direction_bin_size) + ] + for peak_direction in peak_directions.isel( + part=n_part + ).dp.values + ] + ) + else: + correction_coeffs[n_part, :] = np.array( + [ + self.calibration_params["swell_correction"][ + int(peak_direction / self.direction_bin_size) + ] + for peak_direction in peak_directions.isel( + part=n_part + ).dp.values + ] + ) + corrected_data["corr_coeffs"] = (("part", "time"), correction_coeffs) + corrected_data["corr_efth"] = ( + corrected_data.efth * corrected_data.corr_coeffs + ) + self.logger.info("Spectra correction complete.") + + return corrected_data + + elif isinstance(data, pd.DataFrame): + self.logger.info( + "Input is pandas.DataFrame. Applying correction to wave data." + ) + + corrected_data = data.copy() + corrected_data["Hsea"] = ( + corrected_data["Hsea"] ** 2 + * np.array( + [ + self.calibration_params["sea_correction"][ + int(peak_direction / self.direction_bin_size) + ] + for peak_direction in corrected_data["Dirsea"] + ] + ) + ** 2 + ) + corrected_data["Hs_CORR"] = corrected_data["Hsea"] + for n_part in range(1, self._get_nparts(corrected_data) + 1): + corrected_data[f"Hswell{n_part}"] = ( + corrected_data[f"Hswell{n_part}"] ** 2 + * np.array( + [ + self.calibration_params["swell_correction"][ + int(peak_direction / self.direction_bin_size) + ] + for peak_direction in corrected_data[f"Dirswell{n_part}"] + ] + ) + ** 2 + ) + corrected_data["Hs_CORR"] += corrected_data[f"Hswell{n_part}"] + + corrected_data["Hs_CORR"] = np.sqrt(corrected_data["Hs_CORR"]) + self.logger.info("Wave data correction complete.") + + return corrected_data[["Hs", "Hs_CORR"]] + + def plot_calibration_results(self) -> Tuple[Figure, list]: + """ + Plot the calibration results, including: + - Pie charts of correction coefficients for sea and swell + - Scatter plots of model vs. reference (before and after correction) + - Polar density plots of sea and swell wave climate + + Returns + ------- + Tuple[Figure, list] + The matplotlib Figure and a list of Axes objects for all subplots. + + Notes + ----- + This function is intended for visual inspection of the calibration quality and the + directional distribution of corrections. 
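+        It requires `fit` to have been called first.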
+ """ + + self.logger.info("Plotting calibration results.") + + fig = plt.figure(figsize=(10, 15)) + gs = fig.add_gridspec(7, 2, wspace=0.4, hspace=0.7) + + # Create subplots with proper projections + ax1 = fig.add_subplot(gs[:2, 0]) # Sea correction pie + ax2 = fig.add_subplot(gs[:2, 1]) # Swell correction pie + ax1_cbar = fig.add_subplot(gs[2, 0]) # Sea correction colorbar + ax2_cbar = fig.add_subplot(gs[2, 1]) # Swell correction colorbar + ax3 = fig.add_subplot(gs[3:5, 0]) # No correction scatter + ax4 = fig.add_subplot(gs[3:5, 1]) # With correction scatter + ax5 = fig.add_subplot(gs[5:7, 0], projection="polar") # Sea climate + ax6 = fig.add_subplot(gs[5:7, 1], projection="polar") # Swell climate + + # Plot sea correction pie chart + sea_norm = 0.3 # Smaller range for sea + sea_fracs = np.repeat(10, len(self.calibration_params["sea_correction"])) + sea_norm = mpl.colors.Normalize(1 - sea_norm, 1 + sea_norm) + sea_cmap = mpl.cm.get_cmap( + "bwr", len(self.calibration_params["sea_correction"]) + ) + sea_colors = sea_cmap( + sea_norm(list(self.calibration_params["sea_correction"].values())) + ) + ax1.pie( + sea_fracs, + labels=None, + colors=sea_colors, + startangle=90, + counterclock=False, + radius=1.2, + ) + ax1.set_title("SEA $Correction$", fontweight="bold") + # Add colorbar for sea correction below the pie chart, shrink it + _sea_cbar = mpl.colorbar.ColorbarBase( + ax1_cbar, + cmap=sea_cmap, + norm=sea_norm, + orientation="horizontal", + label="Correction Factor", + ) + box = ax1_cbar.get_position() + ax1_cbar.set_position( + [ + box.x0 + 0.15 * box.width, + box.y0 + 0.3 * box.height, + 0.7 * box.width, + 0.4 * box.height, + ] + ) + ax1_cbar.set_frame_on(False) + ax1_cbar.tick_params( + left=False, right=False, labelleft=False, labelbottom=True, bottom=True + ) + + # Plot swell correction pie chart + swell_norm = 0.6 # Larger range for swell + swell_fracs = np.repeat(10, len(self.calibration_params["swell_correction"])) + swell_norm = mpl.colors.Normalize(1 - swell_norm, 1 + swell_norm) + swell_cmap = mpl.cm.get_cmap( + "bwr", len(self.calibration_params["swell_correction"]) + ) + swell_colors = swell_cmap( + swell_norm(list(self.calibration_params["swell_correction"].values())) + ) + ax2.pie( + swell_fracs, + labels=None, + colors=swell_colors, + startangle=90, + counterclock=False, + radius=1.2, + ) + ax2.set_title("SWELL $Correction$", fontweight="bold") + # Add colorbar for swell correction below the pie chart, shrink it + _swell_cbar = mpl.colorbar.ColorbarBase( + ax2_cbar, + cmap=swell_cmap, + norm=swell_norm, + orientation="horizontal", + label="Correction Factor", + ) + box = ax2_cbar.get_position() + ax2_cbar.set_position( + [ + box.x0 + 0.15 * box.width, + box.y0 + 0.3 * box.height, + 0.7 * box.width, + 0.4 * box.height, + ] + ) + ax2_cbar.set_frame_on(False) + ax2_cbar.tick_params( + left=False, right=False, labelleft=False, labelbottom=True, bottom=True + ) + + # Plot no correction scatter + validation_scatter( + axs=ax3, + x=self._calibrated_data["Hs"].values, + y=self._calibrated_data["Hs_CAL"].values, + xlabel="Hindcast", + ylabel="Satellite", + title="No Correction", + ) + + # Plot with correction scatter + validation_scatter( + axs=ax4, + x=self._calibrated_data["Hs_CORR"].values, + y=self._calibrated_data["Hs_CAL"].values, + xlabel="Hindcast", + ylabel="Satellite", + title="With Correction", + ) + + # Plot sea wave climate + x, y, z = density_scatter( + self._data["Dirsea"] * np.pi / 180, + self._data["Hsea"], + ) + ax5.scatter(x, y, c=z, s=3, cmap="jet") + 
ax5.set_theta_zero_location("N", offset=0)
+        ax5.set_xticklabels(["N", "NE", "E", "SE", "S", "SW", "W", "NW"])
+        ax5.xaxis.grid(True, color="lavender", linestyle="-")
+        ax5.yaxis.grid(True, color="lavender", linestyle="-")
+        ax5.set_theta_direction(-1)
+        ax5.set_xlabel(r"$\theta_{m}$ ($\degree$)")
+        ax5.set_ylabel("$H_{s}$ (m)", labelpad=20)
+        ax5.set_title("SEA $Wave$ $Climate$", pad=15, fontweight="bold")
+
+        # Plot swell wave climate
+        x, y, z = density_scatter(
+            self._data["Dirswell1"] * np.pi / 180,
+            self._data["Hswell1"],
+        )
+        ax6.scatter(x, y, c=z, s=3, cmap="jet")
+        ax6.set_theta_zero_location("N", offset=0)
+        ax6.set_xticklabels(["N", "NE", "E", "SE", "S", "SW", "W", "NW"])
+        ax6.xaxis.grid(True, color="lavender", linestyle="-")
+        ax6.yaxis.grid(True, color="lavender", linestyle="-")
+        ax6.set_theta_direction(-1)
+        ax6.set_xlabel(r"$\theta_{m}$ ($\degree$)")
+        ax6.set_ylabel("$H_{s}$ (m)", labelpad=20)
+        ax6.set_title("SWELL 1 $Wave$ $Climate$", pad=15, fontweight="bold")
+
+        return fig, [ax1, ax2, ax1_cbar, ax2_cbar, ax3, ax4, ax5, ax6]
+
+    def validate_calibration(
+        self, data_to_validate: pd.DataFrame
+    ) -> Tuple[Figure, list]:
+        """
+        Validate the calibration using independent validation data.
+
+        This method compares the original and corrected model data against the
+        validation data, both as time series and with scatter plots.
+
+        Parameters
+        ----------
+        data_to_validate : pd.DataFrame
+            Validation data. Must contain the column:
+            - 'Hs_VAL' (float): Validation significant wave height (e.g., from buoy)
+            The index must be datetime-like.
+
+        Returns
+        -------
+        Tuple[Figure, list]
+            The matplotlib Figure and a list of Axes objects:
+            [time series axis, scatter (no correction), scatter (corrected)].
+
+        Notes
+        -----
+        This function is intended for visual validation of the calibration performance.
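+        Model and validation times are matched with a 1-hour maximum difference.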
+ """ + + if "Hs_VAL" not in data_to_validate.columns: + raise ValueError("Validation data is missing required column: 'Hs_VAL'") + + data_corr = self.predict(data=self._data) + data_times, data_to_validate_times = get_matching_times_between_arrays( + times1=data_corr.index, + times2=data_to_validate.index, + max_time_diff=1, + ) + + # Create figure with a 2-row, 2-column grid, top row spans both columns + fig = plt.figure(figsize=(12, 8)) + gs = fig.add_gridspec(2, 2, height_ratios=[2, 3], hspace=0.4, wspace=0.3) + + # Top row: time series plot (spans both columns) + ax_ts = fig.add_subplot(gs[0, :]) + t = data_corr.index[data_times] + ax_ts.plot( + t, + data_to_validate["Hs_VAL"].iloc[data_to_validate_times], + label="Validation", + color="k", + lw=1.5, + ) + ax_ts.plot( + t, + data_corr["Hs"].iloc[data_times], + label="Model (No Correction)", + color="tab:blue", + alpha=0.7, + ) + ax_ts.plot( + t, + data_corr["Hs_CORR"].iloc[data_times], + label="Model (Corrected)", + color="tab:orange", + alpha=0.7, + ) + ax_ts.set_ylabel("$H_s$ (m)") + ax_ts.set_xlabel("Time") + ax_ts.set_title("Time Series Comparison") + ax_ts.legend(loc="upper right") + ax_ts.grid(True, linestyle=":", alpha=0.5) + + # Bottom row: scatter plots + ax_sc1 = fig.add_subplot(gs[1, 0]) + ax_sc2 = fig.add_subplot(gs[1, 1]) + validation_scatter( + axs=ax_sc1, + x=data_corr["Hs"].iloc[data_times].values, + y=data_to_validate["Hs_VAL"].iloc[data_to_validate_times].values, + xlabel="Model (No Correction)", + ylabel="Validation", + title="No Correction", + ) + validation_scatter( + axs=ax_sc2, + x=data_corr["Hs_CORR"].iloc[data_times].values, + y=data_to_validate["Hs_VAL"].iloc[data_to_validate_times].values, + xlabel="Model (Corrected)", + ylabel="Validation", + title="With Correction", + ) + + return fig, [ax_ts, ax_sc1, ax_sc2]