Skip to content

Commit 1b59268

Browse files
committed
move the overide inside private api
1 parent 4b4647d commit 1b59268

1 file changed

Lines changed: 43 additions & 10 deletions

File tree

olmoearth_pretrain/train/masking.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,10 @@ def __init__(
924924
self.max_decoded_bandsets = max_decoded_bandsets
925925
self.only_decode_modalities = only_decode_modalities
926926

927+
# =================================================================
928+
# PUBLIC API
929+
# =================================================================
930+
927931
def get_sample_present_modalities_bandsets(
928932
self, batch: MaskedOlmoEarthSample
929933
) -> list[list[tuple[str, int]]]:
@@ -1114,13 +1118,11 @@ def _randomly_select_decoded_bandsets(
11141118
)
11151119
return set([available[i] for i in decoded_idxs])
11161120

1117-
def overide_strategy_mask(self, modality_spec: ModalitySpec) -> bool:
1118-
"""Overide the mask for a modality depending on the strategy being modality cross masked.
1121+
# =================================================================
1122+
# PHASE 2: MASK APPLICATION
1123+
# Methods for applying encode/decode rules to mask tensors
1124+
# =================================================================
11191125

1120-
e.g in time masking, static in time data is randomly masked but we want that data to be either used to predict temporally masked data or
1121-
predicted from temporal data.
1122-
"""
1123-
return False
11241126

11251127
def apply_bandset_mask_rules(
11261128
self,
@@ -1268,7 +1270,7 @@ def _apply_single_bandset_mask(
12681270
is_decoded = (modality, bandset_idx) in decoded_bandset_idxs
12691271

12701272
# Handle special modalities that need override (e.g., static in time/space)
1271-
if self.overide_strategy_mask(modality_spec):
1273+
if self._overide_strategy_mask(modality_spec):
12721274
self._force_override_mask(
12731275
sample_idx,
12741276
bandset_idx,
@@ -1293,6 +1295,15 @@ def _apply_single_bandset_mask(
12931295
sample_idx, bandset_idx, modality_mask, out_modality_mask
12941296
)
12951297

1298+
1299+
def _overide_strategy_mask(self, modality_spec: ModalitySpec) -> bool:
1300+
"""Overide the mask for a modality depending on the strategy being modality cross masked.
1301+
1302+
e.g in time masking, static in time data is randomly masked but we want that data to be either used to predict temporally masked data or
1303+
predicted from temporal data.
1304+
"""
1305+
return False
1306+
12961307
def _force_override_mask(
12971308
self,
12981309
sample_idx: int,
@@ -1371,6 +1382,11 @@ def _suppress_undecoded_bandset(
13711382
modality_mask[sample_idx, ..., bandset_idx],
13721383
)
13731384

1385+
# =================================================================
1386+
# PHASE 3: COUNTING & VALIDATION
1387+
# Methods for counting tokens and accumulating counts across modalities
1388+
# =================================================================
1389+
13741390
def _count_modality_tokens(
13751391
self, modality_mask: torch.Tensor
13761392
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -1388,6 +1404,11 @@ def _accumulate_token_counts(
13881404
return new_count
13891405
return current + new_count
13901406

1407+
# =================================================================
1408+
# PHASE 4: EDGE CASE HANDLING
1409+
# Methods for handling rare edge cases (e.g., samples with no tokens)
1410+
# =================================================================
1411+
13911412
def _handle_no_tokens_edge_cases(
13921413
self,
13931414
masked_batch_dict: dict[str, Any],
@@ -1432,10 +1453,22 @@ def _fix_sample_with_no_tokens(
14321453
modality_mask, modality_spec, patch_size
14331454
)
14341455

1456+
# =================================================================
1457+
# MAIN ENTRY POINT
1458+
# =================================================================
1459+
14351460
def apply_mask(
14361461
self, batch: OlmoEarthSample, patch_size: int | None = None, **kwargs: Any
14371462
) -> MaskedOlmoEarthSample:
1438-
"""Apply space masking to the input data."""
1463+
"""Apply cross-modality masking to the input data.
1464+
1465+
This is the main entry point that orchestrates the entire masking process:
1466+
1. Apply base masking strategy (e.g., space/time masking)
1467+
2. Identify present modalities and bandsets for each sample
1468+
3. Select which bandsets to encode vs decode
1469+
4. Apply bandset-level masking rules
1470+
5. Handle edge cases
1471+
"""
14391472
if patch_size is None:
14401473
# this is because we use a random-masking proxy in case of
14411474
# no encoded or decoded tokens.
@@ -1487,7 +1520,7 @@ def __init__(
14871520
only_decode_modalities=only_decode_modalities,
14881521
)
14891522

1490-
def overide_strategy_mask(self, modality_spec: ModalitySpec) -> bool:
1523+
def _overide_strategy_mask(self, modality_spec: ModalitySpec) -> bool:
14911524
"""Overide the random mask for the given modality by the encoding and decoding bandsets."""
14921525
# For space masking non spatial data is randomly masked but we want to use the encoding and decoding bandsets
14931526
# to determine the mask for the non spatial data
@@ -1526,7 +1559,7 @@ def __init__(
15261559
only_decode_modalities=only_decode_modalities,
15271560
)
15281561

1529-
def overide_strategy_mask(self, modality_spec: ModalitySpec) -> bool:
1562+
def _overide_strategy_mask(self, modality_spec: ModalitySpec) -> bool:
15301563
"""Overide the random mask for the given modality by the encoding and decoding bandsets."""
15311564
# For time masking static data is randomly masked but we want to use the encoding and decoding bandsets
15321565
# to determine the mask for the static data

0 commit comments

Comments
 (0)