Skip to content

Commit ee3fbef

Browse files
authored
Merge pull request #1856 from AdeelH/geom
Improve geometry-related validation in `Scene` and `GeoJSONVectorSource`
2 parents dc23f1f + 9109320 commit ee3fbef

File tree

6 files changed

+93
-21
lines changed

6 files changed

+93
-21
lines changed

Diff for: rastervision_core/rastervision/core/data/scene.py

+5
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def __init__(self,
4545
self.aoi_polygons = []
4646
self.aoi_polygons_bbox_coords = []
4747
else:
48+
for p in aoi_polygons:
49+
if p.geom_type not in ['Polygon', 'MultiPolygon']:
50+
raise ValueError(
51+
'Expected all AOI geometries to be Polygons or '
52+
f'MultiPolygons. Found: {p.geom_type}.')
4853
bbox = self.raster_source.bbox
4954
bbox_geom = bbox.to_shapely()
5055
self.aoi_polygons = [

Diff for: rastervision_core/rastervision/core/data/vector_source/geojson_vector_source.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import TYPE_CHECKING, List, Optional, Union
2+
import logging
23

34
from rastervision.pipeline.file_system import download_if_needed, file_to_json
45
from rastervision.core.box import Box
@@ -8,6 +9,8 @@
89
if TYPE_CHECKING:
910
from rastervision.core.data import CRSTransformer, VectorTransformer
1011

12+
log = logging.getLogger(__name__)
13+
1114

1215
class GeoJSONVectorSource(VectorSource):
1316
"""A :class:`.VectorSource` for reading GeoJSON files."""
@@ -49,13 +52,27 @@ def _get_geojson(self) -> dict:
4952
def _get_geojson_single(self, uri: str) -> dict:
5053
# download first so that it gets cached
5154
geojson = file_to_json(download_if_needed(uri))
52-
if not self.ignore_crs_field and 'crs' in geojson:
53-
raise NotImplementedError(
54-
f'The GeoJSON file at {uri} contains a CRS field which '
55-
'is not allowed by the current GeoJSON standard or by '
56-
'Raster Vision. All coordinates are expected to be in '
57-
'EPSG:4326 CRS. If the file uses EPSG:4326 (ie. lat/lng on '
58-
'the WGS84 reference ellipsoid) and you would like to ignore '
59-
'the CRS field, set ignore_crs_field=True in '
60-
'GeoJSONVectorSourceConfig.')
55+
if 'crs' in geojson:
56+
if not self.ignore_crs_field:
57+
raise NotImplementedError(
58+
f'The GeoJSON file at {uri} contains a CRS field which '
59+
'is not allowed by the current GeoJSON standard or by '
60+
'Raster Vision. All coordinates are expected to be in '
61+
'EPSG:4326 CRS. If the file uses EPSG:4326 (ie. lat/lng '
62+
'on the WGS84 reference ellipsoid) and you would like to '
63+
'ignore the CRS field, set ignore_crs_field=True.')
64+
else:
65+
crs = geojson['crs']
66+
log.info(f'Ignoring CRS ({crs}) specified in {uri} '
67+
'and assuming EPSG:4326 instead.')
68+
# Delete the CRS field to avoid discrepancies in case the
69+
# geojson is passed to code that *does* respect the CRS field
70+
# (e.g. geopandas).
71+
del geojson['crs']
72+
# Also delete any "crs" keys in features' properties.
73+
for f in geojson.get('features', []):
74+
try:
75+
del f['properties']['crs']
76+
except KeyError:
77+
pass
6178
return geojson

Diff for: rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/utils/aoi_sampler.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Sequence, Tuple
1+
from typing import Sequence, Tuple, Union
22

33
import numpy as np
4-
from shapely.geometry import Polygon, MultiPolygon
4+
from shapely.geometry import Polygon, MultiPolygon, LinearRing
55
from shapely.ops import unary_union
66
from triangle import triangulate
77

@@ -86,7 +86,9 @@ def triangulate_polygon(self, polygon: Polygon) -> dict:
8686

8787
# the triangulation algorithm requires a sample point inside each
8888
# hole
89-
hole_centroids = np.stack([hole.centroid for hole in holes])
89+
hole_centroids = [hole.centroid for hole in holes]
90+
hole_centroids = np.concatenate(
91+
[np.array(c.coords) for c in hole_centroids], axis=0)
9092

9193
args = {
9294
'vertices': vertices,
@@ -108,18 +110,20 @@ def triangulate_polygon(self, polygon: Polygon) -> dict:
108110
}
109111
return out
110112

111-
def polygon_to_graph(self,
112-
polygon: Polygon) -> Tuple[np.ndarray, np.ndarray]:
113+
def polygon_to_graph(self, polygon: Union[Polygon, LinearRing]
114+
) -> Tuple[np.ndarray, np.ndarray]:
113115
"""Given a polygon, return its graph representation.
114116
115117
Args:
116-
polygon (Polygon): A polygon.
118+
polygon (Union[Polygon, LinearRing]): A polygon or
119+
polygon-exterior.
117120
118121
Returns:
119122
Tuple[np.ndarray, np.ndarray]: An (N, 2) array of vertices and
120123
an (N, 2) array of indices to vertices representing edges.
121124
"""
122-
vertices = np.array(polygon.exterior.coords)
125+
exterior = getattr(polygon, 'exterior', polygon)
126+
vertices = np.array(exterior.coords)
123127
# Discard the last vertex - it is a duplicate of the first vertex and
124128
# duplicates cause problems for the Triangle library.
125129
vertices = vertices[:-1]

Diff for: tests/core/data/test_scene.py

+13
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,19 @@ def test_aoi_polygons(self):
4343
self.assertListEqual(scene.aoi_polygons_bbox_coords,
4444
aoi_polygons_bbox_coords)
4545

46+
def test_invalid_aoi_polygons(self):
47+
bbox = Box(100, 100, 200, 200)
48+
rs = RasterioSource(self.img_uri, bbox=bbox)
49+
50+
aoi_polygons = [
51+
Box(50, 50, 150, 150).to_shapely(),
52+
Box(150, 150, 250, 250).to_shapely(),
53+
# not a polygon:
54+
Box(150, 150, 250, 250).to_shapely().exterior,
55+
]
56+
with self.assertRaises(ValueError):
57+
_ = Scene(id='', raster_source=rs, aoi_polygons=aoi_polygons)
58+
4659

4760
if __name__ == '__main__':
4861
unittest.main()

Diff for: tests/core/data/vector_source/test_geojson_vector_source.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1+
from typing import Callable
12
import unittest
23
import os
34

45
from shapely.geometry import shape
56

7+
from rastervision.core.data import (
8+
BufferTransformerConfig, ClassConfig, ClassInferenceTransformerConfig,
9+
GeoJSONVectorSource, GeoJSONVectorSourceConfig, IdentityCRSTransformer)
610
from rastervision.core.data.vector_source.geojson_vector_source_config import (
7-
GeoJSONVectorSourceConfig, geojson_vector_source_config_upgrader)
8-
from rastervision.core.data import (ClassConfig, IdentityCRSTransformer,
9-
ClassInferenceTransformerConfig,
10-
BufferTransformerConfig)
11+
geojson_vector_source_config_upgrader)
1112
from rastervision.pipeline.file_system import json_to_file, get_tmp_dir
1213

13-
from tests import test_config_upgrader
14+
from tests import test_config_upgrader, data_file_path
1415
from tests.core.data.mock_crs_transformer import DoubleCRSTransformer
1516

1617

@@ -30,6 +31,12 @@ def test_upgrader(self):
3031
class TestGeoJSONVectorSource(unittest.TestCase):
3132
"""This also indirectly tests the ClassInference class."""
3233

34+
def assertNoError(self, fn: Callable, msg: str = ''):
35+
try:
36+
fn()
37+
except Exception:
38+
self.fail(msg)
39+
3340
def setUp(self):
3441
self.tmp_dir = get_tmp_dir()
3542
self.uri = os.path.join(self.tmp_dir.name, 'vectors.json')
@@ -155,6 +162,19 @@ def test_transform_polygon(self):
155162
trans_geom = trans_geojson['features'][0]['geometry']
156163
self.assertTrue(shape(geom).equals(shape(trans_geom)))
157164

165+
def test_ignore_crs_field(self):
166+
uri = data_file_path('0-aoi.geojson')
167+
crs_transformer = IdentityCRSTransformer()
168+
169+
vs = GeoJSONVectorSource(uri, crs_transformer=crs_transformer)
170+
with self.assertRaises(NotImplementedError):
171+
_ = vs.get_geojson()
172+
173+
vs = GeoJSONVectorSource(
174+
uri, crs_transformer=crs_transformer, ignore_crs_field=True)
175+
self.assertNoError(lambda: vs.get_geojson())
176+
self.assertNotIn('crs', vs.get_geojson())
177+
158178

159179
if __name__ == '__main__':
160180
unittest.main()

Diff for: tests/pytorch_learner/dataset/test_aoi_sampler.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Callable
12
import unittest
23
from itertools import product
34

@@ -9,6 +10,18 @@
910

1011

1112
class TestAoiSampler(unittest.TestCase):
13+
def assertNoError(self, fn: Callable, msg: str = ''):
14+
try:
15+
fn()
16+
except Exception:
17+
self.fail(msg)
18+
19+
def test_polygon_with_holes(self):
20+
p1 = Polygon.from_bounds(0, 0, 20, 20)
21+
p2 = Polygon.from_bounds(5, 5, 15, 15)
22+
polygon_with_holes = p1 - p2
23+
self.assertNoError(lambda: AoiSampler([polygon_with_holes]).sample())
24+
1225
def test_sampler(self, nsamples: int = 200):
1326
"""Attempt to check if points are distributed uniformly within an AOI.
1427

0 commit comments

Comments
 (0)