From 57a5298175b98c4d01d9f5c40fd924f6ded7d921 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 30 Jan 2025 12:30:16 -0300 Subject: [PATCH] weights_only=True for all the occurences of torch.load MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/io/file.py | 3 ++- terratorch/models/backbones/clay_v1/embedder.py | 2 +- terratorch/models/backbones/dofa_vit.py | 2 +- terratorch/models/backbones/prithvi_vit.py | 4 ++-- terratorch/models/backbones/scalemae.py | 2 +- terratorch/models/backbones/torchgeo_resnet.py | 4 ++-- terratorch/models/backbones/torchgeo_swin_satlas.py | 4 ++-- terratorch/models/backbones/torchgeo_vit.py | 4 ++-- terratorch/models/satmae_model_factory.py | 4 ++-- 9 files changed, 15 insertions(+), 14 deletions(-) diff --git a/terratorch/io/file.py b/terratorch/io/file.py index 942a09e0..7805f716 100644 --- a/terratorch/io/file.py +++ b/terratorch/io/file.py @@ -34,10 +34,11 @@ def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = N torch.load( os.path.join(save_dir, name), map_location=torch.device(device), + weights_only=True, ) ) else: - model.load_state_dict(torch.load(os.path.join(save_dir, name), map_location='cpu')) + model.load_state_dict(torch.load(os.path.join(save_dir, name), map_location='cpu', weights_only=True)) except Exception: print( diff --git a/terratorch/models/backbones/clay_v1/embedder.py b/terratorch/models/backbones/clay_v1/embedder.py index 18bf03e9..cd2c7da1 100644 --- a/terratorch/models/backbones/clay_v1/embedder.py +++ b/terratorch/models/backbones/clay_v1/embedder.py @@ -67,7 +67,7 @@ def __init__( def load_clay_weights(self, ckpt_path): "Load the weights from the Clay model encoder." - ckpt = torch.load(ckpt_path) + ckpt = torch.load(ckpt_path, weights_only=True) state_dict = ckpt.get("state_dict") state_dict = { re.sub(r"^model\.encoder\.", "", name): param diff --git a/terratorch/models/backbones/dofa_vit.py b/terratorch/models/backbones/dofa_vit.py index a4d7ef5e..035e1e63 100644 --- a/terratorch/models/backbones/dofa_vit.py +++ b/terratorch/models/backbones/dofa_vit.py @@ -141,7 +141,7 @@ def load_dofa_weights(model: nn.Module, ckpt_data: str | None = None, weights: repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '') filename = ckpt_data.split("/")[-1] ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename) - checkpoint_model = torch.load(ckpt_data, map_location="cpu") + checkpoint_model = torch.load(ckpt_data, map_location="cpu", weights_only=True) for k in ["head.weight", "head.bias"]: if ( diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 9f85b82d..c2723f94 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -215,7 +215,7 @@ def _create_prithvi( if ckpt_path is not None: # Load model from checkpoint - state_dict = torch.load(ckpt_path, map_location="cpu") + state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands) model.load_state_dict(state_dict, strict=False) elif pretrained: @@ -225,7 +225,7 @@ def _create_prithvi( # Load model from Hugging Face pretrained_path = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"], filename=pretrained_weights[variant]["hf_hub_filename"]) - state_dict = torch.load(pretrained_path, map_location="cpu") + state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True) state_dict = checkpoint_filter_wrapper_fn(state_dict, model, pretrained_bands, model_bands) model.load_state_dict(state_dict, strict=True) except RuntimeError as e: diff --git a/terratorch/models/backbones/scalemae.py b/terratorch/models/backbones/scalemae.py index 28b4d12f..0473fd25 100644 --- a/terratorch/models/backbones/scalemae.py +++ b/terratorch/models/backbones/scalemae.py @@ -258,7 +258,7 @@ def vit_huge_patch14(**kwargs): return model def load_scalemae_weights(model: nn.Module, ckpt_data: str, model_bands: list[HLSBands], input_size: int = 224) -> nn.Module: - checkpoint_model = torch.load(ckpt_data, map_location="cpu")["model"] + checkpoint_model = torch.load(ckpt_data, map_location="cpu", weights_only=True)["model"] state_dict = model.state_dict() for k in ["head.weight", "head.bias"]: diff --git a/terratorch/models/backbones/torchgeo_resnet.py b/terratorch/models/backbones/torchgeo_resnet.py index b2ab79ed..2f2397d6 100644 --- a/terratorch/models/backbones/torchgeo_resnet.py +++ b/terratorch/models/backbones/torchgeo_resnet.py @@ -448,8 +448,8 @@ def load_resnet_weights(model: nn.Module, model_bands, ckpt_data: str, weights: repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '') filename = ckpt_data.split("/")[-1] ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename) - # checkpoint_model = torch.load(ckpt_data, map_location="cpu")["model"] - checkpoint_model = torch.load(ckpt_data, map_location="cpu") + + checkpoint_model = torch.load(ckpt_data, map_location="cpu", weights_only=True) state_dict = model.state_dict() for k in ["fc.weight", "fc.bias"]: diff --git a/terratorch/models/backbones/torchgeo_swin_satlas.py b/terratorch/models/backbones/torchgeo_swin_satlas.py index 1de43453..d46e163c 100644 --- a/terratorch/models/backbones/torchgeo_swin_satlas.py +++ b/terratorch/models/backbones/torchgeo_swin_satlas.py @@ -236,8 +236,8 @@ def load_swin_weights(model: nn.Module, model_bands, ckpt_data: str, weights: We repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '') filename = ckpt_data.split("/")[-1] ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename) - # checkpoint_model = torch.load(ckpt_data, map_location="cpu")["model"] - checkpoint_model = torch.load(ckpt_data, map_location="cpu") + + checkpoint_model = torch.load(ckpt_data, map_location="cpu", weights_only=True) state_dict = model.state_dict() for k in ["head.weight", "head.bias"]: diff --git a/terratorch/models/backbones/torchgeo_vit.py b/terratorch/models/backbones/torchgeo_vit.py index dbe9d97b..17abd13a 100644 --- a/terratorch/models/backbones/torchgeo_vit.py +++ b/terratorch/models/backbones/torchgeo_vit.py @@ -179,8 +179,8 @@ def load_vit_weights(model: nn.Module, model_bands, ckpt_data: str, weights: Wei repo_id = ckpt_data.split("/resolve/")[0].replace("https://hf.co/", '') filename = ckpt_data.split("/")[-1] ckpt_data = huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename) - # checkpoint_model = torch.load(ckpt_data, map_location="cpu")["model"] - checkpoint_model = torch.load(ckpt_data, map_location="cpu") + + checkpoint_model = torch.load(ckpt_data, map_location="cpu", weights_only=True) state_dict = model.state_dict() for k in ["head.weight", "head.bias"]: diff --git a/terratorch/models/satmae_model_factory.py b/terratorch/models/satmae_model_factory.py index 17f4f629..98cd04d1 100644 --- a/terratorch/models/satmae_model_factory.py +++ b/terratorch/models/satmae_model_factory.py @@ -227,9 +227,9 @@ def build_model( backbone: nn.Module = ModelWrapper(model=backbone_template(**backbone_kwargs), kind=backbone_kind) if self.CPU_ONLY: - model_dict = torch.load(checkpoint_path, map_location="cpu") + model_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) else: - model_dict = torch.load(checkpoint_path) + model_dict = torch.load(checkpoint_path, weights_only=True) # Filtering parameters from the model state_dict (when necessary)