Skip to content

Feature/rasterized vectordataset support #2505 #2785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
11 changes: 11 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,17 @@ VectorDataset

.. autoclass:: VectorDataset

RasterizedVectorDataset
^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: RasterizedVectorDataset

Vector datasets Rasterizers
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: RasterizationStrategy
.. autoclass:: DefaultRasterizationStrategy

NonGeoDataset
^^^^^^^^^^^^^

Expand Down
3 changes: 2 additions & 1 deletion docs/tutorials/custom_raster_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
"In TorchGeo, each `GeoDataset` uses an [R-tree](https://en.wikipedia.org/wiki/R-tree) to store the spatiotemporal bounding box of each file or data point. To simplify this process and reduce code duplication, we provide two subclasses of `GeoDataset`:\n",
"\n",
"* `RasterDataset`: recursively search for raster files in a directory\n",
"* `VectorDataset`: recursively search for vector files in a directory\n",
"* `VectorDataset`: recursively search for vector files in a directory (Items are features including geometry and attributes)\n",
"* `RasterizedVectorDataset`: Inherits from VectorDataset (Items are rasterized geometry masks)\n",
"\n",
"In this example, we'll be working with raster images, so we'll choose `RasterDataset` as the base class."
]
Expand Down
2 changes: 1 addition & 1 deletion docs/user/alternatives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Features

**Transform Backend**: The transform library used to perform data augmentation. For example, Kornia performs all augmentations on PyTorch Tensors, allowing you to run your transforms on the GPU for an entire mini-batch at a time.

**Datasets**: The number of geospatial datasets built into the library. Note that most projects have something similar to TorchGeo's ``RasterDataset`` and ``VectorDataset``, allowing you to work with generic raster and vector files. Collections of datasets are only counted a single time, so data loaders for Landsats 1--9 are a single dataset, and data loaders for SpaceNets 1--8 are also a single dataset.
**Datasets**: The number of geospatial datasets built into the library. Note that most projects have something similar to TorchGeo's ``RasterDataset``, ``VectorDataset`` and ``RasterizedVectorDataset``, allowing you to work with generic raster and vector files. Collections of datasets are only counted a single time, so data loaders for Landsats 1--9 are a single dataset, and data loaders for SpaceNets 1--8 are also a single dataset.

**Weights**: The number of model weights pre-trained on geospatial data that are offered by the library. Note that most projects support hundreds of model architectures via a library like PyTorch Image Models, and can use models pre-trained on ImageNet. There are far fewer libraries that provide foundation model weights pre-trained on multispectral satellite imagery.

Expand Down
112 changes: 102 additions & 10 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from pathlib import Path
from typing import Any

