Skip to content

Commit 1e8c54a

Browse files
Merge pull request #39 from manaakiwhenua/a5
--band
2 parents 9abdd50 + ff00ffe commit 1e8c54a

File tree

16 files changed

+233
-163
lines changed

16 files changed

+233
-163
lines changed

README.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,17 @@ Options:
6464
-r, --resolution [0|1|2|3|4|5|6|7|8|9|10|11|12|13|14|15]
6565
H3 resolution to index [required]
6666
-pr, --parent_res [0|1|2|3|4|5|6|7|8|9|10|11|12|13|14|15]
67-
H3 Parent resolution to index and aggregate
67+
H3 parent resolution to index and aggregate
6868
to. Defaults to resolution - 6
69+
-b, --band TEXT Band(s) to include in the output. Can
70+
specify multiple, e.g. `-b 1 -b 2 -b 4` for
71+
bands 1, 2, and 4 (all unspecified bands are
72+
ignored). If unused, all bands are included
73+
in the output (this is the default
74+
behaviour). Bands can be specified as
75+
numeric indices (1-based indexing) or string
76+
band labels (if present in the input), e.g.
77+
-b B02 -b B07 -b B12.
6978
-u, --upscale INTEGER Upscaling factor, used to upsample input
7079
data on the fly; useful when the raster
7180
resolution is lower than the target DGGS
@@ -101,7 +110,7 @@ Options:
101110
is a need to resample. This setting
102111
specifies this resampling algorithm.
103112
[default: average]
104-
-co, --compact Compact the H3 cells up to the parent
113+
-co, --compact Compact the cells up to the parent
105114
resolution. Compaction is not applied for
106115
cells without identical values across all
107116
bands.

raster2dggs/a5.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import tempfile
44

55
from pathlib import Path
6-
from typing import Union
6+
from typing import Optional, Sequence, Union
77
from rasterio.enums import Resampling
88

