Skip to content

Commit 3578b4d

Browse files
authored
fix(rust/sedona-query-planner): Ensure user provided RS_EnsureLoaded call preserves metadata (#969)
1 parent 28c1a27 commit 3578b4d

11 files changed

Lines changed: 798 additions & 295 deletions

File tree

python/sedonadb/python/sedonadb/raster.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import struct
19+
import math
20+
1821
from typing import List, Optional, TYPE_CHECKING, Tuple, Any, Iterable
1922
import geoarrow.types as gat
2023
import pyarrow as pa
@@ -229,10 +232,36 @@ def source_data(self) -> memoryview:
229232
view_scalar = self._array.field("data")[0]
230233
return memoryview(view_scalar.as_buffer())
231234

235+
@property
236+
def source_data_size(self) -> int:
237+
"""The number of bytes consumed by source_data if it were loaded"""
238+
buffer_type_id = self._py_field("data_type")
239+
buffer_type_char = BAND_DATA_TYPE_STRUCT_CHARS[buffer_type_id]
240+
element_size = struct.calcsize(buffer_type_char)
241+
return math.prod(self.source_shape) * element_size
242+
243+
@property
244+
def data_size(self) -> int:
245+
"""The number of bytes consumed by data if it were loaded"""
246+
buffer_type_id = self._py_field("data_type")
247+
buffer_type_char = BAND_DATA_TYPE_STRUCT_CHARS[buffer_type_id]
248+
element_size = struct.calcsize(buffer_type_char)
249+
return math.prod(self.shape) * element_size
250+
232251
@property
233252
def data(self) -> memoryview:
234253
"""The band data as a typed, shaped memoryview."""
235-
if self.outdb_uri is not None:
254+
buffer_type_id = self._py_field("data_type")
255+
buffer_type_char = BAND_DATA_TYPE_STRUCT_CHARS[buffer_type_id]
256+
257+
# This is not quite right, but shapes that contain zeroes are not well
258+
# supported by the memoryview yet. Callers should check data_size for
259+
# empty handling with non-numpy views.
260+
if self.data_size == 0:
261+
return memoryview(b"")
262+
263+
source_data = self.source_data
264+
if self.outdb_uri is not None and len(source_data) == 0:
236265
raise ValueError("Can't extract buffer from a reference to external data.")
237266

238267
# When views are supported, we would need to calculate the striding
@@ -241,14 +270,15 @@ def data(self) -> memoryview:
241270
if views:
242271
raise NotImplementedError("Lazy views are not yet supported")
243272

244-
buffer_type_id = self._py_field("data_type")
245-
buffer_type_char = BAND_DATA_TYPE_STRUCT_CHARS[buffer_type_id]
246273
return self.source_data.cast(buffer_type_char, self.shape)
247274

248275
def to_numpy(self) -> "np.ndarray":
249276
"""Convert this band's data to a numpy array."""
250277
import numpy as np
251278

279+
if self.data_size == 0:
280+
return np.empty(self.shape, dtype=self.data_type)
281+
252282
return np.array(self.data)
253283

254284
def __repr__(self) -> str:

python/sedonadb/tests/functions/conftest.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

python/sedonadb/tests/functions/test_raster_functions.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,43 +15,44 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
"""Table-driven tests for RS_ accessor functions over an in-DB example raster.
19-
20-
The `rasters` view (see `raster_con` in conftest.py) holds a single
21-
`RS_Example()` raster: 64x32, three UInt8 bands, nodata 127, with a fixed
22-
geotransform (origin (43.08, 79.07), scale 2, skew 1). These exercise the RS_
23-
accessor kernels against the raster Arrow type with no zarr dependency; the
24-
zarr reader path (OutDb chunk anchors, fill_value->nodata, RS_EnsureLoaded) is
25-
tested in the sedonadb-zarr package.
26-
27-
There is no PostGIS twin for these (unlike the geometry function tests), so
28-
plain `con.sql(...)` assertions are the right altitude.
29-
"""
30-
3118
import pytest
3219

33-
34-
def query_value(con, expr):
35-
"""Evaluate `expr` over the single example raster row and return the value."""
36-
table = con.sql(f"SELECT {expr} AS v FROM rasters").to_arrow_table()
37-
return table["v"][0].as_py()
20+
from sedonadb.testing import SedonaDB
3821

3922

4023
@pytest.mark.parametrize(
4124
("expr", "expected"),
4225
[
43-
("RS_NumBands(raster)", 3),
44-
("RS_Width(raster)", 64),
45-
("RS_Height(raster)", 32),
46-
("RS_BandPixelType(raster, 1)", "UNSIGNED_8BITS"),
47-
("RS_BandNoDataValue(raster, 1)", 127.0),
48-
("RS_ScaleX(raster)", 2.0),
49-
("RS_ScaleY(raster)", 2.0),
50-
("RS_SkewX(raster)", 1.0),
51-
("RS_SkewY(raster)", 1.0),
52-
("RS_UpperLeftX(raster)", 43.08),
53-
("RS_UpperLeftY(raster)", 79.07),
26+
("RS_NumBands(RS_Example())", 3),
27+
("RS_Width(RS_Example())", 64),
28+
("RS_Height(RS_Example())", 32),
29+
("RS_BandPixelType(RS_Example(), 1)", "UNSIGNED_8BITS"),
30+
("RS_BandNoDataValue(RS_Example(), 1)", 127.0),
31+
("RS_ScaleX(RS_Example())", 2.0),
32+
("RS_ScaleY(RS_Example())", 2.0),
33+
("RS_SkewX(RS_Example())", 1.0),
34+
("RS_SkewY(RS_Example())", 1.0),
35+
("RS_UpperLeftX(RS_Example())", 43.08),
36+
("RS_UpperLeftY(RS_Example())", 79.07),
5437
],
5538
)
56-
def test_rs_function(raster_con, expr, expected):
57-
assert query_value(raster_con, expr) == expected
39+
def test_rs_function(expr, expected):
40+
eng = SedonaDB()
41+
eng.assert_query_result(f"SELECT {expr}", expected)
42+
43+
44+
def test_rs_ensureloaded(con, sedona_testing):
45+
path = sedona_testing / "data/raster/sentinel2.tif"
46+
t = con.sql("SELECT RS_FromPath($1) AS raster", params=(str(path),))
47+
tab = t.select(raster=t.raster.funcs.rs_ensureloaded()).to_arrow_table()
48+
r = tab["raster"][0].as_py()
49+
assert r.height == 512
50+
assert r.width == 512
51+
52+
assert len(r.bands) == 1
53+
b = r.bands[0]
54+
assert b.shape == (512, 512)
55+
arr = b.to_numpy()
56+
assert arr.shape == (512, 512)
57+
assert arr.dtype == "uint16"
58+
assert arr[0, 0] == 2324

python/sedonadb/tests/test_raster.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def test_raster_accessors(con):
4949
assert b.source_shape == (32, 64)
5050
assert b.outdb_uri is None
5151
assert b.data_type == "uint8"
52+
assert b.source_data_size == 32 * 64 * 1 # uint8 = 1 byte
53+
assert b.data_size == 32 * 64 * 1
5254
assert repr(b) == "<Band uint8 32x64>"
5355

5456
arr = b.to_numpy()
@@ -87,6 +89,8 @@ def test_raster_lazy():
8789
assert b.source_shape == (512, 1024)
8890
assert b.data_type == "float32"
8991
assert b.outdb_uri == "s3://bucket/path/to/data.zarr"
92+
assert b.source_data_size == 512 * 1024 * 4 # float32 = 4 bytes
93+
assert b.data_size == 512 * 1024 * 4
9094

9195
# Lazy raster should have empty data buffer
9296
assert len(b.source_data) == 0
@@ -116,3 +120,22 @@ def test_raster_lazy_invalid_shape():
116120

117121
with pytest.raises(ValueError, match="exactly two dimensions"):
118122
Raster.lazy(uri="s3://bucket/data.zarr", shape=(10, 20, 30), dtype="UInt8")
123+
124+
125+
def test_raster_lazy_zero_size():
126+
"""Test that a raster with zero-size shape returns an empty memoryview."""
127+
r = Raster.lazy(
128+
uri="s3://bucket/empty.zarr",
129+
shape=(0, 64),
130+
dtype="float32",
131+
)
132+
133+
b = r.bands[0]
134+
assert b.source_shape == (0, 64)
135+
assert b.data_size == 0
136+
assert b.source_data_size == 0
137+
assert b.data == memoryview(b"")
138+
139+
arr = b.to_numpy()
140+
assert arr.shape == (0, 64)
141+
assert arr.dtype == "float32"

0 commit comments

Comments
 (0)