Skip to content

Commit a4a8c05

Browse files
committed
Add geography serde tests and dataframe api tests for python binding
1 parent 3c58ce7 commit a4a8c05

File tree

7 files changed

+82
-31
lines changed

7 files changed

+82
-31
lines changed

python/sedona/core/geom/geography.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,11 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import pickle
19-
20-
from sedona.utils.decorators import require
18+
from shapely.geometry.base import BaseGeometry
2119

2220

2321
class Geography:
22+
geometry: BaseGeometry
2423

25-
def __init__(self, geometry):
26-
self._geom = geometry
27-
self.userData = None
28-
29-
def getUserData(self):
30-
return self.userData
31-
32-
@classmethod
33-
def from_jvm_instance(cls, java_obj):
34-
return Geography(java_obj.geometry)
35-
36-
@classmethod
37-
def serialize_for_java(cls, geogs):
38-
return pickle.dumps(geogs)
39-
40-
@require(["Geography"])
41-
def create_jvm_instance(self, jvm):
42-
return jvm.Geography(self._geom)
24+
def __init__(self, geometry: BaseGeometry):
25+
self.geometry = geometry

python/sedona/sql/st_constructors.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,22 @@ def ST_GeomFromWKT(
176176
return _call_constructor_function("ST_GeomFromWKT", args)
177177

178178

179+
@validate_argument_types
180+
def ST_GeogFromWKT(
181+
wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None
182+
) -> Column:
183+
"""Generate a geography column from a Well-Known Text (WKT) string column.
184+
185+
:param wkt: WKT string column to generate from.
186+
:type wkt: ColumnOrName
187+
:return: Geography column representing the WKT string.
188+
:rtype: Column
189+
"""
190+
args = (wkt) if srid is None else (wkt, srid)
191+
192+
return _call_constructor_function("ST_GeogFromWKT", args)
193+
194+
179195
@validate_argument_types
180196
def ST_GeomFromEWKT(ewkt: ColumnOrName) -> Column:
181197
"""Generate a geometry column from a OGC Extended Well-Known Text (WKT) string column.

python/sedona/sql/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def sqlType(cls):
6868
return BinaryType()
6969

7070
def serialize(self, obj):
71-
return geometry_serde.serialize(obj._geom)
71+
return geometry_serde.serialize(obj.geometry)
7272

7373
def deserialize(self, datum):
7474
geom, offset = geometry_serde.deserialize(datum)

python/sedona/utils/geometry_adapter.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,23 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import Union
19-
2018
from shapely.geometry.base import BaseGeometry
2119

2220
from sedona.core.geom.envelope import Envelope
23-
from sedona.core.geom.geography import Geography
2421
from sedona.core.jvm.translate import JvmGeometryAdapter
2522
from sedona.utils.spatial_rdd_parser import GeometryFactory
2623

2724

2825
class GeometryAdapter:
2926

3027
@classmethod
31-
def create_jvm_geometry_from_base_geometry(
32-
cls, jvm, geom: Union[BaseGeometry, Geography]
33-
):
28+
def create_jvm_geometry_from_base_geometry(cls, jvm, geom: BaseGeometry):
3429
"""
3530
:param jvm:
3631
:param geom:
3732
:return:
3833
"""
39-
if isinstance(geom, (Envelope, Geography)):
34+
if isinstance(geom, Envelope):
4035
jvm_geom = geom.create_jvm_instance(jvm)
4136
else:
4237
decoded_geom = GeometryFactory.to_bytes(geom)

python/tests/sql/test_dataframe_api.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from shapely.geometry.base import BaseGeometry
2727
from tests.test_base import TestBase
2828

29+
from sedona.core.geom.geography import Geography
2930
from sedona.sql import st_aggregates as sta
3031
from sedona.sql import st_constructors as stc
3132
from sedona.sql import st_functions as stf
@@ -85,6 +86,8 @@
8586
(stc.ST_GeomFromWKT, ("wkt",), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"),
8687
(stc.ST_GeomFromWKT, ("wkt", 4326), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"),
8788
(stc.ST_GeomFromEWKT, ("ewkt",), "linestring_ewkt", "", "LINESTRING (1 2, 3 4)"),
89+
(stc.ST_GeogFromWKT, ("wkt",), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"),
90+
(stc.ST_GeogFromWKT, ("wkt", 4326), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"),
8891
(stc.ST_LineFromText, ("wkt",), "linestring_wkt", "", "LINESTRING (1 2, 3 4)"),
8992
(
9093
stc.ST_LineFromWKB,
@@ -1230,6 +1233,7 @@
12301233
(stc.ST_LinestringFromWKB, (None,)),
12311234
(stc.ST_GeomFromEWKB, (None,)),
12321235
(stc.ST_GeomFromWKT, (None,)),
1236+
(stc.ST_GeogFromWKT, (None,)),
12331237
(stc.ST_GeometryFromText, (None,)),
12341238
(stc.ST_LineFromText, (None,)),
12351239
(stc.ST_LineStringFromText, (None, "")),
@@ -1711,6 +1715,9 @@ def test_dataframe_function(
17111715
if isinstance(actual_result, BaseGeometry):
17121716
self.assert_geometry_almost_equal(expected_result, actual_result)
17131717
return
1718+
elif isinstance(actual_result, Geography):
1719+
self.assert_geometry_almost_equal(expected_result, actual_result.geometry)
1720+
return
17141721
elif isinstance(actual_result, bytearray):
17151722
actual_result = actual_result.hex()
17161723
elif isinstance(actual_result, Row):

python/tests/sql/test_geography.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pytest
19+
from pyspark.sql.functions import expr
20+
from pyspark.sql.types import StructType
21+
from shapely.wkt import loads as wkt_loads
22+
from sedona.core.geom.geography import Geography
23+
from sedona.sql.types import GeographyType
24+
from tests.test_base import TestBase
25+
26+
27+
class TestGeography(TestBase):
28+
29+
def test_deserialize_geography(self):
30+
"""Test serialization and deserialization of Geography objects"""
31+
geog_df = self.spark.range(0, 10).withColumn(
32+
"geog", expr("ST_GeogFromWKT(CONCAT('POINT (', id, ' ', id + 1, ')'))")
33+
)
34+
rows = geog_df.collect()
35+
assert len(rows) == 10
36+
for row in rows:
37+
id = row["id"]
38+
geog = row["geog"]
39+
assert geog.geometry.wkt == f"POINT ({id} {id + 1})"
40+
41+
def test_serialize_geography(self):
42+
wkt = "MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((-10 -10, -20 -20, -20 -10, -10 -10)))"
43+
geog = Geography(wkt_loads(wkt))
44+
schema = StructType().add("geog", GeographyType())
45+
returned_geog = self.spark.createDataFrame([(geog,)], schema).take(1)[0][0]
46+
assert geog.geometry.equals(returned_geog.geometry)

spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2692,7 +2692,9 @@ class functionTestScala
26922692
assert(functionDf.first().get(0) == null)
26932693
functionDf = sparkSession.sql("select ST_AsBinary(null)")
26942694
assert(functionDf.first().get(0) == null)
2695-
functionDf = sparkSession.sql("select ST_AsEWKB(null)")
2695+
functionDf = sparkSession.sql("select ST_AsEWKB(ST_GeomFromWKT(null))")
2696+
assert(functionDf.first().get(0) == null)
2697+
functionDf = sparkSession.sql("select ST_AsEWKB(ST_GeogFromWKT(null))")
26962698
assert(functionDf.first().get(0) == null)
26972699
functionDf = sparkSession.sql("select ST_SRID(null)")
26982700
assert(functionDf.first().get(0) == null)
@@ -2764,7 +2766,9 @@ class functionTestScala
27642766
assert(functionDf.first().get(0) == null)
27652767
functionDf = sparkSession.sql("select ST_Reverse(null)")
27662768
assert(functionDf.first().get(0) == null)
2767-
functionDf = sparkSession.sql("select ST_AsEWKT(null)")
2769+
functionDf = sparkSession.sql("select ST_AsEWKT(ST_GeomFromWKT(null))")
2770+
assert(functionDf.first().get(0) == null)
2771+
functionDf = sparkSession.sql("select ST_AsEWKT(ST_GeogFromWKT(null))")
27682772
assert(functionDf.first().get(0) == null)
27692773
functionDf = sparkSession.sql("select ST_Force_2D(null)")
27702774
assert(functionDf.first().get(0) == null)

0 commit comments

Comments
 (0)