99
import raster2dggs.constants as const
@@ -20,82 +20,88 @@
2020
"--resolution",
2121
required=True,
2222
type=click.Choice(list(map(str, range(const.MIN_A5, const.MAX_A5 + 1)))),
23-
help="A5 resolution to index",
23+
help=const.OPTION_HELP['resolution']('A5'),
2424
)
2525
@click.option(
2626
"-pr",
2727
"--parent_res",
2828
required=False,
2929
type=click.Choice(list(map(str, range(const.MIN_A5, const.MAX_A5 + 1)))),
30-
help="A5 parent resolution to index and aggregate to. Defaults to resolution - 6",
30+
help=const.OPTION_HELP['parent_res']('A5', 'resolution - 6'),
31+
)
32+
@click.option(
33+
"-b",
34+
"--band",
35+
required=False,
36+
multiple=True,
37+
help=const.OPTION_HELP['band'],
3138
)
3239
@click.option(
3340
"-u",
3441
"--upscale",
3542
default=const.DEFAULTS["upscale"],
3643
type=int,
37-
help="Upscaling factor, used to upsample input data on the fly; useful when the raster resolution is lower than the target DGGS resolution. Default (1) applies no upscaling. The resampling method controls interpolation.",
44+
help=const.OPTION_HELP['upscale'],
3845
)
3946
@click.option(
4047
"-c",
4148
"--compression",
4249
default=const.DEFAULTS["compression"],
4350
type=str,
44-
help="Compression method to use for the output Parquet files. Options include 'snappy', 'gzip', 'brotli', 'lz4', 'zstd', etc. Use 'none' for no compression.",
45-
)
51+
help=const.OPTION_HELP['compression'],)
4652
@click.option(
4753
"-t",
4854
"--threads",
4955
default=const.DEFAULTS["threads"],
50-
help="Number of threads to use when running in parallel. The default is determined based dynamically as the total number of available cores, minus one.",
56+
help=const.OPTION_HELP['threads'],
5157
)
5258
@click.option(
5359
"-a",
5460
"--aggfunc",
5561
default=const.DEFAULTS["aggfunc"],
5662
type=click.Choice(
57-
["count", "mean", "sum", "prod", "std", "var", "min", "max", "median", "mode"]
63+
const.AGGFUNC_OPTIONS
5864
),
59-
help="Numpy aggregate function to apply when aggregating cell values after DGGS indexing, in case of multiple pixels mapping to the same DGGS cell.",
65+
help=const.OPTION_HELP['aggfunc'],
6066
)
6167
@click.option(
6268
"-d",
6369
"--decimals",
6470
default=const.DEFAULTS["decimals"],
6571
type=int,
66-
help="Number of decimal places to round values when aggregating. Use 0 for integer output.",
67-
)
72+
help=const.OPTION_HELP['decimals'],)
6873
@click.option("-o", "--overwrite", is_flag=True)
6974
@click.option(
7075
"--warp_mem_limit",
7176
default=const.DEFAULTS["warp_mem_limit"],
7277
type=int,
73-
help="Input raster may be warped to EPSG:4326 if it is not already in this CRS. This setting specifies the warp operation's memory limit in MB.",
78+
help=const.OPTION_HELP['warp_mem_limit'],
7479
)
7580
@click.option(
7681
"--resampling",
7782
default=const.DEFAULTS["resampling"],
7883
type=click.Choice(Resampling._member_names_),
79-
help="Input raster may be warped to EPSG:4326 if it is not already in this CRS. Or, if the upscale parameter is greater than 1, there is a need to resample. This setting specifies this resampling algorithm.",
84+
help=const.OPTION_HELP['resampling'],
8085
)
8186
@click.option(
8287
"-co",
8388
"--compact",
8489
is_flag=True,
85-
help="Compact the cells up to the parent resolution. Compaction is not applied for cells without identical values across all bands.",
90+
help=const.OPTION_HELP['compact'],
8691
)
8792
@click.option(
8893
"--tempdir",
8994
default=const.DEFAULTS["tempdir"],
9095
type=click.Path(),
91-
help="Temporary data is created during the execution of this program. This parameter allows you to control where this data will be written.",
96+
help=const.OPTION_HELP['tempdir'],
9297
)
9398
@click.version_option(version=__version__)
9499
def a5(
95100
raster_input: Union[str, Path],
96101
output_directory: Union[str, Path],
97102
resolution: str,
98103
parent_res: str,
104+
band: Optional[Sequence[Union[int, str]]],
99105
upscale: int,
100106
compression: str,
101107
threads: int,
@@ -139,5 +145,6 @@ def a5(
139145
int(resolution),
140146
parent_res,
141147
warp_args,
148+
band,
142149
**kwargs,
143150
)

raster2dggs/cli.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,6 @@
88
from raster2dggs.s2 import s2
99
from raster2dggs.a5 import a5
1010

11-
# If the program does terminal interaction, make it output a short
12-
# notice like this when it starts in an interactive mode:
13-
14-
# <program> Copyright (C) <year> <name of author>
15-
# This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
16-
# This is free software, and you are welcome to redistribute it
17-
# under certain conditions; type `show c' for details.
18-
19-
2011
@click.group()
2112
@click.version_option(version=__version__)
2213
def cli():

raster2dggs/common.py

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import errno
33
import tempfile
44
import logging
5+
import numpy as np
56
import threading
67
import rioxarray
78
import dask
@@ -11,7 +12,7 @@
1112
import pandas as pd
1213
import pyarrow.parquet as pq
1314

14-
from typing import Union, Callable
15+
from typing import Union, Optional, Sequence, Callable
1516
from pathlib import Path
1617
from rasterio import crs
1718
from rasterio.vrt import WarpedVRT
@@ -67,7 +68,7 @@ def resolve_input_path(raster_input: Union[str, Path]) -> Union[str, Path]:
6768

6869
def assemble_warp_args(resampling: str, warp_mem_limit: int) -> dict:
6970
warp_args: dict = {
70-
"resampling": Resampling._member_map_[resampling],
71+
"resampling": Resampling[resampling],
7172
"crs": crs.CRS.from_epsg(
7273
4326
7374
), # Input raster must be converted to WGS84 (4326) for DGGS indexing
@@ -76,13 +77,17 @@ def assemble_warp_args(resampling: str, warp_mem_limit: int) -> dict:
7677

7778
return warp_args
7879

80+
def first_mode(x):
81+
m = pd.Series.mode(x, dropna=False)
82+
# Result is empty if all x is nan
83+
return m.iloc[0] if not m.empty else np.nan
7984

8085
def create_aggfunc(aggfunc: str) -> str:
8186
if aggfunc == "mode":
8287
logging.warning(
8388
"Mode aggregation: arbitrary behaviour: if there is more than one mode when aggregating, only the first value will be recorded."
8489
)
85-
aggfunc = lambda x: pd.Series.mode(x)[0]
90+
aggfunc = first_mode
8691

8792
return aggfunc
8893

@@ -91,7 +96,7 @@ def assemble_kwargs(
9196
upscale: int,
9297
compression: str,
9398
threads: int,
94-
aggfunc: str,
99+
aggfunc: Union[str, Callable],
95100
decimals: int,
96101
warp_mem_limit: int,
97102
resampling: str,
@@ -158,13 +163,12 @@ def address_boundary_issues(
158163
LOGGER.debug(
159164
f"Reading Stage 1 output ({pq_input}) and setting index for parent-based partitioning"
160165
)
161-
with TqdmCallback(desc="Reading window partitions"):
162-
# Set index as parent cell
163-
pad_width = const.zero_padding(indexer.dggs)
164-
index_col = f"{indexer.dggs}_{parent_res:0{pad_width}d}"
165-
ddf = dd.read_parquet(pq_input).set_index(index_col)
166+
# Set index as parent cell
167+
pad_width = const.zero_padding(indexer.dggs)
168+
index_col = f"{indexer.dggs}_{parent_res:0{pad_width}d}"
169+
ddf = dd.read_parquet(pq_input).set_index(index_col)
166170

167-
with TqdmCallback(desc="Counting parents"):
171+
with TqdmCallback(desc="Reading window partitions and counting parents"):
168172
# Count parents, to get target number of partitions
169173
uniqueparents = sorted(list(ddf.index.unique().compute()))
170174

@@ -209,6 +213,7 @@ def initial_index(
209213
resolution: int,
210214
parent_res: Union[None, int],
211215
warp_args: dict,
216+
bands: Optional[Sequence[Union[int, str]]] = None,
212217
**kwargs,
213218
) -> Path:
214219
"""
@@ -237,10 +242,33 @@ def initial_index(
237242

238243
# https://rasterio.readthedocs.io/en/latest/api/rasterio.warp.html#rasterio.warp.calculate_default_transform
239244
with rio.Env(CHECK_WITH_INVERT_PROJ=True):
240-
with rio.open(raster_input) as src:
245+
with rio.open(raster_input, mode='r', sharing=False) as src:
241246
LOGGER.debug("Source CRS: %s", src.crs)
242247
# VRT used to avoid additional disk use given the potential for reprojection to 4326 prior to DGGS indexing
243-
band_names = src.descriptions
248+
band_names = tuple(src.descriptions) if src.descriptions else tuple()
249+
count = src.count # Bands
250+
labels_by_index = {
251+
i: (band_names[i-1] if i-1 < len(band_names) and band_names[i-1] else f"band_{i}")
252+
for i in range(1, count + 1)
253+
}
254+
if not bands: # Covers None or empty tuple
255+
selected_indices = list(range(1, count + 1))
256+
else:
257+
if all(isinstance(b, int) or str(b).isdigit() for b in bands):
258+
selected_indices = list(map(int, bands))
259+
else:
260+
name_to_index = {v: k for k, v in labels_by_index.items()}
261+
try:
262+
selected_indices = [name_to_index[str(b)] for b in bands]
263+
except KeyError as e:
264+
raise ValueError(f"Requested band name not found: {e.args[0]}")
265+
# Validate
266+
for i in selected_indices:
267+
if i < 1 or i > count:
268+
raise ValueError(f"Band index out of range: {i} (1..{count})")
269+
# De-duplicate, preserving order
270+
seen = set()
271+
selected_indices = [i for i in selected_indices if not (i in seen or seen.add(i))]
244272

245273
upscale_factor = kwargs["upscale"]
246274
if upscale_factor > 1:
@@ -262,7 +290,7 @@ def initial_index(
262290
upsample_args = dict({})
263291

264292
with WarpedVRT(
265-
src, src_crs=src.crs, **warp_args, **upsample_args
293+
src, src_crs=src.crs, **warp_args, **upsample_args,
266294
) as vrt:
267295
LOGGER.debug("VRT CRS: %s", vrt.crs)
268296
da: xr.Dataset = rioxarray.open_rasterio(
@@ -272,13 +300,21 @@ def initial_index(
272300
default_name=const.DEFAULT_NAME,
273301
).chunk(**{"y": "auto", "x": "auto"})
274302

303+
# Band selection
304+
if "band" in da.dims and (len(selected_indices) != count):
305+
if "band" in da.coords: # rioxarray commonly exposes 1..N as band coords
306+
da = da.sel(band=selected_indices)
307+
else:
308+
da = da.isel(band=[i - 1 for i in selected_indices])
309+
275310
windows = [window for _, window in vrt.block_windows()]
276311
LOGGER.debug(
277312
"%d windows (the same number of partitions will be created)",
278313
len(windows),
279314
)
280315

281-
write_lock = threading.Lock()
316+
selected_labels = tuple([labels_by_index[i] for i in selected_indices])
317+
compression = kwargs["compression"]
282318

283319
def process(window):
284320
sdf = da.rio.isel_window(window)
@@ -288,28 +324,22 @@ def process(window):
288324
resolution,
289325
parent_res,
290326
vrt.nodata,
291-
band_labels=band_names,
327+
band_labels=selected_labels,
292328
)
293329

294-
with write_lock:
295-
pq.write_to_dataset(
296-
result,
297-
root_path=tmpdir,
298-
compression=kwargs["compression"],
299-
)
330+
pq.write_to_dataset(
331+
result,
332+
root_path=tmpdir,
333+
compression=compression,
334+
)
300335

301336
return None
302337

303-
with tqdm(total=len(windows), desc="Raster windows") as pbar:
304-
with ThreadPoolExecutor(
338+
with ThreadPoolExecutor(
305339
max_workers=kwargs["threads"]
306-
) as executor:
307-
futures = [
308-
executor.submit(process, window) for window in windows
309-
]
310-
for future in as_completed(futures):
311-
result = future.result()
312-
pbar.update(1)
340+
) as executor, tqdm(total=len(windows), desc="Raster windows") as pbar:
341+
for _ in executor.map(process, windows, chunksize=1):
342+
pbar.update(1)
313343

314344
LOGGER.debug("Stage 1 (primary indexing) complete")
315345
return address_boundary_issues(

raster2dggs/constants.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,20 @@ def zero_padding(dggs: str) -> int:
4848
if max_res is None:
4949
raise ValueError(f"Unknown DGGS type: {dggs}")
5050
return len(str(max_res))
51+
52+
OPTION_HELP = {
53+
'resolution': lambda dggs: f"{dggs} resolution to index",
54+
'parent_res': lambda dggs, default: f"{dggs} parent resolution to index and aggregate to. Defaults to {default}",
55+
'band': "Band(s) to include in the output. Can specify multiple, e.g. `-b 1 -b 2 -b 4` for bands 1, 2, and 4 (all unspecified bands are ignored). If unused, all bands are included in the output (this is the default behaviour). Bands can be specified as numeric indices (1-based indexing) or string band labels (if present in the input), e.g. -b B02 -b B07 -b B12.",
56+
'upscale': "Upscaling factor, used to upsample input data on the fly; useful when the raster resolution is lower than the target DGGS resolution. Default (1) applies no upscaling. The resampling method controls interpolation.",
57+
'compression': "Compression method to use for the output Parquet files. Options include 'snappy', 'gzip', 'brotli', 'lz4', 'zstd', etc. Use 'none' for no compression.",
58+
'threads': "Number of threads to use when running in parallel. The default is determined based dynamically as the total number of available cores, minus one.",
59+
'aggfunc': "Numpy aggregate function to apply when aggregating cell values after DGGS indexing, in case of multiple pixels mapping to the same DGGS cell.",
60+
'decimals': "Number of decimal places to round values when aggregating. Use 0 for integer output.",
61+
'warp_mem_limit': "Input raster may be warped to EPSG:4326 if it is not already in this CRS. This setting specifies the warp operation's memory limit in MB.",
62+
'resampling': "Input raster may be warped to EPSG:4326 if it is not already in this CRS. Or, if the upscale parameter is greater than 1, there is a need to resample. This setting specifies this resampling algorithm.",
63+
'compact': "Compact the cells up to the parent resolution. Compaction is not applied for cells without identical values across all bands.",
64+
'tempdir': "Temporary data is created during the execution of this program. This parameter allows you to control where this data will be written."
65+
}
66+
67+
AGGFUNC_OPTIONS = ["count", "mean", "sum", "prod", "std", "var", "min", "max", "median", "mode"]

0 commit comments

Comments
 (0)