Skip to content

Commit 6789a25

Browse files
keves1adamjstewart
andauthored
Add ChangeDetectionTask (#2422)
* starting from PR #1760 * changed from image1, image2 to stacked images. * fixed mypy and ruff issues * adding tests. some still need work. * making Kornia transforms work with added temporal dimension. * Support only binary change with two timesteps. Moved loss functions to torchgeo/losses. * fixed issues with tests. * Update versionadded Co-authored-by: Adam J. Stewart <[email protected]> * removed custom loss functions. * added docstring. * Fix syntax error in Python 3.10 * revert target dtype to long in dataset and change to float in trainer instead. * ruff format * updated OSCD dataset tests * prettier format * Using K.CenterCrop until Kornia has a better option. * Removing file per #978 * misc updates from review comments * adding test coverage * match statements and denormalizing for plotting * using monkeypatch to test predict_step * updated docstring * ruff format * test predict_step and other misc changes * misc fixes * updated docstring --------- Co-authored-by: Adam J. Stewart <[email protected]>
1 parent b19fac1 commit 6789a25

File tree

9 files changed

+546
-101
lines changed

9 files changed

+546
-101
lines changed

tests/conf/oscd.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
model:
2+
class_path: ChangeDetectionTask
3+
init_args:
4+
loss: 'bce'
5+
model: 'unet'
6+
backbone: 'resnet18'
7+
in_channels: 13
8+
data:
9+
class_path: OSCDDataModule
10+
init_args:
11+
batch_size: 2
12+
patch_size: 16
13+
val_split_pct: 0.5
14+
dict_kwargs:
15+
root: 'tests/data/oscd'

tests/datamodules/test_oscd.py

Lines changed: 0 additions & 82 deletions
This file was deleted.

tests/datasets/test_oscd.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,15 @@ def dataset(
6666
def test_getitem(self, dataset: OSCD) -> None:
6767
x = dataset[0]
6868
assert isinstance(x, dict)
69-
assert isinstance(x['image1'], torch.Tensor)
70-
assert x['image1'].ndim == 3
71-
assert isinstance(x['image2'], torch.Tensor)
72-
assert x['image2'].ndim == 3
69+
assert isinstance(x['image'], torch.Tensor)
70+
assert x['image'].ndim == 4
7371
assert isinstance(x['mask'], torch.Tensor)
7472
assert x['mask'].ndim == 2
7573

7674
if dataset.bands == OSCD.rgb_bands:
77-
assert x['image1'].shape[0] == 3
78-
assert x['image2'].shape[0] == 3
75+
assert x['image'].shape[1] == 3
7976
else:
80-
assert x['image1'].shape[0] == 13
81-
assert x['image2'].shape[0] == 13
77+
assert x['image'].shape[1] == 13
8278

8379
def test_len(self, dataset: OSCD) -> None:
8480
if dataset.split == 'train':

tests/trainers/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from pathlib import Path
77

88
import pytest
9+
import timm
910
import torch
10-
import torchvision
1111
from _pytest.fixtures import SubRequest
1212
from torch import Tensor
1313
from torch.nn.modules import Module
@@ -22,8 +22,9 @@ def fast_dev_run(request: SubRequest) -> bool:
2222

2323

2424
@pytest.fixture(scope='package')
25-
def model() -> Module:
26-
model: Module = torchvision.models.resnet18(weights=None)
25+
def model(request: SubRequest) -> Module:
26+
in_channels = getattr(request, 'param', 3)
27+
model: Module = timm.create_model('resnet18', in_chans=in_channels)
2728
return model
2829

2930

tests/trainers/test_change.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import os
5+
from pathlib import Path
6+
from typing import Any, Literal
7+
8+
import pytest
9+
import segmentation_models_pytorch as smp
10+
import timm
11+
import torch
12+
import torch.nn as nn
13+
from lightning.pytorch import Trainer
14+
from pytest import MonkeyPatch
15+
from torch.nn.modules import Module
16+
from torchvision.models._api import WeightsEnum
17+
18+
from torchgeo.datamodules import MisconfigurationException, OSCDDataModule
19+
from torchgeo.datasets import OSCD, RGBBandsMissingError
20+
from torchgeo.main import main
21+
from torchgeo.models import ResNet18_Weights
22+
from torchgeo.trainers import ChangeDetectionTask
23+
24+
25+
class ChangeDetectionTestModel(Module):
26+
def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None:
27+
super().__init__()
28+
self.conv1 = nn.Conv2d(
29+
in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0
30+
)
31+
32+
def forward(self, x: torch.Tensor) -> torch.Tensor:
33+
x = self.conv1(x)
34+
return x
35+
36+
37+
def create_model(**kwargs: Any) -> Module:
38+
return ChangeDetectionTestModel(**kwargs)
39+
40+
41+
def plot(*args: Any, **kwargs: Any) -> None:
42+
return None
43+
44+
45+
def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
46+
raise RGBBandsMissingError()
47+
48+
49+
class PredictChangeDetectionDataModule(OSCDDataModule):
50+
def setup(self, stage: str) -> None:
51+
self.predict_dataset = OSCD(**self.kwargs)
52+
53+
54+
class TestChangeDetectionTask:
55+
@pytest.mark.parametrize('name', ['oscd'])
56+
def test_trainer(
57+
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
58+
) -> None:
59+
config = os.path.join('tests', 'conf', name + '.yaml')
60+
61+
monkeypatch.setattr(smp, 'Unet', create_model)
62+
63+
args = [
64+
'--config',
65+
config,
66+
'--trainer.accelerator',
67+
'cpu',
68+
'--trainer.fast_dev_run',
69+
str(fast_dev_run),
70+
'--trainer.max_epochs',
71+
'1',
72+
'--trainer.log_every_n_steps',
73+
'1',
74+
]
75+
76+
main(['fit', *args])
77+
try:
78+
main(['test', *args])
79+
except MisconfigurationException:
80+
pass
81+
try:
82+
main(['predict', *args])
83+
except MisconfigurationException:
84+
pass
85+
86+
def test_predict(self, fast_dev_run: bool) -> None:
87+
datamodule = PredictChangeDetectionDataModule(
88+
root=os.path.join('tests', 'data', 'oscd'),
89+
batch_size=2,
90+
patch_size=32,
91+
val_split_pct=0.5,
92+
num_workers=0,
93+
)
94+
model = ChangeDetectionTask(backbone='resnet18', in_channels=13, model='unet')
95+
trainer = Trainer(
96+
accelerator='cpu',
97+
fast_dev_run=fast_dev_run,
98+
log_every_n_steps=1,
99+
max_epochs=1,
100+
)
101+
trainer.predict(model=model, datamodule=datamodule)
102+
103+
@pytest.fixture
104+
def weights(self) -> WeightsEnum:
105+
return ResNet18_Weights.SENTINEL2_ALL_MOCO
106+
107+
@pytest.fixture
108+
def mocked_weights(
109+
self,
110+
tmp_path: Path,
111+
monkeypatch: MonkeyPatch,
112+
weights: WeightsEnum,
113+
load_state_dict_from_url: None,
114+
) -> WeightsEnum:
115+
path = tmp_path / f'{weights}.pth'
116+
# multiply in_chans by 2 since images are concatenated
117+
model = timm.create_model(
118+
weights.meta['model'], in_chans=weights.meta['in_chans'] * 2
119+
)
120+
torch.save(model.state_dict(), path)
121+
try:
122+
monkeypatch.setattr(weights.value, 'url', str(path))
123+
except AttributeError:
124+
monkeypatch.setattr(weights, 'url', str(path))
125+
return weights
126+
127+
@pytest.mark.parametrize('model', [6], indirect=True)
128+
def test_weight_file(self, checkpoint: str) -> None:
129+
ChangeDetectionTask(backbone='resnet18', weights=checkpoint)
130+
131+
def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
132+
ChangeDetectionTask(
133+
backbone=mocked_weights.meta['model'],
134+
weights=mocked_weights,
135+
in_channels=mocked_weights.meta['in_chans'],
136+
)
137+
138+
def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
139+
ChangeDetectionTask(
140+
backbone=mocked_weights.meta['model'],
141+
weights=str(mocked_weights),
142+
in_channels=mocked_weights.meta['in_chans'],
143+
)
144+
145+
@pytest.mark.slow
146+
def test_weight_enum_download(self, weights: WeightsEnum) -> None:
147+
ChangeDetectionTask(
148+
backbone=weights.meta['model'],
149+
weights=weights,
150+
in_channels=weights.meta['in_chans'],
151+
)
152+
153+
@pytest.mark.slow
154+
def test_weight_str_download(self, weights: WeightsEnum) -> None:
155+
ChangeDetectionTask(
156+
backbone=weights.meta['model'],
157+
weights=str(weights),
158+
in_channels=weights.meta['in_chans'],
159+
)
160+
161+
@pytest.mark.parametrize('model_name', ['unet', 'fcsiamdiff', 'fcsiamconc'])
162+
@pytest.mark.parametrize(
163+
'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0']
164+
)
165+
def test_freeze_backbone(
166+
self, model_name: Literal['unet', 'fcsiamdiff', 'fcsiamconc'], backbone: str
167+
) -> None:
168+
model = ChangeDetectionTask(
169+
model=model_name, backbone=backbone, freeze_backbone=True
170+
)
171+
assert all(
172+
[param.requires_grad is False for param in model.model.encoder.parameters()]
173+
)
174+
assert all([param.requires_grad for param in model.model.decoder.parameters()])
175+
assert all(
176+
[
177+
param.requires_grad
178+
for param in model.model.segmentation_head.parameters()
179+
]
180+
)
181+
182+
@pytest.mark.parametrize('model_name', ['unet', 'fcsiamdiff', 'fcsiamconc'])
183+
def test_freeze_decoder(
184+
self, model_name: Literal['unet', 'fcsiamdiff', 'fcsiamconc']
185+
) -> None:
186+
model = ChangeDetectionTask(model=model_name, freeze_decoder=True)
187+
assert all(
188+
[param.requires_grad is False for param in model.model.decoder.parameters()]
189+
)
190+
assert all([param.requires_grad for param in model.model.encoder.parameters()])
191+
assert all(
192+
[
193+
param.requires_grad
194+
for param in model.model.segmentation_head.parameters()
195+
]
196+
)
197+
198+
@pytest.mark.parametrize('loss_fn', ['bce', 'jaccard', 'focal'])
199+
def test_losses(self, loss_fn: Literal['bce', 'jaccard', 'focal']) -> None:
200+
ChangeDetectionTask(loss=loss_fn)
201+
202+
def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
203+
monkeypatch.setattr(OSCDDataModule, 'plot', plot)
204+
datamodule = OSCDDataModule(
205+
root=os.path.join('tests', 'data', 'oscd'),
206+
batch_size=2,
207+
patch_size=32,
208+
val_split_pct=0.5,
209+
num_workers=0,
210+
)
211+
model = ChangeDetectionTask(backbone='resnet18', in_channels=13, model='unet')
212+
trainer = Trainer(
213+
accelerator='cpu',
214+
fast_dev_run=fast_dev_run,
215+
log_every_n_steps=1,
216+
max_epochs=1,
217+
)
218+
trainer.validate(model=model, datamodule=datamodule)
219+
220+
def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
221+
monkeypatch.setattr(OSCDDataModule, 'plot', plot_missing_bands)
222+
datamodule = OSCDDataModule(
223+
root=os.path.join('tests', 'data', 'oscd'),
224+
batch_size=2,
225+
patch_size=32,
226+
val_split_pct=0.5,
227+
num_workers=0,
228+
)
229+
model = ChangeDetectionTask(backbone='resnet18', in_channels=13, model='unet')
230+
trainer = Trainer(
231+
accelerator='cpu',
232+
fast_dev_run=fast_dev_run,
233+
log_every_n_steps=1,
234+
max_epochs=1,
235+
)
236+
trainer.validate(model=model, datamodule=datamodule)

0 commit comments

Comments
 (0)