import os
import errno
import tempfile
import logging
import threading
import rioxarray
import dask
import click_log

import rasterio as rio
import pandas as pd
import pyarrow.parquet as pq

from typing import Union, Callable
from pathlib import Path
from rasterio import crs
from rasterio.vrt import WarpedVRT
from rasterio.enums import Resampling
from tqdm import tqdm
from tqdm.dask import TqdmCallback
import dask.dataframe as dd
import xarray as xr

from concurrent.futures import ThreadPoolExecutor, as_completed

from urllib.parse import urlparse
from rasterio.warp import calculate_default_transform

import raster2dggs.constants as const

LOGGER = logging.getLogger(__name__)
click_log.basic_config(LOGGER)


class ParentResolutionException(Exception):
    pass


def check_resolutions(resolution: int, parent_res: Union[None, int]) -> None:
    if parent_res is not None and int(parent_res) >= int(resolution):
        raise ParentResolutionException(
            "Parent resolution ({pr}) must be less than target resolution ({r})".format(
                pr=parent_res, r=resolution
            )
        )

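# Illustrative behaviour of check_resolutions (the values are arbitrary examples):
#   check_resolutions(9, 5)  -> returns None; 5 < 9, so the combination is valid
#   check_resolutions(9, 9)  -> raises ParentResolutionException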

def resolve_input_path(raster_input: Union[str, Path]) -> Union[str, Path]:
    if not Path(raster_input).exists():
        if not urlparse(str(raster_input)).scheme:
            LOGGER.warning(
                f"Input raster {raster_input} does not exist and is not recognised as a remote URI"
            )
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), raster_input
            )
        # Quacks like a path to remote data
        raster_input = str(raster_input)
    else:
        raster_input = Path(raster_input)

    return raster_input

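# Illustrative behaviour of resolve_input_path (the paths below are hypothetical):
#   resolve_input_path("./dem.tif")            -> Path("./dem.tif") if the file exists locally
#   resolve_input_path("s3://bucket/dem.tif")  -> "s3://bucket/dem.tif" (treated as a remote URI)
#   resolve_input_path("./missing.tif")        -> raises FileNotFoundError (no scheme, not on disk)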

def assemble_warp_args(resampling: str, warp_mem_limit: int) -> dict:
    warp_args: dict = {
        "resampling": Resampling[resampling],
        "crs": crs.CRS.from_epsg(
            4326
        ),  # Input raster must be converted to WGS84 (4326) for H3 indexing
        "warp_mem_limit": warp_mem_limit,
    }

    return warp_args

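# For reference, a call such as assemble_warp_args("bilinear", 128) (illustrative values)
# would produce:
#   {"resampling": Resampling.bilinear, "crs": CRS.from_epsg(4326), "warp_mem_limit": 128}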

def create_aggfunc(aggfunc: str) -> Union[str, Callable]:
    if aggfunc == "mode":
        LOGGER.warning(
            "Mode aggregation: if there is more than one mode when aggregating, only the first value will be recorded."
        )
        aggfunc = lambda x: pd.Series.mode(x)[0]

    return aggfunc

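# Example of the tie-breaking behaviour create_aggfunc warns about for "mode":
#   pd.Series([1, 1, 2, 2]).mode()              -> two modes, [1, 2]
#   pd.Series.mode(pd.Series([1, 1, 2, 2]))[0]  -> 1 (only the first mode is recorded)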

def assemble_kwargs(
    upscale: int,
    compression: str,
    threads: int,
    aggfunc: str,
    decimals: int,
    warp_mem_limit: int,
    resampling: str,
    overwrite: bool,
) -> dict:
    kwargs = {
        "upscale": upscale,
        "compression": compression,
        "threads": threads,
        "aggfunc": aggfunc,
        "decimals": decimals,
        "warp_mem_limit": warp_mem_limit,
        "resampling": resampling,
        "overwrite": overwrite,
    }

    return kwargs


def get_parent_res(dggs: str, parent_res: Union[None, int], resolution: int) -> int:
    """
    Returns the given parent resolution if one is set;
    otherwise, derives a recommended parent resolution from the target resolution.

    Used for intermediate re-partitioning.
    """
    if dggs == "h3":
        return (
            parent_res
            if parent_res is not None
            else max(const.MIN_H3, (resolution - const.DEFAULT_PARENT_OFFSET))
        )
    elif dggs == "rhp":
        return (
            parent_res
            if parent_res is not None
            else max(const.MIN_RHP, (resolution - const.DEFAULT_PARENT_OFFSET))
        )
    else:
        raise RuntimeError(
            "Unknown dggs {dggs} - must be one of [ 'h3', 'rhp' ]".format(dggs=dggs)
        )

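# Illustrative arithmetic for get_parent_res, assuming const.MIN_H3 == 0 and
# const.DEFAULT_PARENT_OFFSET == 6 (see raster2dggs.constants for the actual values):
#   get_parent_res("h3", None, 9)  -> max(0, 9 - 6) == 3
#   get_parent_res("h3", 2, 9)     -> 2 (an explicit parent resolution is returned unchanged)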

def address_boundary_issues(
    dggs: str,
    parent_groupby: Callable,
    pq_input: Union[str, Path],  # Path to the temporary directory holding stage 1 Parquet output
    output: Path,
    resolution: int,
    parent_res: int,
    **kwargs,
) -> Path:
    """
    After "stage 1" processing, there is a DGGS cell and band value(s) for each pixel in the input image.
    Partitions are based on raster windows.

    This function re-partitions based on parent cell IDs at a fixed offset from the target resolution.

    Once re-partitioned on this basis, values are aggregated at the target resolution, to account for
    multiple pixels mapping to the same cell.

    This re-partitioning is necessary because the same cell IDs can be present in different partitions
    of the original (i.e. window-based) partitioning. The nested structure of the DGGS is a useful
    property for addressing this problem.
    """
    parent_res = get_parent_res(dggs, parent_res, resolution)

    LOGGER.debug(
        f"Reading Stage 1 output ({pq_input}) and setting index for parent-based partitioning"
    )
    with TqdmCallback(desc="Reading window partitions"):
        # Set index as parent cell
        ddf = dd.read_parquet(pq_input).set_index(f"{dggs}_{parent_res:02}")

    with TqdmCallback(desc="Counting parents"):
        # Count parents, to get target number of partitions
        uniqueparents = sorted(ddf.index.unique().compute())

    LOGGER.debug(
        "Repartitioning into %d partitions, based on parent cells",
        len(uniqueparents),
    )
    LOGGER.debug("Aggregating cell values where conflicts exist")

    with TqdmCallback(desc="Repartitioning/aggregating"):
        ddf = (
            ddf.repartition(  # See "notes" on why divisions expects repetition of the last item https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.repartition.html
                divisions=(uniqueparents + [uniqueparents[-1]])
            )
            .map_partitions(
                parent_groupby, resolution, kwargs["aggfunc"], kwargs["decimals"]
            )
            .to_parquet(
                output,
                overwrite=kwargs["overwrite"],
                engine="pyarrow",
                write_index=True,
                append=False,
                name_function=lambda i: f"{uniqueparents[i]}.parquet",
                compression=kwargs["compression"],
            )
        )

    LOGGER.debug(
        "Stage 2 (parent cell repartitioning) and Stage 3 (aggregation) complete"
    )

    return output

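# Minimal sketch of the divisions-based repartitioning performed above. Dask divisions must be
# sorted, with the final division repeated so the last partition is closed (see the dask docs
# linked in the comment above); the parent cell IDs shown are hypothetical:
#   parents = ["85bb5a73fffffff", "85bb5a77fffffff"]          # sorted unique parent cells
#   ddf = ddf.repartition(divisions=parents + [parents[-1]])  # one partition per parent cell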

def initial_index(
    dggs: str,
    dggsfunc: Callable,
    parent_groupby: Callable,
    raster_input: Union[Path, str],
    output: Path,
    resolution: int,
    parent_res: Union[None, int],
    warp_args: dict,
    **kwargs,
) -> Path:
    """
    Responsible for opening the raster_input and performing DGGS indexing per window of a WarpedVRT.

    A WarpedVRT is used to enforce reprojection to https://epsg.io/4326, which is required for H3 indexing.

    It also allows on-the-fly resampling of the input, which is useful if the target DGGS resolution exceeds
    the resolution of the input.

    This function passes a path to a temporary directory (which contains the output of this "stage 1"
    processing) to a secondary function that addresses issues at the boundaries of raster windows.
    """
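    # Stage 1 flow: open the source raster, wrap it in a WarpedVRT (reprojecting to EPSG:4326 and
    # optionally upsampling), index each raster window to DGGS cells in a thread pool, and write
    # each window's result as a partition of a temporary Parquet dataset, which is then handed to
    # address_boundary_issues() for repartitioning and aggregation. An illustrative (hypothetical)
    # invocation of this function is sketched in the comment at the end of this module.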
    parent_res = get_parent_res(dggs, parent_res, resolution)
    LOGGER.info(
        "Indexing %s at %s resolution %d, parent resolution %d",
        raster_input,
        dggs.upper(),
        resolution,
        parent_res,
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        LOGGER.debug(f"Created temporary directory {tmpdir}")

        # https://rasterio.readthedocs.io/en/latest/api/rasterio.warp.html#rasterio.warp.calculate_default_transform
        with rio.Env(CHECK_WITH_INVERT_PROJ=True):
            with rio.open(raster_input) as src:
                LOGGER.debug("Source CRS: %s", src.crs)
                # VRT used to avoid additional disk use given the potential for reprojection to 4326 prior to H3 indexing
                band_names = src.descriptions

                upscale_factor = kwargs["upscale"]
                if upscale_factor > 1:
                    dst_crs = warp_args["crs"]
                    transform, width, height = calculate_default_transform(
                        src.crs,
                        dst_crs,
                        src.width,
                        src.height,
                        *src.bounds,
                        dst_width=src.width * upscale_factor,
                        dst_height=src.height * upscale_factor,
                    )
                    upsample_args = {
                        "transform": transform,
                        "width": width,
                        "height": height,
                    }
                    LOGGER.debug("Upsampling args: %s", upsample_args)
                else:
                    upsample_args = {}

                with WarpedVRT(
                    src, src_crs=src.crs, **warp_args, **upsample_args
                ) as vrt:
                    LOGGER.debug("VRT CRS: %s", vrt.crs)
                    da: xr.DataArray = rioxarray.open_rasterio(
                        vrt,
                        lock=dask.utils.SerializableLock(),
                        masked=True,
                        default_name=const.DEFAULT_NAME,
                    ).chunk(**{"y": "auto", "x": "auto"})

                    windows = [window for _, window in vrt.block_windows()]
                    LOGGER.debug(
                        "%d windows (the same number of partitions will be created)",
                        len(windows),
                    )

                    # Serialise writes: multiple threads append to the same Parquet dataset root
                    write_lock = threading.Lock()

                    def process(window):
                        # Index a single raster window to DGGS cells ...
                        sdf = da.rio.isel_window(window)

                        result = dggsfunc(
                            sdf,
                            resolution,
                            parent_res,
                            vrt.nodata,
                            band_labels=band_names,
                        )

                        # ... and append the result as a partition of the temporary Parquet dataset
                        with write_lock:
                            pq.write_to_dataset(
                                result,
                                root_path=tmpdir,
                                compression=kwargs["compression"],
                            )

                        return None

                    with tqdm(total=len(windows), desc="Raster windows") as pbar:
                        with ThreadPoolExecutor(
                            max_workers=kwargs["threads"]
                        ) as executor:
                            futures = [
                                executor.submit(process, window) for window in windows
                            ]
                            for future in as_completed(futures):
                                future.result()  # Propagate any exception raised in a worker thread
                                pbar.update(1)

        # Stages 2 and 3 must run while the temporary stage 1 output still exists
        LOGGER.debug("Stage 1 (primary indexing) complete")
        return address_boundary_issues(
            dggs,
            parent_groupby,
            tmpdir,
            output,
            resolution,
            parent_res,
            **kwargs,
        )
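

# Hedged usage sketch, not part of the module's API: dggsfunc and parent_groupby are hypothetical
# stand-ins for the DGGS-specific indexing and aggregation callables supplied by the h3/rhp
# frontends, and the file paths and parameter values are invented for illustration only.
#
#   initial_index(
#       "h3",
#       dggsfunc,                # indexes an xarray window to H3 cells
#       parent_groupby,          # aggregates band values per H3 cell
#       Path("dem.tif"),
#       Path("dem_h3.parquet"),
#       resolution=9,
#       parent_res=None,
#       warp_args=assemble_warp_args("bilinear", 128),
#       **assemble_kwargs(1, "snappy", 4, "mean", 1, 128, "bilinear", True),
#   )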