Skip to content

Commit 76a1a6e

Browse files
committed
random_bbox_assignment: port to geopandas
1 parent ba45e7b commit 76a1a6e

File tree

2 files changed

+65
-74
lines changed

2 files changed

+65
-74
lines changed

tests/datasets/test_splits.py

Lines changed: 61 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22
# Licensed under the MIT License.
33

44
from collections.abc import Sequence
5+
from datetime import datetime
56
from math import floor, isclose
67
from typing import Any
78

9+
import pandas as pd
810
import pytest
11+
import shapely
12+
from geopandas import GeoDataFrame
913
from pyproj import CRS
14+
from shapely import Geometry
1015

1116
from torchgeo.datasets import (
1217
BoundingBox,
@@ -18,13 +23,12 @@
1823
time_series_split,
1924
)
2025

26+
MINT = datetime(2025, 4, 24)
27+
MAXT = datetime(2025, 4, 25)
2128

22-
def total_area(dataset: GeoDataset) -> float:
23-
total_area = 0.0
24-
for hit in dataset.index.intersection(dataset.index.bounds, objects=True):
25-
total_area += BoundingBox(*hit.bounds).area
2629

27-
return total_area
30+
def total_area(dataset: GeoDataset) -> float:
31+
return dataset.index.geometry.area.sum()
2832

2933

3034
def no_overlap(ds1: GeoDataset, ds2: GeoDataset) -> bool:
@@ -39,20 +43,20 @@ def no_overlap(ds1: GeoDataset, ds2: GeoDataset) -> bool:
3943
class CustomGeoDataset(GeoDataset):
4044
def __init__(
4145
self,
42-
items: list[tuple[BoundingBox, str]] = [(BoundingBox(0, 1, 0, 1, 0, 40), '')],
43-
crs: CRS = CRS.from_epsg(3005),
44-
res: tuple[float, float] = (1, 1),
46+
index: pd.IntervalIndex | None = None,
47+
geometry: Sequence[Geometry] = [shapely.box(0, 0, 1, 1)],
4548
) -> None:
46-
super().__init__()
47-
for box, content in items:
48-
self.index.insert(0, tuple(box), content)
49-
self._crs = crs
50-
self.res = res
49+
if index is None:
50+
intervals = [(MINT, MAXT)] * len(geometry)
51+
index = pd.IntervalIndex.from_tuples(
52+
intervals, closed='both', name='datetime'
53+
)
54+
crs = CRS.from_epsg(3005)
55+
self.index = GeoDataFrame(index=index, geometry=geometry, crs=crs)
56+
self.res = (1, 1)
5157

5258
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
53-
hits = self.index.intersection(tuple(query), objects=True)
54-
hit = next(iter(hits))
55-
return {'content': hit.object}
59+
return {'index': query}
5660

5761