import fiona
import pytest
import shapely
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
Expand All @@ -26,6 +28,7 @@
NonGeoClassificationDataset,
NonGeoDataset,
RasterDataset,
RasterizedVectorDataset,
Sentinel2,
UnionDataset,
VectorDataset,
Expand Down Expand Up @@ -71,6 +74,14 @@ class CustomVectorDataset(VectorDataset):
"""


class CustomRasterizedVectorDataset(RasterizedVectorDataset):
filename_glob = '*.geojson'
date_format = '%Y'
filename_regex = r"""
^vector_(?P<date>\d{4})\.geojson
"""


class CustomSentinelDataset(Sentinel2):
all_bands: tuple[str, ...] = ()
separate_files = False
Expand Down Expand Up @@ -393,17 +404,96 @@ class TestVectorDataset:
def dataset(self) -> CustomVectorDataset:
root = os.path.join('tests', 'data', 'vector')
transforms = nn.Identity()
return CustomVectorDataset(root, res=(0.1, 0.1), transforms=transforms)
return CustomVectorDataset(root, transforms=transforms)

@pytest.fixture(scope='class')
def multilabel(self) -> CustomVectorDataset:
def clipped_geoemtries_dataset(self) -> CustomVectorDataset:
root = os.path.join('tests', 'data', 'vector')
transforms = nn.Identity()
return CustomVectorDataset(
return CustomVectorDataset(root, transforms=transforms, clip_geometries=True)

def test_getitem(self, dataset: CustomRasterizedVectorDataset) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert len(x) == 3

def test_empty_query_window(self, dataset: CustomRasterizedVectorDataset) -> None:
query = BoundingBox(1.1, 1.9, 1.1, 1.9, 0, sys.maxsize)
x = dataset[query]
assert len(x) == 0

def test_invalid_query(self, dataset: CustomRasterizedVectorDataset) -> None:
query = BoundingBox(3, 3, 3, 3, 0, 0)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
dataset[query]

def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
RasterizedVectorDataset(tmp_path)

def test_clip_geometries(
self,
dataset: CustomRasterizedVectorDataset,
clipped_geoemtries_dataset: CustomRasterizedVectorDataset,
) -> None:
# split the query window in two
query_window_1, query_window_2 = dataset.bounds.split(0.3, horizontal=False)

# retrieve the datasets elements in window 1
dataset_elements = dataset[query_window_1]
clipped_dataset_elements = clipped_geoemtries_dataset[query_window_1]

# Compare clippied and non clipped geoemtries area in window 1
def feature_area(f: 'fiona.Feature') -> float:
area = float(shapely.geometry.shape(f['geometry']).area)
return area

dataset_elements_area = sum(
[feature_area(feature) for feature in dataset_elements.values()]
)
clipped_dataset_elements_area = sum(
[feature_area(feature) for feature in clipped_dataset_elements.values()]
)
assert clipped_dataset_elements_area < dataset_elements_area

# Compare clipped geometries area in both windows with whole dataset window
w1_clipped_elements = clipped_geoemtries_dataset[query_window_1]
w2_clipped_elements = clipped_geoemtries_dataset[query_window_2]
unclipped_elements = dataset[dataset.bounds]

w1_clipped_area = sum(
[feature_area(feature) for feature in w1_clipped_elements.values()]
)
w2_clipped_area = sum(
[feature_area(feature) for feature in w2_clipped_elements.values()]
)
unclipped_area = sum(
[feature_area(feature) for feature in unclipped_elements.values()]
)

assert (w1_clipped_area + w2_clipped_area) == unclipped_area


class TestRasterizedVectorDataset:
@pytest.fixture(scope='class')
def dataset(self) -> CustomRasterizedVectorDataset:
root = os.path.join('tests', 'data', 'vector')
transforms = nn.Identity()
return CustomRasterizedVectorDataset(
root, res=(0.1, 0.1), transforms=transforms
)

@pytest.fixture(scope='class')
def multilabel(self) -> CustomRasterizedVectorDataset:
root = os.path.join('tests', 'data', 'vector')
transforms = nn.Identity()
return CustomRasterizedVectorDataset(
root, res=(0.1, 0.1), transforms=transforms, label_name='label_id'
)

def test_getitem(self, dataset: CustomVectorDataset) -> None:
def test_getitem(self, dataset: CustomRasterizedVectorDataset) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x['crs'], CRS)
Expand All @@ -413,11 +503,13 @@ def test_getitem(self, dataset: CustomVectorDataset) -> None:
torch.tensor([0, 1], dtype=torch.uint8),
)

def test_time_index(self, dataset: CustomVectorDataset) -> None:
def test_time_index(self, dataset: CustomRasterizedVectorDataset) -> None:
assert dataset.index.bounds[4] > 0
assert dataset.index.bounds[5] < sys.maxsize

def test_getitem_multilabel(self, multilabel: CustomVectorDataset) -> None:
def test_getitem_multilabel(
self, multilabel: CustomRasterizedVectorDataset
) -> None:
x = multilabel[multilabel.bounds]
assert isinstance(x, dict)
assert isinstance(x['crs'], CRS)
Expand All @@ -427,12 +519,12 @@ def test_getitem_multilabel(self, multilabel: CustomVectorDataset) -> None:
torch.tensor([0, 1, 2, 3], dtype=torch.uint8),
)

def test_empty_shapes(self, dataset: CustomVectorDataset) -> None:
def test_empty_shapes(self, dataset: CustomRasterizedVectorDataset) -> None:
query = BoundingBox(1.1, 1.9, 1.1, 1.9, 0, sys.maxsize)
x = dataset[query]
assert torch.equal(x['mask'], torch.zeros(8, 8, dtype=torch.uint8))

def test_invalid_query(self, dataset: CustomVectorDataset) -> None:
def test_invalid_query(self, dataset: CustomRasterizedVectorDataset) -> None:
query = BoundingBox(3, 3, 3, 3, 0, 0)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
Expand All @@ -441,11 +533,11 @@ def test_invalid_query(self, dataset: CustomVectorDataset) -> None:

def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
VectorDataset(tmp_path)
RasterizedVectorDataset(tmp_path)

def test_single_res(self) -> None:
root = os.path.join('tests', 'data', 'vector')
ds = CustomVectorDataset(root, res=0.1)
ds = CustomRasterizedVectorDataset(root, res=0.1)
assert ds.res == (0.1, 0.1)


Expand Down
6 changes: 6 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,14 @@
from .gbif import GBIF
from .gbm import GlobalBuildingMap
from .geo import (
DefaultRasterizationStrategy,
GeoDataset,
IntersectionDataset,
NonGeoClassificationDataset,
NonGeoDataset,
RasterDataset,
RasterizationStrategy,
RasterizedVectorDataset,
UnionDataset,
VectorDataset,
)
Expand Down Expand Up @@ -257,6 +260,7 @@
'DL4GAMAlps',
'DatasetNotFoundError',
'DeepGlobeLandCover',
'DefaultRasterizationStrategy',
'DependencyNotFoundError',
'DigitalTyphoon',
'EDDMapS',
Expand Down Expand Up @@ -314,6 +318,8 @@
'QuakeSet',
'RGBBandsMissingError',
'RasterDataset',
'RasterizationStrategy',
'RasterizedVectorDataset',
'ReforesTree',
'RwandaFieldBoundary',
'SSL4EOLBenchmark',
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datasets/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from rasterio.crs import CRS

from .errors import DatasetNotFoundError
from .geo import VectorDataset
from .geo import RasterizedVectorDataset
from .utils import Path, check_integrity, download_and_extract_archive


class CanadianBuildingFootprints(VectorDataset):
class CanadianBuildingFootprints(RasterizedVectorDataset):
"""Canadian Building Footprints dataset.

The `Canadian Building Footprints
Expand Down Expand Up @@ -134,7 +134,7 @@ def plot(
"""Plot a sample from the dataset.

Args:
sample: a sample returned by :meth:`VectorDataset.__getitem__`
sample: a sample returned by :meth:`RasterizedVectorDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle

Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datasets/eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from rasterio.crs import CRS

from .errors import DatasetNotFoundError
from .geo import VectorDataset
from .geo import RasterizedVectorDataset
from .utils import Path, check_integrity, download_and_extract_archive, download_url


class EuroCrops(VectorDataset):
class EuroCrops(RasterizedVectorDataset):
"""EuroCrops Dataset (Version 9).

The `EuroCrops <https://www.eurocrops.tum.de/index.html>`__ dataset combines "all
Expand Down Expand Up @@ -231,7 +231,7 @@ def plot(
"""Plot a sample from the dataset.

Args:
sample: a sample returned by :meth:`VectorDataset.__getitem__`
sample: a sample returned by :meth:`RasterizedVectorDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle

Expand Down
Loading
Loading