Skip to content

Commit 166e55a

Browse files
authored
GeoDataset: rtree -> geopandas (#2747)
* GeoDataset: rtree -> geopandas * Update documented spatiotemporal indexing backend * Re-add rtree dep until we are ready to remove * EDDMapS: update __init__ * GBIF: update __init__ * iNaturalist: update __init__ * Fewer backwards-incompatible changes * L7 Irish: pray for the best? * MMFlood: pray for the best? * Minor docstring change * Intersection/Union * Add missing import * Raster/VectorDataset: update __getitem__ * GeoDataset: tests pass * Intersection/UnionDataset: various fixes * Use real datetime objects in tests * RasterDataset: all tests now pass * VectorDataset: all tests passing * Fix remaining type hint issues in geo.py * GlobBiomass: try to fix docs * Fix intersection/union index reprojection * Fix no intersection * Line/point intersection now counts as intersection * Silence geopandas/shapely warning * rasterio.crs.CRS to pyproj.CRS * sjoin -> overlay * BoundingBox: convert from int to datetime * Docs: link to pyproj.CRS * VectorDataset: fix fiona/pyproj CRS compatibility * CDL: fix tests * CopernicusBench: fix tests * EDDMapS: fix tests * GBIF: fix getitem * HySpecNet-11k: fix tests * iNaturalist: fix tests * IntersectionDataset: implement temporal intersection * BoundingBox and shapely.box have different order * Restore backwards compatibility for edge overlap * Break backwards compatibility for point datasets * Update util tests * GeoSampler: port to geopandas * RandomGeoSampler: port to geopandas * Ruff * GridGeoSampler: convert to geopandas * PreChippedGeoSampler: port to geopandas * BatchGeoSampler: port to geopandas * RandomBatchGeoSampler: port to geopandas * Simpler logic * RandomGeoSampler: revert most backwards-incompatible changes * Samplers: fix type hints * random_bbox_assignment: port to geopandas * random_bbox_splitting: port to geopandas * random_grid_cell_assignment: port to geopandas * roi_split: port to geopandas * time_series_split: port to geopandas * Splitters: fix type hints * EDDMapS: fix tests * GBIF: fix tests * iNaturalist: fix tests * Landsat/Sentinel: remove redundant length tests * SSL4EO: fix tests * SACT: port to geopandas * L7 Irish: fix tests * AgriFieldNet: port to geopandas * Docs: correct link to pandas * Datamodules: port tests to geopandas * Ruff * BoundingBox: update tests, type hints * Tutorials: port to geopandas * shapely.LineString did not exist in shapely 1.X * shapely.Geometry did not exist in shapely 1.X * Deps: bump shapely to 2+ * Tutorials: port to geopandas * Open Buildings: port to geopandas * Try geopandas 0.12.0 * Try geopandas 0.12.1 * Add geopandas min version * GlobBiomass: port to geopandas * LandCoverAI: port to geopandas * MMFlood: port to geopandas * EnviroAtlas: port to geopandas * Chesapeake CVPR: port to geopandas * mypy: fix all type hints * Tutorials: port to geopandas * L7 Irish: 1 + 1 != 3 * Fiona: switch from PROJ dict to WKT for greater compatibility * Explicitly drop geom type conversions during intersection to silence warning * Ruff * IntersectionDataset: add more tests * UnionDataset: increase code coverage * random_bbox_splitting: 100% test coverage * RandomGeoSampler: 100% test coverage * index.sindex.query -> index.cx
1 parent f50dcf0 commit 166e55a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+1146
-1188
lines changed

docs/conf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,11 @@
113113
'kornia': ('https://kornia.readthedocs.io/en/stable/', None),
114114
'matplotlib': ('https://matplotlib.org/stable/', None),
115115
'numpy': ('https://numpy.org/doc/stable/', None),
116+
'pandas': ('https://pandas.pydata.org/docs/', None),
117+
'pyproj': ('https://pyproj4.github.io/pyproj/stable/', None),
116118
'python': ('https://docs.python.org/3', None),
117119
'lightning': ('https://lightning.ai/docs/pytorch/stable/', None),
118120
'rasterio': ('https://rasterio.readthedocs.io/en/stable/', None),
119-
'rtree': ('https://rtree.readthedocs.io/en/stable/', None),
120121
'segmentation_models_pytorch': ('https://smp.readthedocs.io/en/stable/', None),
121122
'sklearn': ('https://scikit-learn.org/stable/', None),
122123
'timm': ('https://huggingface.co/docs/timm/main/en/', None),

docs/tutorials/custom_raster_dataset.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"source": [
5050
"### GeoDataset\n",
5151
"\n",
52-
"In TorchGeo, each `GeoDataset` uses an [R-tree](https://en.wikipedia.org/wiki/R-tree) to store the spatiotemporal bounding box of each file or data point. To simplify this process and reduce code duplication, we provide two subclasses of `GeoDataset`:\n",
52+
"In TorchGeo, each `GeoDataset` uses a [GeoDataFrame](https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.html) to store the spatiotemporal bounding box of each file or data point. To simplify this process and reduce code duplication, we provide two subclasses of `GeoDataset`:\n",
5353
"\n",
5454
"* `RasterDataset`: recursively search for raster files in a directory\n",
5555
"* `VectorDataset`: recursively search for vector files in a directory\n",

docs/tutorials/earth_surface_water.ipynb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,8 @@
492492
" \"\"\"\n",
493493
"\n",
494494
" # To avoid loading the entire dataset in memory, we will loop through each img\n",
495-
" # The filenames will be retrieved from the dataset's rtree index\n",
496-
" files = [\n",
497-
" item.object for item in dset.index.intersection(dset.index.bounds, objects=True)\n",
498-
" ]\n",
495+
" # The filenames will be retrieved from the dataset's GeoDataFrame index\n",
496+
" files = dset.index.filepath\n",
499497
"\n",
500498
" # Resetting statistics\n",
501499
" accum_mean = 0\n",

docs/tutorials/torchgeo.ipynb

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@
7070
"source": [
7171
"import os\n",
7272
"import tempfile\n",
73-
"from datetime import datetime\n",
7473
"\n",
74+
"import pandas as pd\n",
7575
"from matplotlib import pyplot as plt\n",
7676
"from torch.utils.data import DataLoader\n",
7777
"\n",
@@ -273,11 +273,7 @@
273273
"source": [
274274
"### Spatiotemporal indexing\n",
275275
"\n",
276-
"How did we do this? TorchGeo uses a data structure called an *R-tree* to store the spatiotemporal bounding box of every file in the dataset. \n",
277-
"\n",
278-
"![R-tree](https://raw.githubusercontent.com/davidmoten/davidmoten.github.io/master/resources/rtree-3d/plot2.png)\n",
279-
"\n",
280-
"TorchGeo extracts the spatial bounding box from the metadata of each file, and the timestamp from the filename. This geospatial and geotemporal metadata allows us to efficiently compute the intersection or union of two datasets. It also lets us quickly retrieve an image and corresponding mask for a particular location in space and time."
276+
"How did we do this? TorchGeo uses a [GeoDataFrame](https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.html) to store the spatiotemporal bounding box of every file in the dataset. TorchGeo extracts the spatial bounding box from the metadata of each file, and the timestamp from the filename. This geospatial and geotemporal metadata allows us to efficiently compute the intersection or union of two datasets. It also lets us quickly retrieve an image and corresponding mask for a particular location in space and time."
281277
]
282278
},
283279
{
@@ -293,8 +289,8 @@
293289
"xmax = xmin + size * 30\n",
294290
"ymin = 4470000\n",
295291
"ymax = ymin + size * 30\n",
296-
"tmin = datetime(2023, 1, 1).timestamp()\n",
297-
"tmax = datetime(2023, 12, 31).timestamp()\n",
292+
"tmin = pd.Timestamp(2023, 1, 1)\n",
293+
"tmax = pd.Timestamp(2023, 12, 31)\n",
298294
"\n",
299295
"bbox = BoundingBox(xmin, xmax, ymin, ymax, tmin, tmax)\n",
300296
"sample = dataset[bbox]\n",

docs/user/metrics/features.csv

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
Library,ML Backend,I/O Backend,Spatial Backend,Transform Backend,Datasets,Weights,CLI,GUI,Reprojection,STAC,Time Series
2-
`TorchGeo`_,PyTorch,"GDAL, h5py, laspy, NetCDF4, OpenCV, pandas, pillow, scipy, xarray",R-tree,Kornia,125,93,✅,❌,✅,❌,🚧
2+
`TorchGeo`_,PyTorch,"GDAL, h5py, laspy, NetCDF4, OpenCV, pandas, pillow, scipy, xarray",geopandas,Kornia,125,93,✅,❌,✅,❌,🚧
33
`eo-learn`_,scikit-learn,"GDAL, OpenCV, pandas, scipy, Zarr",geopandas,numpy,0,0,❌,❌,✅,❌,🚧
44
`Raster Vision`_,"PyTorch, TensorFlow*","GDAL, OpenCV, pandas, pillow, scipy, xarray",STAC,Albumentations,0,6,✅,❌,✅,✅,🚧
55
`DeepForest`_,"PyTorch, TensorFlow*","GDAL, OpenCV, pandas, pillow, scipy",R-tree,Albumentations,0,4,❌,❌,❌,❌,❌
66
`samgeo`_,PyTorch,"GDAL, OpenCV, pandas, xarray",geopandas,numpy,0,0,❌,✅,✅,❌,❌
7-
`TerraTorch`_,PyTorch,"GDAL, h5py, pandas, xarray",R-tree,Albumentations,27,1,✅,❌,✅,❌,🚧
7+
`TerraTorch`_,PyTorch,"GDAL, h5py, pandas, xarray",-,Albumentations,27,1,✅,❌,✅,❌,🚧
88
`SITS`_,R Torch,GDAL,-,tidyverse,22,0,❌,❌,✅,✅,✅
99
`srai`_,PyTorch,"pandas, polars, duckdb","geopandas, duckdb-spatial",-,0,0,❌,❌,❌,❌,❌
1010
`scikit-eo`_,"scikit-learn, TensorFlow","GDAL, pandas, scipy","geopandas",numpy,0,0,❌,❌,❌,❌,🚧

experiments/ssl4eo/sample_conus.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import fiona
1212
from rtree import index
1313
from sample_ssl4eo import create_bbox, km2deg
14-
from shapely.geometry import MultiPolygon, Point, shape
14+
from shapely import MultiPolygon, Point
15+
from shapely.geometry import shape
1516
from shapely.ops import unary_union
1617
from torchvision.datasets.utils import download_and_extract_archive
1718
from tqdm import tqdm

pyproject.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ dependencies = [
3939
"einops>=0.3",
4040
# fiona 1.8.22+ required for Python 3.11 wheels
4141
"fiona>=1.8.22",
42+
# geopandas 0.12.1+ required for Shapely 2+ support
43+
"geopandas>=0.12.1",
4244
# kornia 0.7.4+ required for AugmentationSequential support for unknown keys
4345
"kornia>=0.7.4",
4446
# lightly 1.4.5+ required for LARS optimizer
@@ -63,12 +65,10 @@ dependencies = [
6365
# rasterio 1.4.0-1.4.2 lack support for merging WarpedVRT objects
6466
# https://github.com/rasterio/rasterio/issues/3196
6567
"rasterio>=1.3.3,!=1.4.0,!=1.4.1,!=1.4.2",
66-
# rtree 1.0.1+ required for Python 3.11 wheels
67-
"rtree>=1.0.1",
6868
# segmentation-models-pytorch 0.5+ required for new UnetDecoder API
6969
"segmentation-models-pytorch>=0.5",
70-
# shapely 1.8.5+ required for Python 3.11 wheels
71-
"shapely>=1.8.5",
70+
# shapely 2+ required for shapely.Geometry
71+
"shapely>=2",
7272
# timm 0.8+ required for timm.models.adapt_input_conv, 0.9.2 required by SMP
7373
"timm>=0.9.2",
7474
# torch 2+ required for Python 3.11 wheels
@@ -259,6 +259,8 @@ filterwarnings = [
259259
"ignore:Failing to pass a value to the 'type_params' parameter of 'typing.*' is deprecated:DeprecationWarning:pydantic",
260260
# https://github.com/Lightning-AI/pytorch-lightning/pull/20802
261261
"ignore::jsonargparse._deprecated.JsonargparseDeprecationWarning:lightning.pytorch.cli",
262+
# https://github.com/geopandas/geopandas/pull/3453
263+
"ignore:The 'shapely.geos' module is deprecated:DeprecationWarning:geopandas._compat",
262264

263265
# Expected warnings
264266
# Lightning warns us about using num_workers=0, but it's faster on macOS

requirements/min-reqs.old

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ setuptools==77.0.1
44
# install
55
einops==0.3.0
66
fiona==1.8.22
7+
geopandas==0.12.1
78
kornia==0.7.4
89
lightly==1.4.5
910
lightning[pytorch-extra]==2.0.0
@@ -13,9 +14,8 @@ pandas==1.5.0
1314
pillow==9.2.0
1415
pyproj==3.4.0
1516
rasterio==1.3.11
16-
rtree==1.0.1
1717
segmentation-models-pytorch==0.5.0
18-
shapely==1.8.5
18+
shapely==2.0.0
1919
timm==0.9.2
2020
torch==2.0.0
2121
torchmetrics==1.2.0

requirements/required.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ setuptools==80.7.0
44
# install
55
einops==0.8.1
66
fiona==1.10.1
7+
geopandas==1.0.1
78
kornia==0.8.1
89
lightly==1.5.20
910
lightning[pytorch-extra]==2.5.0.post0
@@ -13,7 +14,6 @@ pandas==2.2.3
1314
pillow==11.2.1
1415
pyproj==3.7.1
1516
rasterio==1.4.3
16-
rtree==1.4.0
1717
segmentation-models-pytorch==0.5.0
1818
shapely==2.0.7
1919
timm==1.0.15

tests/data/eurocrops/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
import fiona
1212
from rasterio.crs import CRS
13-
from shapely.geometry import Polygon, mapping
13+
from shapely import Polygon
14+
from shapely.geometry import mapping
1415

1516
# Size of example crop field polygon in projection units.
1617
# This is set to align with Sentinel-2 test data, which is a 128x128 image at 10

tests/data/openbuildings/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import os
1111
import shutil
1212

13-
from shapely.geometry import Polygon
13+
from shapely import Polygon
1414

1515
SIZE = 0.05
1616

tests/datamodules/test_geo.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
from typing import Any
55

66
import matplotlib.pyplot as plt
7+
import pandas as pd
78
import pytest
9+
import shapely
810
import torch
911
from _pytest.fixtures import SubRequest
12+
from geopandas import GeoDataFrame
1013
from lightning.pytorch import Trainer
1114
from matplotlib.figure import Figure
12-
from rasterio.crs import CRS
15+
from pyproj import CRS
1316
from torch import Tensor
1417

1518
from torchgeo.datamodules import (
@@ -20,19 +23,23 @@
2023
from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset
2124
from torchgeo.samplers import RandomBatchGeoSampler, RandomGeoSampler
2225

26+
MINT = pd.Timestamp(2025, 4, 24)
27+
MAXT = pd.Timestamp(2025, 4, 25)
28+
2329

2430
class CustomGeoDataset(GeoDataset):
2531
def __init__(
2632
self, split: str = 'train', length: int = 1, download: bool = False
2733
) -> None:
28-
super().__init__()
29-
for i in range(length):
30-
self.index.insert(i, (0, 1, 2, 3, 4, 5))
34+
geometry = [shapely.box(0, 0, 1, 1)] * length
35+
index = pd.IntervalIndex([pd.Interval(MINT, MAXT)] * length, name='datetime')
36+
crs = CRS.from_epsg(4326)
37+
self.index = GeoDataFrame(index=index, geometry=geometry, crs=crs)
3138
self.res = (1, 1)
3239

3340
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
3441
image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
35-
return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}
42+
return {'image': image, 'crs': self.index.crs, 'bounds': query}
3643

3744
def plot(self, *args: Any, **kwargs: Any) -> Figure:
3845
return plt.figure()

tests/datasets/test_agb_live_woody_density.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import pytest
99
import torch
1010
import torch.nn as nn
11+
from pyproj import CRS
1112
from pytest import MonkeyPatch
12-
from rasterio.crs import CRS
1313

1414
from torchgeo.datasets import (
1515
AbovegroundLiveWoodyBiomassDensity,

tests/datasets/test_agrifieldnet.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from pathlib import Path
66

77
import matplotlib.pyplot as plt
8+
import pandas as pd
89
import pytest
910
import torch
1011
import torch.nn as nn
12+
from pyproj import CRS
1113
from pytest import MonkeyPatch
12-
from rasterio.crs import CRS
1314

1415
from torchgeo.datasets import (
1516
AgriFieldNet,
@@ -69,7 +70,7 @@ def test_plot_prediction(self, dataset: AgriFieldNet) -> None:
6970
plt.close()
7071

7172
def test_invalid_query(self, dataset: AgriFieldNet) -> None:
72-
query = BoundingBox(0, 0, 0, 0, 0, 0)
73+
query = BoundingBox(0, 0, 0, 0, pd.Timestamp.min, pd.Timestamp.min)
7374
with pytest.raises(
7475
IndexError, match='query: .* not found in index with bounds:'
7576
):

tests/datasets/test_airphen.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from pathlib import Path
66

77
import matplotlib.pyplot as plt
8+
import pandas as pd
89
import pytest
910
import torch
1011
import torch.nn as nn
11-
from rasterio.crs import CRS
12+
from pyproj import CRS
1213

1314
from torchgeo.datasets import (
1415
Airphen,
@@ -55,7 +56,7 @@ def test_no_data(self, tmp_path: Path) -> None:
5556
Airphen(tmp_path)
5657

5758
def test_invalid_query(self, dataset: Airphen) -> None:
58-
query = BoundingBox(0, 0, 0, 0, 0, 0)
59+
query = BoundingBox(0, 0, 0, 0, pd.Timestamp.min, pd.Timestamp.min)
5960
with pytest.raises(
6061
IndexError, match='query: .* not found in index with bounds:'
6162
):

tests/datasets/test_astergdem.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
from pathlib import Path
77

88
import matplotlib.pyplot as plt
9+
import pandas as pd
910
import pytest
1011
import torch
1112
import torch.nn as nn
12-
from rasterio.crs import CRS
13+
from pyproj import CRS
1314

1415
from torchgeo.datasets import (
1516
AsterGDEM,
@@ -66,7 +67,7 @@ def test_plot_prediction(self, dataset: AsterGDEM) -> None:
6667
plt.close()
6768

6869
def test_invalid_query(self, dataset: AsterGDEM) -> None:
69-
query = BoundingBox(100, 100, 100, 100, 0, 0)
70+
query = BoundingBox(100, 100, 100, 100, pd.Timestamp.min, pd.Timestamp.min)
7071
with pytest.raises(
7172
IndexError, match='query: .* not found in index with bounds:'
7273
):

tests/datasets/test_cbf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from pathlib import Path
66

77
import matplotlib.pyplot as plt
8+
import pandas as pd
89
import pytest
910
import torch
1011
import torch.nn as nn
12+
from pyproj import CRS
1113
from pytest import MonkeyPatch
12-
from rasterio.crs import CRS
1314

1415
from torchgeo.datasets import (
1516
BoundingBox,
@@ -76,7 +77,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
7677
CanadianBuildingFootprints(tmp_path)
7778

7879
def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
79-
query = BoundingBox(2, 2, 2, 2, 2, 2)
80+
query = BoundingBox(2, 2, 2, 2, pd.Timestamp.min, pd.Timestamp.min)
8081
with pytest.raises(
8182
IndexError, match='query: .* not found in index with bounds:'
8283
):

tests/datasets/test_cdl.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
import glob
55
import os
66
import shutil
7-
from datetime import datetime
87
from pathlib import Path
98

109
import matplotlib.pyplot as plt
10+
import pandas as pd
1111
import pytest
1212
import torch
1313
import torch.nn as nn
14+
from pyproj import CRS
1415
from pytest import MonkeyPatch
15-
from rasterio.crs import CRS
1616

1717
from torchgeo.datasets import (
1818
CDL,
@@ -71,9 +71,9 @@ def test_or(self, dataset: CDL) -> None:
7171

7272
def test_full_year(self, dataset: CDL) -> None:
7373
bbox = dataset.bounds
74-
time = datetime(2023, 6, 1).timestamp()
74+
time = pd.Timestamp(2023, 6, 1)
7575
query = BoundingBox(bbox.minx, bbox.maxx, bbox.miny, bbox.maxy, time, time)
76-
next(dataset.index.intersection(tuple(query)))
76+
dataset[query]
7777

7878
def test_already_extracted(self, dataset: CDL) -> None:
7979
CDL(dataset.paths, years=[2023, 2022])
@@ -117,7 +117,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
117117
CDL(tmp_path)
118118

119119
def test_invalid_query(self, dataset: CDL) -> None:
120-
query = BoundingBox(0, 0, 0, 0, 0, 0)
120+
query = BoundingBox(0, 0, 0, 0, pd.Timestamp.min, pd.Timestamp.min)
121121
with pytest.raises(
122122
IndexError, match='query: .* not found in index with bounds:'
123123
):

tests/datasets/test_chesapeake.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from pathlib import Path
77

88
import matplotlib.pyplot as plt
9+
import pandas as pd
910
import pytest
1011
import torch
1112
import torch.nn as nn
1213
from _pytest.fixtures import SubRequest
14+
from pyproj import CRS
1315
from pytest import MonkeyPatch
14-
from rasterio.crs import CRS
1516

1617
from torchgeo.datasets import (
1718
BoundingBox,
@@ -83,7 +84,7 @@ def test_plot(self, dataset: ChesapeakeDC) -> None:
8384
plt.close()
8485

8586
def test_invalid_query(self, dataset: ChesapeakeDC) -> None:
86-
query = BoundingBox(0, 0, 0, 0, 0, 0)
87+
query = BoundingBox(0, 0, 0, 0, pd.Timestamp.min, pd.Timestamp.min)
8788
with pytest.raises(
8889
IndexError, match='query: .* not found in index with bounds:'
8990
):
@@ -191,7 +192,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
191192
ChesapeakeCVPR(tmp_path, checksum=True)
192193

193194
def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
194-
query = BoundingBox(0, 0, 0, 0, 0, 0)
195+
query = BoundingBox(0, 0, 0, 0, pd.Timestamp.min, pd.Timestamp.min)
195196
with pytest.raises(
196197
IndexError, match='query: .* not found in index with bounds:'
197198
):

0 commit comments

Comments
 (0)