From 6a7cc01d4020cdeb74fbc34aafc3f08934207cd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 30 Jan 2025 13:40:09 -0300 Subject: [PATCH 1/3] Argument to allow segmentation class to output all the categories MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/tasks/segmentation_tasks.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index aff17ab0..a5f06d92 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -65,6 +65,7 @@ def __init__( tiled_inference_parameters: TiledInferenceParameters = None, test_dataloaders_names: list[str] | None = None, lr_overrides: dict[str, float] | None = None, + output_most_probable: bool = True, ) -> None: """Constructor @@ -110,6 +111,8 @@ def __init__( lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name)and the value should be the new lr. Defaults to None. + output_most_probable (bool): A boolean to define if the output during the inference will be just + for the most probable class or if it will include all of them. """ self.tiled_inference_parameters = tiled_inference_parameters self.aux_loss = aux_loss @@ -136,7 +139,13 @@ def __init__( self.val_loss_handler = LossHandler(self.val_metrics.prefix) self.monitor = f"{self.val_metrics.prefix}loss" self.plot_on_val = int(plot_on_val) + self.output_most_probable = output_most_probable + if output_most_probable: + self.select_classes = lambda y: y.argmax(dim=1) + else: + self.select_classes = lambda y: y + print(self.output_most_probable) def configure_losses(self) -> None: """Initialize the loss criterion. 
@@ -349,5 +358,7 @@ def model_forward(x): ) else: y_hat: Tensor = self(x, **rest).output - y_hat = y_hat.argmax(dim=1) + + y_hat = self.select_classes(y_hat) + return y_hat, file_names From d35ceb3dad0a8e8a401edae0f8ee47400ab408f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 30 Jan 2025 17:04:51 -0300 Subject: [PATCH 2/3] Adapting the number of bands for the tiff output files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/cli_tools.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index 50b6d697..6a9c7377 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -81,6 +81,11 @@ def is_one_band(img): def write_tiff(img_wrt, filename, metadata): + # Derive the band count from the array. A one-band image has no band axis + # yet (it is added below), so shape[0] is only the count when multi-band. + count = 1 if is_one_band(img_wrt) else img_wrt.shape[0] + metadata = {**metadata, 'count': count}  # copy: leave the caller's dict untouched + with rasterio.open(filename, "w", **metadata) as dest: if is_one_band(img_wrt): img_wrt = img_wrt[None] From 2b6404b7b39a32413145bf470016c6ddc268886c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 30 Jan 2025 17:05:54 -0300 Subject: [PATCH 3/3] files for testing multiple outputs in segmentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/tasks/segmentation_tasks.py | 2 +- ...ed-finetune_prithvi_swin_B_segmentation.yaml | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index a5f06d92..1286ff57 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -145,7 +145,7 @@ def __init__( self.select_classes = lambda y: y.argmax(dim=1) else: self.select_classes = lambda y: y - 
print(self.output_most_probable) + def configure_losses(self) -> None: """Initialize the loss criterion. diff --git a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml index 52b9ee58..586aa6b1 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml @@ -51,12 +51,12 @@ data: - 2 - 1 - 0 - train_data_root: tests/ - train_label_data_root: tests/ - val_data_root: tests/ - val_label_data_root: tests/ - test_data_root: tests/ - test_label_data_root: tests/ + train_data_root: tests/resources/inputs + train_label_data_root: tests/resources/inputs + val_data_root: tests/resources/inputs + val_label_data_root: tests/resources/inputs + test_data_root: tests/resources/inputs + test_label_data_root: tests/resources/inputs img_grep: "segmentation*input*.tif" label_grep: "segmentation*label*.tif" means: @@ -83,8 +83,8 @@ model: decoder: UperNetDecoder pretrained: true backbone: prithvi_swin_B - backbone_pretrained_cfg_overlay: - file: tests/prithvi_swin_B.pt + #backbone_pretrained_cfg_overlay: + #file: tests/prithvi_swin_B.pt backbone_drop_path_rate: 0.3 # backbone_window_size: 8 decoder_channels: 256 @@ -99,6 +99,7 @@ model: num_frames: 1 num_classes: 2 head_dropout: 0.5708022831486758 + output_most_probable: false loss: ce #aux_heads: # - name: aux_head