@@ -35,22 +35,23 @@ def input_386():
35
35
return torch .ones ((1 , NUM_CHANNELS , 386 , 386 ))
36
36
37
37
38
- @pytest .mark .parametrize ("model_name" , ["prithvi_eo_v1_100" , "prithvi_eo_v2_300" , "prithvi_swin_B" ])
38
+ @pytest .mark .parametrize ("model_name" , ["prithvi_swin_B" , "prithvi_swin_L" , " prithvi_eo_v1_100" , "prithvi_eo_v2_300" , "prithvi_swin_B" ])
39
39
@pytest .mark .parametrize ("test_input" , ["input_224" , "input_512" ])
40
40
def test_can_create_backbones_from_timm (model_name , test_input , request ):
41
41
backbone = timm .create_model (model_name , pretrained = False )
42
42
input_tensor = request .getfixturevalue (test_input )
43
43
backbone (input_tensor )
44
44
gc .collect ()
45
45
46
- @pytest .mark .parametrize ("model_name" , ["prithvi_eo_v1_100" , "prithvi_eo_v2_300" , "prithvi_swin_B" ])
46
+ @pytest .mark .parametrize ("model_name" , ["prithvi_swin_B" , "prithvi_swin_L" , " prithvi_eo_v1_100" , "prithvi_eo_v2_300" , "prithvi_swin_B" ])
47
47
@pytest .mark .parametrize ("test_input" , ["input_224" , "input_512" ])
48
48
def test_can_create_backbones_from_timm_features_only (model_name , test_input , request ):
49
49
backbone = timm .create_model (model_name , pretrained = False , features_only = True )
50
50
input_tensor = request .getfixturevalue (test_input )
51
51
backbone (input_tensor )
52
52
gc .collect ()
53
- @pytest .mark .parametrize ("model_name" , ["prithvi_eo_v1_100" , "prithvi_eo_v2_300" , "prithvi_swin_B" ])
53
+
54
+ @pytest .mark .parametrize ("model_name" , ["prithvi_swin_L" , "prithvi_swin_L" , "prithvi_eo_v1_100" , "prithvi_eo_v2_300" , "prithvi_swin_B" ])
54
55
@pytest .mark .parametrize ("prefix" , ["" , "timm_" ])
55
56
def test_can_create_timm_backbones_from_registry (model_name , input_224 , prefix ):
56
57
backbone = BACKBONE_REGISTRY .build (prefix + model_name , pretrained = False )
@@ -62,12 +63,14 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal):
62
63
backbone = timm .create_model (model_name , pretrained = False , num_frames = NUM_FRAMES )
63
64
backbone (input_224_multitemporal )
64
65
gc .collect ()
66
+
65
67
@pytest .mark .parametrize ("model_name" , ["prithvi_eo_v1_100" , "prithvi_eo_v2_300" ])
66
68
def test_vit_models_non_divisible_input (model_name , input_non_divisible ):
67
69
#padding 'none','constant', 'reflect', 'replicate' or 'circular' default is 'none'
68
70
backbone = timm .create_model (model_name , pretrained = False , features_only = True , num_frames = NUM_FRAMES , padding = 'constant' )
69
71
backbone (input_non_divisible )
70
72
gc .collect ()
73
+
71
74
@pytest .mark .parametrize ("model_name" , ["prithvi_eo_v1_100" , "prithvi_eo_v2_300" ])
72
75
@pytest .mark .parametrize ("patch_size" , [8 , 16 ])
73
76
@pytest .mark .parametrize ("patch_size_time" , [1 , 2 , 4 ])
0 commit comments