[WiP] Replace/timm #433

Draft: wants to merge 20 commits into base: main
30 changes: 30 additions & 0 deletions terratorch/cli_tools.py
@@ -81,6 +81,11 @@ def is_one_band(img):

def write_tiff(img_wrt, filename, metadata):

# Adapt the band count in the metadata so it matches the number of
# bands actually being written.
count = img_wrt.shape[0]
metadata['count'] = count

with rasterio.open(filename, "w", **metadata) as dest:
if is_one_band(img_wrt):
img_wrt = img_wrt[None]
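A minimal sketch of why the band count is adjusted (array shape and profile values below are made up): the "count" entry in the metadata must match the number of bands actually written.

import numpy as np

# Hypothetical values: a 4-band prediction, but metadata copied from a 6-band input.
img_wrt = np.zeros((4, 256, 256), dtype="float32")
metadata = {"driver": "GTiff", "height": 256, "width": 256, "dtype": "float32", "count": 6}
metadata["count"] = img_wrt.shape[0]  # now 4, consistent with the array being written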
@@ -132,6 +137,28 @@ def import_custom_modules(custom_modules_path: str | Path | None = None) -> None
else:
logger.debug("No custom module is being used.")

# TODO remove it for future releases
def remove_unexpected_prefix(state_dict):
    """Remove residual artifacts from older checkpoints: the `_timm_module`
    segment introduced by the timm wrapper, and the `stages_` naming used
    instead of `stages.`."""
    state_dict_ = {}
    for k, v in state_dict.items():
        keys = k.split(".")
        if "_timm_module" in keys:
            keys.pop(keys.index("_timm_module"))
        k_ = ".".join(keys).replace("stages_", "stages.")
        state_dict_[k_] = v
    return state_dict_
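A minimal sketch of the resulting key remapping (the checkpoint key below is hypothetical):

old = {"model._timm_module.stages_0.blocks.0.attn.qkv.weight": "w"}
new = remove_unexpected_prefix(old)
# {"model.stages.0.blocks.0.attn.qkv.weight": "w"}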

class CustomWriter(BasePredictionWriter):
"""Callback class to write geospatial data to file."""

@@ -486,6 +513,9 @@ def __init__(
weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
if "state_dict" in weights:
weights = weights["state_dict"]
# Remove residual timm-related artifacts from the keys of older
# checkpoints.
weights = remove_unexpected_prefix(weights)
weights = {k.replace("model.", ""): v for k, v in weights.items() if k.startswith("model.")}
self.model.model.load_state_dict(weights)

39 changes: 38 additions & 1 deletion terratorch/models/backbones/prithvi_swin.py
@@ -97,6 +97,21 @@ def weights_are_swin_implementation(state_dict: dict[str, torch.Tensor]):
return True
return False

# Identify any prefix prepended to the checkpoint keys, relative to the
# model's own state_dict keys.
def identify_prefix(state_dict, model):

state_dict_ = model.state_dict()

prefix = list(state_dict.keys())[0].replace(list(state_dict_.keys())[0], "")

return prefix

# Replace "stages_" with "stages." when necessary.
def adapt_prefix(key):
if key.startswith("stages_"):
key = key.replace("stages_", "stages.")
return key
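For illustration (hypothetical keys): identify_prefix derives the prefix by stripping the model's first state_dict key from the checkpoint's first key, and adapt_prefix only rewrites keys that start with "stages_".

ckpt_key = "encoder.patch_embed.proj.weight"    # first key in the checkpoint
model_key = "patch_embed.proj.weight"           # first key in model.state_dict()
prefix = ckpt_key.replace(model_key, "")        # -> "encoder.", what identify_prefix returns

adapt_prefix("stages_0.blocks.0.norm1.weight")  # -> "stages.0.blocks.0.norm1.weight"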

def checkpoint_filter_fn(state_dict: dict[str, torch.Tensor], model: torch.nn.Module, pretrained_bands, model_bands):
"""convert patch embedding weight from manual patchify + linear proj to conv"""
@@ -134,9 +149,27 @@ def checkpoint_filter_fn(state_dict: dict[str, torch.Tensor], model: torch.nn.Mo
state_dict[k] = v

relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
# Sometimes the checkpoints can contain an unexpected prefix that must be
# removed.
prefix = identify_prefix(state_dict, model)

for table_key in relative_position_bias_table_keys:

# The checkpoints can sometimes contain unexpected prefixes.
# TODO Guarantee that it will not happen in the future.
if prefix:
table_key_ = table_key.replace(prefix, "")
else:
table_key_ = table_key

# The key can unexpectedly use "stages_" instead of "stages.";
# we enforce the "." form here.
# TODO Standardize it.
table_key_ = adapt_prefix(table_key_)

table_pretrained = state_dict[table_key]
table_current = model.state_dict()[table_key]

table_current = model.state_dict()[table_key_]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
@@ -190,8 +223,10 @@ def _create_swin_mmseg_transformer(
def checkpoint_filter_wrapper_fn(state_dict, model):
return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands)

# TODO Totally remove the usage of timm for Swin in the future.
# When the pretrained configuration is not available in HF, we shift to
# pretrained=False
"""
try:
model: MMSegSwinTransformer = build_model_with_cfg(
MMSegSwinTransformer,
@@ -213,6 +248,8 @@ def checkpoint_filter_wrapper_fn(state_dict, model):
feature_cfg={"flatten_sequential": True, "out_indices": out_indices},
**kwargs,
)
"""
model = MMSegSwinTransformer(**kwargs)

model.pretrained_bands = pretrained_bands
model.model_bands = model_bands
5 changes: 3 additions & 2 deletions terratorch/models/backbones/prithvi_vit.py
@@ -117,7 +117,7 @@ def checkpoint_filter_fn_vit(

state_dict = clean_dict

state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)
state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands, encoder_only=True)

return state_dict

@@ -153,7 +153,7 @@ def checkpoint_filter_fn_mae(

state_dict = clean_dict

state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)
state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands, encoder_only=False)

return state_dict

@@ -214,6 +214,7 @@ def _create_prithvi(
# Load model from checkpoint
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands)

loaded_keys = model.load_state_dict(state_dict, strict=False)
if loaded_keys.missing_keys:
logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}")
89 changes: 82 additions & 7 deletions terratorch/models/backbones/select_patch_embed_weights.py
@@ -18,9 +18,74 @@ def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoi
checkpoint_shape = [checkpoint_patch_embed.shape[i] for i in range(len(checkpoint_patch_embed.shape)) if i != 1]
return model_shape == checkpoint_shape

def get_state_dict(state_dict):

def search_state_dict(keys):
key = 0
for k in keys:
if k.endswith("state_dict"):
key = k
break
return key

state_dict_key = search_state_dict(state_dict.keys())

if state_dict_key:
return state_dict[state_dict_key]
else:
return state_dict
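A brief sketch of the two cases handled here (dictionary values are placeholders for tensors):

nested = {"state_dict": {"patch_embed.proj.weight": "w"}}
flat = {"patch_embed.proj.weight": "w"}

get_state_dict(nested)  # -> {"patch_embed.proj.weight": "w"}
get_state_dict(flat)    # -> returned unchanged; no key ending in "state_dict"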

def get_common_prefix(keys):

keys_big_list = []

keys = list(keys)
keys.pop(-1)

for k in keys:
keys_big_list.append(set(k.split(".")))
prefix_list = set.intersection(*keys_big_list)

if len(prefix_list) > 1:
prefix = ".".join(prefix_list)
else:
prefix = prefix_list.pop()

return prefix + "."

def get_proj_key(state_dict, encoder_only=True, return_prefix=False):

proj_key = None

for key in state_dict.keys():
if key.endswith('patch_embed.proj.weight') or key.endswith('patch_embed.projection.weight'):
proj_key = key
break


if return_prefix and proj_key:
if encoder_only:
for suffix in ['patch_embed.proj.weight', 'patch_embed.projection.weight']:
if proj_key.endswith(suffix):
prefix = proj_key.replace(suffix, "")
break
else:
prefix = get_common_prefix(state_dict.keys())
else:
prefix = None

return proj_key, prefix

def remove_prefixes(state_dict, prefix):
new_state_dict = {}
for k, v in state_dict.items():
new_state_dict[k.replace(prefix, "")] = v
return new_state_dict
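A hedged usage sketch of these helpers on a hypothetical encoder checkpoint (string values stand in for tensors):

sd = {
    "encoder.patch_embed.proj.weight": "w0",
    "encoder.blocks.0.attn.qkv.weight": "w1",
}
proj_key, prefix = get_proj_key(sd, encoder_only=True, return_prefix=True)
# proj_key == "encoder.patch_embed.proj.weight", prefix == "encoder."
sd = remove_prefixes(sd, prefix)
# keys become "patch_embed.proj.weight" and "blocks.0.attn.qkv.weight"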

def select_patch_embed_weights(
state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int | OpticalBands| SARBands], model_bands: list[HLSBands | int | OpticalBands| SARBands], proj_key: str | None = None
) -> dict:
state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int | OpticalBands| SARBands], model_bands: list[HLSBands | int | OpticalBands| SARBands],
proj_key: str | None = None, encoder_only: bool = True) -> dict:

"""Filter out the patch embedding weights according to the bands being used.
If a band exists in the pretrained_bands, but not in model_bands, drop it.
If a band exists in model_bands, but not pretrained_bands, randomly initialize those weights.
@@ -38,18 +103,25 @@ def select_patch_embed_weights(
"""
if (type(pretrained_bands) == type(model_bands)) | (type(pretrained_bands) == int) | (type(model_bands) == int):

state_dict = get_state_dict(state_dict)
prefix = None  # in principle, no prefix should be needed

if proj_key is None:
# Search for patch embedding weight in state dict
for key in state_dict.keys():
if key.endswith('patch_embed.proj.weight') or key.endswith('patch_embed.projection.weight'):
proj_key = key
break
proj_key, prefix = get_proj_key(state_dict, return_prefix=True, encoder_only=encoder_only)
if proj_key is None or proj_key not in state_dict:
raise Exception("Could not find key for patch embed weight in state_dict.")

patch_embed_weight = state_dict[proj_key]

temp_weight = model.state_dict()[proj_key].clone()
# The projection key can have a different name in the checkpoint and in
# the model instance.
proj_key_, _ = get_proj_key(model.state_dict(), encoder_only=encoder_only)

if proj_key_:
temp_weight = model.state_dict()[proj_key_].clone()
else:
temp_weight = model.state_dict()[proj_key].clone()

# only do this if the patch size and tubelet size match. If not, start with random weights
if patch_embed_weights_are_compatible(temp_weight, patch_embed_weight):
@@ -68,4 +140,7 @@

state_dict[proj_key] = temp_weight

if prefix:
state_dict = remove_prefixes(state_dict, prefix)

return state_dict
2 changes: 2 additions & 0 deletions terratorch/models/backbones/swin_encoder_decoder.py
@@ -907,6 +907,7 @@ def __init__(
norm_layer=nn.LayerNorm,
with_cp=False, # noqa: FBT002
frozen_stages=-1,
**kwargs,
):
self.frozen_stages = frozen_stages
self.output_fmt = "NHWC"
@@ -984,6 +985,7 @@ def __init__(
in_chans = downsample.out_channels
self.stages = nn.Sequential(*stages)
self.num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
self.out_channels = self.num_features
# Add a norm layer for each output

self.head = ClassifierHead(
7 changes: 6 additions & 1 deletion terratorch/models/encoder_decoder_factory.py
@@ -17,6 +17,7 @@
from terratorch.models.scalar_output_model import ScalarOutputModel
from terratorch.models.utils import extract_prefix_keys
from terratorch.registry import BACKBONE_REGISTRY, DECODER_REGISTRY, MODEL_FACTORY_REGISTRY
from terratorch.registry.timm_registry import TimmBackboneWrapper

PIXEL_WISE_TASKS = ["segmentation", "regression"]
SCALAR_TASKS = ["classification"]
@@ -152,7 +153,11 @@ def build_model(
backbone = get_peft_backbone(peft_config, backbone)

try:
out_channels = backbone.out_channels
if isinstance(backbone, TimmBackboneWrapper):
backbone_ = backbone._timm_module
out_channels = backbone_.out_channels
else:
out_channels = backbone.out_channels
except AttributeError as e:
msg = "backbone must have out_channels attribute"
raise AttributeError(msg) from e
9 changes: 9 additions & 0 deletions terratorch/registry/timm_registry.py
@@ -34,6 +34,7 @@ def build(self, name: str, features_only=True, *constructor_args, **constructor_
Use prefixes ending with _ to forward to a specific source
"""
try:
"""
return TimmBackboneWrapper(
timm.create_model(
name,
@@ -42,6 +43,14 @@ def build(self, name: str, features_only=True, *constructor_args, **constructor_
**constructor_kwargs,
)
)
"""
return timm.create_model(
name,
*constructor_args,
features_only=features_only,
**constructor_kwargs,
)

except RuntimeError as e:
if "Unknown model" in str(e):
msg = f"Unknown model {name}"
13 changes: 12 additions & 1 deletion terratorch/tasks/segmentation_tasks.py
@@ -66,6 +66,7 @@ def __init__(
tiled_inference_parameters: TiledInferenceParameters = None,
test_dataloaders_names: list[str] | None = None,
lr_overrides: dict[str, float] | None = None,
output_most_probable: bool = True,
) -> None:
"""Constructor

@@ -112,6 +113,8 @@ def __init__(
lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
parameters. The key should be a substring of the parameter names (it will check the substring is
contained in the parameter name)and the value should be the new lr. Defaults to None.
output_most_probable (bool): If True, inference returns only the most probable class
(argmax over the class dimension); if False, the full per-class output is returned.
Defaults to True.
"""
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
@@ -138,6 +141,12 @@ def __init__(
self.val_loss_handler = LossHandler(self.val_metrics.prefix)
self.monitor = f"{self.val_metrics.prefix}loss"
self.plot_on_val = int(plot_on_val)
self.output_most_probable = output_most_probable

if output_most_probable:
self.select_classes = lambda y: y.argmax(dim=1)
else:
self.select_classes = lambda y: y
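To illustrate the effect of output_most_probable at prediction time (shapes below are invented):

import torch

y_hat = torch.rand(1, 3, 2, 2)  # (batch, classes, height, width)

y_hat.argmax(dim=1)  # output_most_probable=True  -> class indices, shape (1, 2, 2)
y_hat                # output_most_probable=False -> full per-class output, shape (1, 3, 2, 2)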

def configure_losses(self) -> None:
"""Initialize the loss criterion.
@@ -351,5 +360,7 @@ def model_forward(x):
)
else:
y_hat: Tensor = self(x, **rest).output
y_hat = y_hat.argmax(dim=1)

y_hat = self.select_classes(y_hat)

return y_hat, file_names