Skip to content

Commit 58af1b5

Browse files
Merge pull request #20 from manaakiwhenua/demaion/nics_mods
demaion/nics mods
2 parents 214d231 + e8f9cd8 commit 58af1b5

File tree

10 files changed

+1250
-336
lines changed

10 files changed

+1250
-336
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
tests
22
__pycache__
3-
dist
3+
dist
4+
.conda*

.gitmodules

Whitespace-only changes.

conda_dev.yml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
2+
channels:
3+
- conda-forge
4+
dependencies:
5+
# Required
6+
- python>=3.12
7+
- gdal==3.8.4
8+
- geopandas>=1.0.1
9+
- h3pandas>=0.2.6
10+
- rioxarray>=0.13.4
11+
- dask-geopandas>=0.4.1
12+
- pyarrow>=16.0.0
13+
- dask>=2024.8.0
14+
- click>=8.1.3
15+
- boto3>=1.26.85
16+
- tqdm>=4.66.5
17+
- click-log>=0.4.0
18+
- rasterio>=1.3.6
19+
- dask-expr>=1.1.2
20+
- numpy<2
21+
- pip
22+
- pip:
23+
# Use pip to install rhppandas from dev branch of git repo
24+
- git+https://github.com/manaakiwhenua/rHP-Pandas.git
25+
- git+https://github.com/manaakiwhenua/rhealpixdggs-py.git

poetry.lock

