Commit 3161ba2

Merge pull request #329 from fmartiescofet/reduce_lr
Feat: Implement option to have multiple learning rates
2 parents 7f4533b + 04117cf commit 3161ba2

File tree

6 files changed (+191 additions, -19 deletions)

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger: True # will use tensorboardlogger
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch

  max_epochs: 200
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: true
  default_root_dir: <your_path_here>
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 16
    num_workers: 8
    constant_scale: 0.0001
    dataset_bands:
      - COASTAL_AEROSOL
      - BLUE
      - GREEN
      - RED
      - RED_EDGE_1
      - RED_EDGE_2
      - RED_EDGE_3
      - NIR_BROAD
      - NIR_NARROW
      - WATER_VAPOR
      - CIRRUS
      - SWIR_1
      - SWIR_2
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
    train_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
    val_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
    val_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
    test_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
    test_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
    # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files
    train_split: <sen1floods11_root>/v1.1/splits/flood_handlabeled/flood_train_data.txt
    test_split: <sen1floods11_root>/v1.1/splits/flood_handlabeled/flood_test_data.txt
    val_split: <sen1floods11_root>/v1.1/splits/flood_handlabeled/flood_valid_data.txt
    img_grep: "*_S2Hand.tif"
    label_grep: "*_LabelHand.tif"
    no_label_replace: -1
    no_data_replace: 0
    means:
      - 0.1412956
      - 0.13795798
      - 0.12353792
      - 0.30902815
      - 0.2044958
      - 0.11912015
    stds:
      - 0.07406382
      - 0.07370365
      - 0.08692279
      - 0.11798815
      - 0.09772074
      - 0.07659938
    num_classes: 2

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_args:
      decoder: FCNDecoder
      backbone_pretrained: true
      backbone: prithvi_vit_100
      decoder_channels: 256
      backbone_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_classes: 2
      head_dropout: 0.1
      decoder_num_convs: 4
      head_channel_list:
        - 256
      necks:
        - name: SelectIndices
          indices:
            - -1
        - name: ReshapeTokensToImage
    loss: ce
    aux_heads:
      - name: aux_head
        decoder: FCNDecoder
        decoder_args:
          decoder_channels: 256
          decoder_in_index: -1
          decoder_num_convs: 2
          head_dropout: 0.1
          # head_channel_list:
          #   - 64
    aux_loss:
      aux_head: 1.0
    ignore_index: -1
    class_weights:
      - 0.3
      - 0.7
    freeze_backbone: false
    freeze_decoder: false
    model_factory: EncoderDecoderFactory
    optimizer: AdamW
    lr: 1e-4
    lr_overrides:
      encoder: 1e-5
    scheduler: ReduceLROnPlateau
    scheduler_hparams:
      monitor: val/loss
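
The optimizer block at the end of this config is what the new lr_overrides option consumes: parameters whose names contain the key "encoder" train at 1e-5, everything else at the top-level lr of 1e-4, with ReduceLROnPlateau monitoring val/loss. Below is a minimal plain-PyTorch sketch of the equivalent setup; the two-layer model is only a stand-in (not the Prithvi backbone), and in TerraTorch the groups are built from named_parameters() as shown in base_task.py further down.

import torch
from torch import nn

# Stand-in model; TerraTorch builds the real one via EncoderDecoderFactory.
model = nn.ModuleDict({"encoder": nn.Linear(6, 16), "decoder": nn.Linear(16, 2)})

# lr_overrides: {"encoder": 1e-5} -> parameters whose name contains "encoder"
# get their own group; everything else falls back to the top-level lr of 1e-4.
param_groups = [
    {"params": [p for n, p in model.named_parameters() if "encoder" in n], "lr": 1e-5},
    {"params": [p for n, p in model.named_parameters() if "encoder" not in n]},
]

optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
# Lightning steps ReduceLROnPlateau against the monitored metric (val/loss here).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)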

terratorch/tasks/base_task.py

Lines changed: 18 additions & 1 deletion
@@ -1,4 +1,5 @@
 import logging
+from collections.abc import Iterable

 import lightning
 from lightning.pytorch.callbacks import Callback
@@ -52,10 +53,26 @@ def configure_optimizers(
         optimizer = self.hparams["optimizer"]
         if optimizer is None:
             optimizer = "Adam"
+
+        parameters: Iterable
+        if self.hparams.get("lr_overrides", None) is not None and len(self.hparams["lr_overrides"]) > 0:
+            parameters = []
+            for param_name, custom_lr in self.hparams["lr_overrides"].items():
+                p = [p for n, p in self.named_parameters() if param_name in n]
+                parameters.append({"params": p, "lr": custom_lr})
+            rest_p = [
+                p
+                for n, p in self.named_parameters()
+                if all(param_name not in n for param_name in self.hparams["lr_overrides"])
+            ]
+            parameters.append({"params": rest_p})
+        else:
+            parameters = self.parameters()
+
         return optimizer_factory(
             optimizer,
             self.hparams["lr"],
-            self.parameters(),
+            parameters,
             self.hparams["optimizer_hparams"],
             self.hparams["scheduler"],
             self.monitor,
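
For illustration, the same grouping logic run outside Lightning on a toy module; a hedged sketch in which lr_overrides, the module, and the learning rates are arbitrary placeholders.

import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 2)
        self.head = nn.Linear(2, 2)

model = ToyModel()
lr_overrides = {"encoder": 1e-5, "decoder": 5e-5}  # substring of parameter name -> lr
lr = 1e-4  # default lr for everything else

# Same algorithm as configure_optimizers above: one group per override key,
# plus a final group with every parameter that matches none of the keys.
parameters = []
for param_name, custom_lr in lr_overrides.items():
    p = [p for n, p in model.named_parameters() if param_name in n]
    parameters.append({"params": p, "lr": custom_lr})
rest_p = [
    p
    for n, p in model.named_parameters()
    if all(param_name not in n for param_name in lr_overrides)
]
parameters.append({"params": rest_p})

optimizer = torch.optim.AdamW(parameters, lr=lr)
for i, group in enumerate(optimizer.param_groups):
    print(i, group["lr"], sum(p.numel() for p in group["params"]))
# Note: a parameter whose name matches more than one override key would land in
# two groups, which torch.optim rejects, so keys should not overlap.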

terratorch/tasks/classification_tasks.py

Lines changed: 12 additions & 8 deletions
@@ -16,7 +16,8 @@
 from terratorch.tasks.optimizer_factory import optimizer_factory
 from terratorch.tasks.base_task import TerraTorchTask

-logger = logging.getLogger('terratorch')
+logger = logging.getLogger("terratorch")
+

 def to_class_prediction(y: ModelOutput) -> Tensor:
     y_hat = y.output
@@ -62,6 +63,7 @@ def __init__(
         freeze_backbone: bool = False,  # noqa: FBT001, FBT002
         freeze_decoder: bool = False,  # noqa: FBT002, FBT001
         class_names: list[str] | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor

@@ -97,6 +99,9 @@ def __init__(
             freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
             class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                 Defaults to numeric ordering.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
+                parameters. The key should be a substring of the parameter names (it will check the substring is
+                contained in the parameter name) and the value should be the new lr. Defaults to None.
         """
         self.aux_loss = aux_loss
         self.aux_heads = aux_heads
@@ -120,7 +125,6 @@ def __init__(
         self.val_loss_handler = LossHandler(self.val_metrics.prefix)
         self.monitor = f"{self.val_metrics.prefix}loss"

-
     def configure_losses(self) -> None:
         """Initialize the loss criterion.

@@ -131,8 +135,8 @@ def configure_losses(self) -> None:
         ignore_index = self.hparams["ignore_index"]

         class_weights = (
-            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
-        )
+            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
+        )
         if loss == "ce":
             ignore_value = -100 if ignore_index is None else ignore_index
             self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
@@ -200,7 +204,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}

         model_output: ModelOutput = self(x, **rest)
         loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
@@ -221,7 +225,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -239,7 +243,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -260,7 +264,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
         x = batch["image"]
         file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)

         y_hat = self(x).output
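
Because override keys are matched as substrings of named_parameters() names, it can be worth checking what a key actually captures before training; a small hedged sketch where the nn.Sequential stands in for a real task (which is itself an nn.Module) and the key "2" is purely hypothetical.

from torch import nn

# Stand-in for a constructed task; the same check works on a real task module.
module = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

key = "2"  # hypothetical override key; with TerraTorch tasks this would be e.g. "encoder" or "decoder"
matched = [n for n, _ in module.named_parameters() if key in n]
print(matched)  # parameter names that would receive the override lr, e.g. ['2.weight', '2.bias']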

terratorch/tasks/regression_tasks.py

Lines changed: 10 additions & 5 deletions
@@ -24,7 +24,8 @@

 BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

-logger = logging.getLogger('terratorch')
+logger = logging.getLogger("terratorch")
+

 class RootLossWrapper(nn.Module):
     def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None:
@@ -152,6 +153,7 @@ def __init__(
         freeze_decoder: bool = False,  # noqa: FBT001, FBT002
         plot_on_val: bool | int = 10,
         tiled_inference_parameters: TiledInferenceParameters | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor

@@ -186,6 +188,9 @@ def __init__(
                 If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
             tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
                 used to determine if inference is done on the whole image or through tiling.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
+                parameters. The key should be a substring of the parameter names (it will check the substring is
+                contained in the parameter name) and the value should be the new lr. Defaults to None.
         """
         self.tiled_inference_parameters = tiled_inference_parameters
         self.aux_loss = aux_loss
@@ -266,7 +271,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}

         model_output: ModelOutput = self(x, **rest)
         loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
@@ -287,7 +292,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
@@ -329,7 +334,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -350,7 +355,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
         x = batch["image"]
         file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}

         def model_forward(x):
             return self(x).output

terratorch/tasks/segmentation_tasks.py

Lines changed: 9 additions & 2 deletions
@@ -64,6 +64,7 @@ def __init__(
         class_names: list[str] | None = None,
         tiled_inference_parameters: TiledInferenceParameters = None,
         test_dataloaders_names: list[str] | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor

@@ -106,6 +107,9 @@ def __init__(
             test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
                 multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
                 which assumes only one test dataloader is used.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
+                parameters. The key should be a substring of the parameter names (it will check the substring is
+                contained in the parameter name) and the value should be the new lr. Defaults to None.
         """
         self.tiled_inference_parameters = tiled_inference_parameters
         self.aux_loss = aux_loss
@@ -294,7 +298,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
                 batch["prediction"] = y_hat_hard

                 if isinstance(batch["image"], dict):
-                    if hasattr(datamodule, 'rgb_modality'):
+                    if hasattr(datamodule, "rgb_modality"):
                         # Generic multimodal dataset
                         batch["image"] = batch["image"][datamodule.rgb_modality]
                     else:
@@ -343,7 +347,7 @@ def model_forward(x):
         if self.tiled_inference_parameters:
             y_hat: Tensor = tiled_inference(
                 # TODO: tiled inference does not work with additional input data (**rest)
-                model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters
+                model_forward,
+                x,
+                self.hparams["model_args"]["num_classes"],
+                self.tiled_inference_parameters,
             )
         else:
             y_hat: Tensor = self(x, **rest).output
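
The same hyperparameters can also be passed when building the task directly in Python rather than through the Lightning CLI; a hedged sketch that simply mirrors the init_args of the config at the top of this commit (keyword names are taken from that config rather than checked against the full constructor signature, and aux_heads is omitted for brevity).

from terratorch.tasks import SemanticSegmentationTask

# Mirrors the model section of the YAML config above: parameters whose names
# contain "encoder" train at 1e-5 via lr_overrides, everything else at 1e-4.
task = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        "backbone": "prithvi_vit_100",
        "backbone_pretrained": True,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        "decoder": "FCNDecoder",
        "decoder_channels": 256,
        "decoder_num_convs": 4,
        "head_dropout": 0.1,
        "head_channel_list": [256],
        "num_classes": 2,
        "necks": [
            {"name": "SelectIndices", "indices": [-1]},
            {"name": "ReshapeTokensToImage"},
        ],
    },
    loss="ce",
    ignore_index=-1,
    class_weights=[0.3, 0.7],
    optimizer="AdamW",
    lr=1e-4,
    lr_overrides={"encoder": 1e-5},
    scheduler="ReduceLROnPlateau",
    scheduler_hparams={"monitor": "val/loss"},
)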