5862
@pytest.mark.parametrize(
@@ -67,14 +71,13 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
6771
def test_random_bbox_assignment(
6872
lengths: Sequence[int | float], expected_lengths: Sequence[int]
6973
) -> None:
70-
ds = CustomGeoDataset(
71-
[
72-
(BoundingBox(0, 1, 0, 1, 0, 0), 'a'),
73-
(BoundingBox(1, 2, 0, 1, 0, 0), 'b'),
74-
(BoundingBox(2, 3, 0, 1, 0, 0), 'c'),
75-
(BoundingBox(3, 4, 0, 1, 0, 0), 'd'),
76-
]
77-
)
74+
geometry = [
75+
shapely.box(0, 0, 1, 1),
76+
shapely.box(1, 0, 2, 1),
77+
shapely.box(2, 0, 3, 1),
78+
shapely.box(3, 0, 4, 1),
79+
]
80+
ds = CustomGeoDataset(geometry=geometry)
7881

7982
train_ds, val_ds, test_ds = random_bbox_assignment(ds, lengths)
8083

@@ -94,7 +97,6 @@ def test_random_bbox_assignment(
9497
# Test __getitem__
9598
x = train_ds[train_ds.bounds]
9699
assert isinstance(x, dict)
97-
assert isinstance(x['content'], str)
98100

99101

100102
def test_random_bbox_assignment_invalid_inputs() -> None:
@@ -110,14 +112,13 @@ def test_random_bbox_assignment_invalid_inputs() -> None:
110112

111113

112114
def test_random_bbox_splitting() -> None:
113-
ds = CustomGeoDataset(
114-
[
115-
(BoundingBox(0, 1, 0, 1, 0, 0), 'a'),
116-
(BoundingBox(1, 2, 0, 1, 0, 0), 'b'),
117-
(BoundingBox(2, 3, 0, 1, 0, 0), 'c'),
118-
(BoundingBox(3, 4, 0, 1, 0, 0), 'd'),
119-
]
120-
)
115+
geometry = [
116+
shapely.box(0, 0, 1, 1),
117+
shapely.box(1, 0, 2, 1),
118+
shapely.box(2, 0, 3, 1),
119+
shapely.box(3, 0, 4, 1),
120+
]
121+
ds = CustomGeoDataset(geometry=geometry)
121122

122123
ds_area = total_area(ds)
123124

@@ -145,7 +146,6 @@ def test_random_bbox_splitting() -> None:
145146
# Test __get_item__
146147
x = train_ds[train_ds.bounds]
147148
assert isinstance(x, dict)
148-
assert isinstance(x['content'], str)
149149

150150
# Test invalid input fractions
151151
with pytest.raises(ValueError, match='Sum of input fractions must equal 1.'):
@@ -157,12 +157,8 @@ def test_random_bbox_splitting() -> None:
157157

158158

159159
def test_random_grid_cell_assignment() -> None:
160-
ds = CustomGeoDataset(
161-
[
162-
(BoundingBox(0, 12, 0, 12, 0, 0), 'a'),
163-
(BoundingBox(12, 24, 0, 12, 0, 0), 'b'),
164-
]
165-
)
160+
geometry = [shapely.box(0, 0, 12, 12), shapely.box(12, 0, 24, 12)]
161+
ds = CustomGeoDataset(geometry=geometry)
166162

167163
train_ds, val_ds, test_ds = random_grid_cell_assignment(
168164
ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=5
@@ -185,7 +181,6 @@ def test_random_grid_cell_assignment() -> None:
185181
# Test __get_item__
186182
x = train_ds[train_ds.bounds]
187183
assert isinstance(x, dict)
188-
assert isinstance(x['content'], str)
189184

190185
# Test invalid input fractions
191186
with pytest.raises(ValueError, match='Sum of input fractions must equal 1.'):
@@ -199,21 +194,20 @@ def test_random_grid_cell_assignment() -> None:
199194

200195

201196
def test_roi_split() -> None:
202-
ds = CustomGeoDataset(
203-
[
204-
(BoundingBox(0, 1, 0, 1, 0, 0), 'a'),
205-
(BoundingBox(1, 2, 0, 1, 0, 0), 'b'),
206-
(BoundingBox(2, 3, 0, 1, 0, 0), 'c'),
207-
(BoundingBox(3, 4, 0, 1, 0, 0), 'd'),
208-
]
209-
)
197+
geometry = [
198+
shapely.box(0, 0, 1, 1),
199+
shapely.box(1, 0, 2, 1),
200+
shapely.box(2, 0, 3, 1),
201+
shapely.box(3, 0, 4, 1),
202+
]
203+
ds = CustomGeoDataset(geometry=geometry)
210204

211205
train_ds, val_ds, test_ds = roi_split(
212206
ds,
213207
rois=[
214-
BoundingBox(0, 2, 0, 1, 0, 0),
215-
BoundingBox(2, 3.5, 0, 1, 0, 0),
216-
BoundingBox(3.5, 4, 0, 1, 0, 0),
208+
BoundingBox(0, 2, 0, 1, MINT, MAXT),
209+
BoundingBox(2, 3.5, 0, 1, MINT, MAXT),
210+
BoundingBox(3.5, 4, 0, 1, MINT, MAXT),
217211
],
218212
)
219213

@@ -234,7 +228,6 @@ def test_roi_split() -> None:
234228
# Test __get_item__
235229
x = train_ds[train_ds.bounds]
236230
assert isinstance(x, dict)
237-
assert isinstance(x['content'], str)
238231

239232
# Test invalid input rois
240233
with pytest.raises(ValueError, match="ROIs in input rois can't overlap."):
@@ -257,14 +250,23 @@ def test_roi_split() -> None:
257250
def test_time_series_split(
258251
lengths: Sequence[tuple[int, int] | int | float], expected_lengths: Sequence[int]
259252
) -> None:
260-
ds = CustomGeoDataset(
253+
geometry = [
254+
shapely.box(0, 0, 1, 1),
255+
shapely.box(0, 0, 1, 1),
256+
shapely.box(0, 0, 1, 1),
257+
shapely.box(0, 0, 1, 1),
258+
]
259+
index = pd.IntervalIndex.from_tuples(
261260
[
262-
(BoundingBox(0, 1, 0, 1, 0, 10), 'a'),
263-
(BoundingBox(0, 1, 0, 1, 10, 20), 'b'),
264-
(BoundingBox(0, 1, 0, 1, 20, 30), 'c'),
265-
(BoundingBox(0, 1, 0, 1, 30, 40), 'd'),
266-
]
261+
(datetime(2025, 4, 25), datetime(2025, 4, 26)),
262+
(datetime(2025, 4, 26), datetime(2025, 4, 27)),
263+
(datetime(2025, 4, 27), datetime(2025, 4, 28)),
264+
(datetime(2025, 4, 28), datetime(2025, 4, 29)),
265+
],
266+
closed='both',
267+
name='datetime',
267268
)
269+
ds = CustomGeoDataset(index, geometry)
268270

269271
train_ds, val_ds, test_ds = time_series_split(ds, lengths)
270272

@@ -284,7 +286,6 @@ def test_time_series_split(
284286
# Test __get_item__
285287
x = train_ds[train_ds.bounds]
286288
assert isinstance(x, dict)
287-
assert isinstance(x['content'], str)
288289

289290

290291
def test_time_series_split_invalid_input() -> None:

torchgeo/datasets/splits.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
"""Dataset splitting utilities."""
55

6+
import itertools
67
from collections.abc import Sequence
78
from copy import deepcopy
89
from itertools import accumulate
@@ -71,23 +72,12 @@ def random_bbox_assignment(
7172
lengths = _fractions_to_lengths(lengths, len(dataset))
7273
lengths = cast(Sequence[int], lengths)
7374

74-
hits = list(dataset.index.intersection(dataset.index.bounds, objects=True))
75-
76-
hits = [hits[i] for i in randperm(sum(lengths), generator=generator)]
77-
78-
new_indexes = [
79-
Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths
80-
]
81-
82-
for i, length in enumerate(lengths):
83-
for j in range(length):
84-
hit = hits.pop()
85-
new_indexes[i].insert(j, hit.bounds, hit.object)
75+
indices = randperm(sum(lengths), generator=generator)
8676

8777
new_datasets = []
88-
for index in new_indexes:
78+
for offset, length in zip(itertools.accumulate(lengths), lengths):
8979
ds = deepcopy(dataset)
90-
ds.index = index
80+
ds.index = dataset.index.iloc[indices[offset - length : offset]]
9181
new_datasets.append(ds)
9282

9383
return new_datasets

0 commit comments

Comments
 (0)