Lines changed: 641 additions & 50 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ maintainers = ["Richard Law <[email protected]>"]
77
readme = "README.md"
88
license = "LGPL-3.0-or-later"
99
repository = "https://github.com/manaakiwhenua/raster2dggs"
10-
keywords = ["dggs", "raster", "h3", "cli"]
10+
keywords = ["dggs", "raster", "h3", "rHEALPix", "cli"]
1111
classifiers = [
1212
"Topic :: Scientific/Engineering",
1313
"Topic :: Scientific/Engineering :: GIS",
@@ -30,6 +30,9 @@ tqdm = "^4.66.4"
3030
click-log = "^0.4.0"
3131
rasterio = "^1.3.6"
3232
dask-expr = "^1.1.2"
33+
numpy = "<2"
34+
rhppandas = { git = "https://github.com/manaakiwhenua/rHP-Pandas.git" }
35+
rhealpixdggs = { git = "https://github.com/manaakiwhenua/rhealpixdggs-py.git" }
3336

3437
[tool.poetry.group.dev.dependencies]
3538
pytest = "^7.2.2"

raster2dggs/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from raster2dggs import __version__
44
from raster2dggs.h3 import h3
5+
from raster2dggs.rHP import rhp
56

67
# If the program does terminal interaction, make it output a short
78
# notice like this when it starts in an interactive mode:
@@ -19,6 +20,7 @@ def cli():
1920

2021

2122
cli.add_command(h3)
23+
cli.add_command(rhp)
2224

2325

2426
def main():

raster2dggs/common.py

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
import os
2+
import errno
3+
import tempfile
4+
import logging
5+
import threading
6+
import rioxarray
7+
import dask
8+
import click_log
9+
10+
import rasterio as rio
11+
import pandas as pd
12+
import pyarrow.parquet as pq
13+
14+
from typing import Union, Callable
15+
from pathlib import Path
16+
from rasterio import crs
17+
from rasterio.vrt import WarpedVRT
18+
from rasterio.enums import Resampling
19+
from tqdm import tqdm
20+
from tqdm.dask import TqdmCallback
21+
import dask.dataframe as dd
22+
import xarray as xr
23+
24+
from concurrent.futures import ThreadPoolExecutor, as_completed
25+
26+
from urllib.parse import urlparse
27+
from rasterio.warp import calculate_default_transform
28+
29+
import raster2dggs.constants as const
30+
31+
LOGGER = logging.getLogger(__name__)
32+
click_log.basic_config(LOGGER)
33+
34+
35+
class ParentResolutionException(Exception):
36+
pass
37+
38+
39+
def check_resolutions(resolution: int, parent_res: int) -> None:
40+
if parent_res is not None and not int(parent_res) < int(resolution):
41+
raise ParentResolutionException(
42+
"Parent resolution ({pr}) must be less than target resolution ({r})".format(
43+
pr=parent_res, r=resolution
44+
)
45+
)
46+
47+
48+
def resolve_input_path(raster_input: Union[str, Path]) -> Union[str, Path]:
49+
if not Path(raster_input).exists():
50+
if not urlparse(raster_input).scheme:
51+
LOGGER.warning(
52+
f"Input raster {raster_input} does not exist, and is not recognised as a remote URI"
53+
)
54+
raise FileNotFoundError(
55+
errno.ENOENT, os.strerror(errno.ENOENT), raster_input
56+
)
57+
# Quacks like a path to remote data
58+
raster_input = str(raster_input)
59+
else:
60+
raster_input = Path(raster_input)
61+
62+
return raster_input
63+
64+
65+
def assemble_warp_args(resampling: str, warp_mem_limit: int) -> dict:
66+
warp_args: dict = {
67+
"resampling": Resampling._member_map_[resampling],
68+
"crs": crs.CRS.from_epsg(
69+
4326
70+
), # Input raster must be converted to WGS84 (4326) for H3 indexing
71+
"warp_mem_limit": warp_mem_limit,
72+
}
73+
74+
return warp_args
75+
76+
77+
def create_aggfunc(aggfunc: str) -> str:
78+
if aggfunc == "mode":
79+
logging.warning(
80+
"Mode aggregation: arbitrary behaviour: if there is more than one mode when aggregating, only the first value will be recorded."
81+
)
82+
aggfunc = lambda x: pd.Series.mode(x)[0]
83+
84+
return aggfunc
85+
86+
87+
def assemble_kwargs(
88+
upscale: int,
89+
compression: str,
90+
threads: int,
91+
aggfunc: str,
92+
decimals: int,
93+
warp_mem_limit: int,
94+
resampling: str,
95+
overwrite: bool,
96+
) -> dict:
97+
kwargs = {
98+
"upscale": upscale,
99+
"compression": compression,
100+
"threads": threads,
101+
"aggfunc": aggfunc,
102+
"decimals": decimals,
103+
"warp_mem_limit": warp_mem_limit,
104+
"resampling": resampling,
105+
"overwrite": overwrite,
106+
}
107+
108+
return kwargs
109+
110+
111+
def get_parent_res(dggs: str, parent_res: Union[None, int], resolution: int) -> int:
112+
"""
113+
Uses a parent resolution,
114+
OR,
115+
Given a target resolution, returns our recommended parent resolution.
116+
117+
Used for intermediate re-partioning.
118+
"""
119+
if dggs == "h3":
120+
return (
121+
parent_res
122+
if parent_res is not None
123+
else max(const.MIN_H3, (resolution - const.DEFAULT_PARENT_OFFSET))
124+
)
125+
elif dggs == "rhp":
126+
return (
127+
parent_res
128+
if parent_res is not None
129+
else max(const.MIN_RHP, (resolution - const.DEFAULT_PARENT_OFFSET))
130+
)
131+
else:
132+
raise RuntimeError(
133+
"Unknown dggs {dggs}) - must be one of [ 'h3', 'rhp' ]".format(dggs=dggs)
134+
)
135+
136+
137+
def address_boundary_issues(
138+
dggs: str,
139+
parent_groupby: Callable,
140+
pq_input: tempfile.TemporaryDirectory,
141+
output: Path,
142+
resolution: int,
143+
parent_res: int,
144+
**kwargs,
145+
) -> Path:
146+
"""
147+
After "stage 1" processing, there is a DGGS cell and band value/s for each pixel in the input image. Partitions are based
148+
on raster windows.
149+
150+
This function will re-partition based on parent cell IDs at a fixed offset from the target resolution.
151+
152+
Once re-partitioned on this basis, values are aggregated at the target resolution, to account for multiple pixels mapping
153+
to the same cell.
154+
155+
This re-partitioning is necessary to address the issue of the same cell IDs being present in different partitions
156+
of the original (i.e. window-based) partitioning. Using the nested structure of the DGGS is an useful property
157+
to address this problem.
158+
"""
159+
parent_res = get_parent_res(dggs, parent_res, resolution)
160+
161+
LOGGER.debug(
162+
f"Reading Stage 1 output ({pq_input}) and setting index for parent-based partitioning"
163+
)
164+
with TqdmCallback(desc="Reading window partitions"):
165+
# Set index as parent cell
166+
ddf = dd.read_parquet(pq_input).set_index(f"{dggs}_{parent_res:02}")
167+
168+
with TqdmCallback(desc="Counting parents"):
169+
# Count parents, to get target number of partitions
170+
uniqueparents = sorted(list(ddf.index.unique().compute()))
171+
172+
LOGGER.debug(
173+
"Repartitioning into %d partitions, based on parent cells",
174+
len(uniqueparents) + 1,
175+
)
176+
LOGGER.debug("Aggregating cell values where conflicts exist")
177+
178+
with TqdmCallback(desc="Repartioning/aggregating"):
179+
ddf = (
180+
ddf.repartition( # See "notes" on why divisions expects repetition of the last item https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.repartition.html
181+
divisions=(uniqueparents + [uniqueparents[-1]])
182+
)
183+
.map_partitions(
184+
parent_groupby, resolution, kwargs["aggfunc"], kwargs["decimals"]
185+
)
186+
.to_parquet(
187+
output,
188+
overwrite=kwargs["overwrite"],
189+
engine="pyarrow",
190+
write_index=True,
191+
append=False,
192+
name_function=lambda i: f"{uniqueparents[i]}.parquet",
193+
compression=kwargs["compression"],
194+
)
195+
)
196+
197+
LOGGER.debug(
198+
"Stage 2 (parent cell repartioning) and Stage 3 (aggregation) complete"
199+
)
200+
201+
return output
202+
203+
204+
def initial_index(
205+
dggs: str,
206+
dggsfunc: Callable,
207+
parent_groupby: Callable,
208+
raster_input: Union[Path, str],
209+
output: Path,
210+
resolution: int,
211+
parent_res: Union[None, int],
212+
warp_args: dict,
213+
**kwargs,
214+
) -> Path:
215+
"""
216+
Responsible for opening the raster_input, and performing DGGS indexing per window of a WarpedVRT.
217+
218+
A WarpedVRT is used to enforce reprojection to https://epsg.io/4326, which is required for H3 indexing.
219+
220+
It also allows on-the-fly resampling of the input, which is useful if the target DGGS resolution exceeds the resolution
221+
of the input.
222+
223+
This function passes a path to a temporary directory (which contains the output of this "stage 1" processing) to
224+
a secondary function that addresses issues at the boundaries of raster windows.
225+
"""
226+
parent_res = get_parent_res(dggs, parent_res, resolution)
227+
LOGGER.info(
228+
"Indexing %s at %s resolution %d, parent resolution %d",
229+
raster_input,
230+
str.upper(dggs),
231+
resolution,
232+
parent_res,
233+
)
234+
235+
with tempfile.TemporaryDirectory() as tmpdir:
236+
LOGGER.debug(f"Create temporary directory {tmpdir}")
237+
238+
# https://rasterio.readthedocs.io/en/latest/api/rasterio.warp.html#rasterio.warp.calculate_default_transform
239+
with rio.Env(CHECK_WITH_INVERT_PROJ=True):
240+
with rio.open(raster_input) as src:
241+
LOGGER.debug("Source CRS: %s", src.crs)
242+
# VRT used to avoid additional disk use given the potential for reprojection to 4326 prior to H3 indexing
243+
band_names = src.descriptions
244+
245+
upscale_factor = kwargs["upscale"]
246+
if upscale_factor > 1:
247+
dst_crs = warp_args["crs"]
248+
transform, width, height = calculate_default_transform(
249+
src.crs,
250+
dst_crs,
251+
src.width,
252+
src.height,
253+
*src.bounds,
254+
dst_width=src.width * upscale_factor,
255+
dst_height=src.height * upscale_factor,
256+
)
257+
upsample_args = dict(
258+
{"transform": transform, "width": width, "height": height}
259+
)
260+
LOGGER.debug(upsample_args)
261+
else:
262+
upsample_args = dict({})
263+
264+
with WarpedVRT(
265+
src, src_crs=src.crs, **warp_args, **upsample_args
266+
) as vrt:
267+
LOGGER.debug("VRT CRS: %s", vrt.crs)
268+
da: xr.Dataset = rioxarray.open_rasterio(
269+
vrt,
270+
lock=dask.utils.SerializableLock(),
271+
masked=True,
272+
default_name=const.DEFAULT_NAME,
273+
).chunk(**{"y": "auto", "x": "auto"})
274+
275+
windows = [window for _, window in vrt.block_windows()]
276+
LOGGER.debug(
277+
"%d windows (the same number of partitions will be created)",
278+
len(windows),
279+
)
280+
281+
write_lock = threading.Lock()
282+
283+
def process(window):
284+
sdf = da.rio.isel_window(window)
285+
286+
result = dggsfunc(
287+
sdf,
288+
resolution,
289+
parent_res,
290+
vrt.nodata,
291+
band_labels=band_names,
292+
)
293+
294+
with write_lock:
295+
pq.write_to_dataset(
296+
result,
297+
root_path=tmpdir,
298+
compression=kwargs["compression"],
299+
)
300+
301+
return None
302+
303+
with tqdm(total=len(windows), desc="Raster windows") as pbar:
304+
with ThreadPoolExecutor(
305+
max_workers=kwargs["threads"]
306+
) as executor:
307+
futures = [
308+
executor.submit(process, window) for window in windows
309+
]
310+
for future in as_completed(futures):
311+
result = future.result()
312+
pbar.update(1)
313+
314+
LOGGER.debug("Stage 1 (primary indexing) complete")
315+
return address_boundary_issues(
316+
dggs,
317+
parent_groupby,
318+
tmpdir,
319+
output,
320+
resolution,
321+
parent_res,
322+
**kwargs,
323+
)

0 commit comments

Comments
 (0)