 import tempfile
 import logging
 import numpy as np
-import threading
 import rioxarray
 import dask
 import click_log
+import shutil

 import rasterio as rio
 import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
 import pyarrow.dataset as ds
+import json
+import shapely
+import pyproj

 from typing import Union, Optional, Sequence, Callable
 from pathlib import Path
@@ -24,7 +27,7 @@
 import dask.dataframe as dd
 import xarray as xr

-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor

 from urllib.parse import urlparse
 from rasterio.warp import calculate_default_transform
@@ -106,6 +109,7 @@ def assemble_kwargs(
     resampling: str,
     overwrite: bool,
     compact: bool,
+    geo: str,
 ) -> dict:
     kwargs = {
         "upscale": upscale,
@@ -117,6 +121,7 @@ def assemble_kwargs(
         "resampling": resampling,
         "overwrite": overwrite,
         "compact": compact,
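+        # Normalise the "none" sentinel to None so downstream code can simply
+        # test the truthiness of kwargs["geo"]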
+        "geo": geo if geo != "none" else None,
     }

     return kwargs
@@ -143,6 +148,72 @@ def get_parent_res(dggs: str, parent_res: Union[None, int], resolution: int) ->
     )


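+# Writes a single Dask partition as GeoParquet: cell indexes are mapped to
+# shapely geometries, encoded as WKB, and file-level "geo" metadata
+# (GeoParquet 1.1.0: encoding, CRS, geometry types, bbox) is attached before
+# the table is written into the hive-partitioned dataset under base_dir.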
+def write_partition_as_geoparquet(
+    pdf: pd.DataFrame,
+    geom_func,
+    base_dir: Union[str, Path],
+    partition_col_name: str,
+    compression: str,
+) -> None:
+    # Build shapely geometries for this partition
+    geoms = pdf.index.map(geom_func)
+
+    # Compute GeoParquet 1.1.0 extras
+    valid = [g for g in geoms if (g is not None and not g.is_empty)]
+    if len(valid):
+        arr = np.asarray(shapely.bounds(geoms))  # Shapely 2.x vectorized
+        m = ~np.isnan(arr).any(axis=1)
+        bbox_vals = arr[m]
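+        # shapely.bounds yields (minx, miny, maxx, maxy) per geometry;
+        # reduce these to a single partition-level bbox for the "geo" metadata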
+        bbox = [
+            float(np.min(bbox_vals[:, 0])),
+            float(np.min(bbox_vals[:, 1])),
+            float(np.max(bbox_vals[:, 2])),
+            float(np.max(bbox_vals[:, 3])),
+        ]
+        geometry_types = sorted({g.geom_type for g in valid})
+    else:
+        bbox = None
+        geometry_types = []
+
+    # Convert to WKB bytes (canonical encoding)
+    pdf["geometry"] = shapely.to_wkb(geoms, hex=False)
+
+    table = pa.Table.from_pandas(pdf, preserve_index=True)
+
+    # Ensure geometry is Binary
+    geom_idx = table.schema.get_field_index("geometry")
+    if not pa.types.is_binary(table.field(geom_idx).type):
+        geom_array = pa.array(table.column(geom_idx).to_pylist(), type=pa.binary())
+        table = table.set_column(geom_idx, "geometry", geom_array)
+
+    # GeoParquet 1.1.0 metadata
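+    # (the "crs" entry is PROJJSON, which pyproj's to_json_dict() produces;
+    # geometries are recorded as EPSG:4326)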
+    crs_meta = pyproj.CRS.from_epsg(4326).to_json_dict()
+    col_meta = {"encoding": "WKB", "crs": crs_meta}
+    if geometry_types:
+        col_meta["geometry_types"] = geometry_types
+    if bbox is not None:
+        col_meta["bbox"] = bbox
+
+    geo_meta = {
+        "version": "1.1.0",
+        "primary_column": "geometry",
+        "columns": {"geometry": col_meta},
+    }
+    existing_meta = table.schema.metadata or {}
+    new_meta = {**existing_meta, b"geo": json.dumps(geo_meta).encode("utf-8")}
+    table = table.replace_schema_metadata(new_meta)
+
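+    # Hive-partition on the partition column; "delete_matching" deletes data
+    # only in the partition directories touched by this table, so repeated
+    # writes of the same partition do not accumulate duplicate files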
+    pq.write_to_dataset(
+        table,
+        root_path=str(base_dir),
+        partition_cols=[partition_col_name],
+        compression=compression,
+        basename_template="part.{i}.parquet",
+        existing_data_behavior="delete_matching",
+        use_threads=True,
+    )
+
+
 def address_boundary_issues(
     indexer: RasterIndexer,
     pq_input: tempfile.TemporaryDirectory,
@@ -163,6 +234,9 @@ def address_boundary_issues(
     IDs being present in different windows of the original image
     windows.
     """
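+    # Clear any previous output up front when overwriting: the GeoParquet path
+    # below writes with pyarrow, which (unlike dask's to_parquet) has no
+    # overwrite flag, so stale partitions would otherwise be left behind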
+    if kwargs.get("overwrite", False) and Path(output).exists():
+        shutil.rmtree(output)
+
     LOGGER.debug(f"Reading Stage 1 output ({pq_input})")
     index_col = indexer.index_col(resolution)
     partition_col = indexer.partition_col(parent_res)
@@ -202,15 +276,39 @@ def address_boundary_issues(
         indexer.compaction, resolution, parent_res, meta=out_meta
     )

-    ddf.to_parquet(
-        output,
-        engine="pyarrow",
-        partition_on=[partition_col],
-        overwrite=kwargs["overwrite"],
-        write_index=True,
-        append=False,
-        compression=kwargs["compression"],
-    )
+    if kwargs["geo"]:
+
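+        # Geometry output: write each partition with pyarrow directly rather
+        # than dask's to_parquet, so every file carries GeoParquet "geo"
+        # footer metadata with its own bbox and geometry types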
+        # Create one delayed write task per Dask partition
+        delayed_parts = ddf.to_delayed()
+
+        geo_serialisation_method = (
+            indexer.cell_to_polygon
+            if kwargs["geo"] == "polygon"
+            else indexer.cell_to_point
+        )
+
+        write_tasks = [
+            dask.delayed(write_partition_as_geoparquet)(
+                part, geo_serialisation_method, output, partition_col, kwargs["compression"]
+            )
+            for part in delayed_parts
+        ]
+
+        # Execute writes with progress
+        with TqdmCallback(desc="Writing GeoParquet"):
+            dask.compute(*write_tasks)
+
+    else:
+
+        ddf.to_parquet(
+            output,
+            engine="pyarrow",
+            partition_on=[partition_col],
+            overwrite=kwargs["overwrite"],
+            write_index=True,
+            append=False,
+            compression=kwargs["compression"],
+        )

     LOGGER.debug("Stage 2 (aggregation) complete")

@@ -364,9 +462,9 @@ def process(window):
             root_path=tmpdir,
             partition_cols=[partition_col],
             basename_template=str(window.col_off)
-            + "_"
+            + "."
             + str(window.row_off)
-            + "_guid-{i}.parquet",
+            + ".{i}.parquet",
             use_threads=False,  # Already threading indexing and reading
             existing_data_behavior="overwrite_or_ignore",  # Overwrite files with the same name; other existing files are ignored. Allows for an append workflow
             compression=compression,