Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions olmoearth_pretrain/evals/eval_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
pooling_type: PoolingType,
concat_features: bool = False,
use_pooled_tokens: bool = False,
feature_exit_depth: int | None = None,
):
"""Initialize the eval wrapper.

Expand All @@ -58,6 +59,7 @@ def __init__(
pooling_type: The pooling type to use for the model.
concat_features: Whether to concatenate features across modalities.
use_pooled_tokens: Whether to use pooled tokens.
feature_exit_depth: If set, use a uniform token exit depth for all modalities.
is_train: whether this is being used on the training data.
"""
super().__init__()
Expand All @@ -68,6 +70,7 @@ def __init__(
self.concat_features = concat_features
self.spatial_pool = task_type == TaskType.SEGMENTATION
self.use_pooled_tokens = use_pooled_tokens
self.feature_exit_depth = feature_exit_depth
if self.use_pooled_tokens:
assert isinstance(self.model, EncodeEarlyAttnPool), (
"Pooled tokens are only supported for EncodeEarlyAttnPool"
Expand Down Expand Up @@ -104,6 +107,30 @@ def __call__(
class OlmoEarthEvalWrapper(EvalWrapper):
"""Wrapper for OlmoEarth Pretrain models."""

def _get_token_exit_cfg(self) -> dict[str, int] | None:
    """Construct the per-modality token exit configuration.

    Returns:
        A mapping from each supported modality name to the uniform exit
        depth ``self.feature_exit_depth``, or ``None`` when no exit depth
        was configured (the model then runs its normal full-depth pass).

    Raises:
        ValueError: if an exit depth is combined with pooled tokens, falls
            outside ``[0, len(self.model.blocks)]``, or the model does not
            expose ``supported_modality_names``.
    """
    depth = self.feature_exit_depth
    # No configured depth means "use the default forward pass".
    if depth is None:
        return None

    # Early-exit features and attention-pooled tokens are mutually exclusive.
    if self.use_pooled_tokens:
        raise ValueError(
            "feature_exit_depth is not supported when use_pooled_tokens=True"
        )

    # The exit depth must name a valid layer boundary of the encoder stack.
    encoder_depth = len(self.model.blocks)
    if not 0 <= depth <= encoder_depth:
        raise ValueError(
            f"feature_exit_depth must be in [0, {encoder_depth}], got {depth}"
        )

    # We need the model to tell us which modalities exist so every one of
    # them can be assigned the same exit depth.
    modalities = getattr(self.model, "supported_modality_names", None)
    if modalities is None:
        raise ValueError(
            "feature_exit_depth requires model.supported_modality_names"
        )

    # Uniform config: every supported modality exits at the same depth.
    return dict.fromkeys(modalities, depth)

def __call__(
self,
masked_olmoearth_sample: MaskedOlmoEarthSample,
Expand All @@ -112,9 +139,13 @@ def __call__(
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
if not self.use_pooled_tokens:
token_exit_cfg = self._get_token_exit_cfg()
batch_embeddings: TokensAndMasks = self.model(
masked_olmoearth_sample, patch_size=self.patch_size, fast_pass=True
)["tokens_and_masks"] # (bsz, dim)
masked_olmoearth_sample,
patch_size=self.patch_size,
token_exit_cfg=token_exit_cfg,
fast_pass=(token_exit_cfg is None),
)["tokens_and_masks"]
# Concat features across modalities in space averaged across time
batch_embeddings = pool_unmasked_tokens(
batch_embeddings,
Expand Down
Loading
Loading