Merged (changes from all 29 commits)
olmoearth_pretrain/evals/finetune.py (34 additions, 22 deletions)
@@ -21,7 +21,7 @@
 logger = getLogger(__name__)


-class _BackboneWithHead(nn.Module):
+class BackboneWithHead(nn.Module):
     """Backbone model with a classification or segmentation head."""

     def __init__(
@@ -33,6 +33,7 @@ def __init__(
         num_classes: int,
         use_pooled_tokens: bool = False,
     ) -> None:
+        """Initialize the backbone with head."""
         super().__init__()
         self.backbone = model
         self.wrapper = get_eval_wrapper(
@@ -42,6 +43,8 @@ def __init__(
             pooling_type=pooling_type,
             concat_features=False,
             use_pooled_tokens=use_pooled_tokens,
+            # Set this to False to avoid downsampling embeddings and labels for AnySat
+            is_train=False,
         )
         self.task_type = task_type
         self.patch_size = patch_size
@@ -66,15 +69,14 @@ def forward(
     ) -> torch.Tensor:
         """Forward pass through the model and head."""
         dev = next(self.wrapper.parameters()).device
-        # classification: (B, D), segmentation: (B, H, W, D)
-        emb, _ = self.wrapper(batch, None)
+        emb, labels = self.wrapper(batch, labels)
         emb = cast(torch.Tensor, emb)
         emb_dim = emb.shape[-1]
         if not self._inited:
             self._init_head(emb_dim, dev)
         if emb.device != dev:
             emb = emb.to(dev, non_blocking=True)
-        return self._head(emb)
+        return self._head(emb), labels


 def _to_device(
@@ -92,7 +94,7 @@ def _to_device(

 @torch.no_grad()
 def _eval_cls(
-    module: _BackboneWithHead,
+    module: BackboneWithHead,
     loader: DataLoader,
     device: torch.device,
     is_multilabel: bool,
@@ -104,7 +106,7 @@ def _eval_cls(
         label = label.to(device=device)
         masked = _to_device(masked, device)
         with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
-            logits = module(masked, label)  # (B, C)
+            logits, _ = module(masked, label)  # (B, C)
         logits_all.append(logits.float().cpu())
         labels_all.append(label.cpu())
     logits = torch.cat(logits_all, 0)
@@ -124,7 +126,7 @@

 @torch.no_grad()
 def _eval_seg(
-    module: _BackboneWithHead,
+    module: BackboneWithHead,
     loader: DataLoader,
     device: torch.device,
     num_classes: int,
@@ -137,7 +139,7 @@ def _eval_seg(
         label = label.to(device=device)
         masked = _to_device(masked, device)
         with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
-            logits = module(masked, label)  # (B, H, W, C*p*p)
+            logits, _ = module(masked, label)  # (B, H, W, C*p*p)
         H, W = logits.shape[1], logits.shape[2]
         logits = rearrange(
             logits,
@@ -162,11 +164,17 @@ def _eval_seg(
     return mean_iou(preds, labels, num_classes=num_classes, ignore_label=-1)


-def count_params(module: nn.Module) -> tuple[int, int]:
-    """Count total and trainable parameters in a module."""
-    total = sum(p.numel() for p in module.parameters())
-    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
-    return total, trainable
+def count_params(backbone: nn.Module, head: nn.Module) -> tuple[int, int, int, int]:
+    """Count total and trainable parameters separately for the backbone and the linear head."""
+    total_backbone = sum(p.numel() for p in backbone.parameters())
+    trainable_backbone = sum(
+        p.numel() for p in backbone.parameters() if p.requires_grad
+    )
+
+    total_head = sum(p.numel() for p in head.parameters())
+    trainable_head = sum(p.numel() for p in head.parameters() if p.requires_grad)
+
+    return total_backbone, trainable_backbone, total_head, trainable_head


 def _snapshot_state_dict(module: nn.Module) -> dict[str, torch.Tensor]:
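The new signature reports backbone and head parameter counts separately; a minimal usage sketch (the two nn.Linear modules are hypothetical stand-ins; in the eval code the arguments are ft.backbone and ft._head):

import torch.nn as nn

backbone = nn.Linear(128, 128)  # hypothetical stand-in for the encoder
head = nn.Linear(128, 10)       # hypothetical stand-in for the linear head
total_b, train_b, total_h, train_h = count_params(backbone, head)
print(f"backbone: {total_b:,} total / {train_b:,} trainable")
print(f"head:     {total_h:,} total / {train_h:,} trainable")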
@@ -188,7 +196,7 @@ def run_finetune_eval(
     test_loader: DataLoader | None,
 ) -> tuple[float, float]:
     """Finetune the model on a downstream task and evaluate."""
-    ft = _BackboneWithHead(
+    ft = BackboneWithHead(
         model=model,
         task_type=task_config.task_type,
         patch_size=patch_size,
@@ -200,11 +208,15 @@ def run_finetune_eval(
     # Trigger _init_head once with a tiny dry pass
    with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16):
         sample_batch, label = next(iter(train_loader))
-        _ = ft(_to_device(sample_batch, device), label.to(device))
+        _, _ = ft(_to_device(sample_batch, device), label.to(device))

-    total, trainable = count_params(ft)
-    logger.info(f"Total parameters: {total:,}")
-    logger.info(f"Trainable parameters: {trainable:,}")
+    total_backbone, trainable_backbone, total_head, trainable_head = count_params(
+        ft.backbone, ft._head
+    )
+    logger.info(f"Total backbone parameters: {total_backbone:,}")
+    logger.info(f"Trainable backbone parameters: {trainable_backbone:,}")
+    logger.info(f"Total head parameters: {total_head:,}")
+    logger.info(f"Trainable head parameters: {trainable_head:,}")

     opt = torch.optim.AdamW(ft.parameters(), lr=lr)
     if task_config.task_type == TaskType.CLASSIFICATION:
@@ -216,7 +228,8 @@ def run_finetune_eval(
     else:
         loss_fn = nn.CrossEntropyLoss(ignore_index=-1)

-    patience = max(1, int(0.1 * epochs)) if epochs > 0 else 1
+    # Use a higher patience so that we don't miss the best model
+    patience = max(1, int(0.2 * epochs)) if epochs > 0 else 1
     logger.info(f"Using early stopping patience of {patience} epochs")

     best_state = _snapshot_state_dict(ft)
@@ -230,7 +243,7 @@ def run_finetune_eval(
             label = label.to(device=device)
             masked = _to_device(masked, device)
             with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
-                logits = ft(masked, label)
+                logits, label = ft(masked, label)
             if task_config.task_type == TaskType.SEGMENTATION:
                 H, W = logits.shape[1], logits.shape[2]
                 logits = rearrange(
@@ -260,9 +273,8 @@ def run_finetune_eval(
                 total_epochs=epochs,
                 warmup_epochs=max(1, int(0.1 * epochs)),
                 max_lr=lr,
-                min_lr=1.0e-5,
+                min_lr=1.0e-6,
             )
-            # torch.nn.utils.clip_grad_norm_(ft.parameters(), 1.0)
             opt.step()
             opt.zero_grad()
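With this change BackboneWithHead.forward returns a (logits, labels) tuple, since the eval wrapper may resample labels (e.g. AnySat downsampling), so every call site now unpacks both; note that the unchanged "-> torch.Tensor" annotation in the context above is now stale. A minimal sketch of the new calling convention, assuming a loader that yields (masked, label) pairs as in the diff:

# Sketch only; ft, train_loader, and loss_fn follow the names used in the diff.
for masked, label in train_loader:
    logits, label = ft(masked, label)  # labels come back aligned with the logits
    loss = loss_fn(logits, label)
    loss.backward()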
olmoearth_pretrain/evals/models/anysat/anysat.py (1 addition, 0 deletions)
@@ -135,6 +135,7 @@ def _calculate_patch_size(self, h: int) -> int:
         # based on https://arxiv.org/pdf/2412.14123, a patch size of
         # 40 is the minimum used for images of 128x128. Since smaller patches
         # = more tokens, this should lead to the best performance
+        # TODO: this does not take the input image size into account, e.g. 256x256
         h_in_m = h * 10
         patch_size = min(40, h_in_m)
         return patch_size
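For intuition, a worked example of the rule above, assuming the 10 m/pixel resolution that the h * 10 conversion implies (the value of h is illustrative):

h = 128                       # tile height in pixels
h_in_m = h * 10               # 1280 m at 10 m per pixel
patch_size = min(40, h_in_m)  # -> 40, the paper's minimum for 128x128 inputs
# Only a tile narrower than 4 pixels (h_in_m < 40) would yield a smaller patch size.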
olmoearth_pretrain/evals/models/galileo/single_file_galileo.py (38 additions, 13 deletions)
@@ -11,6 +11,7 @@
 from collections import OrderedDict
 from collections import OrderedDict as OrderedDictType
 from collections.abc import Sequence
+from contextlib import nullcontext
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
@@ -1681,6 +1682,10 @@ def forward(
             torch.stack(output_st, dim=-2),
         )

+AUTOCAST_DTYPE_MAP = {
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
+}

 class GalileoWrapper(nn.Module):
     """GalileoWrapper."""
@@ -1698,6 +1703,7 @@ def __init__(
         month: int = 6,
         add_layernorm_on_exit: bool = True,
         use_pretrained_normalizer: bool = True,
+        autocast_dtype: str | None = "bfloat16",
     ):
         """Init GalileoWrapper."""
         super().__init__()
@@ -1735,6 +1741,11 @@ def __init__(
         else:
             self.normalizer = None

+        if autocast_dtype is not None:
+            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+        else:
+            self.autocast_dtype = None
+
     def preproccess(
         self,
         s2: torch.Tensor | None = None,
@@ -1870,19 +1881,33 @@ def forward(
         if s_t_x.shape[1] < self.patch_size:
             logger.info(f"tile size {s_t_x.shape[1]} < self.patch size {self.patch_size}. Using tile size as patch size.")
             patch_size = s_t_x.shape[1]
-        output = self.galileo_encoder(
-            s_t_x,
-            sp_x,
-            t_x,
-            st_x,
-            s_t_m,
-            sp_m,
-            t_m,
-            st_m,
-            month,
-            patch_size=patch_size,
-            add_layernorm_on_exit=self.add_layernorm_on_exit,
-        )
+
+        # Decide context based on self.autocast_dtype
+        device = s_t_x.device
+        if self.autocast_dtype is None:
+            context = nullcontext()
+        else:
+            assert device is not None
+            context = torch.amp.autocast(
+                device_type=device.type,
+                dtype=self.autocast_dtype
+            )
+
+        with context:
+            output = self.galileo_encoder(
+                s_t_x,
+                sp_x,
+                t_x,
+                st_x,
+                s_t_m,
+                sp_m,
+                t_m,
+                st_m,
+                month,
+                patch_size=patch_size,
+                add_layernorm_on_exit=self.add_layernorm_on_exit,
+            )

         s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, _ = output
         if not spatial_pool:
             return self.galileo_encoder.average_tokens(
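The dtype selection above amounts to a small helper; a standalone sketch of the same logic (pick_context is a hypothetical name, not part of the wrapper):

from contextlib import nullcontext

import torch

AUTOCAST_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float32": torch.float32}

def pick_context(autocast_dtype: str | None, device: torch.device):
    # No dtype requested: run the encoder in its native precision.
    if autocast_dtype is None:
        return nullcontext()
    # Otherwise wrap the forward pass in torch.amp.autocast.
    return torch.amp.autocast(device_type=device.type, dtype=AUTOCAST_DTYPE_MAP[autocast_dtype])

Passing autocast_dtype=None therefore disables autocast entirely, and any other value must be a key of AUTOCAST_DTYPE_MAP (the constructor raises a KeyError otherwise).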
olmoearth_pretrain/evals/models/panopticon/panopticon.py (4 additions, 1 deletion)
@@ -23,6 +23,7 @@ class Panopticon(nn.Module):
     """Class containing the Panopticon model that can ingest MaskedOlmoEarthSample objects."""

     patch_size: int = 14
+    image_resolution: int = 224
     supported_modalities: list[str] = [
         Modality.SENTINEL2_L2A.name,
         Modality.LANDSAT.name,
@@ -84,7 +85,9 @@ def _process_modality_data(self, data: torch.Tensor) -> list[torch.Tensor]:
         for i in range(t_dim):
             data_i = rearrange(data[:, :, :, i, :], "b h w c -> b c h w")

-            new_height = self.patch_size if original_height == 1 else 224
+            new_height = (
+                self.patch_size if original_height == 1 else self.image_resolution
+            )

             data_i = F.interpolate(
                 data_i,
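The hard-coded 224 is now read from the image_resolution class attribute; a minimal sketch of the resize rule (tensor shapes are illustrative):

import torch
import torch.nn.functional as F

patch_size, image_resolution = 14, 224   # values of the class attributes
x = torch.randn(2, 12, 32, 32)           # (B, C, H, W), illustrative input
original_height = x.shape[-2]
# Height-1 inputs collapse to a single ViT patch; everything else goes to 224.
new_height = patch_size if original_height == 1 else image_resolution
x = F.interpolate(x, size=(new_height, new_height), mode="bilinear")
print(tuple(x.shape))  # (2, 12, 224, 224)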