|
| 1 | +import numpy as np |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import rasterio |
| 4 | +import cartopy.crs as ccrs |
| 5 | +import xarray as xr |
| 6 | +from rasterio.transform import from_bounds |
| 7 | +from bluemath_tk.core.plotting.colors import hex_colors_land, hex_colors_water |
| 8 | +from bluemath_tk.core.plotting.utils import join_colormaps |
| 9 | + |
| 10 | +def plot_bathymetry(rasters, figsize=(10, 8), cbar=False, ax=None): |
| 11 | + """ |
| 12 | + Plot a bathymetry map from either a raster file or a NetCDF dataset. |
| 13 | +
|
| 14 | + Parameters |
| 15 | + ---------- |
| 16 | + rasters : str or xarray.Dataset |
| 17 | + Either a path to a raster file readable by rasterio or an xarray Dataset. |
| 18 | + figsize : tuple of float, optional |
| 19 | + Figure size in inches, by default ``(10, 8)``. |
| 20 | + cbar : bool, optional |
| 21 | + If ``True``, display a colorbar. |
| 22 | +
|
| 23 | + Returns |
| 24 | + ------- |
| 25 | + fig : matplotlib.figure.Figure |
| 26 | + The generated figure. |
| 27 | + ax : matplotlib.axes._subplots.AxesSubplot |
| 28 | + The map axis with ``PlateCarree`` projection. |
| 29 | +
|
| 30 | + Examples |
| 31 | + -------- |
| 32 | + >>> import xarray as xr |
| 33 | + >>> ds = xr.open_dataset("GEBCO_sample.nc") |
| 34 | + >>> fig, ax = plot_bathymetry(ds) |
| 35 | +
|
| 36 | + >>> fig, ax = plot_bathymetry("path/to/raster.tif") |
| 37 | + """ |
| 38 | + if isinstance(rasters, str): |
| 39 | + data = [] |
| 40 | + with rasterio.open(rasters) as src: |
| 41 | + raster_data = src.read(1) |
| 42 | + no_data_value = src.nodata |
| 43 | + if no_data_value is not None: |
| 44 | + raster_data = np.ma.masked_equal(raster_data, no_data_value) |
| 45 | + data.append(raster_data) |
| 46 | + transform = src.transform |
| 47 | + height, width = data[0].shape |
| 48 | + |
| 49 | + elif isinstance(rasters, xr.Dataset): |
| 50 | + data, transform, height, width = nc_to_raster_like(rasters, var_name=None) |
| 51 | + else: |
| 52 | + raise TypeError("Input must be a raster path or an xarray Dataset.") |
| 53 | + |
| 54 | + cols, rows = np.meshgrid(np.arange(width), np.arange(height)) |
| 55 | + xs, ys = rasterio.transform.xy(transform, rows, cols) |
| 56 | + |
| 57 | + vmin = np.nanmin(data[0]) |
| 58 | + vmax = np.nanmax(data[0]) |
| 59 | + |
| 60 | + cmap, norm = join_colormaps( |
| 61 | + cmap1=hex_colors_water, |
| 62 | + cmap2=hex_colors_land, |
| 63 | + value_range1=(vmin, 0.0), |
| 64 | + value_range2=(0.0, vmax), |
| 65 | + name="raster_cmap", |
| 66 | + ) |
| 67 | + |
| 68 | + if ax is None: |
| 69 | + fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}, figsize=figsize) |
| 70 | + else: |
| 71 | + fig = ax.figure |
| 72 | + im = ax.imshow( |
| 73 | + data[0], |
| 74 | + cmap=cmap, |
| 75 | + norm=norm, |
| 76 | + extent=(np.min(xs), np.max(xs), np.min(ys), np.max(ys)), |
| 77 | + ) |
| 78 | + if cbar: |
| 79 | + plt.colorbar(im, ax=ax, orientation="vertical", label="Elevation (m)") |
| 80 | + |
| 81 | + return fig, ax |
| 82 | + |
| 83 | +def find_main_data_variable(ds): |
| 84 | + """ |
| 85 | + Find the first variable in the dataset that depends on the coordinate axes. |
| 86 | +
|
| 87 | + Parameters |
| 88 | + ---------- |
| 89 | + ds : xarray.Dataset |
| 90 | + Input dataset. |
| 91 | +
|
| 92 | + Returns |
| 93 | + ------- |
| 94 | + str or None |
| 95 | + Name of the main data variable if found, otherwise ``None``. |
| 96 | + """ |
| 97 | + coord_names = list(ds.coords) |
| 98 | + for var_name, var in ds.data_vars.items(): |
| 99 | + var_dims = set(var.dims) |
| 100 | + if all(c in var_dims for c in coord_names): |
| 101 | + return var_name |
| 102 | + |
| 103 | + best_var = None |
| 104 | + best_score = -1 |
| 105 | + for var_name, var in ds.data_vars.items(): |
| 106 | + score = len(set(var.dims).intersection(coord_names)) |
| 107 | + if score > best_score: |
| 108 | + best_score = score |
| 109 | + best_var = var_name |
| 110 | + return best_var |
| 111 | + |
| 112 | + |
| 113 | +def nc_to_raster_like(ds, var_name=None): |
| 114 | + """ |
| 115 | + Convert an xarray Dataset to a raster-like structure equivalent to a rasterio read. |
| 116 | +
|
| 117 | + This function extracts the main data variable, determines the coordinate |
| 118 | + system (lon/lat), flips the data if needed (north-south inversion), |
| 119 | + and builds the affine transform. |
| 120 | +
|
| 121 | + Parameters |
| 122 | + ---------- |
| 123 | + ds : xarray.Dataset |
| 124 | + Input dataset. |
| 125 | + var_name : str, optional |
| 126 | + Variable name to extract. If ``None``, the first variable depending |
| 127 | + on coordinates will be automatically detected. |
| 128 | +
|
| 129 | + Returns |
| 130 | + ------- |
| 131 | + data : list of numpy.ndarray |
| 132 | + List containing the raster-like array. |
| 133 | + transform : affine.Affine |
| 134 | + Affine transform equivalent to rasterio geotransform. |
| 135 | + height : int |
| 136 | + Number of rows. |
| 137 | + width : int |
| 138 | + Number of columns. |
| 139 | +
|
| 140 | + Raises |
| 141 | + ------ |
| 142 | + ValueError |
| 143 | + If no suitable variable or coordinate system is found. |
| 144 | +
|
| 145 | + Examples |
| 146 | + -------- |
| 147 | + >>> import xarray as xr |
| 148 | + >>> ds = xr.open_dataset("GEBCO_sample.nc") |
| 149 | + >>> data, transform, height, width = nc_to_raster_like(ds) |
| 150 | + """ |
| 151 | + if var_name is None: |
| 152 | + var_name = find_main_data_variable(ds) |
| 153 | + if var_name is None: |
| 154 | + raise ValueError("No variable found depending on lat/lon coordinates.") |
| 155 | + |
| 156 | + da = ds[var_name] |
| 157 | + raster_data = da.values |
| 158 | + if np.isnan(raster_data).any(): |
| 159 | + raster_data = np.ma.masked_invalid(raster_data) |
| 160 | + |
| 161 | + coords_and_vars = list(ds.coords) |
| 162 | + lon_name = next((n for n in coords_and_vars if "lon" in n.lower()), None) |
| 163 | + lat_name = next((n for n in coords_and_vars if "lat" in n.lower()), None) |
| 164 | + if lon_name is None or lat_name is None: |
| 165 | + raise ValueError("Could not detect latitude/longitude coordinates.") |
| 166 | + |
| 167 | + lon = ds[lon_name].values |
| 168 | + lat = ds[lat_name].values |
| 169 | + lon_min, lon_max = lon.min(), lon.max() |
| 170 | + lat_min, lat_max = lat.min(), lat.max() |
| 171 | + width = len(lon) |
| 172 | + height = len(lat) |
| 173 | + |
| 174 | + if lat[0] < lat[-1]: |
| 175 | + raster_data = raster_data[::-1, :] |
| 176 | + lat_min, lat_max = lat_max, lat_min |
| 177 | + |
| 178 | + transform = from_bounds(lon_min, lat_min, lon_max, lat_max, width, height) |
| 179 | + data = [raster_data] |
| 180 | + |
| 181 | + return data, transform, height, width |
| 182 | + |
| 183 | + |
| 184 | +# ======================= |
| 185 | +# Example usage |
| 186 | +# ======================= |
| 187 | +if __name__ == "__main__": |
| 188 | + GEBCO = xr.open_dataset( |
| 189 | + "https://dap.ceda.ac.uk/thredds/dodsC/bodc/gebco/global/gebco_2025/ice_surface_elevation/netcdf/GEBCO_2025.nc" |
| 190 | + ) |
| 191 | + GEBCO_sel = GEBCO.sel(lon=slice(-4, -3), lat=slice(43, 44)) |
| 192 | + |
| 193 | + EMODnet = xr.open_dataset( |
| 194 | + "https://geoocean.sci.unican.es/thredds/dodsC/geoocean/emodnet-bathy-2024" |
| 195 | + ) |
| 196 | + EMODnet_sel = EMODnet.sel(lon=slice(-4, -3), lat=slice(43, 44)) |
| 197 | + |
| 198 | + fig, axes = plt.subplots( |
| 199 | + 1, 2, |
| 200 | + subplot_kw={"projection": ccrs.PlateCarree()}, |
| 201 | + figsize=(14, 6) |
| 202 | + ) |
| 203 | + |
| 204 | + plot_bathymetry(GEBCO_sel, ax=axes[0]) |
| 205 | + axes[0].set_title("GEBCO 2025 Bathymetry") |
| 206 | + |
| 207 | + plot_bathymetry(EMODnet_sel, ax=axes[1]) |
| 208 | + axes[1].set_title("EMODnet 2024 Bathymetry") |
| 209 | + |
| 210 | + # For raster files: |
| 211 | + # plot_bathymetry("path/to/raster.tif", cbar=True) |
0 commit comments