Skip to content

Commit 9a00845

Browse files
Merge pull request #322 from IBM/improve/tests
Testing finetuning for more Prithvi-2 backbones
2 parents 4a0afc2 + 710b7de commit 9a00845

5 files changed

+308
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 42
3+
trainer:
4+
accelerator: cpu
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
# precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: tests/
13+
name: all_ecos_random
14+
callbacks:
15+
- class_path: RichProgressBar
16+
- class_path: LearningRateMonitor
17+
init_args:
18+
logging_interval: epoch
19+
- class_path: EarlyStopping
20+
init_args:
21+
monitor: val/loss
22+
patience: 100
23+
max_epochs: 2
24+
check_val_every_n_epoch: 1
25+
log_every_n_steps: 20
26+
enable_checkpointing: true
27+
default_root_dir: tests/
28+
data:
29+
class_path: GenericNonGeoPixelwiseRegressionDataModule
30+
init_args:
31+
batch_size: 2
32+
num_workers: 4
33+
train_transform:
34+
#- class_path: albumentations.HorizontalFlip
35+
# init_args:
36+
# p: 0.5
37+
#- class_path: albumentations.Rotate
38+
# init_args:
39+
# limit: 30
40+
# border_mode: 0 # cv2.BORDER_CONSTANT
41+
# value: 0
42+
# # mask_value: 1
43+
# p: 0.5
44+
- class_path: ToTensorV2
45+
dataset_bands:
46+
- 0
47+
- BLUE
48+
- GREEN
49+
- RED
50+
- NIR_NARROW
51+
- SWIR_1
52+
- SWIR_2
53+
- 1
54+
- 2
55+
- 3
56+
- 4
57+
output_bands:
58+
- BLUE
59+
- GREEN
60+
- RED
61+
- NIR_NARROW
62+
- SWIR_1
63+
- SWIR_2
64+
rgb_indices:
65+
- 2
66+
- 1
67+
- 0
68+
train_data_root: tests/resources/inputs
69+
train_label_data_root: tests/resources/inputs
70+
val_data_root: tests/resources/inputs
71+
val_label_data_root: tests/resources/inputs
72+
test_data_root: tests/resources/inputs
73+
test_label_data_root: tests/resources/inputs
74+
img_grep: "regression*input*.tif"
75+
label_grep: "regression*label*.tif"
76+
means:
77+
- 547.36707
78+
- 898.5121
79+
- 1020.9082
80+
- 2665.5352
81+
- 2340.584
82+
- 1610.1407
83+
stds:
84+
- 411.4701
85+
- 558.54065
86+
- 815.94025
87+
- 812.4403
88+
- 1113.7145
89+
- 1067.641
90+
no_label_replace: -1
91+
no_data_replace: 0
92+
93+
model:
94+
class_path: terratorch.tasks.PixelwiseRegressionTask
95+
init_args:
96+
model_args:
97+
decoder: UperNetDecoder
98+
pretrained: false
99+
backbone: prithvi_eo_v2_300
100+
# backbone_pretrained_cfg_overlay:
101+
# file: tests/prithvi_vit_300.pt
102+
backbone_drop_path_rate: 0.3
103+
# backbone_window_size: 8
104+
decoder_channels: 64
105+
num_frames: 1
106+
in_channels: 6
107+
bands:
108+
- BLUE
109+
- GREEN
110+
- RED
111+
- NIR_NARROW
112+
- SWIR_1
113+
- SWIR_2
114+
head_dropout: 0.5708022831486758
115+
head_final_act: torch.nn.ReLU
116+
head_learned_upscale_layers: 2
117+
loss: rmse
118+
#aux_heads:
119+
# - name: aux_head
120+
# decoder: IdentityDecoder
121+
# decoder_args:
122+
# decoder_out_index: 2
123+
# head_dropout: 0,5
124+
# head_channel_list:
125+
# - 64
126+
# head_final_act: torch.nn.ReLU
127+
#aux_loss:
128+
# aux_head: 0.4
129+
ignore_index: -1
130+
freeze_backbone: true
131+
freeze_decoder: false
132+
model_factory: PrithviModelFactory
133+
134+
# uncomment this block for tiled inference
135+
# tiled_inference_parameters:
136+
# h_crop: 224
137+
# h_stride: 192
138+
# w_crop: 224
139+
# w_stride: 192
140+
# average_patches: true
141+
optimizer:
142+
class_path: torch.optim.AdamW
143+
init_args:
144+
lr: 0.00013524680528283027
145+
weight_decay: 0.047782217873995426
146+
lr_scheduler:
147+
class_path: ReduceLROnPlateau
148+
init_args:
149+
monitor: val/loss
150+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 42
3+
trainer:
4+
accelerator: cpu
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
# precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: tests/
13+
name: all_ecos_random
14+
callbacks:
15+
- class_path: RichProgressBar
16+
- class_path: LearningRateMonitor
17+
init_args:
18+
logging_interval: epoch
19+
- class_path: EarlyStopping
20+
init_args:
21+
monitor: val/loss
22+
patience: 100
23+
max_epochs: 2
24+
check_val_every_n_epoch: 1
25+
log_every_n_steps: 20
26+
enable_checkpointing: true
27+
default_root_dir: tests/
28+
data:
29+
class_path: GenericNonGeoPixelwiseRegressionDataModule
30+
init_args:
31+
batch_size: 2
32+
num_workers: 4
33+
train_transform:
34+
#- class_path: albumentations.HorizontalFlip
35+
# init_args:
36+
# p: 0.5
37+
#- class_path: albumentations.Rotate
38+
# init_args:
39+
# limit: 30
40+
# border_mode: 0 # cv2.BORDER_CONSTANT
41+
# value: 0
42+
# # mask_value: 1
43+
# p: 0.5
44+
- class_path: ToTensorV2
45+
dataset_bands:
46+
- 0
47+
- BLUE
48+
- GREEN
49+
- RED
50+
- NIR_NARROW
51+
- SWIR_1
52+
- SWIR_2
53+
- 1
54+
- 2
55+
- 3
56+
- 4
57+
output_bands:
58+
- BLUE
59+
- GREEN
60+
- RED
61+
- NIR_NARROW
62+
- SWIR_1
63+
- SWIR_2
64+
rgb_indices:
65+
- 2
66+
- 1
67+
- 0
68+
train_data_root: tests/resources/inputs
69+
train_label_data_root: tests/resources/inputs
70+
val_data_root: tests/resources/inputs
71+
val_label_data_root: tests/resources/inputs
72+
test_data_root: tests/resources/inputs
73+
test_label_data_root: tests/resources/inputs
74+
img_grep: "regression*input*.tif"
75+
label_grep: "regression*label*.tif"
76+
means:
77+
- 547.36707
78+
- 898.5121
79+
- 1020.9082
80+
- 2665.5352
81+
- 2340.584
82+
- 1610.1407
83+
stds:
84+
- 411.4701
85+
- 558.54065
86+
- 815.94025
87+
- 812.4403
88+
- 1113.7145
89+
- 1067.641
90+
no_label_replace: -1
91+
no_data_replace: 0
92+
93+
model:
94+
class_path: terratorch.tasks.PixelwiseRegressionTask
95+
init_args:
96+
model_args:
97+
decoder: UperNetDecoder
98+
pretrained: false
99+
backbone: prithvi_eo_v2_600
100+
# backbone_pretrained_cfg_overlay:
101+
# file: tests/prithvi_vit_300.pt
102+
backbone_drop_path_rate: 0.3
103+
# backbone_window_size: 8
104+
decoder_channels: 64
105+
num_frames: 1
106+
in_channels: 6
107+
bands:
108+
- BLUE
109+
- GREEN
110+
- RED
111+
- NIR_NARROW
112+
- SWIR_1
113+
- SWIR_2
114+
head_dropout: 0.5708022831486758
115+
head_final_act: torch.nn.ReLU
116+
head_learned_upscale_layers: 2
117+
loss: rmse
118+
#aux_heads:
119+
# - name: aux_head
120+
# decoder: IdentityDecoder
121+
# decoder_args:
122+
# decoder_out_index: 2
123+
# head_dropout: 0,5
124+
# head_channel_list:
125+
# - 64
126+
# head_final_act: torch.nn.ReLU
127+
#aux_loss:
128+
# aux_head: 0.4
129+
ignore_index: -1
130+
freeze_backbone: true
131+
freeze_decoder: false
132+
model_factory: PrithviModelFactory
133+
134+
# uncomment this block for tiled inference
135+
# tiled_inference_parameters:
136+
# h_crop: 224
137+
# h_stride: 192
138+
# w_crop: 224
139+
# w_stride: 192
140+
# average_patches: true
141+
optimizer:
142+
class_path: torch.optim.AdamW
143+
init_args:
144+
lr: 0.00013524680528283027
145+
weight_decay: 0.047782217873995426
146+
lr_scheduler:
147+
class_path: ReduceLROnPlateau
148+
init_args:
149+
monitor: val/loss
150+

tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ model:
9696
model_args:
9797
decoder: UperNetDecoder
9898
pretrained: false
99-
backbone: prithvi_vit_300
99+
backbone: prithvi_eo_v2_300
100100
# backbone_pretrained_cfg_overlay:
101101
# file: tests/prithvi_vit_300.pt
102102
backbone_drop_path_rate: 0.3

tests/test_backbones.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,23 @@ def input_386():
3535
return torch.ones((1, NUM_CHANNELS, 386, 386))
3636

3737

38-
@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
38+
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
3939
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
4040
def test_can_create_backbones_from_timm(model_name, test_input, request):
4141
backbone = timm.create_model(model_name, pretrained=False)
4242
input_tensor = request.getfixturevalue(test_input)
4343
backbone(input_tensor)
4444
gc.collect()
4545

46-
@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
46+
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
4747
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
4848
def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
4949
backbone = timm.create_model(model_name, pretrained=False, features_only=True)
5050
input_tensor = request.getfixturevalue(test_input)
5151
backbone(input_tensor)
5252
gc.collect()
53-
@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
53+
54+
@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
5455
@pytest.mark.parametrize("prefix", ["", "timm_"])
5556
def test_can_create_timm_backbones_from_registry(model_name, input_224, prefix):
5657
backbone = BACKBONE_REGISTRY.build(prefix+model_name, pretrained=False)
@@ -62,12 +63,14 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal):
6263
backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES)
6364
backbone(input_224_multitemporal)
6465
gc.collect()
66+
6567
@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
6668
def test_vit_models_non_divisible_input(model_name, input_non_divisible):
6769
#padding 'none','constant', 'reflect', 'replicate' or 'circular' default is 'none'
6870
backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, padding='constant')
6971
backbone(input_non_divisible)
7072
gc.collect()
73+
7174
@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
7275
@pytest.mark.parametrize("patch_size", [8, 16])
7376
@pytest.mark.parametrize("patch_size_time", [1, 2, 4])

tests/test_finetune.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def setup_and_cleanup(model_name):
2222
if os.path.isdir(os.path.join("tests", "all_ecos_random")):
2323
shutil.rmtree(os.path.join("tests", "all_ecos_random"))
2424

25-
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_vit_100"])
25+
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_eo_v2_600"])
2626
@pytest.mark.parametrize("case", ["fit", "test", "validate"])
2727
def test_finetune_multiple_backbones(model_name, case):
2828
command_list = [case, "-c", f"tests/resources/configs/manufactured-finetune_{model_name}.yaml"]

0 commit comments

Comments
 (0)