Skip to content

Commit 73567c5

Browse files
committed
random_bbox_splitting: port to geopandas
1 parent 76a1a6e commit 73567c5

File tree

2 files changed

+54
-32
lines changed

2 files changed

+54
-32
lines changed

tests/datasets/test_splits.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -123,16 +123,16 @@ def test_random_bbox_splitting() -> None:
123123
ds_area = total_area(ds)
124124

125125
train_ds, val_ds, test_ds = random_bbox_splitting(
126-
ds, fractions=[1 / 2, 1 / 4, 1 / 4]
126+
ds, fractions=[5 / 8, 2 / 8, 1 / 8]
127127
)
128128
train_ds_area = total_area(train_ds)
129129
val_ds_area = total_area(val_ds)
130130
test_ds_area = total_area(test_ds)
131131

132132
# Check datasets areas
133-
assert train_ds_area == ds_area / 2
134-
assert val_ds_area == ds_area / 4
135-
assert test_ds_area == ds_area / 4
133+
assert isclose(train_ds_area, ds_area * 5 / 8)
134+
assert isclose(val_ds_area, ds_area * 2 / 8)
135+
assert isclose(test_ds_area, ds_area * 1 / 8)
136136

137137
# No overlap
138138
assert no_overlap(train_ds, val_ds)

torchgeo/datasets/splits.py

Lines changed: 50 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,9 @@
1010
from math import floor, isclose
1111
from typing import cast
1212

13+
import shapely
1314
from rtree.index import Index, Property
15+
from shapely import LineString
1416
from torch import Generator, default_generator, randint, randperm
1517

1618
from ..datasets import GeoDataset
@@ -110,43 +112,63 @@ def random_bbox_splitting(
110112
if any(n <= 0 for n in fractions):
111113
raise ValueError('All items in input fractions must be greater than 0.')
112114

113-
new_indexes = [
114-
Index(interleaved=False, properties=Property(dimension=3)) for _ in fractions
115-
]
115+
new_datasets = [deepcopy(dataset) for _ in fractions]
116116

117-
for i, hit in enumerate(
118-
dataset.index.intersection(dataset.index.bounds, objects=True)
119-
):
120-
box = BoundingBox(*hit.bounds)
121-
fraction_left = 1.0
117+
for i in range(len(dataset)):
118+
geometry_remaining = dataset.index.geometry.iloc[i]
119+
fraction_remaining = 1.0
122120

123121
# Randomly choose the split direction
124122
horizontal, flip = randint(0, 2, (2,), generator=generator)
125123
for j, fraction in enumerate(fractions):
126-
if fraction_left == fraction:
124+
if isclose(fraction_remaining, fraction):
127125
# For the last fraction, no need to split again
128-
new_box = box
129-
elif flip:
130-
# new_box corresponds to fraction, box is the remainder that we might
131-
# split again in the next iteration. Each split is done according to
132-
# fraction wrt what's left
133-
box, new_box = box.split(
134-
(fraction_left - fraction) / fraction_left, horizontal
135-
)
126+
new_geometry = geometry_remaining
136127
else:
137-
# Same as above, but without flipping
138-
new_box, box = box.split(fraction / fraction_left, horizontal)
139-
140-
new_indexes[j].insert(i, tuple(new_box), hit.object)
141-
fraction_left -= fraction
128+
# Create a new_geometry from geometry_remaining
129+
minx, miny, maxx, maxy = geometry_remaining.bounds
130+
131+
if flip:
132+
frac = fraction_remaining - fraction
133+
else:
134+
frac = fraction
135+
136+
if horizontal:
137+
splity = miny + (maxy - miny) * frac / fraction_remaining
138+
line = LineString([(minx, splity), (maxx, splity)])
139+
else:
140+
splitx = minx + (maxx - minx) * frac / fraction_remaining
141+
line = LineString([(splitx, miny), (splitx, maxy)])
142+
143+
geom1, geom2 = shapely.ops.split(geometry_remaining, line).geoms
144+
if horizontal:
145+
if flip:
146+
if geom1.centroid.y < splity:
147+
geometry_remaining, new_geometry = geom1, geom2
148+
else:
149+
new_geometry, geometry_remaining = geom1, geom2
150+
else:
151+
if geom1.centroid.y < splity:
152+
new_geometry, geometry_remaining = geom1, geom2
153+
else:
154+
geometry_remaining, new_geometry = geom1, geom2
155+
else:
156+
if flip:
157+
if geom1.centroid.x < splitx:
158+
geometry_remaining, new_geometry = geom1, geom2
159+
else:
160+
new_geometry, geometry_remaining = geom1, geom2
161+
else:
162+
if geom1.centroid.x < splitx:
163+
new_geometry, geometry_remaining = geom1, geom2
164+
else:
165+
geometry_remaining, new_geometry = geom1, geom2
166+
167+
new_datasets[j].index.iloc[i].geometry = new_geometry
168+
169+
fraction_remaining -= fraction
142170
horizontal = not horizontal
143171

144-
new_datasets = []
145-
for index in new_indexes:
146-
ds = deepcopy(dataset)
147-
ds.index = index
148-
new_datasets.append(ds)
149-
150172
return new_datasets
151173

152174

0 commit comments

Comments
 (0)