@@ -924,6 +924,10 @@ def __init__(
924924 self .max_decoded_bandsets = max_decoded_bandsets
925925 self .only_decode_modalities = only_decode_modalities
926926
927+ # =================================================================
928+ # PUBLIC API
929+ # =================================================================
930+
927931 def get_sample_present_modalities_bandsets (
928932 self , batch : MaskedOlmoEarthSample
929933 ) -> list [list [tuple [str , int ]]]:
@@ -1114,13 +1118,11 @@ def _randomly_select_decoded_bandsets(
11141118 )
11151119 return set ([available [i ] for i in decoded_idxs ])
11161120
1117- def overide_strategy_mask (self , modality_spec : ModalitySpec ) -> bool :
1118- """Overide the mask for a modality depending on the strategy being modality cross masked.
1121+ # =================================================================
1122+ # PHASE 2: MASK APPLICATION
1123+ # Methods for applying encode/decode rules to mask tensors
1124+ # =================================================================
11191125
1120- e.g in time masking, static in time data is randomly masked but we want that data to be either used to predict temporally masked data or
1121- predicted from temporal data.
1122- """
1123- return False
11241126
11251127 def apply_bandset_mask_rules (
11261128 self ,
@@ -1268,7 +1270,7 @@ def _apply_single_bandset_mask(
12681270 is_decoded = (modality , bandset_idx ) in decoded_bandset_idxs
12691271
12701272 # Handle special modalities that need override (e.g., static in time/space)
1271- if self .overide_strategy_mask (modality_spec ):
1273+ if self ._overide_strategy_mask (modality_spec ):
12721274 self ._force_override_mask (
12731275 sample_idx ,
12741276 bandset_idx ,
@@ -1293,6 +1295,15 @@ def _apply_single_bandset_mask(
12931295 sample_idx , bandset_idx , modality_mask , out_modality_mask
12941296 )
12951297
1298+
1299+ def _overide_strategy_mask (self , modality_spec : ModalitySpec ) -> bool :
1300+ """Overide the mask for a modality depending on the strategy being modality cross masked.
1301+
1302+ e.g in time masking, static in time data is randomly masked but we want that data to be either used to predict temporally masked data or
1303+ predicted from temporal data.
1304+ """
1305+ return False
1306+
12961307 def _force_override_mask (
12971308 self ,
12981309 sample_idx : int ,
@@ -1371,6 +1382,11 @@ def _suppress_undecoded_bandset(
13711382 modality_mask [sample_idx , ..., bandset_idx ],
13721383 )
13731384
1385+ # =================================================================
1386+ # PHASE 3: COUNTING & VALIDATION
1387+ # Methods for counting tokens and accumulating counts across modalities
1388+ # =================================================================
1389+
13741390 def _count_modality_tokens (
13751391 self , modality_mask : torch .Tensor
13761392 ) -> tuple [torch .Tensor , torch .Tensor ]:
@@ -1388,6 +1404,11 @@ def _accumulate_token_counts(
13881404 return new_count
13891405 return current + new_count
13901406
1407+ # =================================================================
1408+ # PHASE 4: EDGE CASE HANDLING
1409+ # Methods for handling rare edge cases (e.g., samples with no tokens)
1410+ # =================================================================
1411+
13911412 def _handle_no_tokens_edge_cases (
13921413 self ,
13931414 masked_batch_dict : dict [str , Any ],
@@ -1432,10 +1453,22 @@ def _fix_sample_with_no_tokens(
14321453 modality_mask , modality_spec , patch_size
14331454 )
14341455
1456+ # =================================================================
1457+ # MAIN ENTRY POINT
1458+ # =================================================================
1459+
14351460 def apply_mask (
14361461 self , batch : OlmoEarthSample , patch_size : int | None = None , ** kwargs : Any
14371462 ) -> MaskedOlmoEarthSample :
1438- """Apply space masking to the input data."""
1463+ """Apply cross-modality masking to the input data.
1464+
1465+ This is the main entry point that orchestrates the entire masking process:
1466+ 1. Apply base masking strategy (e.g., space/time masking)
1467+ 2. Identify present modalities and bandsets for each sample
1468+ 3. Select which bandsets to encode vs decode
1469+ 4. Apply bandset-level masking rules
1470+ 5. Handle edge cases
1471+ """
14391472 if patch_size is None :
14401473 # this is because we use a random-masking proxy in case of
14411474 # no encoded or decoded tokens.
@@ -1487,7 +1520,7 @@ def __init__(
14871520 only_decode_modalities = only_decode_modalities ,
14881521 )
14891522
1490- def overide_strategy_mask (self , modality_spec : ModalitySpec ) -> bool :
1523+ def _overide_strategy_mask (self , modality_spec : ModalitySpec ) -> bool :
14911524 """Overide the random mask for the given modality by the encoding and decoding bandsets."""
14921525 # For space masking non spatial data is randomly masked but we want to use the encoding and decoding bandsets
14931526 # to determine the mask for the non spatial data
@@ -1526,7 +1559,7 @@ def __init__(
15261559 only_decode_modalities = only_decode_modalities ,
15271560 )
15281561
1529- def overide_strategy_mask (self , modality_spec : ModalitySpec ) -> bool :
1562+ def _overide_strategy_mask (self , modality_spec : ModalitySpec ) -> bool :
15301563 """Overide the random mask for the given modality by the encoding and decoding bandsets."""
15311564 # For time masking static data is randomly masked but we want to use the encoding and decoding bandsets
15321565 # to determine the mask for the static data
0 commit comments