diff --git a/docs/conf.py b/docs/conf.py index 4078970c2e4..dfec675d3f4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -67,6 +67,7 @@ ('py:class', 'torchvision.models._api.WeightsEnum'), ('py:class', 'torchvision.models.resnet.ResNet'), ('py:class', 'torchvision.models.swin_transformer.SwinTransformer'), + ('py:class', 'geopandas.GeoDataFrame'), ] @@ -122,6 +123,7 @@ 'torch': ('https://pytorch.org/docs/stable', None), 'torchmetrics': ('https://lightning.ai/docs/torchmetrics/stable/', None), 'torchvision': ('https://pytorch.org/vision/stable', None), + 'geopandas': ('https://geopandas.org/en/stable/', None), } # nbsphinx diff --git a/docs/index.rst b/docs/index.rst index ced959493a8..60deae2c855 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,6 +29,7 @@ torchgeo :caption: Tutorials tutorials/getting_started + tutorials/visualizing_samples tutorials/custom_raster_dataset tutorials/transforms tutorials/indices diff --git a/docs/tutorials/visualizing_samples.ipynb b/docs/tutorials/visualizing_samples.ipynb new file mode 100644 index 00000000000..d1d5c9924d3 --- /dev/null +++ b/docs/tutorials/visualizing_samples.ipynb @@ -0,0 +1,159 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visualizing Samples\n", + "\n", + "This tutorial shows how to visualize and save the extent of your samples before and during training. In this particular example, we compare a vanilla RandomGeoSampler with one bounded by multiple ROIs and show how easy it is to gain insight into the distribution of your samples." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from torchgeo.datasets import NAIP, stack_samples\n", + "from torchgeo.datasets.utils import download_url\n", + "from torchgeo.samplers import RandomGeoSampler\n", + "\n", + "\n", + "def run_epochs(dataset, sampler):\n", + " dataloader = DataLoader(\n", + " dataset, sampler=sampler, batch_size=1, collate_fn=stack_samples, num_workers=0\n", + " )\n", + " fig, ax = plt.subplots()\n", + " num_epochs = 5\n", + " for epoch in range(num_epochs):\n", + " color = plt.cm.viridis(epoch / num_epochs)\n", + " # sampler.chips.to_file(f'naip_chips_epoch_{epoch}') # Optional: save chips to file for display in GIS software\n", + " ax = sampler.chips.plot(ax=ax, color=color)\n", + " for sample in dataloader:\n", + " pass\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "naip_root = os.path.join(tempfile.gettempdir(), 'naip')\n", + "naip_url = (\n", + " 'https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/'\n", + ")\n", + "tiles = ['m_3807511_ne_18_060_20181104.tif', 'm_3807512_sw_18_060_20180815.tif']\n", + "for tile in tiles:\n", + " download_url(naip_url + tile, naip_root)\n", + "\n", + "naip = NAIP(naip_root)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First we create the default sampler for our dataset (3 samples) and run it for 5 epochs and plot its results. Each color displays a different epoch, so we can see how the RandomGeoSampler has distributed its samples for every epoch." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sampler = RandomGeoSampler(naip, size=1000, length=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_epochs(naip, sampler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we split our dataset by two bounding boxes and re-inspect the samples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "from torchgeo.datasets import roi_split\n", + "from torchgeo.datasets.utils import BoundingBox\n", + "\n", + "rois = [\n", + " BoundingBox(440854, 442938, 4299766, 4301731, 0, np.inf),\n", + " BoundingBox(449070, 451194, 4289463, 4291746, 0, np.inf),\n", + "]\n", + "datasets = roi_split(naip, rois)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "combined = datasets[0] | datasets[1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sampler = RandomGeoSampler(combined, size=1000, length=3)\n", + "run_epochs(combined, sampler)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cca", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 4a985b3a40a..83062ce28fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,8 @@ dependencies = [ "einops>=0.3", # fiona 1.8.21+ required for Python 3.10 wheels "fiona>=1.8.21", + # geopandas 0.13.2 is the last version to support pandas 
1.3, but has feather support + "geopandas>=0.13.2", # kornia 0.7.3+ required for instance segmentation support in AugmentationSequential "kornia>=0.7.3", # lightly 1.4.5+ required for LARS optimizer @@ -58,6 +60,8 @@ dependencies = [ "pandas>=1.3.3", # pillow 8.4+ required for Python 3.10 wheels "pillow>=8.4", + # pyarrow 12.0+ required for feather support + "pyarrow>=17.0.0", # pyproj 3.3+ required for Python 3.10 wheels "pyproj>=3.3", # rasterio 1.3+ required for Python 3.10 wheels diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index a6e91f70fe9..24e15ba962a 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -4,6 +4,7 @@ setuptools==61.0.0 # install einops==0.3.0 fiona==1.8.21 +geopandas==0.13.2 kornia==0.7.3 lightly==1.4.5 lightning[pytorch-extra]==2.0.0 @@ -11,6 +12,7 @@ matplotlib==3.5.0 numpy==1.21.2 pandas==1.3.3 pillow==8.4.0 +pyarrow==17.0.0 pyproj==3.3.0 rasterio==1.3.0.post1 rtree==1.0.0 diff --git a/requirements/required.txt b/requirements/required.txt index 695d77e45f2..b875a18843c 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -4,6 +4,7 @@ setuptools==75.1.0 # install einops==0.8.0 fiona==1.10.1 +geopandas==0.14.4 kornia==0.7.3 lightly==1.5.12 lightning[pytorch-extra]==2.4.0 @@ -11,6 +12,7 @@ matplotlib==3.9.2 numpy==2.1.1 pandas==2.2.3 pillow==10.4.0 +pyarrow==17.0.0 pyproj==3.6.1 rasterio==1.3.11 rtree==1.3.0 diff --git a/requirements/style.txt b/requirements/style.txt index a88e62af3cc..648a2033db5 100644 --- a/requirements/style.txt +++ b/requirements/style.txt @@ -1,3 +1,3 @@ # style mypy==1.11.2 -ruff==0.6.6 +ruff==0.6.7 diff --git a/tests/data/samplers/filtering_4x4.feather b/tests/data/samplers/filtering_4x4.feather new file mode 100644 index 00000000000..305d37e4fa6 Binary files /dev/null and b/tests/data/samplers/filtering_4x4.feather differ diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.cpg b/tests/data/samplers/filtering_4x4/filtering_4x4.cpg new file mode 
100644 index 00000000000..57decb48120 --- /dev/null +++ b/tests/data/samplers/filtering_4x4/filtering_4x4.cpg @@ -0,0 +1 @@ +ISO-8859-1 diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.dbf b/tests/data/samplers/filtering_4x4/filtering_4x4.dbf new file mode 100644 index 00000000000..499d67bcec4 Binary files /dev/null and b/tests/data/samplers/filtering_4x4/filtering_4x4.dbf differ diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.prj b/tests/data/samplers/filtering_4x4/filtering_4x4.prj new file mode 100644 index 00000000000..42fd4b91b78 --- /dev/null +++ b/tests/data/samplers/filtering_4x4/filtering_4x4.prj @@ -0,0 +1 @@ +PROJCS["NAD_1983_BC_Environment_Albers",GEOGCS["GCS_North_American_1983",DATUM["D_North_American_1983",SPHEROID["GRS_1980",6378137.0,298.257222101]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Albers"],PARAMETER["False_Easting",1000000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-126.0],PARAMETER["Standard_Parallel_1",50.0],PARAMETER["Standard_Parallel_2",58.5],PARAMETER["Latitude_Of_Origin",45.0],UNIT["Meter",1.0]] diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.shp b/tests/data/samplers/filtering_4x4/filtering_4x4.shp new file mode 100644 index 00000000000..65606c26dd6 Binary files /dev/null and b/tests/data/samplers/filtering_4x4/filtering_4x4.shp differ diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.shx b/tests/data/samplers/filtering_4x4/filtering_4x4.shx new file mode 100644 index 00000000000..b2028e759e5 Binary files /dev/null and b/tests/data/samplers/filtering_4x4/filtering_4x4.shx differ diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index 8e5fd13d292..80e71c52a43 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -7,6 +7,7 @@ import pytest import torch from _pytest.fixtures import SubRequest +from geopandas import GeoDataFrame from lightning.pytorch import Trainer from 
matplotlib.figure import Figure from rasterio.crs import CRS @@ -182,7 +183,7 @@ def test_zero_length_sampler(self) -> None: dm = CustomGeoDataModule() dm.dataset = CustomGeoDataset() dm.sampler = RandomGeoSampler(dm.dataset, 1, 1) - dm.sampler.length = 0 + dm.sampler.chips = GeoDataFrame() msg = r'CustomGeoDataModule\.sampler has length 0.' with pytest.raises(MisconfigurationException, match=msg): dm.train_dataloader() diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 743c8be70da..96a5861ac9c 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -2,13 +2,16 @@ # Licensed under the MIT License. import math -from collections.abc import Iterator +import os from itertools import product +import geopandas as gpd import pytest import torch from _pytest.fixtures import SubRequest +from geopandas import GeoDataFrame from rasterio.crs import CRS +from shapely.geometry import box from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples @@ -24,14 +27,23 @@ class CustomGeoSampler(GeoSampler): def __init__(self) -> None: - pass - - def __iter__(self) -> Iterator[BoundingBox]: - for i in range(len(self)): - yield BoundingBox(i, i, i, i, i, i) - - def __len__(self) -> int: - return 2 + self.chips = self.get_chips() + + def get_chips(self) -> GeoDataFrame: + chips = [] + for i in range(2): + chips.append( + { + 'geometry': box(i, i, i, i), + 'minx': i, + 'miny': i, + 'maxx': i, + 'maxy': i, + 'mint': i, + 'maxt': i, + } + ) + return GeoDataFrame(chips, crs=CRS.from_epsg(3005)) class CustomGeoDataset(GeoDataset): @@ -65,6 +77,64 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): GeoSampler(dataset) # type: ignore[abstract] + @pytest.mark.parametrize( + 'filtering_file', ['filtering_4x4', 'filtering_4x4.feather'] + ) + def test_filtering_from_path(self, filtering_file: str) -> None: + 
datadir = os.path.join('tests', 'data', 'samplers') + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler( + ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10) + ) + iterator = iter(sampler) + + assert len(sampler) == 4 + filtering_path = os.path.join(datadir, filtering_file) + sampler.filter_chips(filtering_path, 'intersects', 'drop') + assert len(sampler) == 3 + assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) + + def test_filtering_from_gdf(self) -> None: + datadir = os.path.join('tests', 'data', 'samplers') + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler( + ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10) + ) + iterator = iter(sampler) + + # Dropping first chip + assert len(sampler) == 4 + filtering_gdf = gpd.read_file(os.path.join(datadir, 'filtering_4x4')) + sampler.filter_chips(filtering_gdf, 'intersects', 'drop') + assert len(sampler) == 3 + assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) + + # Keeping only first chip + sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) + iterator = iter(sampler) + sampler.filter_chips(filtering_gdf, 'intersects', 'keep') + assert len(sampler) == 1 + assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) + + def test_set_worker_split(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler( + ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10) + ) + assert len(sampler) == 4 + sampler.set_worker_split(total_workers=4, worker_num=1) + assert len(sampler) == 1 + + def test_save_chips(self, tmpdir_factory: pytest.TempdirFactory) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) + sampler.save(str(tmpdir_factory.mktemp('out').join('chips'))) + sampler.save(str(tmpdir_factory.mktemp('out').join('chips.feather'))) + @pytest.mark.slow 
@pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -116,6 +186,19 @@ def test_roi(self, dataset: CustomGeoDataset) -> None: for query in sampler: assert query in roi + def test_empty(self, dataset: CustomGeoDataset) -> None: + sampler = RandomGeoSampler(dataset, 5, length=0) + assert len(sampler) == 0 + + def test_refresh_samples(self, dataset: CustomGeoDataset) -> None: + dataset.index.insert(0, (0, 100, 200, 300, 400, 500)) + sampler = RandomGeoSampler(dataset, 5, length=1) + samples = list(sampler) + assert len(sampler) == 1 + sampler.refresh_samples() + assert list(sampler) != samples + assert len(sampler) == 1 + def test_small_area(self) -> None: ds = CustomGeoDataset(res=1) ds.index.insert(0, (0, 10, 0, 10, 0, 10)) @@ -277,11 +360,11 @@ def dataset(self) -> CustomGeoDataset: def sampler(self, dataset: CustomGeoDataset) -> PreChippedGeoSampler: return PreChippedGeoSampler(dataset, shuffle=True) - def test_iter(self, sampler: GridGeoSampler) -> None: + def test_iter(self, sampler: PreChippedGeoSampler) -> None: for _ in sampler: continue - def test_len(self, sampler: GridGeoSampler) -> None: + def test_len(self, sampler: PreChippedGeoSampler) -> None: assert len(sampler) == 2 def test_roi(self, dataset: CustomGeoDataset) -> None: diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 6fa4331c4b7..45159713465 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -7,16 +7,44 @@ from collections.abc import Callable, Iterable, Iterator from functools import partial +import geopandas as gpd +import numpy as np import torch +from geopandas import GeoDataFrame from rtree.index import Index, Property +from shapely.geometry import box from torch import Generator from torch.utils.data import Sampler +from tqdm import tqdm from ..datasets import BoundingBox, GeoDataset from .constants import Units from .utils import _to_tuple, get_random_bounding_box, tile_to_chips +def load_file(path: str | 
GeoDataFrame) -> GeoDataFrame: + """Load a file from the given path. + + Parameters: + path (str or GeoDataFrame): The path to the file or a GeoDataFrame object. + + Returns: + GeoDataFrame: The loaded file as a GeoDataFrame. + + Raises: + None + + """ + if isinstance(path, GeoDataFrame): + return path + if path.endswith('.feather'): + print(f'Reading feather file: {path}') + return gpd.read_feather(path) + else: + print(f'Reading shapefile: {path}') + return gpd.read_file(path) + + class GeoSampler(Sampler[BoundingBox], abc.ABC): """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`. @@ -46,14 +74,124 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: self.res = dataset.res self.roi = roi + self.dataset = dataset + self.chips: GeoDataFrame = GeoDataFrame() + + @staticmethod + def __save_as_gpd_or_feather( + path: str, gdf: GeoDataFrame, driver: str = 'ESRI Shapefile' + ) -> None: + """Save a GeoDataFrame as a file supported by any geopandas driver or as a feather file. + + Parameters: + path (str): The path to save the file. + gdf (GeoDataFrame): The GeoDataFrame to be saved. + driver (str, optional): The driver to be used for saving the file. Defaults to 'ESRI Shapefile'. + + Returns: + None + """ + if path.endswith('.feather'): + gdf.to_feather(path) + else: + gdf.to_file(path, driver=driver) @abc.abstractmethod + def get_chips(self) -> GeoDataFrame: + """Determines the way to get the extent of the chips (samples) of the dataset. + + Should return a GeoDataFrame with the extent of the chips with the columns + geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip. It is + expected that every sampler calls this method to get the chips as one of the + last steps in the ``__init__`` method. 
+ """ + + def filter_chips( + self, + filter_by: str | GeoDataFrame, + predicate: str = 'intersects', + action: str = 'keep', + ) -> None: + """Filter the default set of chips in the sampler down to a specific subset. + + Args: + filter_by: The file or geodataframe for which the geometries will be used during filtering + predicate: Predicate as used in Geopandas sindex.query_bulk + action: What to do with the chips that satisfy the condition by the predicate. + Can either be ``'drop'`` or ``'keep'``. + """ + prefilter_leng = len(self.chips) + filtering_gdf = load_file(filter_by).to_crs(self.dataset.crs) + + if action == 'keep': + self.chips = self.chips.iloc[ + list( + set( + self.chips.sindex.query_bulk( + filtering_gdf.geometry, predicate=predicate + )[1] + ) + ) + ].reset_index(drop=True) + elif action == 'drop': + self.chips = self.chips.drop( + index=list( + set( + self.chips.sindex.query_bulk( + filtering_gdf.geometry, predicate=predicate + )[1] + ) + ) + ).reset_index(drop=True) + + self.chips.fid = self.chips.index + print(f'Filter step reduced chips from {prefilter_leng} to {len(self.chips)}') + assert not self.chips.empty, 'No chips left after filtering!' + + def set_worker_split(self, total_workers: int, worker_num: int) -> None: + """Split the chips for multi-worker inference. + + Splits the chips in n equal parts for the number of workers and keeps the set of + chips for the specific worker id, convenient if you want to split the chips across + multiple dataloaders for multi-worker inference. + + Args: + total_workers (int): The total number of parts to split the chips + worker_num (int): The id of the worker (which part to keep), starts from 0 + + """ + self.chips = np.array_split(self.chips, total_workers)[worker_num] + + def save(self, path: str, driver: str = 'ESRI Shapefile') -> None: + """Save the chips as a file format supported by GeoPandas or feather file. + + Parameters: + - path (str): The path to save the file. 
+ - driver (str): The driver to use for saving the file. Defaults to 'ESRI Shapefile'. + + Returns: + - None + """ + self.__save_as_gpd_or_feather(path, self.chips, driver) + def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. Returns: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ + for _, chip in self.chips.iterrows(): + yield BoundingBox( + chip.minx, chip.maxx, chip.miny, chip.maxy, chip.mint, chip.maxt + ) + + def __len__(self) -> int: + """Return the number of samples over the ROI. + + Returns: + number of patches that will be sampled + """ + return len(self.chips) class RandomGeoSampler(GeoSampler): @@ -137,32 +275,63 @@ def __init__( if torch.sum(self.areas) == 0: self.areas += 1 + self.chips = self.get_chips() + def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. Returns: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ - for _ in range(len(self)): + self.refresh_samples() + for _, chip in self.chips.iterrows(): + yield BoundingBox( + chip.minx, chip.maxx, chip.miny, chip.maxy, chip.mint, chip.maxt + ) + + def refresh_samples(self) -> None: + """Refresh the samples in the sampler. + + This method is useful when you want to refresh the samples in the sampler + without creating a new sampler instance. + """ + self.chips = self.get_chips() + + def get_chips(self) -> GeoDataFrame: + """Generate chips from the dataset. + + Returns: + A GeoDataFrame containing the generated chips. + """ + chips = [] + print('generating samples... 
') + for _ in tqdm(range(self.length)): # Choose a random tile, weighted by area idx = torch.multinomial(self.areas, 1) hit = self.hits[idx] bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box( - bounds, self.size, self.res, self.generator - ) + bbox = get_random_bounding_box(bounds, self.size, self.res, self.generator) + minx, maxx, miny, maxy, mint, maxt = tuple(bbox) + chip = { + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, + } + chips.append(chip) + + if chips: + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf['fid'] = chips_gdf.index - yield bounding_box - - def __len__(self) -> int: - """Return the number of samples in a single epoch. - - Returns: - length of the epoch - """ - return self.length + else: + chips_gdf = GeoDataFrame() + return chips_gdf class GridGeoSampler(GeoSampler): @@ -225,24 +394,20 @@ def __init__( ): self.hits.append(hit) - self.length = 0 - for hit in self.hits: - bounds = BoundingBox(*hit.bounds) - rows, cols = tile_to_chips(bounds, self.size, self.stride) - self.length += rows * cols + self.chips = self.get_chips() - def __iter__(self) -> Iterator[BoundingBox]: - """Return the index of a dataset. + def get_chips(self) -> GeoDataFrame: + """Generates chips from the given hits. Returns: - (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset + A GeoDataFrame containing the generated chips. """ - # For each tile... + print('generating samples... ') + self.length = 0 + chips = [] for hit in self.hits: bounds = BoundingBox(*hit.bounds) rows, cols = tile_to_chips(bounds, self.size, self.stride) - mint = bounds.mint - maxt = bounds.maxt # For each row... 
for i in range(rows): @@ -254,15 +419,25 @@ def __iter__(self) -> Iterator[BoundingBox]: minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] - yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) + chip = { + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': bounds.mint, + 'maxt': bounds.maxt, + } + self.length += 1 + chips.append(chip) - def __len__(self) -> int: - """Return the number of samples over the ROI. + if chips: + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf['fid'] = chips_gdf.index - Returns: - number of patches that will be sampled - """ - return self.length + else: + chips_gdf = GeoDataFrame() + return chips_gdf class PreChippedGeoSampler(GeoSampler): @@ -309,23 +484,35 @@ def __init__( for hit in self.index.intersection(tuple(self.roi), objects=True): self.hits.append(hit) - def __iter__(self) -> Iterator[BoundingBox]: - """Return the index of a dataset. + self.length = len(self.hits) + self.chips = self.get_chips() + + def get_chips(self) -> GeoDataFrame: + """Generate chips from the hits and return them as a GeoDataFrame. Returns: - (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset + A GeoDataFrame containing the generated chips. """ generator: Callable[[int], Iterable[int]] = range if self.shuffle: generator = partial(torch.randperm, generator=self.generator) - for idx in generator(len(self)): - yield BoundingBox(*self.hits[idx].bounds) - - def __len__(self) -> int: - """Return the number of samples over the ROI. - - Returns: - number of patches that will be sampled - """ - return len(self.hits) + print('generating samples... 
') + chips = [] + for idx in generator(self.length): + minx, maxx, miny, maxy, mint, maxt = self.hits[idx].bounds + chip = { + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, + } + self.length += 1 + chips.append(chip) + + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf['fid'] = chips_gdf.index + return chips_gdf