Merge pull request #329 from fmartiescofet/reduce_lr
Feat: Implement option to have multiple learning rates
Joao-L-S-Almeida authored Jan 17, 2025
2 parents 7f4533b + 04117cf commit 3161ba2
Showing 6 changed files with 191 additions and 19 deletions.
133 changes: 133 additions & 0 deletions examples/confs/sen1floods11_vit_dual_lr.yaml
@@ -0,0 +1,133 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger: True  # will use tensorboardlogger
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch

  max_epochs: 200
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: true
  default_root_dir: <your_path_here>
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 16
    num_workers: 8
    constant_scale: 0.0001
    dataset_bands:
      - COASTAL_AEROSOL
      - BLUE
      - GREEN
      - RED
      - RED_EDGE_1
      - RED_EDGE_2
      - RED_EDGE_3
      - NIR_BROAD
      - NIR_NARROW
      - WATER_VAPOR
      - CIRRUS
      - SWIR_1
      - SWIR_2
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
    train_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
    val_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
    val_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
    test_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
    test_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
    # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files
    train_split: <sen1floods11_root>/v1.1/splits/flood_handlabeled/flood_train_data.txt
    test_split: <sen1floods11_root>/v1.1/splits/flood_handlabeled/flood_test_data.txt
    val_split: <sen1floods11_root>/v1.1/splits/flood_handlabeled/flood_valid_data.txt
    img_grep: "*_S2Hand.tif"
    label_grep: "*_LabelHand.tif"
    no_label_replace: -1
    no_data_replace: 0
    means:
      - 0.1412956
      - 0.13795798
      - 0.12353792
      - 0.30902815
      - 0.2044958
      - 0.11912015
    stds:
      - 0.07406382
      - 0.07370365
      - 0.08692279
      - 0.11798815
      - 0.09772074
      - 0.07659938
    num_classes: 2

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_args:
      decoder: FCNDecoder
      backbone_pretrained: true
      backbone: prithvi_vit_100
      decoder_channels: 256
      backbone_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_classes: 2
      head_dropout: 0.1
      decoder_num_convs: 4
      head_channel_list:
        - 256
      necks:
        - name: SelectIndices
          indices:
            - -1
        - name: ReshapeTokensToImage
    loss: ce
    aux_heads:
      - name: aux_head
        decoder: FCNDecoder
        decoder_args:
          decoder_channels: 256
          decoder_in_index: -1
          decoder_num_convs: 2
          head_dropout: 0.1
          # head_channel_list:
          #   - 64
    aux_loss:
      aux_head: 1.0
    ignore_index: -1
    class_weights:
      - 0.3
      - 0.7
    freeze_backbone: false
    freeze_decoder: false
    model_factory: EncoderDecoderFactory
    optimizer: AdamW
    lr: 1e-4
    lr_overrides:
      encoder: 1e-5
    scheduler: ReduceLROnPlateau
    scheduler_hparams:
      monitor: val/loss
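
The new option this config exercises is the lr_overrides block at the end: every parameter whose name contains the substring "encoder" is trained at 1e-5, while all remaining parameters use the top-level lr of 1e-4. A minimal sketch of the matching rule, with hypothetical parameter names (not taken from the repository):

# Sketch of how an lr_overrides entry selects parameters; the names are made up.
lr_overrides = {"encoder": 1e-5}
default_lr = 1e-4
param_names = [
    "model.encoder.blocks.0.attn.qkv.weight",
    "model.decoder.convs.0.weight",
    "model.head.weight",
]
for name in param_names:
    # first override whose key is a substring of the name wins; else the default
    lr = next((v for k, v in lr_overrides.items() if k in name), default_lr)
    print(f"{name}: lr={lr}")
# model.encoder.blocks.0.attn.qkv.weight: lr=1e-05
# model.decoder.convs.0.weight: lr=0.0001
# model.head.weight: lr=0.0001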
19 changes: 18 additions & 1 deletion terratorch/tasks/base_task.py
@@ -1,4 +1,5 @@
 import logging
+from collections.abc import Iterable
 
 import lightning
 from lightning.pytorch.callbacks import Callback
@@ -52,10 +53,26 @@ def configure_optimizers(
         optimizer = self.hparams["optimizer"]
         if optimizer is None:
             optimizer = "Adam"
+
+        parameters: Iterable
+        if self.hparams.get("lr_overrides", None) is not None and len(self.hparams["lr_overrides"]) > 0:
+            parameters = []
+            for param_name, custom_lr in self.hparams["lr_overrides"].items():
+                p = [p for n, p in self.named_parameters() if param_name in n]
+                parameters.append({"params": p, "lr": custom_lr})
+            rest_p = [
+                p
+                for n, p in self.named_parameters()
+                if all(param_name not in n for param_name in self.hparams["lr_overrides"])
+            ]
+            parameters.append({"params": rest_p})
+        else:
+            parameters = self.parameters()
+
         return optimizer_factory(
             optimizer,
             self.hparams["lr"],
-            self.parameters(),
+            parameters,
             self.hparams["optimizer_hparams"],
             self.hparams["scheduler"],
             self.monitor,
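
To see what configure_optimizers now builds, here is a standalone sketch of the same grouping logic on a toy model (plain PyTorch; the "encoder"/"decoder" submodule names are placeholders, not terratorch internals). Param groups without an explicit "lr" key inherit the optimizer default, which is how the top-level lr still applies to unmatched parameters:

# Standalone sketch of the parameter grouping above, on a hypothetical toy model.
import torch
from torch import nn

model = nn.ModuleDict({"encoder": nn.Linear(4, 8), "decoder": nn.Linear(8, 2)})
lr_overrides = {"encoder": 1e-5}

parameters = []
for param_name, custom_lr in lr_overrides.items():
    # every parameter whose name contains the override key gets the custom lr
    group = [p for n, p in model.named_parameters() if param_name in n]
    parameters.append({"params": group, "lr": custom_lr})
# everything not matched by any override key falls into a default group
rest_p = [
    p
    for n, p in model.named_parameters()
    if all(param_name not in n for param_name in lr_overrides)
]
parameters.append({"params": rest_p})

optimizer = torch.optim.AdamW(parameters, lr=1e-4)
print([g["lr"] for g in optimizer.param_groups])  # [1e-05, 0.0001]

Note that matching is by substring, so an override key like "encoder" would also catch a parameter named "aux_encoder.weight"; keys should be chosen to be unambiguous within the model.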
20 changes: 12 additions & 8 deletions terratorch/tasks/classification_tasks.py
@@ -16,7 +16,8 @@
 from terratorch.tasks.optimizer_factory import optimizer_factory
 from terratorch.tasks.base_task import TerraTorchTask
 
-logger = logging.getLogger('terratorch')
+logger = logging.getLogger("terratorch")
+
 
 def to_class_prediction(y: ModelOutput) -> Tensor:
     y_hat = y.output
@@ -62,6 +63,7 @@ def __init__(
         freeze_backbone: bool = False,  # noqa: FBT001, FBT002
         freeze_decoder: bool = False,  # noqa: FBT002, FBT001
         class_names: list[str] | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor
@@ -97,6 +99,9 @@ def __init__(
             freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
             class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                 Defaults to numeric ordering.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
+                parameters. The key should be a substring of the parameter names (it will check that the substring is
+                contained in the parameter name) and the value should be the new lr. Defaults to None.
         """
         self.aux_loss = aux_loss
         self.aux_heads = aux_heads
@@ -120,7 +125,6 @@ def __init__(
         self.val_loss_handler = LossHandler(self.val_metrics.prefix)
         self.monitor = f"{self.val_metrics.prefix}loss"
 
-
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
@@ -131,8 +135,8 @@ def configure_losses(self) -> None:
         ignore_index = self.hparams["ignore_index"]
 
         class_weights = (
-                torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
-            )
+            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
+        )
         if loss == "ce":
             ignore_value = -100 if ignore_index is None else ignore_index
             self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
@@ -200,7 +204,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
 
         model_output: ModelOutput = self(x, **rest)
         loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
@@ -221,7 +225,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -239,7 +243,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -260,7 +264,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
         x = batch["image"]
         file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
 
         y_hat = self(x).output
15 changes: 10 additions & 5 deletions terratorch/tasks/regression_tasks.py
@@ -24,7 +24,8 @@
 
 BATCH_IDX_FOR_VALIDATION_PLOTTING = 10
 
-logger = logging.getLogger('terratorch')
+logger = logging.getLogger("terratorch")
+
 
 class RootLossWrapper(nn.Module):
     def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None:
@@ -152,6 +153,7 @@ def __init__(
         freeze_decoder: bool = False,  # noqa: FBT001, FBT002
         plot_on_val: bool | int = 10,
         tiled_inference_parameters: TiledInferenceParameters | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor
@@ -186,6 +188,9 @@ def __init__(
                 If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
             tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
                 used to determine if inference is done on the whole image or through tiling.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
+                parameters. The key should be a substring of the parameter names (it will check that the substring is
+                contained in the parameter name) and the value should be the new lr. Defaults to None.
         """
         self.tiled_inference_parameters = tiled_inference_parameters
         self.aux_loss = aux_loss
@@ -266,7 +271,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
 
         model_output: ModelOutput = self(x, **rest)
         loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
@@ -287,7 +292,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
@@ -329,7 +334,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -350,7 +355,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
         x = batch["image"]
         file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
 
         def model_forward(x):
             return self(x).output
11 changes: 9 additions & 2 deletions terratorch/tasks/segmentation_tasks.py
@@ -64,6 +64,7 @@ def __init__(
         class_names: list[str] | None = None,
         tiled_inference_parameters: TiledInferenceParameters = None,
         test_dataloaders_names: list[str] | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor
@@ -106,6 +107,9 @@ def __init__(
             test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
                 multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
                 which assumes only one test dataloader is used.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
+                parameters. The key should be a substring of the parameter names (it will check that the substring is
+                contained in the parameter name) and the value should be the new lr. Defaults to None.
         """
         self.tiled_inference_parameters = tiled_inference_parameters
         self.aux_loss = aux_loss
@@ -294,7 +298,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
             batch["prediction"] = y_hat_hard
 
             if isinstance(batch["image"], dict):
-                if hasattr(datamodule, 'rgb_modality'):
+                if hasattr(datamodule, "rgb_modality"):
                     # Generic multimodal dataset
                     batch["image"] = batch["image"][datamodule.rgb_modality]
                 else:
@@ -343,7 +347,10 @@ def model_forward(x):
         if self.tiled_inference_parameters:
             y_hat: Tensor = tiled_inference(
                 # TODO: tiled inference does not work with additional input data (**rest)
-                model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters
+                model_forward,
+                x,
+                self.hparams["model_args"]["num_classes"],
+                self.tiled_inference_parameters,
             )
         else:
             y_hat: Tensor = self(x, **rest).output