Commit bd60d48

adds -g/--geo enum for GeoParquet output
1 parent cab1ad6 commit bd60d48

23 files changed, +518 -26 lines

poetry.lock

Lines changed: 122 additions & 2 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ python-geohash = "^0.8"
 maidenhead = "^1.8"
 s2sphere = "^0.2"
 pya5 = "^0.5"
+geoarrow-pyarrow = "^0.2.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.2.2"

raster2dggs/a5.py

Lines changed: 8 additions & 0 deletions
@@ -89,6 +89,12 @@
     is_flag=True,
     help=const.OPTION_HELP["compact"],
 )
+@click.option(
+    "-g",
+    "--geo",
+    default=const.DEFAULTS["geo"],
+    type=click.Choice(const.GEOM_TYPES),
+)
 @click.option(
     "--tempdir",
     default=const.DEFAULTS["tempdir"],
@@ -111,6 +117,7 @@ def a5(
     warp_mem_limit: int,
     resampling: str,
     compact: bool,
+    geo: str,
     tempdir: Union[str, Path],
 ):
     """
@@ -136,6 +143,7 @@ def a5(
         resampling,
         overwrite,
         compact,
+        geo,
     )
 
     common.initial_index(
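
The new option reads its default and its allowed choices from const.DEFAULTS["geo"] and const.GEOM_TYPES, which are changed elsewhere in this commit and not shown in this excerpt. A rough sketch of what those constants would need to look like for the option and the downstream branching in common.py to work is given below; the exact values are assumptions inferred from the "none"/"polygon" checks in common.py, not the repository's actual definitions.

# Hypothetical sketch only; the real definitions live in raster2dggs' constants module
GEOM_TYPES = ["none", "polygon", "point"]  # choices offered to -g/--geo
DEFAULTS = {
    "geo": "none",  # default: plain (non-geo) Parquet output
    # ... other existing defaults (tempdir, etc.) ...
}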

raster2dggs/common.py

Lines changed: 111 additions & 13 deletions
@@ -3,16 +3,19 @@
 import tempfile
 import logging
 import numpy as np
-import threading
 import rioxarray
 import dask
 import click_log
+import shutil
 
 import rasterio as rio
 import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
 import pyarrow.dataset as ds
+import json
+import shapely
+import pyproj
 
 from typing import Union, Optional, Sequence, Callable
 from pathlib import Path
@@ -24,7 +27,7 @@
 import dask.dataframe as dd
 import xarray as xr
 
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor
 
 from urllib.parse import urlparse
 from rasterio.warp import calculate_default_transform
@@ -106,6 +109,7 @@ def assemble_kwargs(
     resampling: str,
     overwrite: bool,
     compact: bool,
+    geo: str,
 ) -> dict:
     kwargs = {
         "upscale": upscale,
@@ -117,6 +121,7 @@
         "resampling": resampling,
         "overwrite": overwrite,
         "compact": compact,
+        "geo": geo if geo != "none" else None,
     }
 
     return kwargs
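
The CLI sentinel "none" is normalised to None here so that the write path later in this file can branch on simple truthiness (if kwargs["geo"]:). A standalone illustration of that normalisation, not taken from the repository:

def normalise_geo(geo: str):
    # Mirrors the kwargs assembly above: the literal string "none" disables geometry output
    return geo if geo != "none" else None

assert normalise_geo("polygon") == "polygon"
assert normalise_geo("none") is None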
@@ -143,6 +148,72 @@ def get_parent_res(dggs: str, parent_res: Union[None, int], resolution: int) ->
     )
 
 
+def write_partition_as_geoparquet(
+    pdf: pd.DataFrame,
+    geom_func,
+    base_dir: Union[str, Path],
+    partition_col_name: str,
+    compression: str,
+) -> None:
+    # Build shapely geometries for this partition
+    geoms = pdf.index.map(geom_func)
+
+    # Compute GeoParquet 1.1.0 extras
+    valid = [g for g in geoms if (g is not None and not g.is_empty)]
+    if len(valid):
+        arr = np.asarray(shapely.bounds(geoms))  # Shapely 2.x vectorized
+        m = ~np.isnan(arr).any(axis=1)
+        bbox_vals = arr[m]
+        bbox = [
+            float(np.min(bbox_vals[:, 0])),
+            float(np.min(bbox_vals[:, 1])),
+            float(np.max(bbox_vals[:, 2])),
+            float(np.max(bbox_vals[:, 3])),
+        ]
+        geometry_types = sorted({g.geom_type for g in valid})
+    else:
+        bbox = None
+        geometry_types = []
+
+    # Convert to WKB bytes (canonical encoding)
+    pdf["geometry"] = shapely.to_wkb(geoms, hex=False)
+
+    table = pa.Table.from_pandas(pdf, preserve_index=True)
+
+    # Ensure geometry is Binary
+    geom_idx = table.schema.get_field_index("geometry")
+    if not pa.types.is_binary(table.field(geom_idx).type):
+        geom_array = pa.array(table.column(geom_idx).to_pylist(), type=pa.binary())
+        table = table.set_column(geom_idx, "geometry", geom_array)
+
+    # GeoParquet 1.1.0 metadata
+    crs_meta = pyproj.CRS.from_epsg(4326).to_json_dict()
+    col_meta = {"encoding": "WKB", "crs": crs_meta}
+    if geometry_types:
+        col_meta["geometry_types"] = geometry_types
+    if bbox is not None:
+        col_meta["bbox"] = bbox
+
+    geo_meta = {
+        "version": "1.1.0",
+        "primary_column": "geometry",
+        "columns": {"geometry": col_meta},
+    }
+    existing_meta = table.schema.metadata or {}
+    new_meta = {**existing_meta, b"geo": json.dumps(geo_meta).encode("utf-8")}
+    table = table.replace_schema_metadata(new_meta)
+
+    pq.write_to_dataset(
+        table,
+        root_path=str(base_dir),
+        partition_cols=[partition_col_name],
+        compression=compression,
+        basename_template="part.{i}.parquet",
+        existing_data_behavior="delete_matching",
+        use_threads=True,
+    )
+
+
 def address_boundary_issues(
     indexer: RasterIndexer,
     pq_input: tempfile.TemporaryDirectory,
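
The new function embeds GeoParquet 1.1.0 metadata under the "geo" key of the Arrow schema metadata before writing. As a quick sanity check of the output, a standalone snippet along these lines (illustrative only; the path argument is a placeholder, not a path produced by the commit) can read that metadata back from one written part file:

import json
import pyarrow.parquet as pq

def check_geo_metadata(parquet_file: str) -> dict:
    # Read back the file-level "geo" metadata written by write_partition_as_geoparquet
    meta = pq.read_schema(parquet_file).metadata or {}
    geo = json.loads(meta[b"geo"])
    assert geo["version"] == "1.1.0"
    assert geo["primary_column"] == "geometry"
    assert geo["columns"]["geometry"]["encoding"] == "WKB"
    return geo

# e.g. check_geo_metadata("output_dir/<partition_col>=<value>/part.0.parquet")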
@@ -163,6 +234,9 @@ def address_boundary_issues(
     IDs being present in different windows of the original image
     windows.
     """
+    if kwargs.get("overwrite", False) and Path(output).exists():
+        shutil.rmtree(output)
+
     LOGGER.debug(f"Reading Stage 1 output ({pq_input})")
     index_col = indexer.index_col(resolution)
     partition_col = indexer.partition_col(parent_res)
@@ -202,15 +276,39 @@
         indexer.compaction, resolution, parent_res, meta=out_meta
     )
 
-    ddf.to_parquet(
-        output,
-        engine="pyarrow",
-        partition_on=[partition_col],
-        overwrite=kwargs["overwrite"],
-        write_index=True,
-        append=False,
-        compression=kwargs["compression"],
-    )
+    if kwargs["geo"]:
+
+        # Create one delayed write task per Dask partition
+        delayed_parts = ddf.to_delayed()
+
+        geo_serialisation_method = (
+            indexer.cell_to_polygon
+            if kwargs["geo"] == "polygon"
+            else indexer.cell_to_point
+        )
+
+        write_tasks = [
+            dask.delayed(write_partition_as_geoparquet)(
+                part, geo_serialisation_method, output, partition_col, kwargs["compression"]
+            )
+            for part in delayed_parts
+        ]
+
+        # Execute writes with progress
+        with TqdmCallback(desc="Writing GeoParquet"):
+            dask.compute(*write_tasks)
+
+    else:
+
+        ddf.to_parquet(
+            output,
+            engine="pyarrow",
+            partition_on=[partition_col],
+            overwrite=kwargs["overwrite"],
+            write_index=True,
+            append=False,
+            compression=kwargs["compression"],
+        )
 
     LOGGER.debug("Stage 2 (aggregation) complete")
 
@@ -364,9 +462,9 @@ def process(window):
         root_path=tmpdir,
         partition_cols=[partition_col],
         basename_template=str(window.col_off)
-        + "_"
+        + "."
         + str(window.row_off)
-        + "_guid-{i}.parquet",
+        + ".{i}.parquet",
         use_threads=False,  # Already threading indexing and reading
         existing_data_behavior="overwrite_or_ignore",  # Overwrite files with the same name; other existing files are ignored. Allows for an append workflow
         compression=compression,
