Skip to content

GeoDataset: rtree -> geopandas #2747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 93 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
222804a
GeoDataset: rtree -> geopandas
adamjstewart Apr 21, 2025
0b972bb
Update documented spatiotemporal indexing backend
adamjstewart Apr 21, 2025
ff6dfe2
Re-add rtree dep until we are ready to remove
adamjstewart Apr 21, 2025
7e22f3d
EDDMapS: update __init__
adamjstewart Apr 22, 2025
9e26c4d
GBIF: update __init__
adamjstewart Apr 22, 2025
876948d
iNaturalist: update __init__
adamjstewart Apr 22, 2025
8f216b1
Fewer backwards-incompatible changes
adamjstewart Apr 22, 2025
f27709f
L7 Irish: pray for the best?
adamjstewart Apr 22, 2025
58d04be
MMFlood: pray for the best?
adamjstewart Apr 22, 2025
e5a3053
Minor docstring change
adamjstewart Apr 22, 2025
cff997f
Intersection/Union
adamjstewart Apr 23, 2025
538b602
Add missing import
adamjstewart Apr 23, 2025
66598f9
Raster/VectorDataset: update __getitem__
adamjstewart Apr 23, 2025
e2bd70c
GeoDataset: tests pass
adamjstewart Apr 23, 2025
2f35f9e
Intersection/UnionDataset: various fixes
adamjstewart Apr 23, 2025
44e68b6
Use real datetime objects in tests
adamjstewart Apr 24, 2025
76967db
RasterDataset: all tests now pass
adamjstewart Apr 24, 2025
aafe860
VectorDataset: all tests passing
adamjstewart Apr 24, 2025
7a1c66b
Fix remaining type hint issues in geo.py
adamjstewart Apr 24, 2025
c599fc0
GlobBiomass: try to fix docs
adamjstewart Apr 24, 2025
ae7c397
Fix intersection/union index reprojection
adamjstewart Apr 25, 2025
39c9ea6
Fix no intersection
adamjstewart Apr 25, 2025
3ad783e
Line/point intersection now counts as intersection
adamjstewart Apr 25, 2025
e57425c
Silence geopandas/shapely warning
adamjstewart Apr 25, 2025
1ab53ed
rasterio.crs.CRS to pyproj.CRS
adamjstewart Apr 25, 2025
035f26b
sjoin -> overlay
adamjstewart Apr 26, 2025
cc3d373
BoundingBox: convert from int to datetime
adamjstewart Apr 26, 2025
4d1c2b5
Docs: link to pyproj.CRS
adamjstewart Apr 26, 2025
ea63552
VectorDataset: fix fiona/pyproj CRS compatibility
adamjstewart Apr 26, 2025
c76dd16
CDL: fix tests
adamjstewart Apr 26, 2025
a3425b3
CopernicusBench: fix tests
adamjstewart Apr 26, 2025
0b49098
EDDMapS: fix tests
adamjstewart Apr 26, 2025
5211629
GBIF: fix getitem
adamjstewart Apr 26, 2025
3aa2ebc
HySpecNet-11k: fix tests
adamjstewart Apr 26, 2025
40ba141
iNaturalist: fix tests
adamjstewart Apr 26, 2025
1a75cc0
IntersectionDataset: implement temporal intersection
adamjstewart Apr 28, 2025
e65af8c
BoundingBox and shapely.box have different order
adamjstewart Apr 28, 2025
a250252
Restore backwards compatibility for edge overlap
adamjstewart Apr 28, 2025
0dcadbb
Break backwards compatibility for point datasets
adamjstewart Apr 28, 2025
4a36a3c
Update util tests
adamjstewart Apr 29, 2025
f4cba79
GeoSampler: port to geopandas
adamjstewart Apr 29, 2025
ca37305
RandomGeoSampler: port to geopandas
adamjstewart Apr 30, 2025
06ca77e
Ruff
adamjstewart Apr 30, 2025
b1d5366
GridGeoSampler: convert to geopandas
adamjstewart May 1, 2025
c0dd9c1
PreChippedGeoSampler: port to geopandas
adamjstewart May 1, 2025
4d62d5e
BatchGeoSampler: port to geopandas
adamjstewart May 1, 2025
18f1a45
RandomBatchGeoSampler: port to geopandas
adamjstewart May 1, 2025
d604e81
Simpler logic
adamjstewart May 1, 2025
b12a40c
RandomGeoSampler: revert most backwards-incompatible changes
adamjstewart May 1, 2025
8b9d19a
Samplers: fix type hints
adamjstewart May 1, 2025
48aaf73
random_bbox_assignment: port to geopandas
adamjstewart May 1, 2025
f07f452
random_bbox_splitting: port to geopandas
adamjstewart May 1, 2025
06a8304
random_grid_cell_assignment: port to geopandas
adamjstewart May 2, 2025
7dc0c56
roi_split: port to geopandas
adamjstewart May 2, 2025
acf6ef0
time_series_split: port to geopandas
adamjstewart May 3, 2025
cbe5711
Splitters: fix type hints
adamjstewart May 3, 2025
8e619a1
EDDMapS: fix tests
adamjstewart May 3, 2025
b2a5973
GBIF: fix tests
adamjstewart May 3, 2025
52e764a
iNaturalist: fix tests
adamjstewart May 3, 2025
1d62e20
Landsat/Sentinel: remove redundant length tests
adamjstewart May 3, 2025
04e53c1
SSL4EO: fix tests
adamjstewart May 3, 2025
96ac820
SACT: port to geopandas
adamjstewart May 3, 2025
599e21f
L7 Irish: fix tests
adamjstewart May 3, 2025
df8b4e3
AgriFieldNet: port to geopandas
adamjstewart May 3, 2025
2bed4ee
Docs: correct link to pandas
adamjstewart May 3, 2025
a4414a8
Datamodules: port tests to geopandas
adamjstewart May 3, 2025
70e0621
Ruff
adamjstewart May 3, 2025
92701c9
BoundingBox: update tests, type hints
adamjstewart May 4, 2025
311858b
Tutorials: port to geopandas
adamjstewart May 4, 2025
30bf3e3
shapely.LineString did not exist in shapely 1.X
adamjstewart May 4, 2025
809bd0f
shapely.Geometry did not exist in shapely 1.X
adamjstewart May 4, 2025
ef3a8b2
Deps: bump shapely to 2+
adamjstewart May 4, 2025
6b896c2
Tutorials: port to geopandas
adamjstewart May 4, 2025
d4cb99f
Open Buildings: port to geopandas
adamjstewart May 6, 2025
6c8bbf8
Try geopandas 0.12.0
adamjstewart May 6, 2025
2797a19
Try geopandas 0.12.1
adamjstewart May 6, 2025
6991ced
Add geopandas min version
adamjstewart May 6, 2025
4cb03b1
GlobBiomass: port to geopandas
adamjstewart May 6, 2025
9cb7143
LandCoverAI: port to geopandas
adamjstewart May 6, 2025
9ea9ae6
MMFlood: port to geopandas
adamjstewart May 6, 2025
3d82a87
EnviroAtlas: port to geopandas
adamjstewart May 6, 2025
136595b
Chesapeake CVPR: port to geopandas
adamjstewart May 6, 2025
fd69c14
mypy: fix all type hints
adamjstewart May 6, 2025
24e4375
Tutorials: port to geopandas
adamjstewart May 6, 2025
9c990b3
L7 Irish: 1 + 1 != 3
adamjstewart May 7, 2025
8f74f5c
Fiona: switch from PROJ dict to WKT for greater compatibility
adamjstewart May 7, 2025
e9ef360
Explicitly drop geom type conversions during intersection to silence …
adamjstewart May 7, 2025
457a106
Ruff
adamjstewart May 7, 2025
e923515
IntersectionDataset: add more tests
adamjstewart May 7, 2025
58891c3
UnionDataset: increase code coverage
adamjstewart May 7, 2025
f641e07
random_bbox_splitting: 100% test coverage
adamjstewart May 7, 2025
1e33ec5
RandomGeoSampler: 100% test coverage
adamjstewart May 7, 2025
78866ff
index.sindex.query -> index.cx
adamjstewart May 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@
'kornia': ('https://kornia.readthedocs.io/en/stable/', None),
'matplotlib': ('https://matplotlib.org/stable/', None),
'numpy': ('https://numpy.org/doc/stable/', None),
'pandas': ('https://pandas.pydata.org/docs/', None),
'pyproj': ('https://pyproj4.github.io/pyproj/stable/', None),
'python': ('https://docs.python.org/3', None),
'lightning': ('https://lightning.ai/docs/pytorch/stable/', None),
'rasterio': ('https://rasterio.readthedocs.io/en/stable/', None),
'rtree': ('https://rtree.readthedocs.io/en/stable/', None),
'segmentation_models_pytorch': ('https://smp.readthedocs.io/en/stable/', None),
'sklearn': ('https://scikit-learn.org/stable/', None),
'timm': ('https://huggingface.co/docs/timm/main/en/', None),
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/custom_raster_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"source": [
"### GeoDataset\n",
"\n",
"In TorchGeo, each `GeoDataset` uses an [R-tree](https://en.wikipedia.org/wiki/R-tree) to store the spatiotemporal bounding box of each file or data point. To simplify this process and reduce code duplication, we provide two subclasses of `GeoDataset`:\n",
"In TorchGeo, each `GeoDataset` uses a [GeoDataFrame](https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.html) to store the spatiotemporal bounding box of each file or data point. To simplify this process and reduce code duplication, we provide two subclasses of `GeoDataset`:\n",
"\n",
"* `RasterDataset`: recursively search for raster files in a directory\n",
"* `VectorDataset`: recursively search for vector files in a directory\n",
Expand Down
6 changes: 2 additions & 4 deletions docs/tutorials/earth_surface_water.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,8 @@
" \"\"\"\n",
"\n",
" # To avoid loading the entire dataset in memory, we will loop through each img\n",
" # The filenames will be retrieved from the dataset's rtree index\n",
" files = [\n",
" item.object for item in dset.index.intersection(dset.index.bounds, objects=True)\n",
" ]\n",
" # The filenames will be retrieved from the dataset's GeoDataFrame index\n",
" files = dset.index.filepath\n",
"\n",
" # Resetting statistics\n",
" accum_mean = 0\n",
Expand Down
12 changes: 4 additions & 8 deletions docs/tutorials/torchgeo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@
"source": [
"import os\n",
"import tempfile\n",
"from datetime import datetime\n",
"\n",
"import pandas as pd\n",
"from matplotlib import pyplot as plt\n",
"from torch.utils.data import DataLoader\n",
"\n",
Expand Down Expand Up @@ -273,11 +273,7 @@
"source": [
"### Spatiotemporal indexing\n",
"\n",
"How did we do this? TorchGeo uses a data structure called an *R-tree* to store the spatiotemporal bounding box of every file in the dataset. \n",
"\n",
"![R-tree](https://raw.githubusercontent.com/davidmoten/davidmoten.github.io/master/resources/rtree-3d/plot2.png)\n",
"\n",
"TorchGeo extracts the spatial bounding box from the metadata of each file, and the timestamp from the filename. This geospatial and geotemporal metadata allows us to efficiently compute the intersection or union of two datasets. It also lets us quickly retrieve an image and corresponding mask for a particular location in space and time."
"How did we do this? TorchGeo uses a [GeoDataFrame](https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.html) to store the spatiotemporal bounding box of every file in the dataset. TorchGeo extracts the spatial bounding box from the metadata of each file, and the timestamp from the filename. This geospatial and geotemporal metadata allows us to efficiently compute the intersection or union of two datasets. It also lets us quickly retrieve an image and corresponding mask for a particular location in space and time."
]
},
{
Expand All @@ -293,8 +289,8 @@
"xmax = xmin + size * 30\n",
"ymin = 4470000\n",
"ymax = ymin + size * 30\n",
"tmin = datetime(2023, 1, 1).timestamp()\n",
"tmax = datetime(2023, 12, 31).timestamp()\n",
"tmin = pd.Timestamp(2023, 1, 1)\n",
"tmax = pd.Timestamp(2023, 12, 31)\n",
"\n",
"bbox = BoundingBox(xmin, xmax, ymin, ymax, tmin, tmax)\n",
"sample = dataset[bbox]\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/user/metrics/features.csv
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
Library,ML Backend,I/O Backend,Spatial Backend,Transform Backend,Datasets,Weights,CLI,GUI,Reprojection,STAC,Time Series
`TorchGeo`_,PyTorch,"GDAL, h5py, laspy, NetCDF4, OpenCV, pandas, pillow, scipy, xarray",R-tree,Kornia,125,93,✅,❌,✅,❌,🚧
`TorchGeo`_,PyTorch,"GDAL, h5py, laspy, NetCDF4, OpenCV, pandas, pillow, scipy, xarray",geopandas,Kornia,125,93,✅,❌,✅,❌,🚧
`eo-learn`_,scikit-learn,"GDAL, OpenCV, pandas, scipy, Zarr",geopandas,numpy,0,0,❌,❌,✅,❌,🚧
`Raster Vision`_,"PyTorch, TensorFlow*","GDAL, OpenCV, pandas, pillow, scipy, xarray",STAC,Albumentations,0,6,✅,❌,✅,✅,🚧
`DeepForest`_,"PyTorch, TensorFlow*","GDAL, OpenCV, pandas, pillow, scipy",R-tree,Albumentations,0,4,❌,❌,❌,❌,❌
`samgeo`_,PyTorch,"GDAL, OpenCV, pandas, xarray",geopandas,numpy,0,0,❌,✅,✅,❌,❌
`TerraTorch`_,PyTorch,"GDAL, h5py, pandas, xarray",R-tree,Albumentations,27,1,✅,❌,✅,❌,🚧
`TerraTorch`_,PyTorch,"GDAL, h5py, pandas, xarray",-,Albumentations,27,1,✅,❌,✅,❌,🚧
`SITS`_,R Torch,GDAL,-,tidyverse,22,0,❌,❌,✅,✅,✅
`srai`_,PyTorch,"pandas, polars, duckdb","geopandas, duckdb-spatial",-,0,0,❌,❌,❌,❌,❌
`scikit-eo`_,"scikit-learn, TensorFlow","GDAL, pandas, scipy","geopandas",numpy,0,0,❌,❌,❌,❌,🚧
Expand Down
3 changes: 2 additions & 1 deletion experiments/ssl4eo/sample_conus.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import fiona
from rtree import index
from sample_ssl4eo import create_bbox, km2deg
from shapely.geometry import MultiPolygon, Point, shape
from shapely import MultiPolygon, Point
from shapely.geometry import shape
from shapely.ops import unary_union
from torchvision.datasets.utils import download_and_extract_archive
from tqdm import tqdm
Expand Down
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ dependencies = [
"einops>=0.3",
# fiona 1.8.22+ required for Python 3.11 wheels
"fiona>=1.8.22",
# geopandas 0.12.1+ required for Shapely 2+ support
"geopandas>=0.12.1",
# kornia 0.7.4+ required for AugmentationSequential support for unknown keys
"kornia>=0.7.4",
# lightly 1.4.5+ required for LARS optimizer
Expand All @@ -63,12 +65,10 @@ dependencies = [
# rasterio 1.4.0-1.4.2 lack support for merging WarpedVRT objects
# https://github.com/rasterio/rasterio/issues/3196
"rasterio>=1.3.3,!=1.4.0,!=1.4.1,!=1.4.2",
# rtree 1.0.1+ required for Python 3.11 wheels
"rtree>=1.0.1",
# segmentation-models-pytorch 0.5+ required for new UnetDecoder API
"segmentation-models-pytorch>=0.5",
# shapely 1.8.5+ required for Python 3.11 wheels
"shapely>=1.8.5",
# shapely 2+ required for shapely.Geometry
"shapely>=2",
# timm 0.8+ required for timm.models.adapt_input_conv, 0.9.2 required by SMP
"timm>=0.9.2",
# torch 2+ required for Python 3.11 wheels
Expand Down Expand Up @@ -259,6 +259,8 @@ filterwarnings = [
"ignore:Failing to pass a value to the 'type_params' parameter of 'typing.*' is deprecated:DeprecationWarning:pydantic",
# https://github.com/Lightning-AI/pytorch-lightning/pull/20802
"ignore::jsonargparse._deprecated.JsonargparseDeprecationWarning:lightning.pytorch.cli",
# https://github.com/geopandas/geopandas/pull/3453
"ignore:The 'shapely.geos' module is deprecated:DeprecationWarning:geopandas._compat",

# Expected warnings
# Lightning warns us about using num_workers=0, but it's faster on macOS
Expand Down
4 changes: 2 additions & 2 deletions requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ setuptools==77.0.1
# install
einops==0.3.0
fiona==1.8.22
geopandas==0.12.1
kornia==0.7.4
lightly==1.4.5
lightning[pytorch-extra]==2.0.0
Expand All @@ -13,9 +14,8 @@ pandas==1.5.0
pillow==9.2.0
pyproj==3.4.0
rasterio==1.3.11
rtree==1.0.1
segmentation-models-pytorch==0.5.0
shapely==1.8.5
shapely==2.0.0
timm==0.9.2
torch==2.0.0
torchmetrics==1.2.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ setuptools==80.4.0
# install
einops==0.8.1
fiona==1.10.1
geopandas==1.0.1
kornia==0.8.1
lightly==1.5.20
lightning[pytorch-extra]==2.5.0.post0
Expand All @@ -13,7 +14,6 @@ pandas==2.2.3
pillow==11.2.1
pyproj==3.7.1
rasterio==1.4.3
rtree==1.4.0
segmentation-models-pytorch==0.5.0
shapely==2.0.7
timm==1.0.15
Expand Down
3 changes: 2 additions & 1 deletion tests/data/eurocrops/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

import fiona
from rasterio.crs import CRS
from shapely.geometry import Polygon, mapping
from shapely import Polygon
from shapely.geometry import mapping

# Size of example crop field polygon in projection units.
# This is set to align with Sentinel-2 test data, which is a 128x128 image at 10
Expand Down
2 changes: 1 addition & 1 deletion tests/data/openbuildings/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import shutil

from shapely.geometry import Polygon
from shapely import Polygon

SIZE = 0.05

Expand Down
17 changes: 12 additions & 5 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd
import pytest
import shapely
import torch
from _pytest.fixtures import SubRequest
from geopandas import GeoDataFrame
from lightning.pytorch import Trainer
from matplotlib.figure import Figure
from rasterio.crs import CRS
from pyproj import CRS
from torch import Tensor

from torchgeo.datamodules import (
Expand All @@ -20,19 +23,23 @@
from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset
from torchgeo.samplers import RandomBatchGeoSampler, RandomGeoSampler

MINT = pd.Timestamp(2025, 4, 24)
MAXT = pd.Timestamp(2025, 4, 25)


class CustomGeoDataset(GeoDataset):
def __init__(
self, split: str = 'train', length: int = 1, download: bool = False
) -> None:
super().__init__()
for i in range(length):
self.index.insert(i, (0, 1, 2, 3, 4, 5))
geometry = [shapely.box(0, 0, 1, 1)] * length
index = pd.IntervalIndex([pd.Interval(MINT, MAXT)] * length, name='datetime')
crs = CRS.from_epsg(4326)
self.index = GeoDataFrame(index=index, geometry=geometry, crs=crs)
self.res = (1, 1)

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}
return {'image': image, 'crs': self.index.crs, 'bounds': query}

def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_agb_live_woody_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import pytest
import torch
import torch.nn as nn
from pyproj import CRS
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
AbovegroundLiveWoodyBiomassDensity,
Expand Down
5 changes: 3 additions & 2 deletions tests/datasets/test_agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import pytest
import torch
import torch.nn as nn
from pyproj import CRS
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
AgriFieldNet,
Expand Down Expand Up @@ -69,7 +70,7 @@ def test_plot_prediction(self, dataset: AgriFieldNet) -> None:
plt.close()

def test_invalid_query(self, dataset: AgriFieldNet) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
query = BoundingBox(0, 0, 0, 0, pd.Timestamp.min, pd.Timestamp.min)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
Expand Down
5 changes: 3 additions & 2 deletions tests/datasets/test_airphen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import pytest
import torch
import torch.nn as nn
from rasterio.crs import CRS
from pyproj import CRS

from torchgeo.datasets import (
Airphen,
Expand Down Expand Up @@ -55,7 +56,7 @@ def test_no_data(self, tmp_path: Path) -> None:
Airphen(tmp_path)

def test_invalid_query(self, dataset: Airphen) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
query = BoundingBox(0, 0, 0, 0, pd.Timestamp.min, pd.Timestamp.min)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
Expand Down
5 changes: 3 additions & 2 deletions tests/datasets/test_astergdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import pytest
import torch
import torch.nn as nn
from rasterio.crs import CRS
from pyproj import CRS

from torchgeo.datasets import (
AsterGDEM,
Expand Down Expand Up @@ -66,7 +67,7 @@ def test_plot_prediction(self, dataset: AsterGDEM) -> None:
plt.close()

def test_invalid_query(self, dataset: AsterGDEM) -> None:
query = BoundingBox(100, 100, 100, 100, 0, 0)
query = BoundingBox(100, 100, 100, 100, pd.Timestamp.min, pd.Timestamp.min)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
Expand Down
5 changes: 3 additions & 2 deletions tests/datasets/test_cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import pytest
import torch
import torch.nn as nn
from pyproj import CRS
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
Expand Down Expand Up @@ -76,7 +77,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
CanadianBuildingFootprints(tmp_path)

def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
query = BoundingBox(2, 2, 2, 2, 2, 2)
query = BoundingBox(2, 2, 2, 2, pd.Timestamp.min, pd.Timestamp.min)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
Expand Down
10 changes: 5 additions & 5 deletions tests/datasets/test_cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import glob
import os
import shutil
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import pytest
import torch
import torch.nn as nn
from pyproj import CRS
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
CDL,
Expand Down Expand Up @@ -71,9 +71,9 @@ def test_or(self, dataset: CDL) -> None:

def test_full_year(self, dataset: CDL) -> None:
bbox = dataset.bounds
time = datetime(2023, 6, 1).timestamp()
time = pd.Timestamp(2023, 6, 1)
query = BoundingBox(bbox.minx, bbox.maxx, bbox.miny, bbox.maxy, time, time)
next(dataset.index.intersection(tuple(query)))
dataset[query]

def test_already_extracted(self, dataset: CDL) -> None:
CDL(dataset.paths, years=[2023, 2022])
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
CDL(tmp_path)

def test_invalid_query(self, dataset: CDL) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
query = BoundingBox(0, 0, 0, 0, pd.Timestamp.min, pd.Timestamp.min)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
Expand Down
7 changes: 4 additions & 3 deletions tests/datasets/test_chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pyproj import CRS
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
Expand Down Expand Up @@ -83,7 +84,7 @@ def test_plot(self, dataset: ChesapeakeDC) -> None:
plt.close()

def test_invalid_query(self, dataset: ChesapeakeDC) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
query = BoundingBox(0, 0, 0, 0, pd.Timestamp.min, pd.Timestamp.min)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
Expand Down Expand Up @@ -191,7 +192,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
ChesapeakeCVPR(tmp_path, checksum=True)

def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
query = BoundingBox(0, 0, 0, 0, pd.Timestamp.min, pd.Timestamp.min)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
Expand Down
Loading