[WiP] More logging during execution (stage 1). #390

Open

wants to merge 12 commits into base: main
7 changes: 5 additions & 2 deletions terratorch/cli_tools.py
@@ -62,7 +62,9 @@
SemanticSegmentationTask, # noqa: F401
)

logger = logging.getLogger("terratorch")
from terratorch.utils import get_logger

logger = get_logger()

from terratorch.utils import remove_unexpected_prefix

@@ -420,6 +422,7 @@ def instantiate_classes(self) -> None:
else:
custom_modules_path = os.getenv("TERRATORCH_CUSTOM_MODULE_PATH", None)

logger.debug(f"Import custom modules from {custom_modules_path}")
import_custom_modules(custom_modules_path)

@staticmethod
@@ -455,7 +458,6 @@ def build_lightning_cli(
UserWarning,
stacklevel=1,
)

return MyLightningCLI(
model_class=BaseTask,
subclass_mode_model=True,
@@ -495,6 +497,7 @@ def __init__(
self.model = model
self.datamodule = datamodule
if checkpoint_path:
logger.info(f"Loading weights from local checkpoint: {checkpoint_path}")
weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
if "state_dict" in weights:
weights = weights["state_dict"]
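Note: since `get_logger()` reads the `LOGLEVEL` environment variable when the module is imported, the new CLI logging can be switched on without touching any config. A minimal sketch, assuming the `get_logger()` helper added in `terratorch/utils.py` further down:

```python
import os

# Assumption: LOGLEVEL must be set before terratorch is imported, because
# get_logger() applies logging.basicConfig() during module import.
os.environ["LOGLEVEL"] = "DEBUG"

from terratorch.utils import get_logger

logger = get_logger()  # shared "terratorch" logger
logger.debug("custom module imports and checkpoint loads are now logged")
```
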
5 changes: 4 additions & 1 deletion terratorch/io/file.py
@@ -1,9 +1,12 @@
import os
import logging
import torch
import importlib
from torch import nn
import numpy as np

logger = logging.getLogger(__name__)

def open_generic_torch_model(model: type | str = None,
model_kwargs: dict = None,
model_weights_path: str = None):
@@ -26,7 +29,7 @@ def open_generic_torch_model(model: type | str = None,

def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = None, device: str = None) -> None:

print(f"Trying to load for {device}")
logger.info(f"Trying to load for {device}")

try: # If 'model' was instantiated outside this function, the dictionary of weights will be loaded.
if device != None:
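Note: unlike the other modules in this PR, `terratorch/io/file.py` keeps a module-level `logging.getLogger(__name__)` rather than the shared `get_logger()` helper, so the new `logger.info` call is only visible once something configures the root logger (for example `get_logger()` with `LOGLEVEL` set). A quick standalone illustration:

```python
import logging

# Without configuration, INFO on a module logger is dropped by default
# (only WARNING and above reach the handler of last resort).
logging.getLogger("terratorch.io.file").info("dropped")

# Once the root logger is configured, the message propagates and prints.
logging.basicConfig(level=logging.INFO)
logging.getLogger("terratorch.io.file").info("Trying to load for cpu")
```
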
3 changes: 2 additions & 1 deletion terratorch/models/backbones/prithvi_vit.py
@@ -9,9 +9,10 @@
from terratorch.datasets.utils import generate_bands_intervals
from terratorch.models.backbones.prithvi_mae import PrithviViT, PrithviMAE
from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY
from terratorch.utils import get_logger
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)
logger = get_logger()

PRETRAINED_BANDS = [
HLSBands.BLUE,
3 changes: 2 additions & 1 deletion terratorch/models/backbones/select_patch_embed_weights.py
@@ -101,6 +101,7 @@ def select_patch_embed_weights(
Returns:
dict: New state dict
"""

if (type(pretrained_bands) == type(model_bands)) | (type(pretrained_bands) == int) | (type(model_bands) == int):

state_dict = get_state_dict(state_dict)
@@ -110,7 +111,7 @@
# Search for patch embedding weight in state dict
proj_key, prefix = get_proj_key(state_dict, return_prefix=True, encoder_only=encoder_only)
if proj_key is None or proj_key not in state_dict:
raise Exception("Could not find key for patch embed weight in state_dict.")
raise Exception(f"Could not find key {proj_key} for patch embed weight in state_dict.")

patch_embed_weight = state_dict[proj_key]

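Note: the reworked exception now interpolates `proj_key`, which is more informative when the key exists but is absent from `state_dict`; on the `proj_key is None` branch, though, the message renders as "Could not find key None …". A toy illustration of the failure path (hypothetical state dict):

```python
state_dict = {"decoder.head.weight": None}  # no patch-embed entry

proj_key = None  # assumption: what get_proj_key returns when nothing matches
if proj_key is None or proj_key not in state_dict:
    raise Exception(f"Could not find key {proj_key} for patch embed weight in state_dict.")
# Exception: Could not find key None for patch embed weight in state_dict.
```
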
12 changes: 12 additions & 0 deletions terratorch/models/encoder_decoder_factory.py
@@ -17,11 +17,13 @@
from terratorch.models.scalar_output_model import ScalarOutputModel
from terratorch.models.utils import extract_prefix_keys
from terratorch.registry import BACKBONE_REGISTRY, DECODER_REGISTRY, MODEL_FACTORY_REGISTRY
from terratorch.utils import get_logger

PIXEL_WISE_TASKS = ["segmentation", "regression"]
SCALAR_TASKS = ["classification"]
SUPPORTED_TASKS = PIXEL_WISE_TASKS + SCALAR_TASKS

logger = get_logger()

def _get_backbone(backbone: str | nn.Module, **backbone_kwargs) -> nn.Module:
if isinstance(backbone, nn.Module):
@@ -231,6 +233,16 @@ def _build_appropriate_model(
neck_module: nn.Module = nn.Sequential(*necks)
else:
neck_module = None

# Log the assembled components for debugging
logger.debug(f"Task: {task}")
logger.debug(f"Backbone: {backbone.__class__}")
logger.debug(f"Decoder: {decoder.__class__}")
logger.debug(f"head_kwargs: {head_kwargs}")
logger.debug(f"patch_size: {patch_size}")
logger.debug(f"rescale: {rescale}")
logger.debug(f"necks: {necks}")

if task in PIXEL_WISE_TASKS:
return PixelWiseModel(
task,
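Note: with `LOGLEVEL=DEBUG`, `_build_appropriate_model` now reports the task, backbone, decoder, and head/neck configuration before assembling the model. A hedged sketch of driving this path through the factory (the backbone and decoder names are illustrative, and `build_model` may need further arguments for a given combination):

```python
import os

os.environ["LOGLEVEL"] = "DEBUG"  # enable the new debug output

from terratorch.models.encoder_decoder_factory import EncoderDecoderFactory

factory = EncoderDecoderFactory()
model = factory.build_model(
    task="segmentation",           # must be in PIXEL_WISE_TASKS or SCALAR_TASKS
    backbone="prithvi_eo_v2_300",  # illustrative registry name
    decoder="FCNDecoder",          # illustrative registry name
    num_classes=2,
)
# Expect debug lines such as "Task: segmentation" and "Decoder: <class ...>".
```
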
5 changes: 4 additions & 1 deletion terratorch/registry/registry.py
@@ -56,8 +56,10 @@ def build(self, name: str, *constructor_args, **constructor_kwargs):

# if no prefix, try to build in order
for source in self._sources.values():
with suppress(KeyError):
try:
return source.build(name, *constructor_args, **constructor_kwargs)
except KeyError:
raise Exception(f"It wasn't possible to load model from source {source}.")

msg = f"Could not instantiate model {name} not from any source."
raise KeyError(msg)
@@ -138,6 +140,7 @@ def build(self, name: str, *constructor_args, **constructor_kwargs):
"""Build and return the component.
Use prefixes ending with _ to forward to a specific source
"""

return self._registry[name](*constructor_args, **constructor_kwargs)

def __iter__(self):
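Note: this replacement changes the fall-through semantics. `suppress(KeyError)` let the loop try the next source after a miss, while the new `except KeyError` raises on the first source that does not know the name, so the final `KeyError` after the loop becomes unreachable. A toy sketch of the two behaviours (dict-backed stand-ins for sources):

```python
from contextlib import suppress

sources = {"first": {}, "second": {"unet": lambda: "built unet"}}

def build_old(name):
    for source in sources.values():
        with suppress(KeyError):
            return source[name]()  # miss in "first" is skipped
    raise KeyError(f"Could not instantiate model {name} from any source.")

def build_new(name):
    for source in sources.values():
        try:
            return source[name]()
        except KeyError:
            # raises on the miss in "first"; "second" is never consulted
            raise Exception(f"It wasn't possible to load model from source {source}.")

print(build_old("unet"))  # -> built unet
```
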
4 changes: 2 additions & 2 deletions terratorch/tasks/base_task.py
@@ -7,10 +7,10 @@

from terratorch.models.model import Model
from terratorch.tasks.optimizer_factory import optimizer_factory
from terratorch.utils import get_logger

BATCH_IDX_FOR_VALIDATION_PLOTTING = 10
logger = logging.getLogger("terratorch")

logger = get_logger()

class TerraTorchTask(BaseTask):
"""
8 changes: 6 additions & 2 deletions terratorch/tasks/regression_tasks.py
@@ -22,11 +22,11 @@
from terratorch.tasks.optimizer_factory import optimizer_factory
from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference
from terratorch.tasks.base_task import TerraTorchTask
from terratorch.utils import get_logger

BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

logger = logging.getLogger("terratorch")

logger = get_logger()

class RootLossWrapper(nn.Module):
def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None:
@@ -226,13 +226,17 @@ def __init__(
self.monitor = f"{self.val_metrics.prefix}loss"
self.plot_on_val = int(plot_on_val)

logger.info(f"Instantiating a class {self.__class__}")
logger.debug(f"Using hparams: {self.hparams}")

def configure_losses(self) -> None:
"""Initialize the loss criterion.

Raises:
ValueError: If *loss* is invalid.
"""
loss: str = self.hparams["loss"].lower()

if loss == "mse":
self.criterion: nn.Module = IgnoreIndexLossWrapper(
nn.MSELoss(reduction="none"), self.hparams["ignore_index"]
7 changes: 5 additions & 2 deletions terratorch/tasks/segmentation_tasks.py
@@ -18,11 +18,11 @@
from terratorch.tasks.optimizer_factory import optimizer_factory
from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference
from terratorch.tasks.base_task import TerraTorchTask
from terratorch.utils import get_logger

BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

logger = logging.getLogger("terratorch")

logger = get_logger()

def to_segmentation_prediction(y: ModelOutput) -> Tensor:
y_hat = y.output
@@ -148,6 +148,9 @@ def __init__(
else:
self.select_classes = lambda y: y

logger.info(f"Instantiating a class {self.__class__}")
logger.debug(f"Using hparams: {self.hparams}")

def configure_losses(self) -> None:
"""Initialize the loss criterion.

13 changes: 12 additions & 1 deletion terratorch/utils.py
@@ -1,6 +1,7 @@
import os
import math
from collections import Counter

import logging
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
@@ -69,6 +70,16 @@ def compute_float_mask_statistics(dataloader: DataLoader) -> dict[str, float]:
std = math.sqrt(variance)
return {"mean": mean, "std": std}

def get_logger():

loglevel = os.getenv("LOGLEVEL")

if loglevel:
logging.basicConfig(level=loglevel.upper(), format='%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d:%H:%M:%S')

logger = logging.getLogger("terratorch")
return logger

# TODO remove it for future releases
def remove_unexpected_prefix(state_dict):
state_dict_ = {}
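Note: `get_logger()` only applies `logging.basicConfig()` when `LOGLEVEL` is set; otherwise it returns the bare `"terratorch"` logger and output depends on the host application's logging setup. Since `basicConfig` is a no-op once the root logger has handlers, the many module-level `get_logger()` calls across this PR configure logging at most once. A usage sketch:

```python
import os

os.environ["LOGLEVEL"] = "debug"  # case-insensitive: upper-cased internally

from terratorch.utils import get_logger

logger = get_logger()
logger.debug("hello")
# With the configured format this prints something like:
# 2025-01-01:12:00:00,123 DEBUG    [example.py:9] hello
```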