diff --git a/examples/confs/sen1floods11_vit_dual_lr.yaml b/examples/confs/sen1floods11_vit_dual_lr.yaml
new file mode 100644
index 00000000..47a6eaa7
--- /dev/null
+++ b/examples/confs/sen1floods11_vit_dual_lr.yaml
@@ -0,0 +1,133 @@
+# lightning.pytorch==2.1.1
+seed_everything: 0
+trainer:
+  accelerator: auto
+  strategy: auto
+  devices: auto
+  num_nodes: 1
+  precision: 16-mixed
+  logger: True # will use tensorboardlogger
+  callbacks:
+    - class_path: RichProgressBar
+    - class_path: LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+
+  max_epochs: 200
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 50
+  enable_checkpointing: true
+  default_root_dir:
+data:
+  class_path: GenericNonGeoSegmentationDataModule
+  init_args:
+    batch_size: 16
+    num_workers: 8
+    constant_scale: 0.0001
+    dataset_bands:
+      - COASTAL_AEROSOL
+      - BLUE
+      - GREEN
+      - RED
+      - RED_EDGE_1
+      - RED_EDGE_2
+      - RED_EDGE_3
+      - NIR_BROAD
+      - NIR_NARROW
+      - WATER_VAPOR
+      - CIRRUS
+      - SWIR_1
+      - SWIR_2
+    output_bands:
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+    rgb_indices:
+      - 2
+      - 1
+      - 0
+    train_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/
+    train_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand
+    val_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/
+    val_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand
+    test_data_root: /v1.1/data/flood_events/HandLabeled/S2Hand/
+    test_label_data_root: /v1.1/data/flood_events/HandLabeled/LabelHand
+    # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files
+    train_split: /v1.1/splits/flood_handlabeled/flood_train_data.txt
+    test_split: /v1.1/splits/flood_handlabeled/flood_test_data.txt
+    val_split: /v1.1/splits/flood_handlabeled/flood_valid_data.txt
+    img_grep: "*_S2Hand.tif"
+    label_grep: "*_LabelHand.tif"
+    no_label_replace: -1
+    no_data_replace: 0
+    means:
+      - 0.1412956
+      - 0.13795798
+      - 0.12353792
+      - 0.30902815
+      - 0.2044958
+      - 0.11912015
+    stds:
+      - 0.07406382
+      - 0.07370365
+      - 0.08692279
+      - 0.11798815
+      - 0.09772074
+      - 0.07659938
+    num_classes: 2
+
+model:
+  class_path: terratorch.tasks.SemanticSegmentationTask
+  init_args:
+    model_args:
+      decoder: FCNDecoder
+      backbone_pretrained: true
+      backbone: prithvi_vit_100
+      decoder_channels: 256
+      backbone_bands:
+        - BLUE
+        - GREEN
+        - RED
+        - NIR_NARROW
+        - SWIR_1
+        - SWIR_2
+      num_classes: 2
+      head_dropout: 0.1
+      decoder_num_convs: 4
+      head_channel_list:
+        - 256
+      necks:
+        - name: SelectIndices
+          indices:
+            - -1
+        - name: ReshapeTokensToImage
+    loss: ce
+    aux_heads:
+      - name: aux_head
+        decoder: FCNDecoder
+        decoder_args:
+          decoder_channels: 256
+          decoder_in_index: -1
+          decoder_num_convs: 2
+          head_dropout: 0.1
+          # head_channel_list:
+          #   - 64
+    aux_loss:
+      aux_head: 1.0
+    ignore_index: -1
+    class_weights:
+      - 0.3
+      - 0.7
+    freeze_backbone: false
+    freeze_decoder: false
+    model_factory: EncoderDecoderFactory
+    optimizer: AdamW
+    lr: 1e-4
+    lr_overrides:
+      encoder: 1e-5
+    scheduler: ReduceLROnPlateau
+    scheduler_hparams:
+      monitor: val/loss
diff --git a/terratorch/tasks/base_task.py b/terratorch/tasks/base_task.py
index e59aaf39..53a26cec 100644
--- a/terratorch/tasks/base_task.py
+++ b/terratorch/tasks/base_task.py
@@ -1,4 +1,5 @@
 import logging
+from collections.abc import Iterable

 import lightning
 from lightning.pytorch.callbacks import Callback
@@ -52,10 +53,26 @@ def configure_optimizers(
         optimizer = self.hparams["optimizer"]
         if optimizer is None:
             optimizer = "Adam"
+
+        parameters: Iterable
+        if self.hparams.get("lr_overrides", None) is not None and len(self.hparams["lr_overrides"]) > 0:
+            parameters = []
+            for param_name, custom_lr in self.hparams["lr_overrides"].items():
+                p = [p for n, p in self.named_parameters() if param_name in n]
+                parameters.append({"params": p, "lr": custom_lr})
+            rest_p = [
+                p
+                for n, p in self.named_parameters()
+                if all(param_name not in n for param_name in self.hparams["lr_overrides"])
+            ]
+            parameters.append({"params": rest_p})
+        else:
+            parameters = self.parameters()
+
         return optimizer_factory(
             optimizer,
             self.hparams["lr"],
-            self.parameters(),
+            parameters,
             self.hparams["optimizer_hparams"],
             self.hparams["scheduler"],
             self.monitor,
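For intuition, here is a minimal, self-contained sketch of the substring-based grouping that `configure_optimizers` performs above (plain PyTorch, not TerraTorch code; the toy module names are made up for illustration): every parameter whose name contains an `lr_overrides` key goes into a group with that lr, and all remaining parameters fall into a default group that inherits the optimizer-level lr.

```python
import torch
from torch import nn

# Toy model whose parameter names contain "encoder"/"decoder" substrings,
# mirroring how backbone and decoder parameters are typically named.
model = nn.Sequential()
model.add_module("encoder", nn.Linear(4, 8))
model.add_module("decoder", nn.Linear(8, 2))

lr_overrides = {"encoder": 1e-5}  # matches "encoder.weight" and "encoder.bias"

parameters = []
for param_name, custom_lr in lr_overrides.items():
    group = [p for n, p in model.named_parameters() if param_name in n]
    parameters.append({"params": group, "lr": custom_lr})
# Parameters matched by no override key keep the optimizer-level default lr.
rest = [p for n, p in model.named_parameters() if all(k not in n for k in lr_overrides)]
parameters.append({"params": rest})

optimizer = torch.optim.AdamW(parameters, lr=1e-4)
for group in optimizer.param_groups:
    print(group["lr"], sum(p.numel() for p in group["params"]))
# -> 1e-05 for the 40 encoder parameters, 0.0001 for the remaining 18
```

Because matching is by substring containment, a key such as `"decoder"` would also catch names like `aux_decoder.weight`; override keys should be chosen to be unambiguous.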
diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py
index f91e1836..8630a0ff 100644
--- a/terratorch/tasks/classification_tasks.py
+++ b/terratorch/tasks/classification_tasks.py
@@ -16,7 +16,8 @@
 from terratorch.tasks.optimizer_factory import optimizer_factory
 from terratorch.tasks.base_task import TerraTorchTask

-logger = logging.getLogger('terratorch')
+logger = logging.getLogger("terratorch")
+

 def to_class_prediction(y: ModelOutput) -> Tensor:
     y_hat = y.output
@@ -62,6 +63,7 @@ def __init__(
         freeze_backbone: bool = False,  # noqa: FBT001, FBT002
         freeze_decoder: bool = False,  # noqa: FBT002, FBT001
         class_names: list[str] | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor

@@ -97,6 +99,9 @@
             freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
             class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                 Defaults to numeric ordering.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr for specific
+                parameters. Each key should be a substring of the parameter names it targets (matching checks
+                whether the key is contained in the parameter name) and the value is the lr to use. Defaults to None.
         """
         self.aux_loss = aux_loss
         self.aux_heads = aux_heads
@@ -120,7 +125,6 @@
         self.val_loss_handler = LossHandler(self.val_metrics.prefix)
         self.monitor = f"{self.val_metrics.prefix}loss"

-
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
@@ -131,8 +135,8 @@ def configure_losses(self) -> None:
         ignore_index = self.hparams["ignore_index"]
         class_weights = (
-                torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
-            )
+            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
+        )
         if loss == "ce":
             ignore_value = -100 if ignore_index is None else ignore_index
             self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
@@ -200,7 +204,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}

         model_output: ModelOutput = self(x, **rest)
         loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
@@ -221,7 +225,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -239,7 +243,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -260,7 +264,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         x = batch["image"]
         file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}

         model_output: ModelOutput = self(x, **rest)
         y_hat = self(x).output
diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py
index bbc1dd48..c57541fa 100644
--- a/terratorch/tasks/regression_tasks.py
+++ b/terratorch/tasks/regression_tasks.py
@@ -24,7 +24,8 @@

 BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

-logger = logging.getLogger('terratorch')
+logger = logging.getLogger("terratorch")
+

 class RootLossWrapper(nn.Module):
     def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None:
@@ -152,6 +153,7 @@ def __init__(
         freeze_decoder: bool = False,  # noqa: FBT001, FBT002
         plot_on_val: bool | int = 10,
         tiled_inference_parameters: TiledInferenceParameters | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor

@@ -186,6 +188,9 @@
                 If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
             tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters used to
                 determine if inference is done on the whole image or through tiling.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr for specific
+                parameters. Each key should be a substring of the parameter names it targets (matching checks
+                whether the key is contained in the parameter name) and the value is the lr to use. Defaults to None.
         """
         self.tiled_inference_parameters = tiled_inference_parameters
         self.aux_loss = aux_loss
@@ -266,7 +271,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}

         model_output: ModelOutput = self(x, **rest)
         loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
@@ -287,7 +292,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
@@ -329,7 +334,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -350,7 +355,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         x = batch["image"]
         file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}

         def model_forward(x):
             return self(x).output
diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py
index 2fdfdc0c..328dbac8 100644
--- a/terratorch/tasks/segmentation_tasks.py
+++ b/terratorch/tasks/segmentation_tasks.py
@@ -64,6 +64,7 @@ def __init__(
         class_names: list[str] | None = None,
         tiled_inference_parameters: TiledInferenceParameters = None,
         test_dataloaders_names: list[str] | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor

@@ -106,6 +107,9 @@
             test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when multiple
                 dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes
                 only one test dataloader is used.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr for specific
+                parameters. Each key should be a substring of the parameter names it targets (matching checks
+                whether the key is contained in the parameter name) and the value is the lr to use. Defaults to None.
         """
         self.tiled_inference_parameters = tiled_inference_parameters
         self.aux_loss = aux_loss
@@ -294,7 +298,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
            batch["prediction"] = y_hat_hard

            if isinstance(batch["image"], dict):
-                if hasattr(datamodule, 'rgb_modality'):
+                if hasattr(datamodule, "rgb_modality"):
                    # Generic multimodal dataset
                    batch["image"] = batch["image"][datamodule.rgb_modality]
                else:
@@ -343,7 +347,10 @@ def model_forward(x):
         if self.tiled_inference_parameters:
             y_hat: Tensor = tiled_inference(
                 # TODO: tiled inference does not work with additional input data (**rest)
-                model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters
+                model_forward,
+                x,
+                self.hparams["model_args"]["num_classes"],
+                self.tiled_inference_parameters,
             )
         else:
             y_hat: Tensor = self(x, **rest).output
diff --git a/tests/test_prithvi_tasks.py b/tests/test_prithvi_tasks.py
index 0d94b6cb..7252a9f7 100644
--- a/tests/test_prithvi_tasks.py
+++ b/tests/test_prithvi_tasks.py
@@ -31,7 +31,8 @@ def model_input() -> torch.Tensor:
 @pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "UNetDecoder"])
 @pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
-def test_create_segmentation_task(backbone, decoder, loss, model_factory: str):
+@pytest.mark.parametrize("lr_overrides", [{"encoder": 0.01}, None])
+def test_create_segmentation_task(backbone, decoder, loss, model_factory: str, lr_overrides):
     model_args = {
         "backbone": backbone,
         "decoder": decoder,
@@ -48,6 +49,7 @@
         model_args,
         model_factory,
         loss=loss,
+        lr_overrides=lr_overrides,
     )
     gc.collect()

@@ -56,7 +58,8 @@
 @pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "UNetDecoder"])
 @pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
-def test_create_regression_task(backbone, decoder, loss, model_factory: str):
+@pytest.mark.parametrize("lr_overrides", [{"encoder": 0.01}, None])
+def test_create_regression_task(backbone, decoder, loss, model_factory: str, lr_overrides):
     model_args = {
         "backbone": backbone,
         "decoder": decoder,
@@ -73,6 +76,7 @@
         model_args,
         model_factory,
         loss=loss,
+        lr_overrides=lr_overrides,
     )
     gc.collect()

@@ -81,7 +85,8 @@
 @pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "UNetDecoder"])
 @pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
-def test_create_classification_task(backbone, decoder, loss, model_factory: str):
+@pytest.mark.parametrize("lr_overrides", [{"encoder": 0.01}, None])
+def test_create_classification_task(backbone, decoder, loss, model_factory: str, lr_overrides):
     model_args = {
         "backbone": backbone,
         "decoder": decoder,
@@ -99,6 +104,7 @@
         model_args,
         model_factory,
         loss=loss,
+        lr_overrides=lr_overrides,
     )
     gc.collect()