Skip to content

Commit 7046e89

Browse files
committed
Merge branch 'geosampler_prechipping' into vers_working_branch
2 parents 514745d + 25ce0e1 commit 7046e89

File tree

15 files changed

+499
-56
lines changed

15 files changed

+499
-56
lines changed

docs/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
('py:class', 'torchvision.models._api.WeightsEnum'),
6868
('py:class', 'torchvision.models.resnet.ResNet'),
6969
('py:class', 'torchvision.models.swin_transformer.SwinTransformer'),
70+
('py:class', 'geopandas.GeoDataFrame'),
7071
]
7172

7273

@@ -122,6 +123,7 @@
122123
'torch': ('https://pytorch.org/docs/stable', None),
123124
'torchmetrics': ('https://lightning.ai/docs/torchmetrics/stable/', None),
124125
'torchvision': ('https://pytorch.org/vision/stable', None),
126+
'geopandas': ('https://geopandas.org/en/stable/', None),
125127
}
126128

127129
# nbsphinx

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ torchgeo
2929
:caption: Tutorials
3030

3131
tutorials/getting_started
32+
tutorials/visualizing_samples
3233
tutorials/custom_raster_dataset
3334
tutorials/transforms
3435
tutorials/indices
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Visualizing Samples\n",
8+
"\n",
9+
"This tutorial shows how to visualize and save the extent of your samples before and during training. In this particular example, we compare a vanilla RandomGeoSampler with one bounded by multiple ROI's and show how easy it is to gain insight on the distribution of your samples."
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": null,
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"import os\n",
19+
"import tempfile\n",
20+
"\n",
21+
"import matplotlib.pyplot as plt\n",
22+
"from torch.utils.data import DataLoader\n",
23+
"\n",
24+
"from torchgeo.datasets import NAIP, stack_samples\n",
25+
"from torchgeo.datasets.utils import download_url\n",
26+
"from torchgeo.samplers import RandomGeoSampler\n",
27+
"\n",
28+
"\n",
29+
"def run_epochs(dataset, sampler):\n",
30+
" dataloader = DataLoader(\n",
31+
" dataset, sampler=sampler, batch_size=1, collate_fn=stack_samples, num_workers=0\n",
32+
" )\n",
33+
" fig, ax = plt.subplots()\n",
34+
" num_epochs = 5\n",
35+
" for epoch in range(num_epochs):\n",
36+
" color = plt.cm.viridis(epoch / num_epochs)\n",
37+
" # sampler.chips.to_file(f'naip_chips_epoch_{epoch}') # Optional: save chips to file for display in GIS software\n",
38+
" ax = sampler.chips.plot(ax=ax, color=color)\n",
39+
" for sample in dataloader:\n",
40+
" pass\n",
41+
" plt.show()"
42+
]
43+
},
44+
{
45+
"cell_type": "markdown",
46+
"metadata": {},
47+
"source": [
48+
"Generate dataset"
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": null,
54+
"metadata": {},
55+
"outputs": [],
56+
"source": [
57+
"naip_root = os.path.join(tempfile.gettempdir(), 'naip')\n",
58+
"naip_url = (\n",
59+
" 'https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/'\n",
60+
")\n",
61+
"tiles = ['m_3807511_ne_18_060_20181104.tif', 'm_3807512_sw_18_060_20180815.tif']\n",
62+
"for tile in tiles:\n",
63+
" download_url(naip_url + tile, naip_root)\n",
64+
"\n",
65+
"naip = NAIP(naip_root)"
66+
]
67+
},
68+
{
69+
"cell_type": "markdown",
70+
"metadata": {},
71+
"source": [
72+
"First we create the default sampler for our dataset (3 samples) and run it for 5 epochs and plot its results. Each color displays a different epoch, so we can see how the RandomGeoSampler has distributed it's samples for every epoch."
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": null,
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"sampler = RandomGeoSampler(naip, size=1000, length=3)"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": null,
87+
"metadata": {},
88+
"outputs": [],
89+
"source": [
90+
"run_epochs(naip, sampler)"
91+
]
92+
},
93+
{
94+
"cell_type": "markdown",
95+
"metadata": {},
96+
"source": [
97+
"Now we split our dataset by two bounding boxes and re-inspect the samples."
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": null,
103+
"metadata": {},
104+
"outputs": [],
105+
"source": [
106+
"import numpy as np\n",
107+
"\n",
108+
"from torchgeo.datasets import roi_split\n",
109+
"from torchgeo.datasets.utils import BoundingBox\n",
110+
"\n",
111+
"rois = [\n",
112+
" BoundingBox(440854, 442938, 4299766, 4301731, 0, np.inf),\n",
113+
" BoundingBox(449070, 451194, 4289463, 4291746, 0, np.inf),\n",
114+
"]\n",
115+
"datasets = roi_split(naip, rois)"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"metadata": {},
122+
"outputs": [],
123+
"source": [
124+
"combined = datasets[0] | datasets[1]"
125+
]
126+
},
127+
{
128+
"cell_type": "code",
129+
"execution_count": null,
130+
"metadata": {},
131+
"outputs": [],
132+
"source": [
133+
"sampler = RandomGeoSampler(combined, size=1000, length=3)\n",
134+
"run_epochs(combined, sampler)"
135+
]
136+
}
137+
],
138+
"metadata": {
139+
"kernelspec": {
140+
"display_name": "cca",
141+
"language": "python",
142+
"name": "python3"
143+
},
144+
"language_info": {
145+
"codemirror_mode": {
146+
"name": "ipython",
147+
"version": 3
148+
},
149+
"file_extension": ".py",
150+
"mimetype": "text/x-python",
151+
"name": "python",
152+
"nbconvert_exporter": "python",
153+
"pygments_lexer": "ipython3",
154+
"version": "3.10.14"
155+
}
156+
},
157+
"nbformat": 4,
158+
"nbformat_minor": 2
159+
}

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ dependencies = [
4040
"einops>=0.3",
4141
# fiona 1.8.21+ required for Python 3.10 wheels
4242
"fiona>=1.8.21",
43+
# geopandas 0.13.2 is the last version to support pandas 1.3, but has feather support
44+
"geopandas>=0.13.2",
4345
# kornia 0.7.3+ required for instance segmentation support in AugmentationSequential
4446
"kornia>=0.7.3",
4547
# lightly 1.4.5+ required for LARS optimizer
@@ -58,6 +60,8 @@ dependencies = [
5860
"pandas>=1.3.3",
5961
# pillow 8.4+ required for Python 3.10 wheels
6062
"pillow>=8.4",
63+
# pyarrow 12.0+ required for feather support
64+
"pyarrow>=17.0.0",
6165
# pyproj 3.3+ required for Python 3.10 wheels
6266
"pyproj>=3.3",
6367
# rasterio 1.3+ required for Python 3.10 wheels

requirements/min-reqs.old

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ setuptools==61.0.0
44
# install
55
einops==0.3.0
66
fiona==1.8.21
7+
geopandas==0.13.2
78
kornia==0.7.3
89
lightly==1.4.5
910
lightning[pytorch-extra]==2.0.0
1011
matplotlib==3.5.0
1112
numpy==1.21.2
1213
pandas==1.3.3
1314
pillow==8.4.0
15+
pyarrow==17.0.0
1416
pyproj==3.3.0
1517
rasterio==1.3.0.post1
1618
rtree==1.0.0

requirements/required.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ setuptools==75.1.0
44
# install
55
einops==0.8.0
66
fiona==1.10.1
7+
geopandas==0.14.4
78
kornia==0.7.3
89
lightly==1.5.12
910
lightning[pytorch-extra]==2.4.0
1011
matplotlib==3.9.2
1112
numpy==2.1.1
1213
pandas==2.2.3
1314
pillow==10.4.0
15+
pyarrow==17.0.0
1416
pyproj==3.6.1
1517
rasterio==1.3.11
1618
rtree==1.3.0
5.36 KB
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ISO-8859-1
78 Bytes
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
PROJCS["NAD_1983_BC_Environment_Albers",GEOGCS["GCS_North_American_1983",DATUM["D_North_American_1983",SPHEROID["GRS_1980",6378137.0,298.257222101]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Albers"],PARAMETER["False_Easting",1000000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-126.0],PARAMETER["Standard_Parallel_1",50.0],PARAMETER["Standard_Parallel_2",58.5],PARAMETER["Latitude_Of_Origin",45.0],UNIT["Meter",1.0]]

0 commit comments

Comments
 (0)