[WIP] Update multimodal functionalities #450

Open
wants to merge 41 commits into
base: main
41 commits
1ab36b6
Fix zarr loading error
blumenstiel Dec 12, 2024
0d293be
Add try catch testing
blumenstiel Jan 8, 2025
7c5872a
Merge branch 'IBM:main' into multimodal
blumenstiel Jan 22, 2025
718ec85
Updated multimodal dataset
blumenstiel Jan 22, 2025
e76f03f
Merge branch 'IBM:main' into multimodal
blumenstiel Jan 27, 2025
f5d7c41
Update multi mae
blumenstiel Jan 27, 2025
f0cf9ae
Update multi mae
blumenstiel Jan 28, 2025
9b5405d
Fix multi mae config
blumenstiel Jan 28, 2025
83fbb1d
Update multi mae code
blumenstiel Jan 29, 2025
11cae7b
updated multimae and plotting
blumenstiel Jan 29, 2025
0ceaa54
Merge branch 'IBM:main' into multimodal
blumenstiel Feb 10, 2025
3b38602
Fix multimae
blumenstiel Feb 10, 2025
b5b1542
Fix multimodal dataset
blumenstiel Feb 10, 2025
43f90a8
Merge branch 'IBM:main' into multimodal
blumenstiel Feb 17, 2025
88bce70
Fix multilabel classification
blumenstiel Feb 19, 2025
0bb1647
Fix classification target dtype
blumenstiel Feb 19, 2025
51248a4
Support multimodal datasets in classification
blumenstiel Feb 19, 2025
d799a14
Fix default setting
blumenstiel Feb 19, 2025
bc207a6
Fix mm dataset
blumenstiel Feb 19, 2025
b686ac1
Fix mm dataset
blumenstiel Feb 19, 2025
453341a
Fix classification tasks
blumenstiel Feb 19, 2025
7161bd0
Fix mm dataset
blumenstiel Feb 19, 2025
a36573b
Add registry
blumenstiel Feb 21, 2025
30e0b72
Add registry
blumenstiel Feb 21, 2025
81f14b6
Add prithvi mae models
blumenstiel Feb 21, 2025
026c6fa
rename prithvi model_bands to bands
blumenstiel Feb 21, 2025
e282b09
Add full model library
blumenstiel Feb 21, 2025
f6e6b70
Prithvi add loss dict
blumenstiel Feb 21, 2025
9f1b393
Merge branch 'IBM:main' into multimodal
blumenstiel Feb 21, 2025
1b7f40e
Add reconstruction_tasks
blumenstiel Feb 21, 2025
ad6b992
Fixed tasks
blumenstiel Feb 21, 2025
e557a38
Update multimae
blumenstiel Feb 21, 2025
7240c12
Update reconstruction task
blumenstiel Feb 21, 2025
49ff9d3
Update prithvi
blumenstiel Feb 21, 2025
dfe8221
Merge branch 'multimodal' into feature/full-model-registry
blumenstiel Feb 21, 2025
c0f2eea
Merge pull request #2 from IBM/feature/full-model-registry
blumenstiel Feb 21, 2025
f64e4bd
Fix import error
blumenstiel Feb 21, 2025
5885d15
Add NDVI modality
blumenstiel Feb 21, 2025
c96e853
Fix reconstruction training
blumenstiel Feb 21, 2025
f965793
Fix reconstruction training
blumenstiel Feb 21, 2025
f5b5dc2
Fix reconstruction training
blumenstiel Feb 21, 2025
55 changes: 22 additions & 33 deletions examples/confs/multimae_sen1floods11.yaml
@@ -35,8 +35,7 @@ data:
  num_workers: 0
  modalities:
  - S2L2A
- - S1
- - LULC
+ - S1GRD
  rgb_modality: S2L2A # If not provided, uses first modality
  rgb_indices:
  - 3
@@ -45,18 +44,15 @@ data:

  train_data_root:
  S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
- S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
- LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
+ S1GRD: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
  train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
  val_data_root:
  S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
- S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
- LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
+ S1GRD: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
  val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
  test_data_root:
  S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
- S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
- LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
+ S1GRD: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
  test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand

  train_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_train.txt
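The rename from S1 to S1GRD has to be applied to every per-modality mapping in this config (data roots, image_grep, the normalization statistics, and the backbone input adapters). A small consistency check can catch a missed key. The sketch below is a hypothetical helper, not part of terratorch; the `missing_keys` name and the abbreviated dictionaries are assumptions made for illustration.

```python
# Hypothetical helper, not part of terratorch: report modalities that lack
# an entry in a per-modality mapping such as train_data_root or image_grep.
def missing_keys(modalities, sections):
    problems = {}
    for name, mapping in sections.items():
        missing = [m for m in modalities if m not in mapping]
        if missing:
            problems[name] = missing
    return problems


modalities = ["S2L2A", "S1GRD"]
root = "data/sen1floods11/data/data/flood_events/HandLabeled"
sections = {
    "train_data_root": {"S2L2A": f"{root}/S2L2AHand", "S1GRD": f"{root}/S1Hand"},
    # Deliberately stale entry that still uses the old S1 key.
    "image_grep": {"S2L2A": "*_S2L2AHand.tif", "S1": "*_S1Hand.tif"},
}

problems = missing_keys(modalities, sections)
print(problems)  # {'image_grep': ['S1GRD']}
```

An empty result means every listed modality has an entry in every mapping that was checked.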
@@ -66,8 +62,7 @@ data:
  allow_substring_file_names: True
  image_grep:
  S2L2A: "*_S2L2AHand.tif"
- S1: "*_S1Hand.tif"
- LULC: "*_LULCHand.npy"
+ S1GRD: "*_S1Hand.tif"
  label_grep: "*_LabelHand.tif"
  no_label_replace: -1
  no_data_replace: 0
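The image_grep entries are glob patterns, one per modality. As a rough illustration of how such patterns can group files and yield a shared sample prefix, here is a minimal sketch using only the standard library. The file names and both helper functions are hypothetical, and the actual matching logic in terratorch may differ.

```python
from fnmatch import fnmatch

image_grep = {"S2L2A": "*_S2L2AHand.tif", "S1GRD": "*_S1Hand.tif"}

# Hypothetical directory listing; real file names may differ.
names = [
    "Bolivia_103757_S2L2AHand.tif",
    "Bolivia_103757_S1Hand.tif",
    "Bolivia_103757_LabelHand.tif",
]


def files_per_modality(names, patterns):
    """Group file names by the glob pattern of each modality."""
    return {m: [n for n in names if fnmatch(n, p)] for m, p in patterns.items()}


def sample_prefix(name, pattern):
    """Strip the fixed suffix of a '*<suffix>' pattern to get the sample id."""
    suffix = pattern.lstrip("*")
    return name[: len(name) - len(suffix)] if name.endswith(suffix) else name


grouped = files_per_modality(names, image_grep)
prefixes = {
    m: [sample_prefix(n, image_grep[m]) for n in found]
    for m, found in grouped.items()
}
print(prefixes)  # {'S2L2A': ['Bolivia_103757'], 'S1GRD': ['Bolivia_103757']}
```

Both modalities resolve to the same prefix, which is what lets a split file list samples by prefix rather than by full file name.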
@@ -86,7 +81,7 @@ data:
  - 3711.071
  - 3416.714
  - 2849.625
- S1:
+ S1GRD:
  - -12.577
  - -20.265

@@ -104,17 +99,13 @@ data:
  - 1652.703
  - 1471.002
  - 1365.30
- S1:
+ S1GRD:
  - 5.179
  - 5.872

  num_classes: 2

  train_transform:
- - class_path: albumentations.RandomCrop
- init_args:
- height: 224
- width: 224
  - class_path: albumentations.D4
  - class_path: ToTensorV2
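The two numeric blocks keyed by S1GRD in the hunks above look like per-band normalization statistics. Assuming the first block holds means and the second holds standard deviations (the enclosing keys are outside the visible diff context), the standardization they imply can be sketched as follows. The `standardize` helper and the sample pixel are hypothetical.

```python
# Assumed to be the per-band means and standard deviations for S1GRD;
# the enclosing config keys are not visible in the diff.
s1grd_means = [-12.577, -20.265]
s1grd_stds = [5.179, 5.872]


def standardize(pixel, means, stds):
    """Apply (value - mean) / std independently to each band."""
    return [(v - m) / s for v, m, s in zip(pixel, means, stds)]


# Hypothetical two-band pixel, one standard deviation above the mean in each band.
pixel = [-7.398, -14.393]
normalized = standardize(pixel, s1grd_means, s1grd_stds)
print(normalized)  # both values are approximately 1.0
```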

@@ -127,35 +118,33 @@ model:
  backbone_pretrained: false
  backbone: multimae_base
  backbone_input_adapters:
- - S1
+ - S1GRD
  - S2L2A
- - LULC
- decoder: FCNDecoder # UperNetDecoder
- decoder_num_convs: 4 # only for FCNDecoder
- # decoder_scale_modules: True # only for UperNetDecoder
- decoder_channels: 256
- num_classes: 2
+ backbone_merge_method: mean
+ necks:
+ - name: ReshapeTokensToImage
+ remove_cls_token: False # Need to be False because of missing CLS token in MultiMAE
+ - name: SelectIndices
+ indices: [2, 5, 8, 11]
+ - name: LearnedInterpolateToPyramidal
+ decoder: UNetDecoder
+ decoder_channels: [512, 256, 128, 64]
  head_dropout: 0.1
- head_channel_list:
- - 256
- loss: ce
+ num_classes: 2
+ loss: dice
  ignore_index: -1
- class_weights:
- - 0.3
- - 0.7
- class_names:
- - Others
- - Flood
  freeze_backbone: false
  freeze_decoder: false

  optimizer:
  class_path: torch.optim.AdamW
  init_args:
- lr: 6.e-5
+ lr: 1.e-4
  weight_decay: 0.05
  lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
  monitor: val/loss
+ factor: 0.5
+ patience: 5
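To see what factor 0.5 and patience 5 mean for the new starting learning rate of 1.e-4, the following dependency-free sketch simulates a reduce-on-plateau rule. It is a simplification and not the PyTorch implementation: it ignores the improvement threshold, cooldown, and minimum learning rate, and the `simulate_plateau` function and the loss values are hypothetical.

```python
def simulate_plateau(losses, lr=1e-4, factor=0.5, patience=5):
    """Simplified reduce-on-plateau: scale lr by `factor` once the monitored
    loss has failed to improve for more than `patience` consecutive epochs."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


# Hypothetical val/loss curve: two improving epochs, then six flat epochs.
losses = [1.0, 0.9] + [0.9] * 6
history = simulate_plateau(losses)
print(history)
```

With these inputs the learning rate stays at 1e-4 through the fifth epoch without improvement and is halved after the sixth.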

12 changes: 10 additions & 2 deletions terratorch/datamodules/generic_multimodal_data_module.py
@@ -23,6 +23,10 @@
  logger = logging.getLogger("terratorch")

  def collate_chunk_dicts(batch_list):
+     if isinstance(batch_list, dict):
+         # batch size = 1
+         return batch_list
+
      batch = {}
      for key, value in batch_list[0].items():  # TODO: Handle missing modalities when allow_missing_modalities is set.
          if isinstance(value, torch.Tensor):
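The added guard matters because indexing a dict with `batch_list[0]` raises a KeyError unless the dict happens to have a key `0`. The sketch below illustrates the pass-through behavior with a dependency-free stand-in. It is not the terratorch function: `collate_chunks` is a hypothetical name, plain lists take the place of `torch.Tensor`, and the merging rules after the guard are assumptions, since the rest of the original function body is not shown in the diff.

```python
def collate_chunks(batch_list):
    """Hypothetical stand-in for collate_chunk_dicts using lists, not tensors."""
    if isinstance(batch_list, dict):
        # Already a single collated batch: pass it through unchanged.
        return batch_list

    batch = {}
    for key, value in batch_list[0].items():
        if isinstance(value, list):
            # Assumed behavior: concatenate along the leading sample dimension.
            batch[key] = [item for chunk in batch_list for item in chunk[key]]
        elif isinstance(value, dict):
            # Recurse into nested per-modality dictionaries.
            batch[key] = collate_chunks([chunk[key] for chunk in batch_list])
        else:
            # Keep other values (for example file names) as a list.
            batch[key] = [chunk[key] for chunk in batch_list]
    return batch


single = {"image": {"S2L2A": [[1, 2]]}, "filename": "a.tif"}
chunks = [
    {"image": {"S2L2A": [[1, 2]]}, "filename": "a.tif"},
    {"image": {"S2L2A": [[3, 4]]}, "filename": "b.tif"},
]

passthrough = collate_chunks(single)
merged = collate_chunks(chunks)
print(merged)
```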
@@ -185,7 +189,7 @@ def __init__(
  image_modalities: list[str] | None = None,
  rgb_modality: str | None = None,
  rgb_indices: list[int] | None = None,
- allow_substring_file_names: bool = False,
+ allow_substring_file_names: bool = True,
  class_names: list[str] | None = None,
  constant_scale: dict[float] = None,
  train_transform: dict | A.Compose | None | list[A.BasicTransform] = None,
@@ -439,7 +443,8 @@ def setup(self, stage: str) -> None:
  expand_temporal_dimension=self.expand_temporal_dimension,
  reduce_zero_label=self.reduce_zero_label,
  channel_position=self.channel_position,
- concat_bands=self.concat_bands ,
+ data_with_sample_dim = self.data_with_sample_dim,
+ concat_bands=self.concat_bands,
  )
  logger.info(f"Train dataset: {len(self.train_dataset)}")
  if stage in ["fit", "validate"]:
@@ -463,6 +468,7 @@ def setup(self, stage: str) -> None:
  expand_temporal_dimension=self.expand_temporal_dimension,
  reduce_zero_label=self.reduce_zero_label,
  channel_position=self.channel_position,
+ data_with_sample_dim = self.data_with_sample_dim,
  concat_bands=self.concat_bands,
  )
  logger.info(f"Val dataset: {len(self.val_dataset)}")
@@ -487,6 +493,7 @@ def setup(self, stage: str) -> None:
  expand_temporal_dimension=self.expand_temporal_dimension,
  reduce_zero_label=self.reduce_zero_label,
  channel_position=self.channel_position,
+ data_with_sample_dim = self.data_with_sample_dim,
  concat_bands=self.concat_bands,
  )
  logger.info(f"Test dataset: {len(self.test_dataset)}")
@@ -507,6 +514,7 @@ def setup(self, stage: str) -> None:
  expand_temporal_dimension=self.expand_temporal_dimension,
  reduce_zero_label=self.reduce_zero_label,
  channel_position=self.channel_position,
+ data_with_sample_dim=self.data_with_sample_dim,
  concat_bands=self.concat_bands,
  )
  logger.info(f"Predict dataset: {len(self.predict_dataset)}")