2
2
# Licensed under the MIT License.
3
3
4
4
from collections .abc import Sequence
5
+ from datetime import datetime
5
6
from math import floor , isclose
6
7
from typing import Any
7
8
9
+ import pandas as pd
8
10
import pytest
11
+ import shapely
12
+ from geopandas import GeoDataFrame
9
13
from pyproj import CRS
14
+ from shapely import Geometry
10
15
11
16
from torchgeo .datasets import (
12
17
BoundingBox ,
18
23
time_series_split ,
19
24
)
20
25
26
+ MINT = datetime (2025 , 4 , 24 )
27
+ MAXT = datetime (2025 , 4 , 25 )
21
28
22
- def total_area (dataset : GeoDataset ) -> float :
23
- total_area = 0.0
24
- for hit in dataset .index .intersection (dataset .index .bounds , objects = True ):
25
- total_area += BoundingBox (* hit .bounds ).area
26
29
27
- return total_area
30
+ def total_area (dataset : GeoDataset ) -> float :
31
+ return dataset .index .geometry .area .sum ()
28
32
29
33
30
34
def no_overlap (ds1 : GeoDataset , ds2 : GeoDataset ) -> bool :
@@ -39,20 +43,20 @@ def no_overlap(ds1: GeoDataset, ds2: GeoDataset) -> bool:
39
43
class CustomGeoDataset (GeoDataset ):
40
44
def __init__ (
41
45
self ,
42
- items : list [tuple [BoundingBox , str ]] = [(BoundingBox (0 , 1 , 0 , 1 , 0 , 40 ), '' )],
43
- crs : CRS = CRS .from_epsg (3005 ),
44
- res : tuple [float , float ] = (1 , 1 ),
46
+ index : pd .IntervalIndex | None = None ,
47
+ geometry : Sequence [Geometry ] = [shapely .box (0 , 0 , 1 , 1 )],
45
48
) -> None :
46
- super ().__init__ ()
47
- for box , content in items :
48
- self .index .insert (0 , tuple (box ), content )
49
- self ._crs = crs
50
- self .res = res
49
+ if index is None :
50
+ intervals = [(MINT , MAXT )] * len (geometry )
51
+ index = pd .IntervalIndex .from_tuples (
52
+ intervals , closed = 'both' , name = 'datetime'
53
+ )
54
+ crs = CRS .from_epsg (3005 )
55
+ self .index = GeoDataFrame (index = index , geometry = geometry , crs = crs )
56
+ self .res = (1 , 1 )
51
57
52
58
def __getitem__ (self , query : BoundingBox ) -> dict [str , Any ]:
53
- hits = self .index .intersection (tuple (query ), objects = True )
54
- hit = next (iter (hits ))
55
- return {'content' : hit .object }
59
+ return {'index' : query }
56
60
57
61
58
62
@pytest .mark .parametrize (
@@ -67,14 +71,13 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
67
71
def test_random_bbox_assignment (
68
72
lengths : Sequence [int | float ], expected_lengths : Sequence [int ]
69
73
) -> None :
70
- ds = CustomGeoDataset (
71
- [
72
- (BoundingBox (0 , 1 , 0 , 1 , 0 , 0 ), 'a' ),
73
- (BoundingBox (1 , 2 , 0 , 1 , 0 , 0 ), 'b' ),
74
- (BoundingBox (2 , 3 , 0 , 1 , 0 , 0 ), 'c' ),
75
- (BoundingBox (3 , 4 , 0 , 1 , 0 , 0 ), 'd' ),
76
- ]
77
- )
74
+ geometry = [
75
+ shapely .box (0 , 0 , 1 , 1 ),
76
+ shapely .box (1 , 0 , 2 , 1 ),
77
+ shapely .box (2 , 0 , 3 , 1 ),
78
+ shapely .box (3 , 0 , 4 , 1 ),
79
+ ]
80
+ ds = CustomGeoDataset (geometry = geometry )
78
81
79
82
train_ds , val_ds , test_ds = random_bbox_assignment (ds , lengths )
80
83
@@ -94,7 +97,6 @@ def test_random_bbox_assignment(
94
97
# Test __getitem__
95
98
x = train_ds [train_ds .bounds ]
96
99
assert isinstance (x , dict )
97
- assert isinstance (x ['content' ], str )
98
100
99
101
100
102
def test_random_bbox_assignment_invalid_inputs () -> None :
@@ -110,14 +112,13 @@ def test_random_bbox_assignment_invalid_inputs() -> None:
110
112
111
113
112
114
def test_random_bbox_splitting () -> None :
113
- ds = CustomGeoDataset (
114
- [
115
- (BoundingBox (0 , 1 , 0 , 1 , 0 , 0 ), 'a' ),
116
- (BoundingBox (1 , 2 , 0 , 1 , 0 , 0 ), 'b' ),
117
- (BoundingBox (2 , 3 , 0 , 1 , 0 , 0 ), 'c' ),
118
- (BoundingBox (3 , 4 , 0 , 1 , 0 , 0 ), 'd' ),
119
- ]
120
- )
115
+ geometry = [
116
+ shapely .box (0 , 0 , 1 , 1 ),
117
+ shapely .box (1 , 0 , 2 , 1 ),
118
+ shapely .box (2 , 0 , 3 , 1 ),
119
+ shapely .box (3 , 0 , 4 , 1 ),
120
+ ]
121
+ ds = CustomGeoDataset (geometry = geometry )
121
122
122
123
ds_area = total_area (ds )
123
124
@@ -145,7 +146,6 @@ def test_random_bbox_splitting() -> None:
145
146
# Test __get_item__
146
147
x = train_ds [train_ds .bounds ]
147
148
assert isinstance (x , dict )
148
- assert isinstance (x ['content' ], str )
149
149
150
150
# Test invalid input fractions
151
151
with pytest .raises (ValueError , match = 'Sum of input fractions must equal 1.' ):
@@ -157,12 +157,8 @@ def test_random_bbox_splitting() -> None:
157
157
158
158
159
159
def test_random_grid_cell_assignment () -> None :
160
- ds = CustomGeoDataset (
161
- [
162
- (BoundingBox (0 , 12 , 0 , 12 , 0 , 0 ), 'a' ),
163
- (BoundingBox (12 , 24 , 0 , 12 , 0 , 0 ), 'b' ),
164
- ]
165
- )
160
+ geometry = [shapely .box (0 , 0 , 12 , 12 ), shapely .box (12 , 0 , 24 , 12 )]
161
+ ds = CustomGeoDataset (geometry = geometry )
166
162
167
163
train_ds , val_ds , test_ds = random_grid_cell_assignment (
168
164
ds , fractions = [1 / 2 , 1 / 4 , 1 / 4 ], grid_size = 5
@@ -185,7 +181,6 @@ def test_random_grid_cell_assignment() -> None:
185
181
# Test __get_item__
186
182
x = train_ds [train_ds .bounds ]
187
183
assert isinstance (x , dict )
188
- assert isinstance (x ['content' ], str )
189
184
190
185
# Test invalid input fractions
191
186
with pytest .raises (ValueError , match = 'Sum of input fractions must equal 1.' ):
@@ -199,21 +194,20 @@ def test_random_grid_cell_assignment() -> None:
199
194
200
195
201
196
def test_roi_split () -> None :
202
- ds = CustomGeoDataset (
203
- [
204
- (BoundingBox (0 , 1 , 0 , 1 , 0 , 0 ), 'a' ),
205
- (BoundingBox (1 , 2 , 0 , 1 , 0 , 0 ), 'b' ),
206
- (BoundingBox (2 , 3 , 0 , 1 , 0 , 0 ), 'c' ),
207
- (BoundingBox (3 , 4 , 0 , 1 , 0 , 0 ), 'd' ),
208
- ]
209
- )
197
+ geometry = [
198
+ shapely .box (0 , 0 , 1 , 1 ),
199
+ shapely .box (1 , 0 , 2 , 1 ),
200
+ shapely .box (2 , 0 , 3 , 1 ),
201
+ shapely .box (3 , 0 , 4 , 1 ),
202
+ ]
203
+ ds = CustomGeoDataset (geometry = geometry )
210
204
211
205
train_ds , val_ds , test_ds = roi_split (
212
206
ds ,
213
207
rois = [
214
- BoundingBox (0 , 2 , 0 , 1 , 0 , 0 ),
215
- BoundingBox (2 , 3.5 , 0 , 1 , 0 , 0 ),
216
- BoundingBox (3.5 , 4 , 0 , 1 , 0 , 0 ),
208
+ BoundingBox (0 , 2 , 0 , 1 , MINT , MAXT ),
209
+ BoundingBox (2 , 3.5 , 0 , 1 , MINT , MAXT ),
210
+ BoundingBox (3.5 , 4 , 0 , 1 , MINT , MAXT ),
217
211
],
218
212
)
219
213
@@ -234,7 +228,6 @@ def test_roi_split() -> None:
234
228
# Test __get_item__
235
229
x = train_ds [train_ds .bounds ]
236
230
assert isinstance (x , dict )
237
- assert isinstance (x ['content' ], str )
238
231
239
232
# Test invalid input rois
240
233
with pytest .raises (ValueError , match = "ROIs in input rois can't overlap." ):
@@ -257,14 +250,23 @@ def test_roi_split() -> None:
257
250
def test_time_series_split (
258
251
lengths : Sequence [tuple [int , int ] | int | float ], expected_lengths : Sequence [int ]
259
252
) -> None :
260
- ds = CustomGeoDataset (
253
+ geometry = [
254
+ shapely .box (0 , 0 , 1 , 1 ),
255
+ shapely .box (0 , 0 , 1 , 1 ),
256
+ shapely .box (0 , 0 , 1 , 1 ),
257
+ shapely .box (0 , 0 , 1 , 1 ),
258
+ ]
259
+ index = pd .IntervalIndex .from_tuples (
261
260
[
262
- (BoundingBox (0 , 1 , 0 , 1 , 0 , 10 ), 'a' ),
263
- (BoundingBox (0 , 1 , 0 , 1 , 10 , 20 ), 'b' ),
264
- (BoundingBox (0 , 1 , 0 , 1 , 20 , 30 ), 'c' ),
265
- (BoundingBox (0 , 1 , 0 , 1 , 30 , 40 ), 'd' ),
266
- ]
261
+ (datetime (2025 , 4 , 25 ), datetime (2025 , 4 , 26 )),
262
+ (datetime (2025 , 4 , 26 ), datetime (2025 , 4 , 27 )),
263
+ (datetime (2025 , 4 , 27 ), datetime (2025 , 4 , 28 )),
264
+ (datetime (2025 , 4 , 28 ), datetime (2025 , 4 , 29 )),
265
+ ],
266
+ closed = 'both' ,
267
+ name = 'datetime' ,
267
268
)
269
+ ds = CustomGeoDataset (index , geometry )
268
270
269
271
train_ds , val_ds , test_ds = time_series_split (ds , lengths )
270
272
@@ -284,7 +286,6 @@ def test_time_series_split(
284
286
# Test __get_item__
285
287
x = train_ds [train_ds .bounds ]
286
288
assert isinstance (x , dict )
287
- assert isinstance (x ['content' ], str )
288
289
289
290
290
291
def test_time_series_split_invalid_input () -> None :
0 commit comments